Mirror of https://github.com/langgenius/dify.git (synced 2026-01-09 07:44:12 +00:00)

Compare commits: feat/agent ... fix/agent- (75 commits)
Commit SHAs:

8ed46893a8, 986b75a5c0, f6472e8709, 50cfb7c9ec, 8281c688ca, ad9d6eb5f4, b0a8dec59e, ae800c21dd,
e149775616, e8eab14658, 45db7d9cd4, d8779b0da2, 5d7400c8bb, 6d36f2d239, 8ac6bc0b5a, 27f2b2050d,
d53282bb44, aa3dc9002c, 82ead2735b, 27185eb98b, e06dc48472, 9f013d6590, d51bd90394, 91b89d755e,
1409c81e76, 4a43e165fb, 26838eb42a, f0098d17ed, 54e9748240, 48111a7f71, cdd610f94f, daf11f4af1,
e3a81f09a9, 8d5a8f0153, 4d25b598f9, 3e9c3d0bb7, fec3bb4469, d4a09805a3, 7e1d9894fb, a8a8a5513c,
470e72c820, 95eeb7b0d1, 933b6abc13, 1097bf314a, beebba0340, 4e27d82d68, cdeaf3f70b, 24839bb3e1,
1650dbfbb1, fd11817044, 6642fc6012, 2710242982, 1de84fdda0, 4920821270, a16c729d5a, 3befbc1d68,
62c413aca5, 6887b501b8, f93bf131ab, ef1f429437, c966bf1474, 899df30bf6, 8d8d3e3f2f, 5f0fa38ec6,
cc1fe70d34, 15ee1e11be, c8b4a76530, 6ee4eba86b, 357d2e8be8, b5accda3fe, de4752a16b, 60427f1adf,
1a313c868d, 0b32b1988f, e56c051d97
.github/actions/setup-poetry/action.yml (vendored, 2 changes)

@@ -8,7 +8,7 @@ inputs:
   poetry-version:
     description: Poetry version to set up
     required: true
-    default: '1.8.4'
+    default: '2.0.1'
   poetry-lockfile:
     description: Path to the Poetry lockfile to restore cache from
     required: true
.github/workflows/api-tests.yml (vendored, 12 changes)

@@ -43,19 +43,17 @@ jobs:
         run: poetry install -C api --with dev

       - name: Check dependencies in pyproject.toml
-        run: poetry run -C api bash dev/pytest/pytest_artifacts.sh
+        run: poetry run -P api bash dev/pytest/pytest_artifacts.sh

       - name: Run Unit tests
-        run: poetry run -C api bash dev/pytest/pytest_unit_tests.sh
+        run: poetry run -P api bash dev/pytest/pytest_unit_tests.sh

       - name: Run dify config tests
-        run: poetry run -C api python dev/pytest/pytest_config_tests.py
+        run: poetry run -P api python dev/pytest/pytest_config_tests.py

       - name: Run mypy
         run: |
-          pushd api
-          poetry run python -m mypy --install-types --non-interactive .
-          popd
+          poetry run -C api python -m mypy --install-types --non-interactive .

       - name: Set up dotenvs
         run: |
@@ -75,4 +73,4 @@ jobs:
           ssrf_proxy

       - name: Run Workflow
-        run: poetry run -C api bash dev/pytest/pytest_workflow.sh
+        run: poetry run -P api bash dev/pytest/pytest_workflow.sh
.github/workflows/docker-build.yml (vendored, new file: 47 lines)

@@ -0,0 +1,47 @@
+name: Build docker image
+
+on:
+  pull_request:
+    branches:
+      - "main"
+    paths:
+      - api/Dockerfile
+      - web/Dockerfile
+
+concurrency:
+  group: docker-build-${{ github.head_ref || github.run_id }}
+  cancel-in-progress: true
+
+jobs:
+  build-docker:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        include:
+          - service_name: "api-amd64"
+            platform: linux/amd64
+            context: "api"
+          - service_name: "api-arm64"
+            platform: linux/arm64
+            context: "api"
+          - service_name: "web-amd64"
+            platform: linux/amd64
+            context: "web"
+          - service_name: "web-arm64"
+            platform: linux/arm64
+            context: "web"
+    steps:
+      - name: Set up QEMU
+        uses: docker/setup-qemu-action@v3
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+
+      - name: Build Docker Image
+        uses: docker/build-push-action@v6
+        with:
+          push: false
+          context: "{{defaultContext}}:${{ matrix.context }}"
+          platforms: ${{ matrix.platform }}
+          cache-from: type=gha
+          cache-to: type=gha,mode=max
.github/workflows/style.yml (vendored, 9 changes)

@@ -39,12 +39,12 @@ jobs:
         if: steps.changed-files.outputs.any_changed == 'true'
         run: |
           poetry run -C api ruff --version
-          poetry run -C api ruff check ./api
-          poetry run -C api ruff format --check ./api
+          poetry run -C api ruff check ./
+          poetry run -C api ruff format --check ./

       - name: Dotenv check
         if: steps.changed-files.outputs.any_changed == 'true'
-        run: poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example
+        run: poetry run -P api dotenv-linter ./api/.env.example ./web/.env.example

       - name: Lint hints
         if: failure()
@@ -87,8 +87,7 @@ jobs:

       - name: Web style check
         if: steps.changed-files.outputs.any_changed == 'true'
-        run: echo "${{ steps.changed-files.outputs.all_changed_files }}" | sed 's|web/||g' | xargs pnpm eslint # wait for next lint support eslint v9
-
+        run: yarn run lint

   docker-compose-template:
     name: Docker Compose Template
.github/workflows/vdb-tests.yml (vendored, 2 changes)

@@ -70,4 +70,4 @@ jobs:
           tidb

       - name: Test Vector Stores
-        run: poetry run -C api bash dev/pytest/pytest_vdb.sh
+        run: poetry run -P api bash dev/pytest/pytest_vdb.sh
.gitignore (vendored, 3 changes)

@@ -197,3 +197,6 @@ api/.vscode

 # pnpm
 /.pnpm-store
+
+# plugin migrate
+plugins.jsonl
@@ -25,6 +25,9 @@
     <a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
         <img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
             alt="follow on X(Twitter)"></a>
+    <a href="https://www.linkedin.com/company/langgenius/" target="_blank">
+        <img src="https://custom-icon-badges.demolab.com/badge/LinkedIn-0A66C2?logo=linkedin-white&logoColor=fff"
+            alt="follow on LinkedIn"></a>
     <a href="https://hub.docker.com/u/langgenius" target="_blank">
         <img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
     <a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">
@@ -21,6 +21,9 @@
     <a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
         <img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
             alt="follow on X(Twitter)"></a>
+    <a href="https://www.linkedin.com/company/langgenius/" target="_blank">
+        <img src="https://custom-icon-badges.demolab.com/badge/LinkedIn-0A66C2?logo=linkedin-white&logoColor=fff"
+            alt="follow on LinkedIn"></a>
     <a href="https://hub.docker.com/u/langgenius" target="_blank">
         <img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
     <a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">
@@ -21,6 +21,9 @@
     <a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
         <img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
             alt="follow on X(Twitter)"></a>
+    <a href="https://www.linkedin.com/company/langgenius/" target="_blank">
+        <img src="https://custom-icon-badges.demolab.com/badge/LinkedIn-0A66C2?logo=linkedin-white&logoColor=fff"
+            alt="follow on LinkedIn"></a>
     <a href="https://hub.docker.com/u/langgenius" target="_blank">
         <img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
     <a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">
@@ -21,6 +21,9 @@
     <a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
         <img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
             alt="seguir en X(Twitter)"></a>
+    <a href="https://www.linkedin.com/company/langgenius/" target="_blank">
+        <img src="https://custom-icon-badges.demolab.com/badge/LinkedIn-0A66C2?logo=linkedin-white&logoColor=fff"
+            alt="seguir en LinkedIn"></a>
     <a href="https://hub.docker.com/u/langgenius" target="_blank">
         <img alt="Descargas de Docker" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
     <a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">
@@ -21,6 +21,9 @@
     <a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
         <img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
             alt="suivre sur X(Twitter)"></a>
+    <a href="https://www.linkedin.com/company/langgenius/" target="_blank">
+        <img src="https://custom-icon-badges.demolab.com/badge/LinkedIn-0A66C2?logo=linkedin-white&logoColor=fff"
+            alt="suivre sur LinkedIn"></a>
     <a href="https://hub.docker.com/u/langgenius" target="_blank">
         <img alt="Tirages Docker" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
     <a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">
@@ -21,6 +21,9 @@
     <a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
         <img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
             alt="X(Twitter)でフォロー"></a>
+    <a href="https://www.linkedin.com/company/langgenius/" target="_blank">
+        <img src="https://custom-icon-badges.demolab.com/badge/LinkedIn-0A66C2?logo=linkedin-white&logoColor=fff"
+            alt="LinkedInでフォロー"></a>
     <a href="https://hub.docker.com/u/langgenius" target="_blank">
         <img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
     <a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">
@@ -21,6 +21,9 @@
     <a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
         <img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
             alt="follow on X(Twitter)"></a>
+    <a href="https://www.linkedin.com/company/langgenius/" target="_blank">
+        <img src="https://custom-icon-badges.demolab.com/badge/LinkedIn-0A66C2?logo=linkedin-white&logoColor=fff"
+            alt="follow on LinkedIn"></a>
     <a href="https://hub.docker.com/u/langgenius" target="_blank">
         <img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
     <a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">
@@ -21,6 +21,9 @@
     <a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
         <img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
             alt="follow on X(Twitter)"></a>
+    <a href="https://www.linkedin.com/company/langgenius/" target="_blank">
+        <img src="https://custom-icon-badges.demolab.com/badge/LinkedIn-0A66C2?logo=linkedin-white&logoColor=fff"
+            alt="follow on LinkedIn"></a>
     <a href="https://hub.docker.com/u/langgenius" target="_blank">
         <img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
     <a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">
@@ -25,6 +25,9 @@
     <a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
         <img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
             alt="follow on X(Twitter)"></a>
+    <a href="https://www.linkedin.com/company/langgenius/" target="_blank">
+        <img src="https://custom-icon-badges.demolab.com/badge/LinkedIn-0A66C2?logo=linkedin-white&logoColor=fff"
+            alt="follow on LinkedIn"></a>
     <a href="https://hub.docker.com/u/langgenius" target="_blank">
         <img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
     <a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">
@@ -22,6 +22,9 @@
     <a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
         <img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
             alt="follow on X(Twitter)"></a>
+    <a href="https://www.linkedin.com/company/langgenius/" target="_blank">
+        <img src="https://custom-icon-badges.demolab.com/badge/LinkedIn-0A66C2?logo=linkedin-white&logoColor=fff"
+            alt="follow on LinkedIn"></a>
     <a href="https://hub.docker.com/u/langgenius" target="_blank">
         <img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
     <a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">
@@ -21,6 +21,9 @@
     <a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
         <img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
             alt="X(Twitter)'da takip et"></a>
+    <a href="https://www.linkedin.com/company/langgenius/" target="_blank">
+        <img src="https://custom-icon-badges.demolab.com/badge/LinkedIn-0A66C2?logo=linkedin-white&logoColor=fff"
+            alt="LinkedIn'da takip et"></a>
     <a href="https://hub.docker.com/u/langgenius" target="_blank">
         <img alt="Docker Çekmeleri" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
     <a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">
@@ -62,8 +65,6 @@ Görsel bir arayüz üzerinde güçlü AI iş akışları oluşturun ve test edi
 

-
-Özür dilerim, haklısınız. Daha anlamlı ve akıcı bir çeviri yapmaya çalışayım. İşte güncellenmiş çeviri:

 **3. Prompt IDE**:
   Komut istemlerini oluşturmak, model performansını karşılaştırmak ve sohbet tabanlı uygulamalara metin-konuşma gibi ek özellikler eklemek için kullanıcı dostu bir arayüz.
@@ -150,8 +151,6 @@ Görsel bir arayüz üzerinde güçlü AI iş akışları oluşturun ve test edi
 ## Dify'ı Kullanma

 - **Cloud </br>**
-İşte verdiğiniz metnin Türkçe çevirisi, kod bloğu içinde:
-
 Herkesin sıfır kurulumla denemesi için bir [Dify Cloud](https://dify.ai) hizmeti sunuyoruz. Bu hizmet, kendi kendine dağıtılan versiyonun tüm yeteneklerini sağlar ve sandbox planında 200 ücretsiz GPT-4 çağrısı içerir.

 - **Dify Topluluk Sürümünü Kendi Sunucunuzda Barındırma</br>**
@@ -177,8 +176,6 @@ GitHub'da Dify'a yıldız verin ve yeni sürümlerden anında haberdar olun.
 >- RAM >= 4GB

 </br>
-İşte verdiğiniz metnin Türkçe çevirisi, kod bloğu içinde:
-
 Dify sunucusunu başlatmanın en kolay yolu, [docker-compose.yml](docker/docker-compose.yaml) dosyamızı çalıştırmaktır. Kurulum komutunu çalıştırmadan önce, makinenizde [Docker](https://docs.docker.com/get-docker/) ve [Docker Compose](https://docs.docker.com/compose/install/)'un kurulu olduğundan emin olun:

 ```bash
@@ -21,6 +21,9 @@
     <a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
         <img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
             alt="theo dõi trên X(Twitter)"></a>
+    <a href="https://www.linkedin.com/company/langgenius/" target="_blank">
+        <img src="https://custom-icon-badges.demolab.com/badge/LinkedIn-0A66C2?logo=linkedin-white&logoColor=fff"
+            alt="theo dõi trên LinkedIn"></a>
     <a href="https://hub.docker.com/u/langgenius" target="_blank">
         <img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
     <a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">
@@ -422,8 +422,7 @@ POSITION_PROVIDER_INCLUDES=
 POSITION_PROVIDER_EXCLUDES=

 # Plugin configuration
-PLUGIN_API_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi
-PLUGIN_API_URL=http://127.0.0.1:5002
+PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi
+PLUGIN_DAEMON_URL=http://127.0.0.1:5002
 PLUGIN_REMOTE_INSTALL_PORT=5003
 PLUGIN_REMOTE_INSTALL_HOST=localhost
@@ -436,7 +435,7 @@ MARKETPLACE_ENABLED=true
 MARKETPLACE_API_URL=https://marketplace.dify.ai

 # Endpoint configuration
-ENDPOINT_URL_TEMPLATE=http://localhost/e/{hook_id}
+ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}

 # Reset password token expiry minutes
 RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
@@ -53,10 +53,12 @@ ignore = [
     "FURB152", # math-constant
     "UP007", # non-pep604-annotation
     "UP032", # f-string
+    "UP045", # non-pep604-annotation-optional
     "B005", # strip-with-multi-characters
     "B006", # mutable-argument-default
     "B007", # unused-loop-control-variable
     "B026", # star-arg-unpacking-after-keyword-arg
+    "B903", # class-as-data-structure
     "B904", # raise-without-from-inside-except
     "B905", # zip-without-explicit-strict
     "N806", # non-lowercase-variable-in-function
@@ -4,7 +4,7 @@ FROM python:3.12-slim-bookworm AS base
 WORKDIR /app/api

 # Install Poetry
-ENV POETRY_VERSION=1.8.4
+ENV POETRY_VERSION=2.0.1

 # if you located in China, you can use aliyun mirror to speed up
 # RUN pip install --no-cache-dir poetry==${POETRY_VERSION} -i https://mirrors.aliyun.com/pypi/simple/
@@ -48,15 +48,18 @@ ENV TZ=UTC

 WORKDIR /app/api

-RUN apt-get update \
-    && apt-get install -y --no-install-recommends curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
-    # if you located in China, you can use aliyun mirror to speed up
-    # && echo "deb http://mirrors.aliyun.com/debian testing main" > /etc/apt/sources.list \
-    && echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
-    && apt-get update \
-    # For Security
-    # install a chinese font to support the use of tools like matplotlib
-    && apt-get install -y fonts-noto-cjk \
+RUN \
+    apt-get update \
+    # Install dependencies
+    && apt-get install -y --no-install-recommends \
+    # basic environment
+    curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
+    # For Security
+    # expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
+    # install a chinese font to support the use of tools like matplotlib
+    fonts-noto-cjk \
+    # install libmagic to support the use of python-magic guess MIMETYPE
+    libmagic1 \
     && apt-get autoremove -y \
     && rm -rf /var/lib/apt/lists/*
@@ -79,7 +82,6 @@ COPY . /app/api/
 COPY docker/entrypoint.sh /entrypoint.sh
 RUN chmod +x /entrypoint.sh

-
 ARG COMMIT_SHA
 ENV COMMIT_SHA=${COMMIT_SHA}
@@ -79,5 +79,5 @@
 2. Run the tests locally with mocked system environment variables in `tool.pytest_env` section in `pyproject.toml`

    ```bash
-   poetry run -C api bash dev/pytest/pytest_all_tests.sh
+   poetry run -P api bash dev/pytest/pytest_all_tests.sh
    ```
@@ -141,10 +141,10 @@ class PluginConfig(BaseSettings):

     PLUGIN_DAEMON_URL: HttpUrl = Field(
         description="Plugin API URL",
-        default="http://plugin:5002",
+        default="http://localhost:5002",
     )

-    PLUGIN_API_KEY: str = Field(
+    PLUGIN_DAEMON_KEY: str = Field(
         description="Plugin API key",
         default="plugin-api-key",
     )
@@ -200,7 +200,7 @@ class EndpointConfig(BaseSettings):
     )

     CONSOLE_WEB_URL: str = Field(
-        description="Base URL for the console web interface," "used for frontend references and CORS configuration",
+        description="Base URL for the console web interface,used for frontend references and CORS configuration",
         default="",
     )
@@ -556,6 +556,11 @@ class AuthConfig(BaseSettings):
         default=86400,
     )

+    FORGOT_PASSWORD_LOCKOUT_DURATION: PositiveInt = Field(
+        description="Time (in seconds) a user must wait before retrying password reset after exceeding the rate limit.",
+        default=86400,
+    )
+

 class ModerationConfig(BaseSettings):
     """
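
The two description fixes above are instances of the same pitfall: adjacent string literals concatenate implicitly with no separator, which silently eats the space the author meant to keep. A minimal illustration (not from the repository):

```python
# Adjacent string literals concatenate with no separator, so wrapping a long
# description across two literals silently drops the intended space:
description = "Base URL for the console web interface," "used for frontend references and CORS configuration"
assert "interface,used" in description  # the words run together
```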
@@ -1,9 +1,40 @@
 from typing import Optional

-from pydantic import Field, NonNegativeInt
+from pydantic import Field, NonNegativeInt, computed_field
 from pydantic_settings import BaseSettings


+class HostedCreditConfig(BaseSettings):
+    HOSTED_MODEL_CREDIT_CONFIG: str = Field(
+        description="Model credit configuration in format 'model:credits,model:credits', e.g., 'gpt-4:20,gpt-4o:10'",
+        default="",
+    )
+
+    def get_model_credits(self, model_name: str) -> int:
+        """
+        Get credit value for a specific model name.
+        Returns 1 if model is not found in configuration (default credit).
+
+        :param model_name: The name of the model to search for
+        :return: The credit value for the model
+        """
+        if not self.HOSTED_MODEL_CREDIT_CONFIG:
+            return 1
+
+        try:
+            credit_map = dict(
+                item.strip().split(":", 1) for item in self.HOSTED_MODEL_CREDIT_CONFIG.split(",") if ":" in item
+            )
+
+            # Search for matching model pattern
+            for pattern, credit in credit_map.items():
+                if pattern.strip() == model_name:
+                    return int(credit)
+            return 1  # Default quota if no match found
+        except (ValueError, AttributeError):
+            return 1  # Return default quota if parsing fails
+
+
 class HostedOpenAiConfig(BaseSettings):
     """
     Configuration for hosted OpenAI service
@@ -181,7 +212,7 @@ class HostedFetchAppTemplateConfig(BaseSettings):
     """

     HOSTED_FETCH_APP_TEMPLATES_MODE: str = Field(
-        description="Mode for fetching app templates: remote, db, or builtin" " default to remote,",
+        description="Mode for fetching app templates: remote, db, or builtin default to remote,",
         default="remote",
     )
@@ -202,5 +233,7 @@ class HostedServiceConfig(
     HostedZhipuAIConfig,
     # moderation
     HostedModerationConfig,
+    # credit config
+    HostedCreditConfig,
 ):
     pass
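
The new HostedCreditConfig parses its comma-separated model:credits string on every lookup; given the code above, a quick usage sketch:

```python
# Usage sketch for the HostedCreditConfig shown above (values illustrative).
config = HostedCreditConfig(HOSTED_MODEL_CREDIT_CONFIG="gpt-4:20,gpt-4o:10")

assert config.get_model_credits("gpt-4") == 20    # exact match on the pattern
assert config.get_model_credits("gpt-4o") == 10
assert config.get_model_credits("claude-3") == 1  # unmatched models fall back to 1 credit
assert HostedCreditConfig().get_model_credits("gpt-4") == 1  # empty config: default credit
```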
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):

     CURRENT_VERSION: str = Field(
         description="Dify version",
-        default="1.0.0-beta.1",
+        default="1.0.0",
     )

     COMMIT_SHA: str = Field(
@@ -1,12 +1,32 @@
 import mimetypes
 import os
+import platform
 import re
 import urllib.parse
+import warnings
 from collections.abc import Mapping
 from typing import Any
 from uuid import uuid4

 import httpx
+
+try:
+    import magic
+except ImportError:
+    if platform.system() == "Windows":
+        warnings.warn(
+            "To use python-magic guess MIMETYPE, you need to run `pip install python-magic-bin`", stacklevel=2
+        )
+    elif platform.system() == "Darwin":
+        warnings.warn("To use python-magic guess MIMETYPE, you need to run `brew install libmagic`", stacklevel=2)
+    elif platform.system() == "Linux":
+        warnings.warn(
+            "To use python-magic guess MIMETYPE, you need to run `sudo apt-get install libmagic1`", stacklevel=2
+        )
+    else:
+        warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2)
+    magic = None  # type: ignore
+
 from pydantic import BaseModel

 from configs import dify_config
@@ -47,6 +67,13 @@ def guess_file_info_from_response(response: httpx.Response):
     # If guessing fails, use Content-Type from response headers
     mimetype = response.headers.get("Content-Type", "application/octet-stream")

+    # Use python-magic to guess MIME type if still unknown or generic
+    if mimetype == "application/octet-stream" and magic is not None:
+        try:
+            mimetype = magic.from_buffer(response.content[:1024], mime=True)
+        except magic.MagicException:
+            pass
+
     extension = os.path.splitext(filename)[1]

     # Ensure filename has an extension
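
python-magic's from_buffer only needs the leading bytes of the payload, which is why the helper above passes response.content[:1024]. A standalone sniff, assuming the libmagic shared library is installed:

```python
# Standalone python-magic usage; requires libmagic on the system
# (e.g. `apt-get install libmagic1` on Debian-based distributions).
import magic

print(magic.from_buffer(b"%PDF-1.7\n", mime=True))         # -> application/pdf
print(magic.from_buffer(b"\x89PNG\r\n\x1a\n", mime=True))  # -> image/png
```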
@@ -59,7 +59,7 @@ class InsertExploreAppListApi(Resource):
         with Session(db.engine) as session:
             app = session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none()
             if not app:
-                raise NotFound(f'App \'{args["app_id"]}\' is not found')
+                raise NotFound(f"App '{args['app_id']}' is not found")

         site = app.site
         if not site:
@@ -22,7 +22,7 @@ from controllers.console.wraps import account_initialization_required, setup_req
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
 from libs.login import login_required
-from models.model import AppMode
+from models import App, AppMode
 from services.audio_service import AudioService
 from services.errors.audio import (
     AudioTooLargeServiceError,
@@ -79,7 +79,7 @@ class ChatMessageTextApi(Resource):
     @login_required
     @account_initialization_required
     @get_app_model
-    def post(self, app_model):
+    def post(self, app_model: App):
         from werkzeug.exceptions import InternalServerError

         try:
@@ -98,9 +98,13 @@ class ChatMessageTextApi(Resource):
             and app_model.workflow.features_dict
         ):
             text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
+            if text_to_speech is None:
+                raise ValueError("TTS is not enabled")
             voice = args.get("voice") or text_to_speech.get("voice")
         else:
             try:
+                if app_model.app_model_config is None:
+                    raise ValueError("AppModelConfig not found")
                 voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
             except Exception:
                 voice = None
@@ -59,3 +59,9 @@ class EmailCodeAccountDeletionRateLimitExceededError(BaseHTTPException):
     error_code = "email_code_account_deletion_rate_limit_exceeded"
     description = "Too many account deletion emails have been sent. Please try again in 5 minutes."
     code = 429
+
+
+class EmailPasswordResetLimitError(BaseHTTPException):
+    error_code = "email_password_reset_limit"
+    description = "Too many failed password reset attempts. Please try again in 24 hours."
+    code = 429
|
||||
|
||||
from constants.languages import languages
|
||||
from controllers.console import api
|
||||
from controllers.console.auth.error import EmailCodeError, InvalidEmailError, InvalidTokenError, PasswordMismatchError
|
||||
from controllers.console.auth.error import (
|
||||
EmailCodeError,
|
||||
EmailPasswordResetLimitError,
|
||||
InvalidEmailError,
|
||||
InvalidTokenError,
|
||||
PasswordMismatchError,
|
||||
)
|
||||
from controllers.console.error import AccountInFreezeError, AccountNotFound, EmailSendIpLimitError
|
||||
from controllers.console.wraps import setup_required
|
||||
from events.tenant_event import tenant_was_created
|
||||
@@ -65,6 +71,10 @@ class ForgotPasswordCheckApi(Resource):
|
||||
|
||||
user_email = args["email"]
|
||||
|
||||
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
|
||||
if is_forgot_password_error_rate_limit:
|
||||
raise EmailPasswordResetLimitError()
|
||||
|
||||
token_data = AccountService.get_reset_password_data(args["token"])
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
@@ -73,8 +83,10 @@ class ForgotPasswordCheckApi(Resource):
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args["code"] != token_data.get("code"):
|
||||
AccountService.add_forgot_password_error_rate_limit(args["email"])
|
||||
raise EmailCodeError()
|
||||
|
||||
AccountService.reset_forgot_password_error_rate_limit(args["email"])
|
||||
return {"is_valid": True, "email": token_data.get("email")}
|
||||
|
||||
|
||||
|
||||
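
The hunk above wires a classic check / penalize / reset pattern around the verification code. Condensed to its control flow (names as in the diff; not literal code):

```python
# Condensed sketch of the rate-limit flow added above.
def check_reset_code(email: str, submitted: str, expected: str) -> bool:
    if AccountService.is_forgot_password_error_rate_limit(email):
        raise EmailPasswordResetLimitError()  # locked out after too many bad codes
    if submitted != expected:
        AccountService.add_forgot_password_error_rate_limit(email)  # count the failure
        raise EmailCodeError()
    AccountService.reset_forgot_password_error_rate_limit(email)  # success clears the counter
    return True
```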
@@ -135,7 +135,7 @@ class DataSourceNotionListApi(Resource):
             data_source_info = json.loads(document.data_source_info)
             exist_page_ids.append(data_source_info["notion_page_id"])
         # get all authorized pages
-        data_source_bindings = session.execute(
+        data_source_bindings = session.scalars(
             select(DataSourceOauthBinding).filter_by(
                 tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
             )
@@ -52,12 +52,12 @@ class DatasetListApi(Resource):
         # provider = request.args.get("provider", default="vendor")
         search = request.args.get("keyword", default=None, type=str)
         tag_ids = request.args.getlist("tag_ids")
+        include_all = request.args.get("include_all", default="false").lower() == "true"

         if ids:
             datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
         else:
             datasets, total = DatasetService.get_datasets(
-                page, limit, current_user.current_tenant_id, current_user, search, tag_ids
+                page, limit, current_user.current_tenant_id, current_user, search, tag_ids, include_all
             )

         # check embedding setting
@@ -457,7 +457,7 @@ class DatasetIndexingEstimateApi(Resource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -619,8 +619,7 @@ class DatasetRetrievalSettingApi(Resource):
         vector_type = dify_config.VECTOR_STORE
         match vector_type:
             case (
-                VectorType.MILVUS
-                | VectorType.RELYT
+                VectorType.RELYT
                 | VectorType.PGVECTOR
                 | VectorType.TIDB_VECTOR
                 | VectorType.CHROMA
@@ -645,6 +644,7 @@ class DatasetRetrievalSettingApi(Resource):
                 | VectorType.TIDB_ON_QDRANT
                 | VectorType.LINDORM
                 | VectorType.COUCHBASE
+                | VectorType.MILVUS
             ):
                 return {
                     "retrieval_method": [
@@ -362,8 +362,7 @@ class DatasetInitApi(Resource):
             )
         except InvokeAuthorizationError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -540,8 +539,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
             return response.model_dump(), 200
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -168,8 +168,7 @@ class DatasetDocumentSegmentApi(Resource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -217,8 +216,7 @@ class DatasetDocumentSegmentAddApi(Resource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -267,8 +265,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -368,9 +365,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
             result = []
             for index, row in df.iterrows():
                 if document.doc_form == "qa_model":
-                    data = {"content": row[0], "answer": row[1]}
+                    data = {"content": row.iloc[0], "answer": row.iloc[1]}
                 else:
-                    data = {"content": row[0]}
+                    data = {"content": row.iloc[0]}
                 result.append(data)
             if len(result) == 0:
                 raise ValueError("The CSV file is empty.")
@@ -437,8 +434,7 @@ class ChildChunkAddApi(Resource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
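
The row[0] to row.iloc[0] change in the batch import is not cosmetic: iterrows() yields Series rows, and indexing a Series with an integer is label-based lookup, which warns or fails on newer pandas when the columns have string names. iloc is the explicit positional accessor:

```python
import pandas as pd

df = pd.DataFrame([["What is Dify?", "An LLMOps platform"]], columns=["content", "answer"])
for _, row in df.iterrows():
    print(row.iloc[0])  # positional access: "What is Dify?"
    # row[0] looks for a column literally labeled 0; with string column names it
    # emits a deprecation warning or raises KeyError depending on the pandas version.
```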
@@ -32,7 +32,7 @@ class ConversationListApi(InstalledAppResource):

         pinned = None
         if "pinned" in args and args["pinned"] is not None:
-            pinned = True if args["pinned"] == "true" else False
+            pinned = args["pinned"] == "true"

         try:
             with Session(db.engine) as session:
@@ -50,7 +50,7 @@ class MessageListApi(InstalledAppResource):

         try:
             return MessageService.pagination_by_first_id(
-                app_model, current_user, args["conversation_id"], args["first_id"], args["limit"], "desc"
+                app_model, current_user, args["conversation_id"], args["first_id"], args["limit"]
             )
         except services.errors.conversation.ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
@@ -1,3 +1,5 @@
+import json
+
 from flask_restful import Resource, reqparse  # type: ignore

 from controllers.console.wraps import setup_required
@@ -29,4 +31,34 @@ class EnterpriseWorkspace(Resource):
         return {"message": "enterprise workspace created."}


+class EnterpriseWorkspaceNoOwnerEmail(Resource):
+    @setup_required
+    @enterprise_inner_api_only
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument("name", type=str, required=True, location="json")
+        args = parser.parse_args()
+
+        tenant = TenantService.create_tenant(args["name"], is_from_dashboard=True)
+
+        tenant_was_created.send(tenant)
+
+        resp = {
+            "id": tenant.id,
+            "name": tenant.name,
+            "encrypt_public_key": tenant.encrypt_public_key,
+            "plan": tenant.plan,
+            "status": tenant.status,
+            "custom_config": json.loads(tenant.custom_config) if tenant.custom_config else {},
+            "created_at": tenant.created_at.isoformat() if tenant.created_at else None,
+            "updated_at": tenant.updated_at.isoformat() if tenant.updated_at else None,
+        }
+
+        return {
+            "message": "enterprise workspace created.",
+            "tenant": resp,
+        }
+
+
 api.add_resource(EnterpriseWorkspace, "/enterprise/workspace")
+api.add_resource(EnterpriseWorkspaceNoOwnerEmail, "/enterprise/workspace/ownerless")
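
A hypothetical call to the new ownerless-workspace route. The exact mount point of the inner API and its auth header are not shown in this diff, so both are assumptions here:

```python
import requests

# Assumed base path and header name; enterprise_inner_api_only gates this route.
resp = requests.post(
    "http://localhost:5001/inner/api/enterprise/workspace/ownerless",
    headers={"X-Inner-Api-Key": "<inner-api-key>"},
    json={"name": "acme-workspace"},
)
print(resp.json()["tenant"]["id"])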
@@ -65,7 +65,7 @@ def enterprise_inner_api_user_auth(view):
 def plugin_inner_api_only(view):
     @wraps(view)
     def decorated(*args, **kwargs):
-        if not dify_config.PLUGIN_API_KEY:
+        if not dify_config.PLUGIN_DAEMON_KEY:
             abort(404)

         # get header 'X-Inner-Api-Key'
@@ -7,4 +7,4 @@ api = ExternalApi(bp)

 from . import index
 from .app import app, audio, completion, conversation, file, message, workflow
-from .dataset import dataset, document, hit_testing, segment
+from .dataset import dataset, document, hit_testing, segment, upload_file
@@ -31,8 +31,11 @@ class DatasetListApi(DatasetApiResource):
         # provider = request.args.get("provider", default="vendor")
         search = request.args.get("keyword", default=None, type=str)
         tag_ids = request.args.getlist("tag_ids")
+        include_all = request.args.get("include_all", default="false").lower() == "true"

-        datasets, total = DatasetService.get_datasets(page, limit, tenant_id, current_user, search, tag_ids)
+        datasets, total = DatasetService.get_datasets(
+            page, limit, tenant_id, current_user, search, tag_ids, include_all
+        )
         # check embedding setting
         provider_manager = ProviderManager()
         configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
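
Both dataset list endpoints now accept the flag; it is parsed as the literal string "true", case-insensitively. A hypothetical service-API call:

```python
import requests

# Hypothetical listing call; host and API key are placeholders.
r = requests.get(
    "https://api.dify.ai/v1/datasets",
    params={"page": 1, "limit": 20, "include_all": "true"},
    headers={"Authorization": "Bearer <dataset-api-key>"},
)
print(r.json())
```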
@@ -18,6 +18,7 @@ from controllers.service_api.app.error import (
 from controllers.service_api.dataset.error import (
     ArchivedDocumentImmutableError,
     DocumentIndexingError,
+    InvalidMetadataError,
 )
 from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
 from core.errors.error import ProviderTokenNotInitError
@@ -50,6 +51,9 @@ class DocumentAddByTextApi(DatasetApiResource):
             "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
         )
         parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
+        parser.add_argument("doc_type", type=str, required=False, nullable=True, location="json")
+        parser.add_argument("doc_metadata", type=dict, required=False, nullable=True, location="json")
+
         args = parser.parse_args()
         dataset_id = str(dataset_id)
         tenant_id = str(tenant_id)
@@ -61,6 +65,28 @@ class DocumentAddByTextApi(DatasetApiResource):
         if not dataset.indexing_technique and not args["indexing_technique"]:
             raise ValueError("indexing_technique is required.")

+        # Validate metadata if provided
+        if args.get("doc_type") or args.get("doc_metadata"):
+            if not args.get("doc_type") or not args.get("doc_metadata"):
+                raise InvalidMetadataError("Both doc_type and doc_metadata must be provided when adding metadata")
+
+            if args["doc_type"] not in DocumentService.DOCUMENT_METADATA_SCHEMA:
+                raise InvalidMetadataError(
+                    "Invalid doc_type. Must be one of: " + ", ".join(DocumentService.DOCUMENT_METADATA_SCHEMA.keys())
+                )
+
+            if not isinstance(args["doc_metadata"], dict):
+                raise InvalidMetadataError("doc_metadata must be a dictionary")
+
+            # Validate metadata schema based on doc_type
+            if args["doc_type"] != "others":
+                metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[args["doc_type"]]
+                for key, value in args["doc_metadata"].items():
+                    if key in metadata_schema and not isinstance(value, metadata_schema[key]):
+                        raise InvalidMetadataError(f"Invalid type for metadata field {key}")
+            # set to MetaDataConfig
+            args["metadata"] = {"doc_type": args["doc_type"], "doc_metadata": args["doc_metadata"]}
+
         text = args.get("text")
         name = args.get("name")
         if text is None or name is None:
@@ -107,6 +133,8 @@ class DocumentUpdateByTextApi(DatasetApiResource):
             "doc_language", type=str, default="English", required=False, nullable=False, location="json"
         )
         parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
+        parser.add_argument("doc_type", type=str, required=False, nullable=True, location="json")
+        parser.add_argument("doc_metadata", type=dict, required=False, nullable=True, location="json")
         args = parser.parse_args()
         dataset_id = str(dataset_id)
         tenant_id = str(tenant_id)
@@ -115,6 +143,32 @@ class DocumentUpdateByTextApi(DatasetApiResource):
         if not dataset:
             raise ValueError("Dataset is not exist.")

+        # indexing_technique is already set in dataset since this is an update
+        args["indexing_technique"] = dataset.indexing_technique
+
+        # Validate metadata if provided
+        if args.get("doc_type") or args.get("doc_metadata"):
+            if not args.get("doc_type") or not args.get("doc_metadata"):
+                raise InvalidMetadataError("Both doc_type and doc_metadata must be provided when adding metadata")
+
+            if args["doc_type"] not in DocumentService.DOCUMENT_METADATA_SCHEMA:
+                raise InvalidMetadataError(
+                    "Invalid doc_type. Must be one of: " + ", ".join(DocumentService.DOCUMENT_METADATA_SCHEMA.keys())
+                )
+
+            if not isinstance(args["doc_metadata"], dict):
+                raise InvalidMetadataError("doc_metadata must be a dictionary")
+
+            # Validate metadata schema based on doc_type
+            if args["doc_type"] != "others":
+                metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[args["doc_type"]]
+                for key, value in args["doc_metadata"].items():
+                    if key in metadata_schema and not isinstance(value, metadata_schema[key]):
+                        raise InvalidMetadataError(f"Invalid type for metadata field {key}")
+
+            # set to MetaDataConfig
+            args["metadata"] = {"doc_type": args["doc_type"], "doc_metadata": args["doc_metadata"]}
+
         if args["text"]:
             text = args.get("text")
             name = args.get("name")
@@ -161,6 +215,30 @@ class DocumentAddByFileApi(DatasetApiResource):
             args["doc_form"] = "text_model"
         if "doc_language" not in args:
             args["doc_language"] = "English"
+
+        # Validate metadata if provided
+        if args.get("doc_type") or args.get("doc_metadata"):
+            if not args.get("doc_type") or not args.get("doc_metadata"):
+                raise InvalidMetadataError("Both doc_type and doc_metadata must be provided when adding metadata")
+
+            if args["doc_type"] not in DocumentService.DOCUMENT_METADATA_SCHEMA:
+                raise InvalidMetadataError(
+                    "Invalid doc_type. Must be one of: " + ", ".join(DocumentService.DOCUMENT_METADATA_SCHEMA.keys())
+                )
+
+            if not isinstance(args["doc_metadata"], dict):
+                raise InvalidMetadataError("doc_metadata must be a dictionary")
+
+            # Validate metadata schema based on doc_type
+            if args["doc_type"] != "others":
+                metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[args["doc_type"]]
+                for key, value in args["doc_metadata"].items():
+                    if key in metadata_schema and not isinstance(value, metadata_schema[key]):
+                        raise InvalidMetadataError(f"Invalid type for metadata field {key}")
+
+            # set to MetaDataConfig
+            args["metadata"] = {"doc_type": args["doc_type"], "doc_metadata": args["doc_metadata"]}
+
         # get dataset info
         dataset_id = str(dataset_id)
         tenant_id = str(tenant_id)
@@ -228,6 +306,29 @@ class DocumentUpdateByFileApi(DatasetApiResource):
         if "doc_language" not in args:
             args["doc_language"] = "English"

+        # Validate metadata if provided
+        if args.get("doc_type") or args.get("doc_metadata"):
+            if not args.get("doc_type") or not args.get("doc_metadata"):
+                raise InvalidMetadataError("Both doc_type and doc_metadata must be provided when adding metadata")
+
+            if args["doc_type"] not in DocumentService.DOCUMENT_METADATA_SCHEMA:
+                raise InvalidMetadataError(
+                    "Invalid doc_type. Must be one of: " + ", ".join(DocumentService.DOCUMENT_METADATA_SCHEMA.keys())
+                )
+
+            if not isinstance(args["doc_metadata"], dict):
+                raise InvalidMetadataError("doc_metadata must be a dictionary")
+
+            # Validate metadata schema based on doc_type
+            if args["doc_type"] != "others":
+                metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[args["doc_type"]]
+                for key, value in args["doc_metadata"].items():
+                    if key in metadata_schema and not isinstance(value, metadata_schema[key]):
+                        raise InvalidMetadataError(f"Invalid type for metadata field {key}")
+
+            # set to MetaDataConfig
+            args["metadata"] = {"doc_type": args["doc_type"], "doc_metadata": args["doc_metadata"]}
+
         # get dataset info
         dataset_id = str(dataset_id)
         tenant_id = str(tenant_id)
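
For reference, a request body the four endpoints above would now accept. "book" and its fields are purely illustrative, since the valid doc_type values live in DocumentService.DOCUMENT_METADATA_SCHEMA, which this diff does not show:

```python
# Illustrative payload only; real doc_type keys come from
# DocumentService.DOCUMENT_METADATA_SCHEMA.
payload = {
    "name": "dify-handbook",
    "text": "chapter one ...",
    "indexing_technique": "high_quality",
    "doc_type": "book",  # hypothetical schema key
    "doc_metadata": {"title": "Dify Handbook", "author": "Jane Doe"},
}
```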
@@ -53,8 +53,7 @@ class SegmentApi(DatasetApiResource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -95,8 +94,7 @@ class SegmentApi(DatasetApiResource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -175,8 +173,7 @@ class DatasetSegmentApi(DatasetApiResource):
             )
         except LLMBadRequestError:
             raise ProviderNotInitializeError(
-                "No Embedding Model available. Please configure a valid provider "
-                "in the Settings -> Model Provider."
+                "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
api/controllers/service_api/dataset/upload_file.py (new file: 54 lines)

@@ -0,0 +1,54 @@
+from werkzeug.exceptions import NotFound
+
+from controllers.service_api import api
+from controllers.service_api.wraps import (
+    DatasetApiResource,
+)
+from core.file import helpers as file_helpers
+from extensions.ext_database import db
+from models.dataset import Dataset
+from models.model import UploadFile
+from services.dataset_service import DocumentService
+
+
+class UploadFileApi(DatasetApiResource):
+    def get(self, tenant_id, dataset_id, document_id):
+        """Get upload file."""
+        # check dataset
+        dataset_id = str(dataset_id)
+        tenant_id = str(tenant_id)
+        dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+        if not dataset:
+            raise NotFound("Dataset not found.")
+        # check document
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset.id, document_id)
+        if not document:
+            raise NotFound("Document not found.")
+        # check upload file
+        if document.data_source_type != "upload_file":
+            raise ValueError(f"Document data source type ({document.data_source_type}) is not upload_file.")
+        data_source_info = document.data_source_info_dict
+        if data_source_info and "upload_file_id" in data_source_info:
+            file_id = data_source_info["upload_file_id"]
+            upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
+            if not upload_file:
+                raise NotFound("UploadFile not found.")
+        else:
+            raise ValueError("Upload file id not found in document data source info.")
+
+        url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
+        return {
+            "id": upload_file.id,
+            "name": upload_file.name,
+            "size": upload_file.size,
+            "extension": upload_file.extension,
+            "url": url,
+            "download_url": f"{url}&as_attachment=true",
+            "mime_type": upload_file.mime_type,
+            "created_by": upload_file.created_by,
+            "created_at": upload_file.created_at.timestamp(),
+        }, 200
+
+
+api.add_resource(UploadFileApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/upload-file")
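
A sketch of consuming the new route with a dataset API key (Bearer auth, as elsewhere in service_api; the host and IDs are placeholders):

```python
import requests

r = requests.get(
    "https://api.dify.ai/v1/datasets/<dataset_id>/documents/<document_id>/upload-file",
    headers={"Authorization": "Bearer <dataset-api-key>"},
)
info = r.json()
print(info["name"], info["mime_type"], info["download_url"])  # signed URL plus attachment variant
```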
@@ -195,7 +195,11 @@ def validate_and_get_api_token(scope: str | None = None):
     with Session(db.engine, expire_on_commit=False) as session:
         update_stmt = (
             update(ApiToken)
-            .where(ApiToken.token == auth_token, ApiToken.last_used_at < cutoff_time, ApiToken.type == scope)
+            .where(
+                ApiToken.token == auth_token,
+                (ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < cutoff_time)),
+                ApiToken.type == scope,
+            )
             .values(last_used_at=current_time)
             .returning(ApiToken)
         )
@@ -236,7 +240,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
             tenant_id=app_model.tenant_id,
             app_id=app_model.id,
             type="service_api",
-            is_anonymous=True if user_id == "DEFAULT-USER" else False,
+            is_anonymous=user_id == "DEFAULT-USER",
             session_id=user_id,
         )
         db.session.add(end_user)
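
The widened WHERE clause matters for tokens that have never been used: their last_used_at is NULL, and in SQL "NULL < cutoff" is not true, so the old statement skipped exactly those rows. Roughly, the updated statement compiles to:

```python
# Approximate SQL emitted by the updated statement (identifiers illustrative):
#
#   UPDATE api_tokens
#      SET last_used_at = :current_time
#    WHERE token = :auth_token
#      AND (last_used_at IS NULL OR last_used_at < :cutoff_time)
#      AND type = :scope
#   RETURNING ...
```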
@@ -39,7 +39,7 @@ class ConversationListApi(WebApiResource):

         pinned = None
         if "pinned" in args and args["pinned"] is not None:
-            pinned = True if args["pinned"] == "true" else False
+            pinned = args["pinned"] == "true"

         try:
             with Session(db.engine) as session:
@@ -91,7 +91,7 @@ class MessageListApi(WebApiResource):

         try:
             return MessageService.pagination_by_first_id(
-                app_model, end_user, args["conversation_id"], args["first_id"], args["limit"], "desc"
+                app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
             )
         except services.errors.conversation.ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
@@ -168,7 +168,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):

             self.save_agent_thought(
                 agent_thought=agent_thought,
-                tool_name=scratchpad.action.action_name if scratchpad.action else "",
+                tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""),
                 tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
                 tool_invoke_meta={},
                 thought=scratchpad.thought or "",
@@ -8,16 +8,16 @@ from core.agent.fc_agent_runner import FunctionCallAgentRunner
 from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.apps.base_app_runner import AppRunner
-from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ModelConfigWithCredentialsEntity
+from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
 from core.app.entities.queue_entities import QueueAnnotationReplyEvent
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
-from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage
+from core.model_runtime.entities.llm_entities import LLMMode
 from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.moderation.base import ModerationError
 from extensions.ext_database import db
-from models.model import App, Conversation, Message, MessageAgentThought
+from models.model import App, Conversation, Message

 logger = logging.getLogger(__name__)

@@ -191,7 +191,8 @@ class AgentChatAppRunner(AppRunner):
         # change function call strategy based on LLM model
         llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
         model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
-        assert model_schema is not None
+        if not model_schema:
+            raise ValueError("Model schema not found")

         if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
             agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
@@ -247,29 +248,3 @@ class AgentChatAppRunner(AppRunner):
             stream=application_generate_entity.stream,
             agent=True,
         )
-
-    def _get_usage_of_all_agent_thoughts(
-        self, model_config: ModelConfigWithCredentialsEntity, message: Message
-    ) -> LLMUsage:
-        """
-        Get usage of all agent thoughts
-        :param model_config: model config
-        :param message: message
-        :return:
-        """
-        agent_thoughts = (
-            db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).all()
-        )
-
-        all_message_tokens = 0
-        all_answer_tokens = 0
-        for agent_thought in agent_thoughts:
-            all_message_tokens += agent_thought.message_tokens
-            all_answer_tokens += agent_thought.answer_tokens
-
-        model_type_instance = model_config.provider_model_bundle.model_type_instance
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
-
-        return model_type_instance._calc_response_usage(
-            model_config.model, model_config.credentials, all_message_tokens, all_answer_tokens
-        )
@@ -167,8 +167,7 @@ class AppQueueManager:
         else:
             if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"):
                 raise TypeError(
-                    "Critical Error: Passing SQLAlchemy Model instances "
-                    "that cause thread safety issues is not allowed."
+                    "Critical Error: Passing SQLAlchemy Model instances that cause thread safety issues is not allowed."
                 )
|
||||
@@ -89,6 +89,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.app_id == app_model.id,
|
||||
Conversation.status == "normal",
|
||||
Conversation.is_deleted.is_(False),
|
||||
]
|
||||
|
||||
if isinstance(user, Account):
|
||||
|
||||
@@ -145,7 +145,7 @@ class MessageCycleManage:

             # get extension
             if "." in message_file.url:
-                extension = f'.{message_file.url.split(".")[-1]}'
+                extension = f".{message_file.url.split('.')[-1]}"
                 if len(extension) > 10:
                     extension = ".bin"
             else:
@@ -62,8 +62,9 @@ class ApiExternalDataTool(ExternalDataTool):

        if not api_based_extension:
            raise ValueError(
                "[External data tool] API query failed, variable: {}, "
                "error: api_based_extension_id is invalid".format(self.variable)
                "[External data tool] API query failed, variable: {}, error: api_based_extension_id is invalid".format(
                    self.variable
                )
            )

        # decrypt api_key

@@ -33,7 +33,7 @@ def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str,
    sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
    encoded_sign = base64.urlsafe_b64encode(sign).decode()

    return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}&user_id={user_id}"
    return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}&user_id={user_id}&tenant_id={tenant_id}"


def verify_plugin_file_signature(

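For context on the signing scheme above: the URL carries an HMAC-SHA256 digest over the request parameters, so the verifying side recomputes the digest and compares it in constant time. A minimal sketch of that verifying half — the key name, message layout, and timeout window are illustrative assumptions, not Dify's exact implementation:

import base64
import hashlib
import hmac
import time

def verify_signed_url_sketch(key: bytes, msg: str, timestamp: str, sign: str, timeout: int = 300) -> bool:
    # Recompute the signature over the exact message string the signer used
    expected = base64.urlsafe_b64encode(hmac.new(key, msg.encode(), hashlib.sha256).digest()).decode()
    # compare_digest avoids timing side channels; also reject expired timestamps
    return hmac.compare_digest(expected, sign) and (time.time() - int(timestamp)) < timeout
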
@@ -90,7 +90,7 @@ class File(BaseModel):
    def markdown(self) -> str:
        url = self.generate_url()
        if self.type == FileType.IMAGE:
            text = f'![{self.filename or ""}]({url})'
            text = f"![{self.filename or ''}]({url})"
        else:
            text = f"[{self.filename or url}]({url})"


@@ -11,15 +11,6 @@ from configs import dify_config

SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES

proxy_mounts = (
    {
        "http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL),
        "https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL),
    }
    if dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL
    else None
)

BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]

@@ -50,7 +41,11 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
            if dify_config.SSRF_PROXY_ALL_URL:
                with httpx.Client(proxy=dify_config.SSRF_PROXY_ALL_URL) as client:
                    response = client.request(method=method, url=url, **kwargs)
            elif proxy_mounts:
            elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
                proxy_mounts = {
                    "http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL),
                    "https://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTPS_URL),
                }
                with httpx.Client(mounts=proxy_mounts) as client:
                    response = client.request(method=method, url=url, **kwargs)
            else:

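The change above replaces a module-level proxy_mounts constant with mounts built inside make_request, so proxy configuration is read at call time instead of import time. A sketch of the underlying httpx mechanism, with placeholder proxy URLs (not Dify config values):

import httpx

# Placeholder proxy endpoints for illustration only
mounts = {
    "http://": httpx.HTTPTransport(proxy="http://proxy.internal:3128"),   # routes plain-HTTP requests
    "https://": httpx.HTTPTransport(proxy="http://proxy.internal:3128"),  # routes TLS requests
}
with httpx.Client(mounts=mounts) as client:
    response = client.get("http://example.com")  # dispatched through the matching transport
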
@@ -530,7 +530,6 @@ class IndexingRunner:
        # chunk nodes by chunk size
        indexing_start_at = time.perf_counter()
        tokens = 0
        chunk_size = 10
        if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
            # create keyword index
            create_keyword_thread = threading.Thread(
@@ -539,11 +538,22 @@
            )
            create_keyword_thread.start()

        max_workers = 10
        if dataset.indexing_technique == "high_quality":
            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = []
                for i in range(0, len(documents), chunk_size):
                    chunk_documents = documents[i : i + chunk_size]

                # Distribute documents into groups based on the hash of page_content, so that
                # no two threads process the same document, thereby avoiding potential
                # database insertion deadlocks.
                document_groups: list[list[Document]] = [[] for _ in range(max_workers)]
                for document in documents:
                    hash = helper.generate_text_hash(document.page_content)
                    group_index = int(hash, 16) % max_workers
                    document_groups[group_index].append(document)
                for chunk_documents in document_groups:
                    if len(chunk_documents) == 0:
                        continue
                    futures.append(
                        executor.submit(
                            self._process_chunk,

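The grouping logic above replaces fixed-size chunking: hashing page_content and reducing it modulo max_workers guarantees identical documents always land in the same bucket, so no two threads can race to insert the same row. A self-contained sketch of the idea (sha256 stands in for generate_text_hash, which the code above implies returns a hex digest):

import hashlib

def bucket_by_content(texts: list[str], workers: int) -> list[list[str]]:
    groups: list[list[str]] = [[] for _ in range(workers)]
    for text in texts:
        digest = hashlib.sha256(text.encode()).hexdigest()  # stand-in for generate_text_hash
        groups[int(digest, 16) % workers].append(text)      # same content -> same bucket
    return groups

# Duplicate texts always fall into the same group:
groups = bucket_by_content(["a", "b", "a"], workers=4)
assert any(group.count("a") == 2 for group in groups)
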
@@ -131,7 +131,7 @@ JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = (
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
    "Please help me predict the three most likely questions that human would ask, "
    "and keeping each question under 20 characters.\n"
    "MAKE SURE your output is the SAME language as the Assistant's latest response"
    "MAKE SURE your output is the SAME language as the Assistant's latest response. "
    "The output must be an array in JSON format following the specified schema:\n"
    '["question1","question2","question3"]\n'
)

@@ -1,4 +1,4 @@
from .llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from .llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from .message_entities import (
    AssistantPromptMessage,
    AudioPromptMessageContent,
@@ -23,6 +23,7 @@ __all__ = [
    "AudioPromptMessageContent",
    "DocumentPromptMessageContent",
    "ImagePromptMessageContent",
    "LLMMode",
    "LLMResult",
    "LLMResultChunk",
    "LLMResultChunkDelta",

@@ -1,5 +1,5 @@
from decimal import Decimal
from enum import Enum
from enum import StrEnum
from typing import Optional

from pydantic import BaseModel
@@ -8,7 +8,7 @@ from core.model_runtime.entities.message_entities import AssistantPromptMessage,
from core.model_runtime.entities.model_entities import ModelUsage, PriceInfo


class LLMMode(Enum):
class LLMMode(StrEnum):
    """
    Enum class for large language model mode.
    """

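The practical effect of moving from Enum to StrEnum (Python 3.11+) is that members are str subclasses, so they compare against and serialize as plain strings. A quick illustration with a hypothetical member:

from enum import Enum, StrEnum

class OldMode(Enum):
    CHAT = "chat"

class NewMode(StrEnum):
    CHAT = "chat"

assert OldMode.CHAT != "chat"       # a plain Enum member is not a string
assert NewMode.CHAT == "chat"       # a StrEnum member is a str subclass
assert f"{NewMode.CHAT}" == "chat"  # and it formats/JSON-serializes cleanly
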
@@ -3,8 +3,11 @@ from typing import Optional

from pydantic import BaseModel, ConfigDict, Field

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from core.model_runtime.entities.model_entities import (
    AIModelEntity,
    DefaultParameterName,
    ModelType,
    PriceConfig,
    PriceInfo,
@@ -18,6 +21,7 @@ from core.model_runtime.errors.invoke import (
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity
from core.plugin.manager.model import PluginModelManager

@@ -144,3 +148,102 @@ class AIModel(BaseModel):
            model=model,
            credentials=credentials or {},
        )

    def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
        """
        Get customizable model schema from credentials

        :param model: model name
        :param credentials: model credentials
        :return: model schema
        """
        return self._get_customizable_model_schema(model, credentials)

    def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
        """
        Get customizable model schema and fill in the template
        """
        schema = self.get_customizable_model_schema(model, credentials)

        if not schema:
            return None

        # fill in the template
        new_parameter_rules = []
        for parameter_rule in schema.parameter_rules:
            if parameter_rule.use_template:
                try:
                    default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template)
                    default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
                    if not parameter_rule.max and "max" in default_parameter_rule:
                        parameter_rule.max = default_parameter_rule["max"]
                    if not parameter_rule.min and "min" in default_parameter_rule:
                        parameter_rule.min = default_parameter_rule["min"]
                    if not parameter_rule.default and "default" in default_parameter_rule:
                        parameter_rule.default = default_parameter_rule["default"]
                    if not parameter_rule.precision and "precision" in default_parameter_rule:
                        parameter_rule.precision = default_parameter_rule["precision"]
                    if not parameter_rule.required and "required" in default_parameter_rule:
                        parameter_rule.required = default_parameter_rule["required"]
                    if not parameter_rule.help and "help" in default_parameter_rule:
                        parameter_rule.help = I18nObject(
                            en_US=default_parameter_rule["help"]["en_US"],
                        )
                    if (
                        parameter_rule.help
                        and not parameter_rule.help.en_US
                        and ("help" in default_parameter_rule and "en_US" in default_parameter_rule["help"])
                    ):
                        parameter_rule.help.en_US = default_parameter_rule["help"]["en_US"]
                    if (
                        parameter_rule.help
                        and not parameter_rule.help.zh_Hans
                        and ("help" in default_parameter_rule and "zh_Hans" in default_parameter_rule["help"])
                    ):
                        parameter_rule.help.zh_Hans = default_parameter_rule["help"].get(
                            "zh_Hans", default_parameter_rule["help"]["en_US"]
                        )
                except ValueError:
                    pass

            new_parameter_rules.append(parameter_rule)

        schema.parameter_rules = new_parameter_rules

        return schema

    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
        """
        Get customizable model schema

        :param model: model name
        :param credentials: model credentials
        :return: model schema
        """
        return None

    def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName) -> dict:
        """
        Get default parameter rule for given name

        :param name: parameter name
        :return: parameter rule
        """
        default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name)

        if not default_parameter_rule:
            raise Exception(f"Invalid model parameter rule name {name}")

        return default_parameter_rule

    def _get_num_tokens_by_gpt2(self, text: str) -> int:
        """
        Get number of tokens for given prompt messages by gpt2
        Some provider models do not provide an interface for obtaining the number of tokens.
        Here, the gpt2 tokenizer is used to calculate the number of tokens.
        This method can be executed offline, and the gpt2 tokenizer has been cached in the project.

        :param text: plain text of prompt. You need to convert the original message to plain text
        :return: number of tokens
        """
        return GPT2Tokenizer.get_num_tokens(text)

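The template fill-in above boils down to: resolve the rule named by use_template, then copy over any bounds or defaults the provider schema left unset. A condensed sketch with a hypothetical template entry (the real PARAMETER_RULE_TEMPLATE is richer, and rules are objects rather than dicts):

# Hypothetical template entry mirroring the shape PARAMETER_RULE_TEMPLATE implies
TEMPLATE = {"temperature": {"min": 0.0, "max": 2.0, "default": 1.0, "precision": 2}}

def fill_rule(rule: dict, use_template: str) -> dict:
    defaults = TEMPLATE.get(use_template, {})
    for key in ("min", "max", "default", "precision", "required"):
        if rule.get(key) is None and key in defaults:
            rule[key] = defaults[key]  # only fill fields the provider left unset
    return rule

assert fill_rule({"max": 1.0}, "temperature") == {"max": 1.0, "min": 0.0, "default": 1.0, "precision": 2}
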
@@ -1,6 +1,9 @@
import logging
from threading import Lock
from typing import Any

logger = logging.getLogger(__name__)

_tokenizer: Any = None
_lock = Lock()

@@ -43,5 +46,6 @@ class GPT2Tokenizer:
            base_path = abspath(__file__)
            gpt2_tokenizer_path = join(dirname(base_path), "gpt2")
            _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
            logger.info("Fallback to Transformers' GPT-2 tokenizer from tiktoken")

        return _tokenizer

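The module-level _tokenizer/_lock pair added above is the standard lazily-initialized singleton. A minimal sketch of the pattern, with a cheap stand-in for the expensive tokenizer load:

from threading import Lock
from typing import Any

_instance: Any = None
_lock = Lock()

def get_instance() -> Any:
    global _instance
    if _instance is None:          # fast path once initialized, no locking
        with _lock:
            if _instance is None:  # re-check under the lock (double-checked locking)
                _instance = object()  # stand-in for the costly from_pretrained() load
    return _instance
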
@@ -1,4 +1,5 @@
- openai
- deepseek
- anthropic
- azure_openai
- google
@@ -32,7 +33,6 @@
- localai
- volcengine_maas
- openai_api_compatible
- deepseek
- hunyuan
- siliconflow
- perfxcloud

@@ -0,0 +1,41 @@
model: gemini-2.0-flash-001
label:
  en_US: Gemini 2.0 Flash 001
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 1048576
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_output_tokens
    use_template: max_tokens
    default: 8192
    min: 1
    max: 8192
  - name: json_schema
    use_template: json_schema
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD

@@ -0,0 +1,41 @@
model: gemini-2.0-flash-lite-preview-02-05
label:
  en_US: Gemini 2.0 Flash Lite Preview 0205
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 1048576
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_output_tokens
    use_template: max_tokens
    default: 8192
    min: 1
    max: 8192
  - name: json_schema
    use_template: json_schema
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD

@@ -0,0 +1,39 @@
model: gemini-2.0-flash-thinking-exp-01-21
label:
  en_US: Gemini 2.0 Flash Thinking Exp 0121
model_type: llm
features:
  - agent-thought
  - vision
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 32767
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_output_tokens
    use_template: max_tokens
    default: 8192
    min: 1
    max: 8192
  - name: json_schema
    use_template: json_schema
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD

@@ -0,0 +1,39 @@
model: gemini-2.0-flash-thinking-exp-1219
label:
  en_US: Gemini 2.0 Flash Thinking Exp 1219
model_type: llm
features:
  - agent-thought
  - vision
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 32767
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_output_tokens
    use_template: max_tokens
    default: 8192
    min: 1
    max: 8192
  - name: json_schema
    use_template: json_schema
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD

@@ -0,0 +1,37 @@
model: gemini-2.0-pro-exp-02-05
label:
  en_US: Gemini 2.0 Pro Exp 0205
model_type: llm
features:
  - agent-thought
  - document
model_properties:
  mode: chat
  context_size: 2000000
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      en_US: Top k
    type: int
    help:
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_output_tokens
    use_template: max_tokens
    required: true
    default: 8192
    min: 1
    max: 8192
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD

@@ -0,0 +1,41 @@
model: gemini-exp-1114
label:
  en_US: Gemini exp 1114
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 32767
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_output_tokens
    use_template: max_tokens
    default: 8192
    min: 1
    max: 8192
  - name: json_schema
    use_template: json_schema
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD

@@ -0,0 +1,41 @@
model: gemini-exp-1121
label:
  en_US: Gemini exp 1121
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 32767
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_output_tokens
    use_template: max_tokens
    default: 8192
    min: 1
    max: 8192
  - name: json_schema
    use_template: json_schema
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD

@@ -0,0 +1,41 @@
model: gemini-exp-1206
label:
  en_US: Gemini exp 1206
model_type: llm
features:
  - agent-thought
  - vision
  - tool-call
  - stream-tool-call
  - document
  - video
  - audio
model_properties:
  mode: chat
  context_size: 2097152
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_output_tokens
    use_template: max_tokens
    default: 8192
    min: 1
    max: 8192
  - name: json_schema
    use_template: json_schema
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD

@@ -1,42 +0,0 @@
model: ernie-lite-pro-128k
label:
  en_US: Ernie-Lite-Pro-128K
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 128000
parameter_rules:
  - name: temperature
    use_template: temperature
    min: 0.1
    max: 1.0
    default: 0.8
  - name: top_p
    use_template: top_p
  - name: min_output_tokens
    label:
      en_US: "Min Output Tokens"
      zh_Hans: "最小输出Token数"
    use_template: max_tokens
    min: 2
    max: 2048
    help:
      zh_Hans: 指定模型最小输出token数
      en_US: Specifies the lower limit on the length of generated results.
  - name: max_output_tokens
    label:
      en_US: "Max Output Tokens"
      zh_Hans: "最大输出Token数"
    use_template: max_tokens
    min: 2
    max: 2048
    default: 2048
    help:
      zh_Hans: 指定模型最大输出token数
      en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty

@@ -0,0 +1,66 @@
model: glm-4-air-0111
label:
  en_US: glm-4-air-0111
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 131072
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 0.95
    min: 0.0
    max: 1.0
    help:
      zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
      en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, The output will be more stable or certain. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time.
  - name: top_p
    use_template: top_p
    default: 0.7
    help:
      zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1,默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如:0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
      en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time.
  - name: do_sample
    label:
      zh_Hans: 采样策略
      en_US: Sampling strategy
    type: boolean
    help:
      zh_Hans: do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
      en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
    default: true
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 4095
  - name: web_search
    type: boolean
    label:
      zh_Hans: 联网搜索
      en_US: Web Search
    default: false
    help:
      zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
      en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: Response Format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: '0.0005'
  output: '0.0005'
  unit: '0.001'
  currency: RMB

@@ -87,6 +87,6 @@ class CommonValidator:
            if value.lower() not in {"true", "false"}:
                raise ValueError(f"Variable {credential_form_schema.variable} should be true or false")

            value = True if value.lower() == "true" else False
            value = value.lower() == "true"

        return value

@@ -6,6 +6,7 @@ from pydantic import BaseModel, ValidationInfo, field_validator
class TracingProviderEnum(Enum):
    LANGFUSE = "langfuse"
    LANGSMITH = "langsmith"
    OPIK = "opik"


class BaseTracingConfig(BaseModel):
@@ -56,5 +57,36 @@ class LangSmithConfig(BaseTracingConfig):
        return v


class OpikConfig(BaseTracingConfig):
    """
    Model class for Opik tracing config.
    """

    api_key: str | None = None
    project: str | None = None
    workspace: str | None = None
    url: str = "https://www.comet.com/opik/api/"

    @field_validator("project")
    @classmethod
    def project_validator(cls, v, info: ValidationInfo):
        if v is None or v == "":
            v = "Default Project"

        return v

    @field_validator("url")
    @classmethod
    def url_validator(cls, v, info: ValidationInfo):
        if v is None or v == "":
            v = "https://www.comet.com/opik/api/"
        if not v.startswith(("https://", "http://")):
            raise ValueError("url must start with https:// or http://")
        if not v.endswith("/api/"):
            raise ValueError("url should end with /api/")

        return v


OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"

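In practice the validators behave like this — a sketch, noting that under pydantic v2 a ValueError raised in a field validator surfaces as a ValidationError, which is itself a ValueError subclass:

cfg = OpikConfig(api_key="sk-placeholder", project="", workspace="team")
assert cfg.project == "Default Project"          # empty project falls back to the default
assert cfg.url == "https://www.comet.com/opik/api/"

try:
    OpikConfig(url="https://example.com/")       # rejected: missing the required /api/ suffix
except ValueError as e:
    print(e)
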
api/core/ops/opik_trace/opik_trace.py (new file, 469 lines)
@@ -0,0 +1,469 @@
import json
import logging
import os
import uuid
from datetime import datetime, timedelta
from typing import Optional, cast

from opik import Opik, Trace
from opik.id_helpers import uuid4_to_uuid7

from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import OpikConfig
from core.ops.entities.trace_entity import (
    BaseTraceInfo,
    DatasetRetrievalTraceInfo,
    GenerateNameTraceInfo,
    MessageTraceInfo,
    ModerationTraceInfo,
    SuggestedQuestionTraceInfo,
    ToolTraceInfo,
    TraceTaskName,
    WorkflowTraceInfo,
)
from extensions.ext_database import db
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecution

logger = logging.getLogger(__name__)


def wrap_dict(key_name, data):
    """Make sure that the input data is a dict"""
    if not isinstance(data, dict):
        return {key_name: data}

    return data


def wrap_metadata(metadata, **kwargs):
    """Add common metadata to all Traces and Spans"""
    metadata["created_from"] = "dify"

    metadata.update(kwargs)

    return metadata


def prepare_opik_uuid(user_datetime: Optional[datetime], user_uuid: Optional[str]):
    """Opik needs UUIDv7, while Dify uses UUIDv4 as the identifier of most
    messages and objects. The type hints of BaseTraceInfo indicate that an
    object's start_time and message_id may be null, in which case we cannot
    map it to a UUIDv7. Since we then have no way to identify the object
    uniquely, we generate a new random UUIDv7 instead.
    """

    if user_datetime is None:
        user_datetime = datetime.now()

    if user_uuid is None:
        user_uuid = str(uuid.uuid4())

    return uuid4_to_uuid7(user_datetime, user_uuid)

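In effect: when a timestamp and a UUIDv4 are both present, the conversion is stable, so retried traces map to the same Opik id; when either is missing, a fresh UUIDv7 is minted. A quick illustration — the message id is a made-up UUIDv4, and determinism of uuid4_to_uuid7 for fixed inputs is assumed:

from datetime import datetime

ts = datetime(2025, 1, 1, 12, 0, 0)
message_id = "1b4e28ba-2fa1-4d22-883f-80a0e01c4a5b"  # made-up UUIDv4 for illustration

a = prepare_opik_uuid(ts, message_id)
b = prepare_opik_uuid(ts, message_id)
assert a == b                       # stable mapping: same Dify id -> same Opik UUIDv7

c = prepare_opik_uuid(None, None)   # nothing to map -> a fresh random UUIDv7
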
class OpikDataTrace(BaseTraceInstance):
    def __init__(
        self,
        opik_config: OpikConfig,
    ):
        super().__init__(opik_config)
        self.opik_client = Opik(
            project_name=opik_config.project,
            workspace=opik_config.workspace,
            host=opik_config.url,
            api_key=opik_config.api_key,
        )
        self.project = opik_config.project
        self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")

    def trace(self, trace_info: BaseTraceInfo):
        if isinstance(trace_info, WorkflowTraceInfo):
            self.workflow_trace(trace_info)
        if isinstance(trace_info, MessageTraceInfo):
            self.message_trace(trace_info)
        if isinstance(trace_info, ModerationTraceInfo):
            self.moderation_trace(trace_info)
        if isinstance(trace_info, SuggestedQuestionTraceInfo):
            self.suggested_question_trace(trace_info)
        if isinstance(trace_info, DatasetRetrievalTraceInfo):
            self.dataset_retrieval_trace(trace_info)
        if isinstance(trace_info, ToolTraceInfo):
            self.tool_trace(trace_info)
        if isinstance(trace_info, GenerateNameTraceInfo):
            self.generate_name_trace(trace_info)

    def workflow_trace(self, trace_info: WorkflowTraceInfo):
        dify_trace_id = trace_info.workflow_run_id
        opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
        workflow_metadata = wrap_metadata(
            trace_info.metadata, message_id=trace_info.message_id, workflow_app_log_id=trace_info.workflow_app_log_id
        )
        root_span_id = None

        if trace_info.message_id:
            dify_trace_id = trace_info.message_id
            opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)

            trace_data = {
                "id": opik_trace_id,
                "name": TraceTaskName.MESSAGE_TRACE.value,
                "start_time": trace_info.start_time,
                "end_time": trace_info.end_time,
                "metadata": workflow_metadata,
                "input": wrap_dict("input", trace_info.workflow_run_inputs),
                "output": wrap_dict("output", trace_info.workflow_run_outputs),
                "tags": ["message", "workflow"],
                "project_name": self.project,
            }
            self.add_trace(trace_data)

            root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id)
            span_data = {
                "id": root_span_id,
                "parent_span_id": None,
                "trace_id": opik_trace_id,
                "name": TraceTaskName.WORKFLOW_TRACE.value,
                "input": wrap_dict("input", trace_info.workflow_run_inputs),
                "output": wrap_dict("output", trace_info.workflow_run_outputs),
                "start_time": trace_info.start_time,
                "end_time": trace_info.end_time,
                "metadata": workflow_metadata,
                "tags": ["workflow"],
                "project_name": self.project,
            }
            self.add_span(span_data)
        else:
            trace_data = {
                "id": opik_trace_id,
                "name": TraceTaskName.MESSAGE_TRACE.value,
                "start_time": trace_info.start_time,
                "end_time": trace_info.end_time,
                "metadata": workflow_metadata,
                "input": wrap_dict("input", trace_info.workflow_run_inputs),
                "output": wrap_dict("output", trace_info.workflow_run_outputs),
                "tags": ["workflow"],
                "project_name": self.project,
            }
            self.add_trace(trace_data)

        # fetch all node executions for this workflow run via workflow_run_id
        workflow_nodes_execution_id_records = (
            db.session.query(WorkflowNodeExecution.id)
            .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
            .all()
        )

        for node_execution_id_record in workflow_nodes_execution_id_records:
            node_execution = (
                db.session.query(
                    WorkflowNodeExecution.id,
                    WorkflowNodeExecution.tenant_id,
                    WorkflowNodeExecution.app_id,
                    WorkflowNodeExecution.title,
                    WorkflowNodeExecution.node_type,
                    WorkflowNodeExecution.status,
                    WorkflowNodeExecution.inputs,
                    WorkflowNodeExecution.outputs,
                    WorkflowNodeExecution.created_at,
                    WorkflowNodeExecution.elapsed_time,
                    WorkflowNodeExecution.process_data,
                    WorkflowNodeExecution.execution_metadata,
                )
                .filter(WorkflowNodeExecution.id == node_execution_id_record.id)
                .first()
            )

            if not node_execution:
                continue

            node_execution_id = node_execution.id
            tenant_id = node_execution.tenant_id
            app_id = node_execution.app_id
            node_name = node_execution.title
            node_type = node_execution.node_type
            status = node_execution.status
            if node_type == "llm":
                inputs = (
                    json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
                )
            else:
                inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
            outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
            created_at = node_execution.created_at or datetime.now()
            elapsed_time = node_execution.elapsed_time
            finished_at = created_at + timedelta(seconds=elapsed_time)

            execution_metadata = (
                json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
            )
            metadata = execution_metadata.copy()
            metadata.update(
                {
                    "workflow_run_id": trace_info.workflow_run_id,
                    "node_execution_id": node_execution_id,
                    "tenant_id": tenant_id,
                    "app_id": app_id,
                    "app_name": node_name,
                    "node_type": node_type,
                    "status": status,
                }
            )

            process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}

            provider = None
            model = None
            total_tokens = 0
            completion_tokens = 0
            prompt_tokens = 0

            if process_data and process_data.get("model_mode") == "chat":
                run_type = "llm"
                provider = process_data.get("model_provider", None)
                model = process_data.get("model_name", "")
                metadata.update(
                    {
                        "ls_provider": provider,
                        "ls_model_name": model,
                    }
                )

                try:
                    if outputs.get("usage"):
                        total_tokens = outputs["usage"].get("total_tokens", 0)
                        prompt_tokens = outputs["usage"].get("prompt_tokens", 0)
                        completion_tokens = outputs["usage"].get("completion_tokens", 0)
                except Exception:
                    logger.error("Failed to extract usage", exc_info=True)

            else:
                run_type = "tool"

            parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id

            if not total_tokens:
                total_tokens = execution_metadata.get("total_tokens", 0)

            span_data = {
                "trace_id": opik_trace_id,
                "id": prepare_opik_uuid(created_at, node_execution_id),
                "parent_span_id": prepare_opik_uuid(trace_info.start_time, parent_span_id),
                "name": node_type,
                "type": run_type,
                "start_time": created_at,
                "end_time": finished_at,
                "metadata": wrap_metadata(metadata),
                "input": wrap_dict("input", inputs),
                "output": wrap_dict("output", outputs),
                "tags": ["node_execution"],
                "project_name": self.project,
                "usage": {
                    "total_tokens": total_tokens,
                    "completion_tokens": completion_tokens,
                    "prompt_tokens": prompt_tokens,
                },
                "model": model,
                "provider": provider,
            }

            self.add_span(span_data)

    def message_trace(self, trace_info: MessageTraceInfo):
        # get message file data
        file_list = cast(list[str], trace_info.file_list) or []
        message_file_data: Optional[MessageFile] = trace_info.message_file_data

        if message_file_data is not None:
            file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
            file_list.append(file_url)

        message_data = trace_info.message_data
        if message_data is None:
            return

        metadata = trace_info.metadata
        message_id = trace_info.message_id

        user_id = message_data.from_account_id
        metadata["user_id"] = user_id
        metadata["file_list"] = file_list

        if message_data.from_end_user_id:
            end_user_data: Optional[EndUser] = (
                db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
            )
            if end_user_data is not None:
                end_user_id = end_user_data.session_id
                metadata["end_user_id"] = end_user_id

        trace_data = {
            "id": prepare_opik_uuid(trace_info.start_time, message_id),
            "name": TraceTaskName.MESSAGE_TRACE.value,
            "start_time": trace_info.start_time,
            "end_time": trace_info.end_time,
            "metadata": wrap_metadata(metadata),
            "input": trace_info.inputs,
            "output": message_data.answer,
            "tags": ["message", str(trace_info.conversation_mode)],
            "project_name": self.project,
        }
        trace = self.add_trace(trace_data)

        span_data = {
            "trace_id": trace.id,
            "name": "llm",
            "type": "llm",
            "start_time": trace_info.start_time,
            "end_time": trace_info.end_time,
            "metadata": wrap_metadata(metadata),
            "input": {"input": trace_info.inputs},
            "output": {"output": message_data.answer},
            "tags": ["llm", str(trace_info.conversation_mode)],
            "usage": {
                "completion_tokens": trace_info.answer_tokens,
                "prompt_tokens": trace_info.message_tokens,
                "total_tokens": trace_info.total_tokens,
            },
            "project_name": self.project,
        }
        self.add_span(span_data)

    def moderation_trace(self, trace_info: ModerationTraceInfo):
        if trace_info.message_data is None:
            return

        start_time = trace_info.start_time or trace_info.message_data.created_at

        span_data = {
            "trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
            "name": TraceTaskName.MODERATION_TRACE.value,
            "type": "tool",
            "start_time": start_time,
            "end_time": trace_info.end_time or trace_info.message_data.updated_at,
            "metadata": wrap_metadata(trace_info.metadata),
            "input": wrap_dict("input", trace_info.inputs),
            "output": {
                "action": trace_info.action,
                "flagged": trace_info.flagged,
                "preset_response": trace_info.preset_response,
                "inputs": trace_info.inputs,
            },
            "tags": ["moderation"],
        }

        self.add_span(span_data)

    def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
        message_data = trace_info.message_data
        if message_data is None:
            return

        start_time = trace_info.start_time or message_data.created_at

        span_data = {
            "trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
            "name": TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
            "type": "tool",
            "start_time": start_time,
            "end_time": trace_info.end_time or message_data.updated_at,
            "metadata": wrap_metadata(trace_info.metadata),
            "input": wrap_dict("input", trace_info.inputs),
            "output": wrap_dict("output", trace_info.suggested_question),
            "tags": ["suggested_question"],
        }

        self.add_span(span_data)

    def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
        if trace_info.message_data is None:
            return

        start_time = trace_info.start_time or trace_info.message_data.created_at

        span_data = {
            "trace_id": prepare_opik_uuid(start_time, trace_info.message_id),
            "name": TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
            "type": "tool",
            "start_time": start_time,
            "end_time": trace_info.end_time or trace_info.message_data.updated_at,
            "metadata": wrap_metadata(trace_info.metadata),
            "input": wrap_dict("input", trace_info.inputs),
            "output": {"documents": trace_info.documents},
            "tags": ["dataset_retrieval"],
        }

        self.add_span(span_data)

    def tool_trace(self, trace_info: ToolTraceInfo):
        span_data = {
            "trace_id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id),
            "name": trace_info.tool_name,
            "type": "tool",
            "start_time": trace_info.start_time,
            "end_time": trace_info.end_time,
            "metadata": wrap_metadata(trace_info.metadata),
            "input": wrap_dict("input", trace_info.tool_inputs),
            "output": wrap_dict("output", trace_info.tool_outputs),
            "tags": ["tool", trace_info.tool_name],
        }

        self.add_span(span_data)

    def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
        trace_data = {
            "id": prepare_opik_uuid(trace_info.start_time, trace_info.message_id),
            "name": TraceTaskName.GENERATE_NAME_TRACE.value,
            "start_time": trace_info.start_time,
            "end_time": trace_info.end_time,
            "metadata": wrap_metadata(trace_info.metadata),
            "input": trace_info.inputs,
            "output": trace_info.outputs,
            "tags": ["generate_name"],
            "project_name": self.project,
        }

        trace = self.add_trace(trace_data)

        span_data = {
            "trace_id": trace.id,
            "name": TraceTaskName.GENERATE_NAME_TRACE.value,
            "start_time": trace_info.start_time,
            "end_time": trace_info.end_time,
            "metadata": wrap_metadata(trace_info.metadata),
            "input": wrap_dict("input", trace_info.inputs),
            "output": wrap_dict("output", trace_info.outputs),
            "tags": ["generate_name"],
        }

        self.add_span(span_data)

    def add_trace(self, opik_trace_data: dict) -> Trace:
        try:
            trace = self.opik_client.trace(**opik_trace_data)
            logger.debug("Opik Trace created successfully")
            return trace
        except Exception as e:
            raise ValueError(f"Opik Failed to create trace: {str(e)}")

    def add_span(self, opik_span_data: dict):
        try:
            self.opik_client.span(**opik_span_data)
            logger.debug("Opik Span created successfully")
        except Exception as e:
            raise ValueError(f"Opik Failed to create span: {str(e)}")

    def api_check(self):
        try:
            self.opik_client.auth_check()
            return True
        except Exception as e:
            logger.info(f"Opik API check failed: {str(e)}", exc_info=True)
            raise ValueError(f"Opik API check failed: {str(e)}")

    def get_project_url(self):
        try:
            return self.opik_client.get_project_url(project_name=self.project)
        except Exception as e:
            logger.info(f"Opik get run url failed: {str(e)}", exc_info=True)
            raise ValueError(f"Opik get run url failed: {str(e)}")

@@ -17,6 +17,7 @@ from core.ops.entities.config_entity import (
    OPS_FILE_PATH,
    LangfuseConfig,
    LangSmithConfig,
    OpikConfig,
    TracingProviderEnum,
)
from core.ops.entities.trace_entity import (
@@ -32,6 +33,7 @@ from core.ops.entities.trace_entity import (
)
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
from core.ops.opik_trace.opik_trace import OpikDataTrace
from core.ops.utils import get_message_data
from extensions.ext_database import db
from extensions.ext_storage import storage
@@ -52,6 +54,12 @@ provider_config_map: dict[str, dict[str, Any]] = {
        "other_keys": ["project", "endpoint"],
        "trace_instance": LangSmithDataTrace,
    },
    TracingProviderEnum.OPIK.value: {
        "config_class": OpikConfig,
        "secret_keys": ["api_key"],
        "other_keys": ["project", "url", "workspace"],
        "trace_instance": OpikDataTrace,
    },
}

@@ -197,3 +197,9 @@ class PluginDependency(BaseModel):

    type: Type
    value: Github | Marketplace | Package
    current_identifier: Optional[str] = None


class MissingPluginDependency(BaseModel):
    plugin_unique_identifier: str
    current_identifier: Optional[str] = None

@@ -30,7 +30,7 @@ from core.plugin.manager.exc import (
)

plugin_daemon_inner_api_baseurl = dify_config.PLUGIN_DAEMON_URL
plugin_daemon_inner_api_key = dify_config.PLUGIN_API_KEY
plugin_daemon_inner_api_key = dify_config.PLUGIN_DAEMON_KEY

T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))

@@ -3,6 +3,7 @@ from collections.abc import Sequence
from core.plugin.entities.bundle import PluginBundleDependency
from core.plugin.entities.plugin import (
    GenericProviderID,
    MissingPluginDependency,
    PluginDeclaration,
    PluginEntity,
    PluginInstallation,
@@ -175,14 +176,16 @@ class PluginInstallationManager(BasePluginManager):
            headers={"Content-Type": "application/json"},
        )

    def fetch_missing_dependencies(self, tenant_id: str, plugin_unique_identifiers: list[str]) -> list[str]:
    def fetch_missing_dependencies(
        self, tenant_id: str, plugin_unique_identifiers: list[str]
    ) -> list[MissingPluginDependency]:
        """
        Fetch missing dependencies
        """
        return self._request_with_plugin_daemon_response(
            "POST",
            f"plugin/{tenant_id}/management/installation/missing",
            list[str],
            list[MissingPluginDependency],
            data={"plugin_unique_identifiers": plugin_unique_identifiers},
            headers={"Content-Type": "application/json"},
        )

@@ -48,8 +48,10 @@ class PluginToolManager(BasePluginManager):
        tool_provider_id = GenericProviderID(provider)

        def transformer(json_response: dict[str, Any]) -> dict:
            for tool in json_response.get("data", {}).get("declaration", {}).get("tools", []):
                tool["identity"]["provider"] = tool_provider_id.provider_name
            data = json_response.get("data")
            if data:
                for tool in data.get("declaration", {}).get("tools", []):
                    tool["identity"]["provider"] = tool_provider_id.provider_name

            return json_response

@@ -23,7 +23,12 @@ from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.helper.position_helper import is_filtered
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderEntity
from core.model_runtime.entities.provider_entities import (
    ConfigurateMethod,
    CredentialFormSchema,
    FormType,
    ProviderEntity,
)
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from extensions import ext_hosting_provider
from extensions.ext_database import db
@@ -839,11 +844,18 @@ class ProviderManager:
        :return:
        """
        # Get provider model credential secret variables
        model_credential_secret_variables = self._extract_secret_variables(
            provider_entity.model_credential_schema.credential_form_schemas
            if provider_entity.model_credential_schema
            else []
        )
        if ConfigurateMethod.PREDEFINED_MODEL in provider_entity.configurate_methods:
            model_credential_secret_variables = self._extract_secret_variables(
                provider_entity.provider_credential_schema.credential_form_schemas
                if provider_entity.provider_credential_schema
                else []
            )
        else:
            model_credential_secret_variables = self._extract_secret_variables(
                provider_entity.model_credential_schema.credential_form_schemas
                if provider_entity.model_credential_schema
                else []
            )

        model_settings: list[ModelSettings] = []
        if not provider_model_settings:

@@ -258,7 +258,7 @@ class LindormVectorStore(BaseVector):
        hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
        ivfpq_m = kwargs.pop("ivfpq_m", dimension)
        nlist = kwargs.pop("nlist", 1000)
        centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", True if nlist >= 5000 else False)
        centroids_use_hnsw = kwargs.pop("centroids_use_hnsw", nlist >= 5000)
        centroids_hnsw_m = kwargs.pop("centroids_hnsw_m", 24)
        centroids_hnsw_ef_construct = kwargs.pop("centroids_hnsw_ef_construct", 500)
        centroids_hnsw_ef_search = kwargs.pop("centroids_hnsw_ef_search", 100)
@@ -305,7 +305,7 @@ def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dic
    if method_name == "ivfpq":
        ivfpq_m = kwargs["ivfpq_m"]
        nlist = kwargs["nlist"]
        centroids_use_hnsw = True if nlist > 10000 else False
        centroids_use_hnsw = nlist > 10000
        centroids_hnsw_m = 24
        centroids_hnsw_ef_construct = 500
        centroids_hnsw_ef_search = 100

@@ -57,6 +57,11 @@ CREATE TABLE IF NOT EXISTS {table_name} (
) using heap;
"""

SQL_CREATE_INDEX = """
CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx ON {table_name}
USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64);
"""


class PGVector(BaseVector):
    def __init__(self, collection_name: str, config: PGVectorConfig):
@@ -205,7 +210,10 @@ class PGVector(BaseVector):
        with self._get_cursor() as cur:
            cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
            cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
            # TODO: create index https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
            # the PG hnsw index only supports 2000 dimensions or fewer
            # ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing
            if dimension <= 2000:
                cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

@@ -74,7 +74,7 @@ class CacheEmbedding(Embeddings):
                    embedding_queue_embeddings.append(normalized_embedding)
                except IntegrityError:
                    db.session.rollback()
                except Exception as e:
                except Exception:
                    logging.exception("Failed transform embedding")
            cache_embeddings = []
            try:

@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import time
|
||||
from typing import cast
|
||||
from typing import Any, cast
|
||||
|
||||
import requests
|
||||
|
||||
@@ -14,48 +14,47 @@ class FirecrawlApp:
|
||||
if self.api_key is None and self.base_url == "https://api.firecrawl.dev":
|
||||
raise ValueError("No API key provided")
|
||||
|
||||
def scrape_url(self, url, params=None) -> dict:
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
json_data = {"url": url}
|
||||
def scrape_url(self, url, params=None) -> dict[str, Any]:
|
||||
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/scrape
|
||||
headers = self._prepare_headers()
|
||||
json_data = {
|
||||
"url": url,
|
||||
"formats": ["markdown"],
|
||||
"onlyMainContent": True,
|
||||
"timeout": 30000,
|
||||
}
|
||||
if params:
|
||||
json_data.update(params)
|
||||
response = requests.post(f"{self.base_url}/v0/scrape", headers=headers, json=json_data)
|
||||
response = self._post_request(f"{self.base_url}/v1/scrape", json_data, headers)
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
if response_data["success"] == True:
|
||||
data = response_data["data"]
|
||||
return {
|
||||
"title": data.get("metadata").get("title"),
|
||||
"description": data.get("metadata").get("description"),
|
||||
"source_url": data.get("metadata").get("sourceURL"),
|
||||
"markdown": data.get("markdown"),
|
||||
}
|
||||
else:
|
||||
raise Exception(f'Failed to scrape URL. Error: {response_data["error"]}')
|
||||
|
||||
elif response.status_code in {402, 409, 500}:
|
||||
error_message = response.json().get("error", "Unknown error occurred")
|
||||
raise Exception(f"Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}")
|
||||
data = response_data["data"]
|
||||
return self._extract_common_fields(data)
|
||||
elif response.status_code in {402, 409, 500, 429, 408}:
|
||||
self._handle_error(response, "scrape URL")
|
||||
return {} # Avoid additional exception after handling error
|
||||
else:
|
||||
raise Exception(f"Failed to scrape URL. Status code: {response.status_code}")
|
||||
|
||||
def crawl_url(self, url, params=None) -> str:
|
||||
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/crawl-post
|
||||
headers = self._prepare_headers()
|
||||
json_data = {"url": url}
|
||||
if params:
|
||||
json_data.update(params)
|
||||
response = self._post_request(f"{self.base_url}/v0/crawl", json_data, headers)
|
||||
response = self._post_request(f"{self.base_url}/v1/crawl", json_data, headers)
|
||||
if response.status_code == 200:
|
||||
job_id = response.json().get("jobId")
|
||||
# There's also another two fields in the response: "success" (bool) and "url" (str)
|
||||
job_id = response.json().get("id")
|
||||
return cast(str, job_id)
|
||||
else:
|
||||
self._handle_error(response, "start crawl job")
|
||||
# FIXME: unreachable code for mypy
|
||||
return "" # unreachable
|
||||
|
||||
def check_crawl_status(self, job_id) -> dict:
|
||||
def check_crawl_status(self, job_id) -> dict[str, Any]:
|
||||
headers = self._prepare_headers()
|
||||
response = self._get_request(f"{self.base_url}/v0/crawl/status/{job_id}", headers)
|
||||
response = self._get_request(f"{self.base_url}/v1/crawl/{job_id}", headers)
|
||||
if response.status_code == 200:
|
||||
crawl_status_response = response.json()
|
||||
if crawl_status_response.get("status") == "completed":
|
||||
@@ -66,42 +65,48 @@ class FirecrawlApp:
|
||||
url_data_list = []
|
||||
for item in data:
|
||||
if isinstance(item, dict) and "metadata" in item and "markdown" in item:
|
||||
url_data = {
|
||||
"title": item.get("metadata", {}).get("title"),
|
||||
"description": item.get("metadata", {}).get("description"),
|
||||
"source_url": item.get("metadata", {}).get("sourceURL"),
|
||||
"markdown": item.get("markdown"),
|
||||
}
|
||||
url_data = self._extract_common_fields(item)
|
||||
url_data_list.append(url_data)
|
||||
if url_data_list:
|
||||
file_key = "website_files/" + job_id + ".txt"
|
||||
if storage.exists(file_key):
|
||||
storage.delete(file_key)
|
||||
storage.save(file_key, json.dumps(url_data_list).encode("utf-8"))
|
||||
return {
|
||||
"status": "completed",
|
||||
"total": crawl_status_response.get("total"),
|
||||
"current": crawl_status_response.get("current"),
|
||||
"data": url_data_list,
|
||||
}
|
||||
|
||||
try:
|
||||
if storage.exists(file_key):
|
||||
storage.delete(file_key)
|
||||
storage.save(file_key, json.dumps(url_data_list).encode("utf-8"))
|
||||
except Exception as e:
|
||||
raise Exception(f"Error saving crawl data: {e}")
|
||||
return self._format_crawl_status_response("completed", crawl_status_response, url_data_list)
|
||||
else:
|
||||
return {
|
||||
"status": crawl_status_response.get("status"),
|
||||
"total": crawl_status_response.get("total"),
|
||||
"current": crawl_status_response.get("current"),
|
||||
"data": [],
|
||||
}
|
||||
|
||||
return self._format_crawl_status_response(
|
||||
crawl_status_response.get("status"), crawl_status_response, []
|
||||
)
|
||||
else:
|
||||
self._handle_error(response, "check crawl status")
|
||||
# FIXME: unreachable code for mypy
|
||||
return {} # unreachable
|
||||
|
||||
-    def _prepare_headers(self):
+    def _format_crawl_status_response(
+        self, status: str, crawl_status_response: dict[str, Any], url_data_list: list[dict[str, Any]]
+    ) -> dict[str, Any]:
+        return {
+            "status": status,
+            "total": crawl_status_response.get("total"),
+            "current": crawl_status_response.get("completed"),
+            "data": url_data_list,
+        }
+
+    def _extract_common_fields(self, item: dict[str, Any]) -> dict[str, Any]:
+        return {
+            "title": item.get("metadata", {}).get("title"),
+            "description": item.get("metadata", {}).get("description"),
+            "source_url": item.get("metadata", {}).get("sourceURL"),
+            "markdown": item.get("markdown"),
+        }
+
+    def _prepare_headers(self) -> dict[str, Any]:
         return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
-    def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5):
+    def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5) -> requests.Response:
         for attempt in range(retries):
             response = requests.post(url, headers=headers, json=data)
             if response.status_code == 502:
@@ -110,7 +115,7 @@ class FirecrawlApp:
                 return response
         return response

-    def _get_request(self, url, headers, retries=3, backoff_factor=0.5):
+    def _get_request(self, url, headers, retries=3, backoff_factor=0.5) -> requests.Response:
         for attempt in range(retries):
             response = requests.get(url, headers=headers)
             if response.status_code == 502:
@@ -119,6 +124,6 @@ class FirecrawlApp:
                 return response
         return response

-    def _handle_error(self, response, action):
+    def _handle_error(self, response, action) -> None:
         error_message = response.json().get("error", "Unknown error occurred")
         raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}")
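Both request helpers retry up to three times and only the signatures changed here; the sleep between attempts sits in the lines the viewer collapsed between the hunks above. A standalone sketch of the pattern under that assumption, with exponential backoff driven by backoff_factor:

    import time
    import requests

    def post_with_backoff(url, data, headers, retries=3, backoff_factor=0.5) -> requests.Response:
        response = None
        for attempt in range(retries):
            response = requests.post(url, headers=headers, json=data)
            if response.status_code == 502:
                # Wait 0.5s, 1s, 2s, ... before retrying the bad gateway.
                time.sleep(backoff_factor * (2 ** attempt))
            else:
                return response
        return response  # the last 502 response once retries are exhausted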
@@ -13,9 +13,10 @@ class FirecrawlWebExtractor(BaseExtractor):
         api_key: The API key for Firecrawl.
         base_url: The base URL for the Firecrawl API. Defaults to 'https://api.firecrawl.dev'.
         mode: The mode of operation. Defaults to 'scrape'. Options are 'crawl', 'scrape' and 'crawl_return_urls'.
+        only_main_content: Only return the main content of the page excluding headers, navs, footers, etc.
     """

-    def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = False):
+    def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = True):
         """Initialize with url, api_key, base_url and mode."""
         self._url = url
         self.job_id = job_id
@@ -358,8 +358,7 @@ class NotionExtractor(BaseExtractor):

         if not data_source_binding:
             raise Exception(
-                f"No notion data source binding found for tenant {tenant_id} "
-                f"and notion workspace {notion_workspace_id}"
+                f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}"
             )

         return cast(str, data_source_binding.access_token)
@@ -47,6 +47,8 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
             embedding_model_instance=kwargs.get("embedding_model_instance"),
         )
         for document in documents:
+            if kwargs.get("preview") and len(all_documents) >= 10:
+                return all_documents
             # document clean
             document_text = CleanProcessor.clean(document.page_content, process_rule)
             document.page_content = document_text
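The new guard makes preview indexing stop after ten parent documents instead of cleaning and splitting the whole corpus. The early-return shape in isolation (function and names hypothetical):

    def build_preview(items, transform, preview=True, limit=10):
        results = []
        for item in items:
            # In preview mode, stop as soon as enough items are collected.
            if preview and len(results) >= limit:
                return results
            results.append(transform(item))
        return results

    print(build_preview(range(100), transform=str))  # ['0', '1', ..., '9']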
@@ -112,7 +112,7 @@ class QAIndexProcessor(BaseIndexProcessor):
         df = pd.read_csv(file)
         text_docs = []
         for index, row in df.iterrows():
-            data = Document(page_content=row[0], metadata={"answer": row[1]})
+            data = Document(page_content=row.iloc[0], metadata={"answer": row.iloc[1]})
             text_docs.append(data)
         if len(text_docs) == 0:
             raise ValueError("The CSV file is empty.")
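On the Series that iterrows() yields, row[0] is a label lookup that only works by position through a deprecated fallback; it raises a FutureWarning in recent pandas and is slated to become a KeyError. row.iloc[0] is explicitly positional:

    import pandas as pd

    row = pd.Series({"question": "Q1", "answer": "A1"})
    print(row.iloc[0])  # 'Q1' — positional access, independent of column names
    # row[0] would be treated as a label here and warns/fails on modern pandas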
@@ -94,9 +94,9 @@ class ApiTool(Tool):
         if "api_key_header_prefix" in credentials:
             api_key_header_prefix = credentials["api_key_header_prefix"]
             if api_key_header_prefix == "basic" and credentials["api_key_value"]:
-                credentials["api_key_value"] = f'Basic {credentials["api_key_value"]}'
+                credentials["api_key_value"] = f"Basic {credentials['api_key_value']}"
             elif api_key_header_prefix == "bearer" and credentials["api_key_value"]:
-                credentials["api_key_value"] = f'Bearer {credentials["api_key_value"]}'
+                credentials["api_key_value"] = f"Bearer {credentials['api_key_value']}"
             elif api_key_header_prefix == "custom":
                 pass
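These quote flips (and the matching ones in the file-URL and operationId hunks below) are style normalization: double-quoted f-strings with single-quoted subscripts inside. The convention also matters for compatibility, since only Python 3.12+ lets an f-string reuse its own quote character inside embedded expressions:

    credentials = {"api_key_value": "abc123"}

    # Portable on all supported Pythons: inner quotes differ from outer ones.
    header = f"Bearer {credentials['api_key_value']}"
    print(header)  # Bearer abc123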
@@ -125,7 +125,7 @@ class ToolInvokeMessage(BaseModel):

     class VariableMessage(BaseModel):
         variable_name: str = Field(..., description="The name of the variable")
-        variable_value: str = Field(..., description="The value of the variable")
+        variable_value: Any = Field(..., description="The value of the variable")
         stream: bool = Field(default=False, description="Whether the variable is streamed")

         @model_validator(mode="before")
@@ -48,7 +48,9 @@ class PluginToolProviderController(BuiltinToolProviderController):
         """
         return tool with given name
         """
-        tool_entity = next(tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name)
+        tool_entity = next(
+            (tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name), None
+        )

         if not tool_entity:
             raise ValueError(f"Tool with name {tool_name} not found")
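A bare next() over a generator raises StopIteration when nothing matches, which surfaces as an opaque error instead of the intended ValueError. Supplying None as the default turns a miss into a value the caller can check:

    tools = ["search", "scrape"]

    # next() with a default converts "no match" into None instead of StopIteration.
    match = next((t for t in tools if t == "translate"), None)
    if match is None:
        raise ValueError("Tool with name translate not found")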
@@ -39,7 +39,7 @@ class ToolFileMessageTransformer:
                 conversation_id=conversation_id,
             )

-            url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
+            url = f"/files/tools/{file.id}{guess_extension(file.mimetype) or '.png'}"

             yield ToolInvokeMessage(
                 type=ToolInvokeMessage.MessageType.IMAGE_LINK,
@@ -115,4 +115,4 @@ class ToolFileMessageTransformer:

     @classmethod
     def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str:
-        return f'/files/tools/{tool_file_id}{extension or ".bin"}'
+        return f"/files/tools/{tool_file_id}{extension or '.bin'}"
@@ -5,6 +5,7 @@ from json import loads as json_loads
 from json.decoder import JSONDecodeError
 from typing import Optional

+from flask import request
 from requests import get
 from yaml import YAMLError, safe_load  # type: ignore

@@ -29,6 +30,10 @@ class ApiBasedToolSchemaParser:
             raise ToolProviderNotFoundError("No server found in the openapi yaml.")

         server_url = openapi["servers"][0]["url"]
+        request_env = request.headers.get("X-Request-Env")
+        if request_env:
+            matched_servers = [server["url"] for server in openapi["servers"] if server["env"] == request_env]
+            server_url = matched_servers[0] if matched_servers else server_url

         # list all interfaces
         interfaces = []
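The parser now honors an X-Request-Env request header when choosing among the OpenAPI servers entries. Note that env is not a standard OpenAPI server field, so the matching only applies to schemas that carry it. A standalone sketch of the selection, using .get("env") so entries without the custom key are simply skipped:

    servers = [
        {"url": "https://api.dev.example.com", "env": "dev"},  # hypothetical schema entries
        {"url": "https://api.example.com", "env": "prod"},
    ]

    def pick_server_url(servers, request_env=None):
        default_url = servers[0]["url"]
        if not request_env:
            return default_url
        matched = [s["url"] for s in servers if s.get("env") == request_env]
        return matched[0] if matched else default_url

    print(pick_server_url(servers, "prod"))  # https://api.example.com
    print(pick_server_url(servers))          # https://api.dev.example.com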
@@ -112,7 +117,7 @@ class ApiBasedToolSchemaParser:
                         llm_description=property.get("description", ""),
                         default=property.get("default", None),
                         placeholder=I18nObject(
-                            en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
+                            en_US=property.get("description", ""), zh_Hans=property.get("description", "")
                         ),
                     )

@@ -144,7 +149,7 @@ class ApiBasedToolSchemaParser:
             if not path:
                 path = str(uuid.uuid4())

-            interface["operation"]["operationId"] = f'{path}_{interface["method"]}'
+            interface["operation"]["operationId"] = f"{path}_{interface['method']}"

             bundles.append(
                 ApiToolBundle(
@@ -223,14 +223,14 @@ class WorkflowTool(Tool):
             if isinstance(value, list):
                 for item in value:
                     if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY:
-                        item["tool_file_id"] = item.get("related_id")
+                        item = self._update_file_mapping(item)
                         file = build_from_mapping(
                             mapping=item,
                             tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id),
                         )
                         files.append(file)
             elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
-                value["tool_file_id"] = value.get("related_id")
+                value = self._update_file_mapping(value)
                 file = build_from_mapping(
                     mapping=value,
                     tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id),
@@ -240,3 +240,11 @@ class WorkflowTool(Tool):
             result[key] = value

         return result, files
+
+    def _update_file_mapping(self, file_dict: dict) -> dict:
+        transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
+        if transfer_method == FileTransferMethod.TOOL_FILE:
+            file_dict["tool_file_id"] = file_dict.get("related_id")
+        elif transfer_method == FileTransferMethod.LOCAL_FILE:
+            file_dict["upload_file_id"] = file_dict.get("related_id")
+        return file_dict
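Where the old code unconditionally wrote tool_file_id, the new helper routes related_id to the key build_from_mapping expects for each transfer method, so locally uploaded files resolve through upload_file_id. The dispatch in isolation (the enum here is a stand-in for Dify's FileTransferMethod, with illustrative values):

    from enum import Enum

    class FileTransferMethod(Enum):  # stand-in; values illustrative
        TOOL_FILE = "tool_file"
        LOCAL_FILE = "local_file"

    def update_file_mapping(file_dict: dict) -> dict:
        # Route related_id to the identifier field each transfer method uses.
        method = FileTransferMethod(file_dict["transfer_method"])
        if method == FileTransferMethod.TOOL_FILE:
            file_dict["tool_file_id"] = file_dict["related_id"]
        elif method == FileTransferMethod.LOCAL_FILE:
            file_dict["upload_file_id"] = file_dict["related_id"]
        return file_dict

    print(update_file_mapping({"transfer_method": "local_file", "related_id": "f-1"}))
    # {'transfer_method': 'local_file', 'related_id': 'f-1', 'upload_file_id': 'f-1'}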
@@ -134,6 +134,10 @@ class ArrayStringSegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_STRING
     value: Sequence[str]

+    @property
+    def text(self) -> str:
+        return json.dumps(self.value)
+

 class ArrayNumberSegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_NUMBER
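Serializing through json.dumps gives the segment a text form that is valid JSON and round-trips cleanly, unlike Python's repr of a list:

    import json

    value = ["a", 'say "hi"']
    print(json.dumps(value))  # ["a", "say \"hi\""] — parseable JSON
    print(str(value))         # ['a', 'say "hi"'] — Python repr, not JSON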
@@ -590,6 +590,8 @@ class Graph(BaseModel):
                     start_node_id=node_id,
                     routes_node_ids=routes_node_ids,
                 )
+                # Exclude conditional branch nodes
+                and all(edge.run_condition is None for edge in reverse_edge_mapping.get(node_id, []))
             ):
                 if node_id not in merge_branch_node_ids:
                     merge_branch_node_ids[node_id] = []
@@ -18,6 +18,7 @@ from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunM
 from core.workflow.entities.variable_pool import VariablePool, VariableValue
 from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
 from core.workflow.graph_engine.entities.event import (
+    BaseAgentEvent,
     BaseIterationEvent,
     GraphEngineEvent,
     GraphRunFailedEvent,
@@ -501,7 +502,7 @@ class GraphEngine:
                     break

                 yield event
-                if event.parallel_id == parallel_id:
+                if not isinstance(event, BaseAgentEvent) and event.parallel_id == parallel_id:
                     if isinstance(event, ParallelBranchRunSucceededEvent):
                         succeeded_count += 1
                         if succeeded_count == len(futures):
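The added isinstance guard evaluates before event.parallel_id, so events emitted by agent strategies, which stream through the same queue but are not parallel-branch bookkeeping, no longer participate in the branch-completion count. The short-circuit shape with stand-in classes (the assumption here is that agent events carry no parallel_id):

    class GraphEngineEvent: ...
    class BaseAgentEvent(GraphEngineEvent): ...  # assumed: no parallel_id attribute
    class ParallelBranchRunSucceededEvent(GraphEngineEvent):
        def __init__(self, parallel_id: str):
            self.parallel_id = parallel_id

    def count_succeeded(events, parallel_id: str) -> int:
        succeeded = 0
        for event in events:
            # isinstance runs first, so agent events are never probed for
            # a parallel_id attribute they may not define.
            if not isinstance(event, BaseAgentEvent) and event.parallel_id == parallel_id:
                if isinstance(event, ParallelBranchRunSucceededEvent):
                    succeeded += 1
        return succeeded

    events = [BaseAgentEvent(), ParallelBranchRunSucceededEvent("p1")]
    print(count_succeeded(events, "p1"))  # 1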
@@ -8,12 +8,12 @@ from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.plugin.manager.exc import PluginDaemonClientSideError
 from core.plugin.manager.plugin import PluginInstallationManager
-from core.tools.entities.tool_entities import ToolProviderType
+from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
 from core.tools.tool_manager import ToolManager
 from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
-from core.workflow.nodes.agent.entities import AgentNodeData
+from core.workflow.nodes.agent.entities import AgentNodeData, ParamsAutoGenerated
 from core.workflow.nodes.base.entities import BaseNodeData
 from core.workflow.nodes.enums import NodeType
 from core.workflow.nodes.event.event import RunCompletedEvent
@@ -156,17 +156,37 @@ class AgentNode(ToolNode):
             value = cast(list[dict[str, Any]], value)
             value = [tool for tool in value if tool.get("enabled", False)]

             for tool in value:
                 if "schemas" in tool:
                     tool.pop("schemas")
+                parameters = tool.get("parameters", {})
+                if all(isinstance(v, dict) for _, v in parameters.items()):
+                    params = {}
+                    for key, param in parameters.items():
+                        if param.get("auto", ParamsAutoGenerated.OPEN.value) == ParamsAutoGenerated.CLOSE.value:
+                            params[key] = param.get("value", {}).get("value", "")
+                        else:
+                            params[key] = None
+                    parameters = params
                 tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
+                tool["parameters"] = parameters

         if not for_log:
             if parameter.type == "array[tools]":
                 value = cast(list[dict[str, Any]], value)
                 tool_value = []
                 for tool in value:
                     provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN.value))
+                    setting_params = tool.get("settings", {})
+                    parameters = tool.get("parameters", {})
+                    manual_input_params = [key for key, value in parameters.items() if value is not None]

+                    parameters = {**parameters, **setting_params}
                     entity = AgentToolEntity(
                         provider_id=tool.get("provider_name", ""),
                         provider_type=provider_type,
                         tool_name=tool.get("tool_name", ""),
-                        tool_parameters=tool.get("parameters", {}),
+                        tool_parameters=parameters,
                         plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
                     )
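The first hunk flattens the tool parameter map coming from the workflow editor: each entry arrives as {"auto": ..., "value": {"value": ...}}, only values whose auto flag is CLOSE (manual input) are kept, and the rest become None placeholders for the agent to fill at runtime. A self-contained sketch of that transformation (enum values are illustrative; the real ParamsAutoGenerated lives in core.workflow.nodes.agent.entities and is not shown in this diff):

    from enum import Enum

    class ParamsAutoGenerated(Enum):
        CLOSE = 0  # illustrative values
        OPEN = 1

    def flatten_parameters(parameters: dict) -> dict:
        """Keep manually entered values; mark auto-generated ones as None."""
        if not all(isinstance(v, dict) for v in parameters.values()):
            return parameters  # already flat
        flat = {}
        for key, param in parameters.items():
            if param.get("auto", ParamsAutoGenerated.OPEN.value) == ParamsAutoGenerated.CLOSE.value:
                flat[key] = param.get("value", {}).get("value", "")
            else:
                flat[key] = None
        return flat

    raw = {
        "query": {"auto": 0, "value": {"value": "dify"}},  # manual input
        "limit": {"auto": 1},                              # left to the agent
    }
    print(flatten_parameters(raw))  # {'query': 'dify', 'limit': None}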
@@ -179,14 +199,27 @@ class AgentNode(ToolNode):
                         tool_runtime.entity.description.llm = (
                             extra.get("descrption", "") or tool_runtime.entity.description.llm
                         )

-                    tool_value.append(
-                        {
-                            **tool_runtime.entity.model_dump(mode="json"),
-                            "runtime_parameters": tool_runtime.runtime.runtime_parameters,
-                            "provider_type": provider_type.value,
-                        }
-                    )
+                    for params in tool_runtime.entity.parameters:
+                        params.form = (
+                            ToolParameter.ToolParameterForm.FORM
+                            if params.name in manual_input_params
+                            else params.form
+                        )
+                    if tool_runtime.entity.parameters:
+                        manual_input_value = {
+                            key: value for key, value in parameters.items() if key in manual_input_params
+                        }
+                        runtime_parameters = {
+                            **tool_runtime.runtime.runtime_parameters,
+                            **manual_input_value,
+                        }
+                    tool_value.append(
+                        {
+                            **tool_runtime.entity.model_dump(mode="json"),
+                            "runtime_parameters": runtime_parameters,
+                            "provider_type": provider_type.value,
+                        }
+                    )
                 value = tool_value
             if parameter.type == "model-selector":
                 value = cast(dict[str, Any], value)
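In the rebuilt append, stored runtime parameters are merged with the manually entered values, and because manual input is unpacked last it wins on key collisions:

    # Later unpacking wins: manual input overrides stored runtime defaults.
    runtime_defaults = {"query": "old", "limit": 10}
    manual_input = {"query": "dify"}

    runtime_parameters = {**runtime_defaults, **manual_input}
    print(runtime_parameters)  # {'query': 'dify', 'limit': 10}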
Some files were not shown because too many files have changed in this diff.