mirror of
https://github.com/langgenius/dify.git
synced 2026-01-07 06:48:28 +00:00
Compare commits
191 Commits
fix/extern
...
fix/redis-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0c6ad1df64 | ||
|
|
40fb522f56 | ||
|
|
96d9951d5c | ||
|
|
d36201f7ff | ||
|
|
b46c7935b1 | ||
|
|
206e6e1e7c | ||
|
|
8685c0d48b | ||
|
|
b674c598f9 | ||
|
|
710230a294 | ||
|
|
169f7440ac | ||
|
|
57ec12eb6b | ||
|
|
2c26f77a25 | ||
|
|
95dc90e6b2 | ||
|
|
400392230b | ||
|
|
eca66f9577 | ||
|
|
121bb99cc2 | ||
|
|
cac1ef7ade | ||
|
|
d74d79b3d8 | ||
|
|
c6b28bc193 | ||
|
|
5d05574518 | ||
|
|
bf478aeba2 | ||
|
|
c9dfe1ad92 | ||
|
|
926609eb59 | ||
|
|
e32116b9a3 | ||
|
|
e11d5ac708 | ||
|
|
f6c3d4cadc | ||
|
|
3e9d271b52 | ||
|
|
ecc8beef3f | ||
|
|
b9afb7bcec | ||
|
|
b4041759f7 | ||
|
|
c3473b5b4f | ||
|
|
1b9bf9c62d | ||
|
|
ed96a6b6c0 | ||
|
|
4989d0c904 | ||
|
|
9a5bdae07f | ||
|
|
67016feb96 | ||
|
|
22bdfb7e56 | ||
|
|
ceb2c4f3ef | ||
|
|
d5a93a6400 | ||
|
|
01a2513812 | ||
|
|
8e7a752b2a | ||
|
|
999d3f1539 | ||
|
|
a7ee51e5d8 | ||
|
|
0e965b6529 | ||
|
|
a9db06f5e7 | ||
|
|
6827c4038b | ||
|
|
e8a6e90a61 | ||
|
|
ff956cb546 | ||
|
|
7d7e0f9800 | ||
|
|
3ae05a672d | ||
|
|
d700abff0a | ||
|
|
5267f34e76 | ||
|
|
d6e8290a1c | ||
|
|
36f66d40e5 | ||
|
|
5f12616cb9 | ||
|
|
bc43efba75 | ||
|
|
ef5f476cd6 | ||
|
|
98bf7710e4 | ||
|
|
7263af13ed | ||
|
|
d992a809f5 | ||
|
|
04f8d39860 | ||
|
|
b7bf14ab72 | ||
|
|
e8abbe0623 | ||
|
|
b14d59e977 | ||
|
|
5f12c17355 | ||
|
|
d170d78530 | ||
|
|
4d9160ca9f | ||
|
|
8f670f31b8 | ||
|
|
5838345f48 | ||
|
|
3f1c84f65a | ||
|
|
83b2b8fe60 | ||
|
|
ac24300274 | ||
|
|
2e657b7b12 | ||
|
|
c063617553 | ||
|
|
38a4f0234d | ||
|
|
740a723072 | ||
|
|
495cf58014 | ||
|
|
8e98759359 | ||
|
|
4ae0bb83f1 | ||
|
|
5459d812e7 | ||
|
|
831c222541 | ||
|
|
faad247d85 | ||
|
|
1e829ceaf3 | ||
|
|
79fe175440 | ||
|
|
9b32bfb3db | ||
|
|
37fea072bc | ||
|
|
31a603e905 | ||
|
|
ca21c285b0 | ||
|
|
5a3eaa85bf | ||
|
|
a5777683f3 | ||
|
|
90dd91c6cd | ||
|
|
8d8a8fe295 | ||
|
|
65e22bb76a | ||
|
|
f83ed19dfe | ||
|
|
53b14bde4d | ||
|
|
7742a5dac2 | ||
|
|
bddcb31fe2 | ||
|
|
b411a89703 | ||
|
|
e61752bd3a | ||
|
|
7a1d6fe509 | ||
|
|
4fd2743efa | ||
|
|
2c0eaaec3d | ||
|
|
3898fe3311 | ||
|
|
853b0e84cc | ||
|
|
42fe208eda | ||
|
|
444dc01931 | ||
|
|
95ce10f23b | ||
|
|
660fc3bb34 | ||
|
|
c71af7f610 | ||
|
|
ce476f2e5c | ||
|
|
a9fc85027d | ||
|
|
b2aa385942 | ||
|
|
424a7da470 | ||
|
|
b92504bebc | ||
|
|
e0846792d2 | ||
|
|
b9bf60ea23 | ||
|
|
bd27b4c162 | ||
|
|
28de676956 | ||
|
|
b3cde9900c | ||
|
|
2155bba5b0 | ||
|
|
a53fdc7126 | ||
|
|
3fc0ebdd51 | ||
|
|
211f416806 | ||
|
|
b90ad587c2 | ||
|
|
e7aecb89dd | ||
|
|
a45f8969a0 | ||
|
|
d3c06a3f76 | ||
|
|
f447ee7b9d | ||
|
|
3e168ce2ca | ||
|
|
c64edd2706 | ||
|
|
8a1f106c72 | ||
|
|
4ac99ffe0e | ||
|
|
fdcf87c70c | ||
|
|
5aabb83f5a | ||
|
|
bd678f9ca1 | ||
|
|
a87890b3cc | ||
|
|
86594851cb | ||
|
|
a83ccccffc | ||
|
|
50635e9c15 | ||
|
|
7d3dad3d1d | ||
|
|
dd22e78515 | ||
|
|
5df1cb0566 | ||
|
|
423df67042 | ||
|
|
568d5c46ed | ||
|
|
da25b91980 | ||
|
|
bc0dad6c1c | ||
|
|
9b8aa9b75d | ||
|
|
d5bc125617 | ||
|
|
4ffaabcc04 | ||
|
|
b597a0d31c | ||
|
|
fb32e5ca9a | ||
|
|
cd7ab6231f | ||
|
|
5908fd6552 | ||
|
|
6d2c6caa23 | ||
|
|
5eb00502ec | ||
|
|
3f9d6759d4 | ||
|
|
aba70207ab | ||
|
|
a8134a49c4 | ||
|
|
fa47f0c707 | ||
|
|
8501af298f | ||
|
|
5c7b1358d4 | ||
|
|
3938d8863e | ||
|
|
7838f9f3a3 | ||
|
|
de3c5751db | ||
|
|
5ee7e03c1b | ||
|
|
7a405b86c9 | ||
|
|
ffc3f33670 | ||
|
|
857055b797 | ||
|
|
d15ba3939d | ||
|
|
1ec83e4969 | ||
|
|
9275760599 | ||
|
|
d97d3ff5fc | ||
|
|
ea6734f550 | ||
|
|
f73751843f | ||
|
|
dbfbc56de7 | ||
|
|
70c5b23089 | ||
|
|
2ec6ffe478 | ||
|
|
793205afc5 | ||
|
|
c6b74daa0a | ||
|
|
23ce1fb1ba | ||
|
|
29188e0562 | ||
|
|
d9773c963f | ||
|
|
1206b1eb96 | ||
|
|
93af87a9e0 | ||
|
|
7a6970e570 | ||
|
|
ea584e94bd | ||
|
|
42b02b3a5f | ||
|
|
44f6a536d2 | ||
|
|
d7b8e071dd | ||
|
|
3f1aa1f9e2 | ||
|
|
82024a65cd |
7
.github/workflows/api-tests.yml
vendored
7
.github/workflows/api-tests.yml
vendored
@@ -27,18 +27,17 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Poetry
|
||||
uses: abatilo/actions-poetry@v3
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: 'poetry'
|
||||
cache-dependency-path: |
|
||||
api/pyproject.toml
|
||||
api/poetry.lock
|
||||
|
||||
- name: Install Poetry
|
||||
uses: abatilo/actions-poetry@v3
|
||||
|
||||
- name: Check Poetry lockfile
|
||||
run: |
|
||||
poetry check -C api --lock
|
||||
|
||||
4
.github/workflows/build-push.yml
vendored
4
.github/workflows/build-push.yml
vendored
@@ -5,7 +5,7 @@ on:
|
||||
branches:
|
||||
- "main"
|
||||
- "deploy/dev"
|
||||
- "fix/external-knowledge-retrieval-issues"
|
||||
- "fix/redis-slow-in-gevent"
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
@@ -126,7 +126,7 @@ jobs:
|
||||
with:
|
||||
images: ${{ env[matrix.image_name_env] }}
|
||||
tags: |
|
||||
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') && !contains(github.ref, '-beta') }}
|
||||
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') && !contains(github.ref, '-') }}
|
||||
type=ref,event=branch
|
||||
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
|
||||
type=raw,value=${{ github.ref_name }},enable=${{ startsWith(github.ref, 'refs/tags/') }}
|
||||
|
||||
7
.github/workflows/db-migration-test.yml
vendored
7
.github/workflows/db-migration-test.yml
vendored
@@ -23,18 +23,17 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Poetry
|
||||
uses: abatilo/actions-poetry@v3
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: 'poetry'
|
||||
cache-dependency-path: |
|
||||
api/pyproject.toml
|
||||
api/poetry.lock
|
||||
|
||||
- name: Install Poetry
|
||||
uses: abatilo/actions-poetry@v3
|
||||
|
||||
- name: Install dependencies
|
||||
run: poetry install -C api
|
||||
|
||||
|
||||
7
.github/workflows/style.yml
vendored
7
.github/workflows/style.yml
vendored
@@ -24,15 +24,16 @@ jobs:
|
||||
with:
|
||||
files: api/**
|
||||
|
||||
- name: Install Poetry
|
||||
uses: abatilo/actions-poetry@v3
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Install Poetry
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: abatilo/actions-poetry@v3
|
||||
|
||||
- name: Python dependencies
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: poetry install -C api --only lint
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -175,6 +175,8 @@ docker/volumes/pgvector/data/*
|
||||
docker/volumes/pgvecto_rs/data/*
|
||||
|
||||
docker/nginx/conf.d/default.conf
|
||||
docker/nginx/ssl/*
|
||||
!docker/nginx/ssl/.gitkeep
|
||||
docker/middleware.env
|
||||
|
||||
sdks/python-client/build
|
||||
|
||||
5
LICENSE
5
LICENSE
@@ -6,8 +6,9 @@ Dify is licensed under the Apache License 2.0, with the following additional con
|
||||
|
||||
a. Multi-tenant service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
|
||||
- Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations.
|
||||
|
||||
b. LOGO and copyright information: In the process of using Dify's frontend components, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend components.
|
||||
|
||||
b. LOGO and copyright information: In the process of using Dify's frontend, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend.
|
||||
- Frontend Definition: For the purposes of this license, the "frontend" of Dify includes all components located in the `web/` directory when running Dify from the raw source code, or the "web" image when running Dify with Docker.
|
||||
|
||||
Please contact business@dify.ai by email to inquire about licensing matters.
|
||||
|
||||
|
||||
@@ -168,7 +168,7 @@ Star Dify on GitHub and be instantly notified of new releases.
|
||||
> Before installing Dify, make sure your machine meets the following minimum system requirements:
|
||||
>
|
||||
>- CPU >= 2 Core
|
||||
>- RAM >= 4GB
|
||||
>- RAM >= 4 GiB
|
||||
|
||||
</br>
|
||||
|
||||
|
||||
@@ -154,7 +154,7 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI
|
||||
我们提供[ Dify 云服务](https://dify.ai),任何人都可以零设置尝试。它提供了自部署版本的所有功能,并在沙盒计划中包含 200 次免费的 GPT-4 调用。
|
||||
|
||||
- **自托管 Dify 社区版</br>**
|
||||
使用这个[入门指南](#quick-start)快速在您的环境中运行 Dify。
|
||||
使用这个[入门指南](#快速启动)快速在您的环境中运行 Dify。
|
||||
使用我们的[文档](https://docs.dify.ai)进行进一步的参考和更深入的说明。
|
||||
|
||||
- **面向企业/组织的 Dify</br>**
|
||||
@@ -174,7 +174,7 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI
|
||||
在安装 Dify 之前,请确保您的机器满足以下最低系统要求:
|
||||
|
||||
- CPU >= 2 Core
|
||||
- RAM >= 4GB
|
||||
- RAM >= 4 GiB
|
||||
|
||||
### 快速启动
|
||||
|
||||
|
||||
@@ -20,6 +20,9 @@ FILES_URL=http://127.0.0.1:5001
|
||||
# The time in seconds after the signature is rejected
|
||||
FILES_ACCESS_TIMEOUT=300
|
||||
|
||||
# Access token expiration time in minutes
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=60
|
||||
|
||||
# celery configuration
|
||||
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
|
||||
|
||||
@@ -39,7 +42,7 @@ DB_DATABASE=dify
|
||||
|
||||
# Storage configuration
|
||||
# use for store upload files, private keys...
|
||||
# storage type: local, s3, azure-blob, google-storage, tencent-cos, huawei-obs, volcengine-tos, baidu-obs
|
||||
# storage type: local, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase
|
||||
STORAGE_TYPE=local
|
||||
STORAGE_LOCAL_PATH=storage
|
||||
S3_USE_AWS_MANAGED_IAM=false
|
||||
@@ -99,11 +102,16 @@ VOLCENGINE_TOS_ACCESS_KEY=your-access-key
|
||||
VOLCENGINE_TOS_SECRET_KEY=your-secret-key
|
||||
VOLCENGINE_TOS_REGION=your-region
|
||||
|
||||
# Supabase Storage Configuration
|
||||
SUPABASE_BUCKET_NAME=your-bucket-name
|
||||
SUPABASE_API_KEY=your-access-key
|
||||
SUPABASE_URL=your-server-url
|
||||
|
||||
# CORS configuration
|
||||
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
|
||||
|
||||
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector
|
||||
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, vikingdb, upstash
|
||||
VECTOR_STORE=weaviate
|
||||
|
||||
# Weaviate configuration
|
||||
@@ -203,14 +211,39 @@ OPENSEARCH_USER=admin
|
||||
OPENSEARCH_PASSWORD=admin
|
||||
OPENSEARCH_SECURE=true
|
||||
|
||||
# Baidu configuration
|
||||
BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287
|
||||
BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000
|
||||
BAIDU_VECTOR_DB_ACCOUNT=root
|
||||
BAIDU_VECTOR_DB_API_KEY=dify
|
||||
BAIDU_VECTOR_DB_DATABASE=dify
|
||||
BAIDU_VECTOR_DB_SHARD=1
|
||||
BAIDU_VECTOR_DB_REPLICAS=3
|
||||
|
||||
# Upstash configuration
|
||||
UPSTASH_VECTOR_URL=your-server-url
|
||||
UPSTASH_VECTOR_TOKEN=your-access-token
|
||||
|
||||
# ViKingDB configuration
|
||||
VIKINGDB_ACCESS_KEY=your-ak
|
||||
VIKINGDB_SECRET_KEY=your-sk
|
||||
VIKINGDB_REGION=cn-shanghai
|
||||
VIKINGDB_HOST=api-vikingdb.xxx.volces.com
|
||||
VIKINGDB_SCHEMA=http
|
||||
VIKINGDB_CONNECTION_TIMEOUT=30
|
||||
VIKINGDB_SOCKET_TIMEOUT=30
|
||||
|
||||
# Upload configuration
|
||||
UPLOAD_FILE_SIZE_LIMIT=15
|
||||
UPLOAD_FILE_BATCH_LIMIT=5
|
||||
UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
|
||||
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
|
||||
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
|
||||
|
||||
# Model Configuration
|
||||
MULTIMODAL_SEND_IMAGE_FORMAT=base64
|
||||
PROMPT_GENERATION_MAX_TOKENS=512
|
||||
CODE_GENERATION_MAX_TOKENS=1024
|
||||
|
||||
# Mail configuration, support: resend, smtp
|
||||
MAIL_TYPE=
|
||||
@@ -276,6 +309,10 @@ RESPECT_XFORWARD_HEADERS_ENABLED=false
|
||||
|
||||
# Log file path
|
||||
LOG_FILE=
|
||||
# Log file max size, the unit is MB
|
||||
LOG_FILE_MAX_SIZE=20
|
||||
# Log file max backup count
|
||||
LOG_FILE_BACKUP_COUNT=5
|
||||
|
||||
# Indexing configuration
|
||||
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000
|
||||
@@ -284,6 +321,7 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000
|
||||
WORKFLOW_MAX_EXECUTION_STEPS=500
|
||||
WORKFLOW_MAX_EXECUTION_TIME=1200
|
||||
WORKFLOW_CALL_MAX_DEPTH=5
|
||||
MAX_VARIABLE_SIZE=204800
|
||||
|
||||
# App configuration
|
||||
APP_MAX_EXECUTION_TIME=1200
|
||||
@@ -301,3 +339,6 @@ POSITION_TOOL_EXCLUDES=
|
||||
POSITION_PROVIDER_PINS=
|
||||
POSITION_PROVIDER_INCLUDES=
|
||||
POSITION_PROVIDER_EXCLUDES=
|
||||
|
||||
# Reset password token expiry minutes
|
||||
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
|
||||
|
||||
15
api/.vscode/launch.json.example
vendored
15
api/.vscode/launch.json.example
vendored
@@ -1,8 +1,15 @@
|
||||
{
|
||||
"version": "0.2.0",
|
||||
"compounds": [
|
||||
{
|
||||
"name": "Launch Flask and Celery",
|
||||
"configurations": ["Python: Flask", "Python: Celery"]
|
||||
}
|
||||
],
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python: Flask",
|
||||
"consoleName": "Flask",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"python": "${workspaceFolder}/.venv/bin/python",
|
||||
@@ -17,12 +24,12 @@
|
||||
},
|
||||
"args": [
|
||||
"run",
|
||||
"--host=0.0.0.0",
|
||||
"--port=5001"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Python: Celery",
|
||||
"consoleName": "Celery",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"python": "${workspaceFolder}/.venv/bin/python",
|
||||
@@ -45,10 +52,10 @@
|
||||
"-c",
|
||||
"1",
|
||||
"--loglevel",
|
||||
"info",
|
||||
"DEBUG",
|
||||
"-Q",
|
||||
"dataset,generation,mail,ops_trace,app_deletion"
|
||||
]
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ RUN apt-get update \
|
||||
&& echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
|
||||
&& apt-get update \
|
||||
# For Security
|
||||
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
|
||||
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.1-1 \
|
||||
&& apt-get autoremove -y \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
||||
@@ -85,3 +85,4 @@
|
||||
cd ../
|
||||
poetry run -C api bash dev/pytest/pytest_all_tests.sh
|
||||
```
|
||||
|
||||
|
||||
220
api/app.py
220
api/app.py
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
if os.environ.get("DEBUG", "false").lower() != "true":
|
||||
from gevent import monkey
|
||||
|
||||
@@ -10,44 +12,20 @@ if os.environ.get("DEBUG", "false").lower() != "true":
|
||||
grpc.experimental.gevent.init_gevent()
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import warnings
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
from flask import Flask, Response, request
|
||||
from flask_cors import CORS
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
from flask import Response
|
||||
|
||||
import contexts
|
||||
from commands import register_commands
|
||||
from configs import dify_config
|
||||
from app_factory import create_app
|
||||
|
||||
# DO NOT REMOVE BELOW
|
||||
from events import event_handlers # noqa: F401
|
||||
from extensions import (
|
||||
ext_celery,
|
||||
ext_code_based_extension,
|
||||
ext_compress,
|
||||
ext_database,
|
||||
ext_hosting_provider,
|
||||
ext_login,
|
||||
ext_mail,
|
||||
ext_migrate,
|
||||
ext_proxy_fix,
|
||||
ext_redis,
|
||||
ext_sentry,
|
||||
ext_storage,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_login import login_manager
|
||||
from libs.passport import PassportService
|
||||
|
||||
# TODO: Find a way to avoid importing models here
|
||||
from models import account, dataset, model, source, task, tool, tools, web # noqa: F401
|
||||
from services.account_service import AccountService
|
||||
|
||||
# DO NOT REMOVE ABOVE
|
||||
|
||||
@@ -60,193 +38,11 @@ if hasattr(time, "tzset"):
|
||||
time.tzset()
|
||||
|
||||
|
||||
class DifyApp(Flask):
|
||||
pass
|
||||
|
||||
|
||||
# -------------
|
||||
# Configuration
|
||||
# -------------
|
||||
|
||||
|
||||
config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Application Factory Function
|
||||
# ----------------------------
|
||||
|
||||
|
||||
def create_flask_app_with_configs() -> Flask:
|
||||
"""
|
||||
create a raw flask app
|
||||
with configs loaded from .env file
|
||||
"""
|
||||
dify_app = DifyApp(__name__)
|
||||
dify_app.config.from_mapping(dify_config.model_dump())
|
||||
|
||||
# populate configs into system environment variables
|
||||
for key, value in dify_app.config.items():
|
||||
if isinstance(value, str):
|
||||
os.environ[key] = value
|
||||
elif isinstance(value, int | float | bool):
|
||||
os.environ[key] = str(value)
|
||||
elif value is None:
|
||||
os.environ[key] = ""
|
||||
|
||||
return dify_app
|
||||
|
||||
|
||||
def create_app() -> Flask:
|
||||
app = create_flask_app_with_configs()
|
||||
|
||||
app.secret_key = app.config["SECRET_KEY"]
|
||||
|
||||
log_handlers = None
|
||||
log_file = app.config.get("LOG_FILE")
|
||||
if log_file:
|
||||
log_dir = os.path.dirname(log_file)
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_handlers = [
|
||||
RotatingFileHandler(
|
||||
filename=log_file,
|
||||
maxBytes=1024 * 1024 * 1024,
|
||||
backupCount=5,
|
||||
),
|
||||
logging.StreamHandler(sys.stdout),
|
||||
]
|
||||
|
||||
logging.basicConfig(
|
||||
level=app.config.get("LOG_LEVEL"),
|
||||
format=app.config.get("LOG_FORMAT"),
|
||||
datefmt=app.config.get("LOG_DATEFORMAT"),
|
||||
handlers=log_handlers,
|
||||
force=True,
|
||||
)
|
||||
log_tz = app.config.get("LOG_TZ")
|
||||
if log_tz:
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
|
||||
timezone = pytz.timezone(log_tz)
|
||||
|
||||
def time_converter(seconds):
|
||||
return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
|
||||
|
||||
for handler in logging.root.handlers:
|
||||
handler.formatter.converter = time_converter
|
||||
initialize_extensions(app)
|
||||
register_blueprints(app)
|
||||
register_commands(app)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def initialize_extensions(app):
|
||||
# Since the application instance is now created, pass it to each Flask
|
||||
# extension instance to bind it to the Flask application instance (app)
|
||||
ext_compress.init_app(app)
|
||||
ext_code_based_extension.init()
|
||||
ext_database.init_app(app)
|
||||
ext_migrate.init(app, db)
|
||||
ext_redis.init_app(app)
|
||||
ext_storage.init_app(app)
|
||||
ext_celery.init_app(app)
|
||||
ext_login.init_app(app)
|
||||
ext_mail.init_app(app)
|
||||
ext_hosting_provider.init_app(app)
|
||||
ext_sentry.init_app(app)
|
||||
ext_proxy_fix.init_app(app)
|
||||
|
||||
|
||||
# Flask-Login configuration
|
||||
@login_manager.request_loader
|
||||
def load_user_from_request(request_from_flask_login):
|
||||
"""Load user based on the request."""
|
||||
if request.blueprint not in {"console", "inner_api"}:
|
||||
return None
|
||||
# Check if the user_id contains a dot, indicating the old format
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header:
|
||||
auth_token = request.args.get("_token")
|
||||
if not auth_token:
|
||||
raise Unauthorized("Invalid Authorization token.")
|
||||
else:
|
||||
if " " not in auth_header:
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
if auth_scheme != "bearer":
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
|
||||
decoded = PassportService().verify(auth_token)
|
||||
user_id = decoded.get("user_id")
|
||||
|
||||
logged_in_account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
|
||||
if logged_in_account:
|
||||
contexts.tenant_id.set(logged_in_account.current_tenant_id)
|
||||
return logged_in_account
|
||||
|
||||
|
||||
@login_manager.unauthorized_handler
|
||||
def unauthorized_handler():
|
||||
"""Handle unauthorized requests."""
|
||||
return Response(
|
||||
json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
|
||||
status=401,
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
|
||||
# register blueprint routers
|
||||
def register_blueprints(app):
|
||||
from controllers.console import bp as console_app_bp
|
||||
from controllers.files import bp as files_bp
|
||||
from controllers.inner_api import bp as inner_api_bp
|
||||
from controllers.service_api import bp as service_api_bp
|
||||
from controllers.web import bp as web_bp
|
||||
|
||||
CORS(
|
||||
service_api_bp,
|
||||
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
|
||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||
)
|
||||
app.register_blueprint(service_api_bp)
|
||||
|
||||
CORS(
|
||||
web_bp,
|
||||
resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}},
|
||||
supports_credentials=True,
|
||||
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
|
||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||
expose_headers=["X-Version", "X-Env"],
|
||||
)
|
||||
|
||||
app.register_blueprint(web_bp)
|
||||
|
||||
CORS(
|
||||
console_app_bp,
|
||||
resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}},
|
||||
supports_credentials=True,
|
||||
allow_headers=["Content-Type", "Authorization"],
|
||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||
expose_headers=["X-Version", "X-Env"],
|
||||
)
|
||||
|
||||
app.register_blueprint(console_app_bp)
|
||||
|
||||
CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"])
|
||||
app.register_blueprint(files_bp)
|
||||
|
||||
app.register_blueprint(inner_api_bp)
|
||||
|
||||
|
||||
# create app
|
||||
app = create_app()
|
||||
celery = app.extensions["celery"]
|
||||
|
||||
if app.config.get("TESTING"):
|
||||
if dify_config.TESTING:
|
||||
print("App is running in TESTING mode")
|
||||
|
||||
|
||||
@@ -254,15 +50,15 @@ if app.config.get("TESTING"):
|
||||
def after_request(response):
|
||||
"""Add Version headers to the response."""
|
||||
response.set_cookie("remember_token", "", expires=0)
|
||||
response.headers.add("X-Version", app.config["CURRENT_VERSION"])
|
||||
response.headers.add("X-Env", app.config["DEPLOY_ENV"])
|
||||
response.headers.add("X-Version", dify_config.CURRENT_VERSION)
|
||||
response.headers.add("X-Env", dify_config.DEPLOY_ENV)
|
||||
return response
|
||||
|
||||
|
||||
@app.route("/health")
|
||||
def health():
|
||||
return Response(
|
||||
json.dumps({"pid": os.getpid(), "status": "ok", "version": app.config["CURRENT_VERSION"]}),
|
||||
json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.CURRENT_VERSION}),
|
||||
status=200,
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
176
api/app_factory.py
Normal file
176
api/app_factory.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import os
|
||||
|
||||
if os.environ.get("DEBUG", "false").lower() != "true":
|
||||
from gevent import monkey
|
||||
|
||||
monkey.patch_all()
|
||||
|
||||
import grpc.experimental.gevent
|
||||
|
||||
grpc.experimental.gevent.init_gevent()
|
||||
|
||||
import json
|
||||
|
||||
from flask import Flask, Response, request
|
||||
from flask_cors import CORS
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
import contexts
|
||||
from commands import register_commands
|
||||
from configs import dify_config
|
||||
from extensions import (
|
||||
ext_celery,
|
||||
ext_code_based_extension,
|
||||
ext_compress,
|
||||
ext_database,
|
||||
ext_hosting_provider,
|
||||
ext_logging,
|
||||
ext_login,
|
||||
ext_mail,
|
||||
ext_migrate,
|
||||
ext_proxy_fix,
|
||||
ext_redis,
|
||||
ext_sentry,
|
||||
ext_storage,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_login import login_manager
|
||||
from libs.passport import PassportService
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
class DifyApp(Flask):
|
||||
pass
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Application Factory Function
|
||||
# ----------------------------
|
||||
def create_flask_app_with_configs() -> Flask:
|
||||
"""
|
||||
create a raw flask app
|
||||
with configs loaded from .env file
|
||||
"""
|
||||
dify_app = DifyApp(__name__)
|
||||
dify_app.config.from_mapping(dify_config.model_dump())
|
||||
|
||||
# populate configs into system environment variables
|
||||
for key, value in dify_app.config.items():
|
||||
if isinstance(value, str):
|
||||
os.environ[key] = value
|
||||
elif isinstance(value, int | float | bool):
|
||||
os.environ[key] = str(value)
|
||||
elif value is None:
|
||||
os.environ[key] = ""
|
||||
|
||||
return dify_app
|
||||
|
||||
|
||||
def create_app() -> Flask:
|
||||
app = create_flask_app_with_configs()
|
||||
app.secret_key = dify_config.SECRET_KEY
|
||||
initialize_extensions(app)
|
||||
register_blueprints(app)
|
||||
register_commands(app)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def initialize_extensions(app):
|
||||
# Since the application instance is now created, pass it to each Flask
|
||||
# extension instance to bind it to the Flask application instance (app)
|
||||
ext_logging.init_app(app)
|
||||
ext_compress.init_app(app)
|
||||
ext_code_based_extension.init()
|
||||
ext_database.init_app(app)
|
||||
ext_migrate.init(app, db)
|
||||
ext_redis.init_app(app)
|
||||
ext_storage.init_app(app)
|
||||
ext_celery.init_app(app)
|
||||
ext_login.init_app(app)
|
||||
ext_mail.init_app(app)
|
||||
ext_hosting_provider.init_app(app)
|
||||
ext_sentry.init_app(app)
|
||||
ext_proxy_fix.init_app(app)
|
||||
|
||||
|
||||
# Flask-Login configuration
|
||||
@login_manager.request_loader
|
||||
def load_user_from_request(request_from_flask_login):
|
||||
"""Load user based on the request."""
|
||||
if request.blueprint not in {"console", "inner_api"}:
|
||||
return None
|
||||
# Check if the user_id contains a dot, indicating the old format
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header:
|
||||
auth_token = request.args.get("_token")
|
||||
if not auth_token:
|
||||
raise Unauthorized("Invalid Authorization token.")
|
||||
else:
|
||||
if " " not in auth_header:
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
if auth_scheme != "bearer":
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
|
||||
decoded = PassportService().verify(auth_token)
|
||||
user_id = decoded.get("user_id")
|
||||
|
||||
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
||||
if logged_in_account:
|
||||
contexts.tenant_id.set(logged_in_account.current_tenant_id)
|
||||
return logged_in_account
|
||||
|
||||
|
||||
@login_manager.unauthorized_handler
|
||||
def unauthorized_handler():
|
||||
"""Handle unauthorized requests."""
|
||||
return Response(
|
||||
json.dumps({"code": "unauthorized", "message": "Unauthorized."}),
|
||||
status=401,
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
|
||||
# register blueprint routers
|
||||
def register_blueprints(app):
|
||||
from controllers.console import bp as console_app_bp
|
||||
from controllers.files import bp as files_bp
|
||||
from controllers.inner_api import bp as inner_api_bp
|
||||
from controllers.service_api import bp as service_api_bp
|
||||
from controllers.web import bp as web_bp
|
||||
|
||||
CORS(
|
||||
service_api_bp,
|
||||
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
|
||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||
)
|
||||
app.register_blueprint(service_api_bp)
|
||||
|
||||
CORS(
|
||||
web_bp,
|
||||
resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
|
||||
supports_credentials=True,
|
||||
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
|
||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||
expose_headers=["X-Version", "X-Env"],
|
||||
)
|
||||
|
||||
app.register_blueprint(web_bp)
|
||||
|
||||
CORS(
|
||||
console_app_bp,
|
||||
resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
|
||||
supports_credentials=True,
|
||||
allow_headers=["Content-Type", "Authorization"],
|
||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||
expose_headers=["X-Version", "X-Env"],
|
||||
)
|
||||
|
||||
app.register_blueprint(console_app_bp)
|
||||
|
||||
CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"])
|
||||
app.register_blueprint(files_bp)
|
||||
|
||||
app.register_blueprint(inner_api_bp)
|
||||
@@ -19,7 +19,7 @@ from extensions.ext_redis import redis_client
|
||||
from libs.helper import email as email_validate
|
||||
from libs.password import hash_password, password_pattern, valid_password
|
||||
from libs.rsa import generate_key_pair
|
||||
from models.account import Tenant
|
||||
from models import Tenant
|
||||
from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
|
||||
@@ -259,6 +259,26 @@ def migrate_knowledge_vector_database():
|
||||
skipped_count = 0
|
||||
total_count = 0
|
||||
vector_type = dify_config.VECTOR_STORE
|
||||
upper_colletion_vector_types = {
|
||||
VectorType.MILVUS,
|
||||
VectorType.PGVECTOR,
|
||||
VectorType.RELYT,
|
||||
VectorType.WEAVIATE,
|
||||
VectorType.ORACLE,
|
||||
VectorType.ELASTICSEARCH,
|
||||
}
|
||||
lower_colletion_vector_types = {
|
||||
VectorType.ANALYTICDB,
|
||||
VectorType.CHROMA,
|
||||
VectorType.MYSCALE,
|
||||
VectorType.PGVECTO_RS,
|
||||
VectorType.TIDB_VECTOR,
|
||||
VectorType.OPENSEARCH,
|
||||
VectorType.TENCENT,
|
||||
VectorType.BAIDU,
|
||||
VectorType.VIKINGDB,
|
||||
VectorType.UPSTASH,
|
||||
}
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
@@ -284,11 +304,9 @@ def migrate_knowledge_vector_database():
|
||||
skipped_count = skipped_count + 1
|
||||
continue
|
||||
collection_name = ""
|
||||
if vector_type == VectorType.WEAVIATE:
|
||||
dataset_id = dataset.id
|
||||
dataset_id = dataset.id
|
||||
if vector_type in upper_colletion_vector_types:
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {"type": VectorType.WEAVIATE, "vector_store": {"class_prefix": collection_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.QDRANT:
|
||||
if dataset.collection_binding_id:
|
||||
dataset_collection_binding = (
|
||||
@@ -301,55 +319,15 @@ def migrate_knowledge_vector_database():
|
||||
else:
|
||||
raise ValueError("Dataset Collection Binding not found")
|
||||
else:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {"type": VectorType.QDRANT, "vector_store": {"class_prefix": collection_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
|
||||
elif vector_type == VectorType.MILVUS:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {"type": VectorType.MILVUS, "vector_store": {"class_prefix": collection_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.RELYT:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {"type": "relyt", "vector_store": {"class_prefix": collection_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.TENCENT:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {"type": VectorType.TENCENT, "vector_store": {"class_prefix": collection_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.PGVECTOR:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {"type": VectorType.PGVECTOR, "vector_store": {"class_prefix": collection_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.OPENSEARCH:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": VectorType.OPENSEARCH,
|
||||
"vector_store": {"class_prefix": collection_name},
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.ANALYTICDB:
|
||||
dataset_id = dataset.id
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {
|
||||
"type": VectorType.ANALYTICDB,
|
||||
"vector_store": {"class_prefix": collection_name},
|
||||
}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type == VectorType.ELASTICSEARCH:
|
||||
dataset_id = dataset.id
|
||||
index_name = Dataset.gen_collection_name_by_id(dataset_id)
|
||||
index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
elif vector_type in lower_colletion_vector_types:
|
||||
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
|
||||
else:
|
||||
raise ValueError(f"Vector store {vector_type} is not supported.")
|
||||
|
||||
index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}}
|
||||
dataset.index_struct = json.dumps(index_struct_dict)
|
||||
vector = Vector(dataset)
|
||||
click.echo(f"Migrating dataset {dataset.id}.")
|
||||
|
||||
@@ -449,14 +427,14 @@ def convert_to_agent_apps():
|
||||
# fetch first 1000 apps
|
||||
sql_query = """SELECT a.id AS id FROM apps a
|
||||
INNER JOIN app_model_configs am ON a.app_model_config_id=am.id
|
||||
WHERE a.mode = 'chat'
|
||||
AND am.agent_mode is not null
|
||||
WHERE a.mode = 'chat'
|
||||
AND am.agent_mode is not null
|
||||
AND (
|
||||
am.agent_mode like '%"strategy": "function_call"%'
|
||||
am.agent_mode like '%"strategy": "function_call"%'
|
||||
OR am.agent_mode like '%"strategy": "react"%'
|
||||
)
|
||||
)
|
||||
AND (
|
||||
am.agent_mode like '{"enabled": true%'
|
||||
am.agent_mode like '{"enabled": true%'
|
||||
OR am.agent_mode like '{"max_iteration": %'
|
||||
) ORDER BY a.created_at DESC LIMIT 1000
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,15 @@
|
||||
from typing import Annotated, Optional
|
||||
from typing import Annotated, Literal, Optional
|
||||
|
||||
from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field
|
||||
from pydantic import (
|
||||
AliasChoices,
|
||||
Field,
|
||||
HttpUrl,
|
||||
NegativeInt,
|
||||
NonNegativeInt,
|
||||
PositiveFloat,
|
||||
PositiveInt,
|
||||
computed_field,
|
||||
)
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from configs.feature.hosted_service import HostedServiceConfig
|
||||
@@ -11,16 +20,31 @@ class SecurityConfig(BaseSettings):
|
||||
Security-related configurations for the application
|
||||
"""
|
||||
|
||||
SECRET_KEY: Optional[str] = Field(
|
||||
SECRET_KEY: str = Field(
|
||||
description="Secret key for secure session cookie signing."
|
||||
"Make sure you are changing this key for your deployment with a strong key."
|
||||
"Generate a strong key using `openssl rand -base64 42` or set via the `SECRET_KEY` environment variable.",
|
||||
default=None,
|
||||
default="",
|
||||
)
|
||||
|
||||
RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
|
||||
description="Duration in hours for which a password reset token remains valid",
|
||||
default=24,
|
||||
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: PositiveInt = Field(
|
||||
description="Duration in minutes for which a password reset token remains valid",
|
||||
default=5,
|
||||
)
|
||||
|
||||
LOGIN_DISABLED: bool = Field(
|
||||
description="Whether to disable login checks",
|
||||
default=False,
|
||||
)
|
||||
|
||||
ADMIN_API_KEY_ENABLE: bool = Field(
|
||||
description="Whether to enable admin api key for authentication",
|
||||
default=False,
|
||||
)
|
||||
|
||||
ADMIN_API_KEY: Optional[str] = Field(
|
||||
description="admin api key for authentication",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -177,6 +201,16 @@ class FileUploadConfig(BaseSettings):
|
||||
default=10,
|
||||
)
|
||||
|
||||
UPLOAD_VIDEO_FILE_SIZE_LIMIT: NonNegativeInt = Field(
|
||||
description="video file size limit in Megabytes for uploading files",
|
||||
default=100,
|
||||
)
|
||||
|
||||
UPLOAD_AUDIO_FILE_SIZE_LIMIT: NonNegativeInt = Field(
|
||||
description="audio file size limit in Megabytes for uploading files",
|
||||
default=50,
|
||||
)
|
||||
|
||||
BATCH_UPLOAD_LIMIT: NonNegativeInt = Field(
|
||||
description="Maximum number of files allowed in a batch upload operation",
|
||||
default=20,
|
||||
@@ -285,6 +319,16 @@ class LoggingConfig(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
LOG_FILE_MAX_SIZE: PositiveInt = Field(
|
||||
description="Maximum file size for file rotation retention, the unit is megabytes (MB)",
|
||||
default=20,
|
||||
)
|
||||
|
||||
LOG_FILE_BACKUP_COUNT: PositiveInt = Field(
|
||||
description="Maximum file backup count file rotation retention",
|
||||
default=5,
|
||||
)
|
||||
|
||||
LOG_FORMAT: str = Field(
|
||||
description="Format string for log messages",
|
||||
default="%(asctime)s.%(msecs)03d %(levelname)s [%(threadName)s] [%(filename)s:%(lineno)d] - %(message)s",
|
||||
@@ -355,14 +399,14 @@ class WorkflowConfig(BaseSettings):
|
||||
)
|
||||
|
||||
MAX_VARIABLE_SIZE: PositiveInt = Field(
|
||||
description="Maximum size in bytes for a single variable in workflows. Default to 5KB.",
|
||||
default=5 * 1024,
|
||||
description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.",
|
||||
default=200 * 1024,
|
||||
)
|
||||
|
||||
|
||||
class OAuthConfig(BaseSettings):
|
||||
class AuthConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for OAuth authentication
|
||||
Configuration for authentication and OAuth
|
||||
"""
|
||||
|
||||
OAUTH_REDIRECT_PATH: str = Field(
|
||||
@@ -371,7 +415,7 @@ class OAuthConfig(BaseSettings):
|
||||
)
|
||||
|
||||
GITHUB_CLIENT_ID: Optional[str] = Field(
|
||||
description="GitHub OAuth client secret",
|
||||
description="GitHub OAuth client ID",
|
||||
default=None,
|
||||
)
|
||||
|
||||
@@ -390,6 +434,11 @@ class OAuthConfig(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: PositiveInt = Field(
|
||||
description="Expiration time for access tokens in minutes",
|
||||
default=60,
|
||||
)
|
||||
|
||||
|
||||
class ModerationConfig(BaseSettings):
|
||||
"""
|
||||
@@ -468,12 +517,18 @@ class MailConfig(BaseSettings):
|
||||
default=False,
|
||||
)
|
||||
|
||||
EMAIL_SEND_IP_LIMIT_PER_MINUTE: PositiveInt = Field(
|
||||
description="Maximum number of emails allowed to be sent from the same IP address in a minute",
|
||||
default=50,
|
||||
)
|
||||
|
||||
|
||||
class RagEtlConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for RAG ETL processes
|
||||
"""
|
||||
|
||||
# TODO: This config is not only for rag etl, it is also for file upload, we should move it to file upload config
|
||||
ETL_TYPE: str = Field(
|
||||
description="RAG ETL type ('dify' or 'Unstructured'), default to 'dify'",
|
||||
default="dify",
|
||||
@@ -501,11 +556,16 @@ class DataSetConfig(BaseSettings):
|
||||
Configuration for dataset management
|
||||
"""
|
||||
|
||||
CLEAN_DAY_SETTING: PositiveInt = Field(
|
||||
description="Interval in days for dataset cleanup operations",
|
||||
PLAN_SANDBOX_CLEAN_DAY_SETTING: PositiveInt = Field(
|
||||
description="Interval in days for dataset cleanup operations - plan: sandbox",
|
||||
default=30,
|
||||
)
|
||||
|
||||
PLAN_PRO_CLEAN_DAY_SETTING: PositiveInt = Field(
|
||||
description="Interval in days for dataset cleanup operations - plan: pro and team",
|
||||
default=7,
|
||||
)
|
||||
|
||||
DATASET_OPERATOR_ENABLED: bool = Field(
|
||||
description="Enable or disable dataset operator functionality",
|
||||
default=False,
|
||||
@@ -535,7 +595,7 @@ class IndexingConfig(BaseSettings):
|
||||
|
||||
|
||||
class ImageFormatConfig(BaseSettings):
|
||||
MULTIMODAL_SEND_IMAGE_FORMAT: str = Field(
|
||||
MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(
|
||||
description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
|
||||
default="base64",
|
||||
)
|
||||
@@ -604,9 +664,37 @@ class PositionConfig(BaseSettings):
|
||||
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
|
||||
|
||||
|
||||
class LoginConfig(BaseSettings):
|
||||
ENABLE_EMAIL_CODE_LOGIN: bool = Field(
|
||||
description="whether to enable email code login",
|
||||
default=False,
|
||||
)
|
||||
ENABLE_EMAIL_PASSWORD_LOGIN: bool = Field(
|
||||
description="whether to enable email password login",
|
||||
default=True,
|
||||
)
|
||||
ENABLE_SOCIAL_OAUTH_LOGIN: bool = Field(
|
||||
description="whether to enable github/google oauth login",
|
||||
default=False,
|
||||
)
|
||||
EMAIL_CODE_LOGIN_TOKEN_EXPIRY_MINUTES: PositiveInt = Field(
|
||||
description="expiry time in minutes for email code login token",
|
||||
default=5,
|
||||
)
|
||||
ALLOW_REGISTER: bool = Field(
|
||||
description="whether to enable register",
|
||||
default=False,
|
||||
)
|
||||
ALLOW_CREATE_WORKSPACE: bool = Field(
|
||||
description="whether to enable create workspace",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class FeatureConfig(
|
||||
# place the configs in alphabet order
|
||||
AppExecutionConfig,
|
||||
AuthConfig, # Changed from OAuthConfig to AuthConfig
|
||||
BillingConfig,
|
||||
CodeExecutionSandboxConfig,
|
||||
DataSetConfig,
|
||||
@@ -621,14 +709,14 @@ class FeatureConfig(
|
||||
MailConfig,
|
||||
ModelLoadBalanceConfig,
|
||||
ModerationConfig,
|
||||
OAuthConfig,
|
||||
PositionConfig,
|
||||
RagEtlConfig,
|
||||
SecurityConfig,
|
||||
ToolConfig,
|
||||
UpdateConfig,
|
||||
WorkflowConfig,
|
||||
WorkspaceConfig,
|
||||
PositionConfig,
|
||||
LoginConfig,
|
||||
# hosted services config
|
||||
HostedServiceConfig,
|
||||
CeleryBeatConfig,
|
||||
|
||||
@@ -12,6 +12,7 @@ from configs.middleware.storage.baidu_obs_storage_config import BaiduOBSStorageC
|
||||
from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
|
||||
from configs.middleware.storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
|
||||
from configs.middleware.storage.oci_storage_config import OCIStorageConfig
|
||||
from configs.middleware.storage.supabase_storage_config import SupabaseStorageConfig
|
||||
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
|
||||
from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
|
||||
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
|
||||
@@ -27,13 +28,16 @@ from configs.middleware.vdb.qdrant_config import QdrantConfig
|
||||
from configs.middleware.vdb.relyt_config import RelytConfig
|
||||
from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig
|
||||
from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
|
||||
from configs.middleware.vdb.upstash_config import UpstashConfig
|
||||
from configs.middleware.vdb.vikingdb_config import VikingDBConfig
|
||||
from configs.middleware.vdb.weaviate_config import WeaviateConfig
|
||||
|
||||
|
||||
class StorageConfig(BaseSettings):
|
||||
STORAGE_TYPE: str = Field(
|
||||
description="Type of storage to use."
|
||||
" Options: 'local', 's3', 'azure-blob', 'aliyun-oss', 'google-storage'. Default is 'local'.",
|
||||
" Options: 'local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', 'huawei-obs', "
|
||||
"'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'local'.",
|
||||
default="local",
|
||||
)
|
||||
|
||||
@@ -222,6 +226,7 @@ class MiddlewareConfig(
|
||||
HuaweiCloudOBSStorageConfig,
|
||||
OCIStorageConfig,
|
||||
S3StorageConfig,
|
||||
SupabaseStorageConfig,
|
||||
TencentCloudCOSStorageConfig,
|
||||
VolcengineTOSStorageConfig,
|
||||
# configs of vdb and vdb providers
|
||||
@@ -241,5 +246,7 @@ class MiddlewareConfig(
|
||||
WeaviateConfig,
|
||||
ElasticsearchConfig,
|
||||
InternalTestConfig,
|
||||
VikingDBConfig,
|
||||
UpstashConfig,
|
||||
):
|
||||
pass
|
||||
|
||||
5
api/configs/middleware/cache/redis_config.py
vendored
5
api/configs/middleware/cache/redis_config.py
vendored
@@ -34,6 +34,11 @@ class RedisConfig(BaseSettings):
|
||||
default=0,
|
||||
)
|
||||
|
||||
REDIS_MAX_CONNECTIONS: PositiveInt = Field(
|
||||
description="Maximum number of connections to Redis",
|
||||
default=200,
|
||||
)
|
||||
|
||||
REDIS_USE_SSL: bool = Field(
|
||||
description="Enable SSL/TLS for the Redis connection",
|
||||
default=False,
|
||||
|
||||
24
api/configs/middleware/storage/supabase_storage_config.py
Normal file
24
api/configs/middleware/storage/supabase_storage_config.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SupabaseStorageConfig(BaseModel):
|
||||
"""
|
||||
Configuration settings for Supabase Object Storage Service
|
||||
"""
|
||||
|
||||
SUPABASE_BUCKET_NAME: Optional[str] = Field(
|
||||
description="Name of the Supabase bucket to store and retrieve objects (e.g., 'dify-bucket')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
SUPABASE_API_KEY: Optional[str] = Field(
|
||||
description="API KEY for authenticating with Supabase",
|
||||
default=None,
|
||||
)
|
||||
|
||||
SUPABASE_URL: Optional[str] = Field(
|
||||
description="URL of the Supabase",
|
||||
default=None,
|
||||
)
|
||||
45
api/configs/middleware/vdb/baidu_vector_config.py
Normal file
45
api/configs/middleware/vdb/baidu_vector_config.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field, NonNegativeInt, PositiveInt
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class BaiduVectorDBConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for Baidu Vector Database
|
||||
"""
|
||||
|
||||
BAIDU_VECTOR_DB_ENDPOINT: Optional[str] = Field(
|
||||
description="URL of the Baidu Vector Database service (e.g., 'http://vdb.bj.baidubce.com')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: PositiveInt = Field(
|
||||
description="Timeout in milliseconds for Baidu Vector Database operations (default is 30000 milliseconds)",
|
||||
default=30000,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_ACCOUNT: Optional[str] = Field(
|
||||
description="Account for authenticating with the Baidu Vector Database",
|
||||
default=None,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_API_KEY: Optional[str] = Field(
|
||||
description="API key for authenticating with the Baidu Vector Database service",
|
||||
default=None,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_DATABASE: Optional[str] = Field(
|
||||
description="Name of the specific Baidu Vector Database to connect to",
|
||||
default=None,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_SHARD: PositiveInt = Field(
|
||||
description="Number of shards for the Baidu Vector Database (default is 1)",
|
||||
default=1,
|
||||
)
|
||||
|
||||
BAIDU_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
|
||||
description="Number of replicas for the Baidu Vector Database (default is 3)",
|
||||
default=3,
|
||||
)
|
||||
@@ -14,7 +14,7 @@ class OracleConfig(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
ORACLE_PORT: Optional[PositiveInt] = Field(
|
||||
ORACLE_PORT: PositiveInt = Field(
|
||||
description="Port number on which the Oracle database server is listening (default is 1521)",
|
||||
default=1521,
|
||||
)
|
||||
|
||||
@@ -14,7 +14,7 @@ class PGVectorConfig(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
PGVECTOR_PORT: Optional[PositiveInt] = Field(
|
||||
PGVECTOR_PORT: PositiveInt = Field(
|
||||
description="Port number on which the PostgreSQL server is listening (default is 5433)",
|
||||
default=5433,
|
||||
)
|
||||
|
||||
@@ -14,7 +14,7 @@ class PGVectoRSConfig(BaseSettings):
|
||||
default=None,
|
||||
)
|
||||
|
||||
PGVECTO_RS_PORT: Optional[PositiveInt] = Field(
|
||||
PGVECTO_RS_PORT: PositiveInt = Field(
|
||||
description="Port number on which the PostgreSQL server with PGVecto.RS is listening (default is 5431)",
|
||||
default=5431,
|
||||
)
|
||||
|
||||
20
api/configs/middleware/vdb/upstash_config.py
Normal file
20
api/configs/middleware/vdb/upstash_config.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class UpstashConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for Upstash vector database
|
||||
"""
|
||||
|
||||
UPSTASH_VECTOR_URL: Optional[str] = Field(
|
||||
description="URL of the upstash server (e.g., 'https://vector.upstash.io')",
|
||||
default=None,
|
||||
)
|
||||
|
||||
UPSTASH_VECTOR_TOKEN: Optional[str] = Field(
|
||||
description="Token for authenticating with the upstash server",
|
||||
default=None,
|
||||
)
|
||||
49
api/configs/middleware/vdb/vikingdb_config.py
Normal file
49
api/configs/middleware/vdb/vikingdb_config.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class VikingDBConfig(BaseModel):
|
||||
"""
|
||||
Configuration for connecting to Volcengine VikingDB.
|
||||
Refer to the following documentation for details on obtaining credentials:
|
||||
https://www.volcengine.com/docs/6291/65568
|
||||
"""
|
||||
|
||||
VIKINGDB_ACCESS_KEY: Optional[str] = Field(
|
||||
description="The Access Key provided by Volcengine VikingDB for API authentication."
|
||||
"Refer to the following documentation for details on obtaining credentials:"
|
||||
"https://www.volcengine.com/docs/6291/65568",
|
||||
default=None,
|
||||
)
|
||||
|
||||
VIKINGDB_SECRET_KEY: Optional[str] = Field(
|
||||
description="The Secret Key provided by Volcengine VikingDB for API authentication.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
VIKINGDB_REGION: str = Field(
|
||||
description="The region of the Volcengine VikingDB service.(e.g., 'cn-shanghai', 'cn-beijing').",
|
||||
default="cn-shanghai",
|
||||
)
|
||||
|
||||
VIKINGDB_HOST: str = Field(
|
||||
description="The host of the Volcengine VikingDB service.(e.g., 'api-vikingdb.volces.com', \
|
||||
'api-vikingdb.mlp.cn-shanghai.volces.com')",
|
||||
default="api-vikingdb.mlp.cn-shanghai.volces.com",
|
||||
)
|
||||
|
||||
VIKINGDB_SCHEME: str = Field(
|
||||
description="The scheme of the Volcengine VikingDB service.(e.g., 'http', 'https').",
|
||||
default="http",
|
||||
)
|
||||
|
||||
VIKINGDB_CONNECTION_TIMEOUT: int = Field(
|
||||
description="The connection timeout of the Volcengine VikingDB service.",
|
||||
default=30,
|
||||
)
|
||||
|
||||
VIKINGDB_SOCKET_TIMEOUT: int = Field(
|
||||
description="The socket timeout of the Volcengine VikingDB service.",
|
||||
default=30,
|
||||
)
|
||||
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
||||
|
||||
CURRENT_VERSION: str = Field(
|
||||
description="Dify version",
|
||||
default="0.9.1-fix1",
|
||||
default="0.10.1",
|
||||
)
|
||||
|
||||
COMMIT_SHA: str = Field(
|
||||
|
||||
@@ -1,2 +1,24 @@
|
||||
from configs import dify_config
|
||||
|
||||
HIDDEN_VALUE = "[__HIDDEN__]"
|
||||
UUID_NIL = "00000000-0000-0000-0000-000000000000"
|
||||
|
||||
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
|
||||
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
|
||||
|
||||
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "mpga"]
|
||||
VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
|
||||
|
||||
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "webm", "amr"]
|
||||
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
|
||||
|
||||
|
||||
if dify_config.ETL_TYPE == "Unstructured":
|
||||
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls"]
|
||||
DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
|
||||
if dify_config.UNSTRUCTURED_API_URL:
|
||||
DOCUMENT_EXTENSIONS.append("ppt")
|
||||
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
|
||||
else:
|
||||
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
|
||||
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from contextvars import ContextVar
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
tenant_id: ContextVar[str] = ContextVar("tenant_id")
|
||||
|
||||
workflow_variable_pool: ContextVar[VariablePool] = ContextVar("workflow_variable_pool")
|
||||
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import os
|
||||
from functools import wraps
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import only_edition_cloud
|
||||
@@ -15,7 +15,7 @@ from models.model import App, InstalledApp, RecommendedApp
|
||||
def admin_required(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
if not os.getenv("ADMIN_API_KEY"):
|
||||
if not dify_config.ADMIN_API_KEY:
|
||||
raise Unauthorized("API key is invalid.")
|
||||
|
||||
auth_header = request.headers.get("Authorization")
|
||||
@@ -31,7 +31,7 @@ def admin_required(view):
|
||||
if auth_scheme != "bearer":
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
|
||||
if os.getenv("ADMIN_API_KEY") != auth_token:
|
||||
if dify_config.ADMIN_API_KEY != auth_token:
|
||||
raise Unauthorized("API key is invalid.")
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
@@ -22,7 +22,8 @@ from fields.conversation_fields import (
|
||||
)
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation
|
||||
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class CompletionConversationApi(Resource):
|
||||
|
||||
@@ -52,4 +52,39 @@ class RuleGenerateApi(Resource):
|
||||
return rules
|
||||
|
||||
|
||||
class RuleCodeGenerateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("instruction", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("no_variable", type=bool, required=True, default=False, location="json")
|
||||
parser.add_argument("code_language", type=str, required=False, default="javascript", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
account = current_user
|
||||
CODE_GENERATION_MAX_TOKENS = int(os.getenv("CODE_GENERATION_MAX_TOKENS", "1024"))
|
||||
try:
|
||||
code_result = LLMGenerator.generate_code(
|
||||
tenant_id=account.current_tenant_id,
|
||||
instruction=args["instruction"],
|
||||
model_config=args["model_config"],
|
||||
code_language=args["code_language"],
|
||||
max_tokens=CODE_GENERATION_MAX_TOKENS,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
|
||||
return code_result
|
||||
|
||||
|
||||
api.add_resource(RuleGenerateApi, "/rule-generate")
|
||||
api.add_resource(RuleCodeGenerateApi, "/rule-code-generate")
|
||||
|
||||
@@ -105,6 +105,8 @@ class ChatMessageListApi(Resource):
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
|
||||
history_messages = list(reversed(history_messages))
|
||||
|
||||
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
|
||||
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_site_fields
|
||||
from libs.login import login_required
|
||||
from models.model import Site
|
||||
from models import Site
|
||||
|
||||
|
||||
def parse_app_site_args():
|
||||
|
||||
@@ -13,14 +13,14 @@ from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.segments import factory
|
||||
from core.errors.error import AppInvokeQuotaExceededError
|
||||
from factories import variable_factory
|
||||
from fields.workflow_fields import workflow_fields
|
||||
from fields.workflow_run_fields import workflow_run_node_execution_fields
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from models.model import App, AppMode
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
@@ -101,9 +101,13 @@ class DraftWorkflowApi(Resource):
|
||||
|
||||
try:
|
||||
environment_variables_list = args.get("environment_variables") or []
|
||||
environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
|
||||
environment_variables = [
|
||||
variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list
|
||||
]
|
||||
conversation_variables_list = args.get("conversation_variables") or []
|
||||
conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
|
||||
conversation_variables = [
|
||||
variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
workflow = workflow_service.sync_draft_workflow(
|
||||
app_model=app_model,
|
||||
graph=args["graph"],
|
||||
@@ -273,17 +277,15 @@ class DraftWorkflowRunApi(Resource):
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True
|
||||
)
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model,
|
||||
user=current_user,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except (ValueError, AppInvokeQuotaExceededError) as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
return helper.compact_generate_response(response)
|
||||
|
||||
|
||||
class WorkflowTaskStopApi(Resource):
|
||||
|
||||
@@ -7,7 +7,8 @@ from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from libs.login import login_required
|
||||
from models.model import App, AppMode
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,8 @@ from fields.workflow_run_fields import (
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import login_required
|
||||
from models.model import App, AppMode
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from services.workflow_run_service import WorkflowRunService
|
||||
|
||||
|
||||
|
||||
@@ -13,8 +13,8 @@ from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRunTriggeredFrom
|
||||
|
||||
|
||||
class WorkflowDailyRunsStatistic(Resource):
|
||||
|
||||
@@ -5,7 +5,8 @@ from typing import Optional, Union
|
||||
from controllers.console.app.error import AppNotFoundError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from models.model import App, AppMode
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
import base64
|
||||
import datetime
|
||||
import secrets
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from constants.languages import supported_language
|
||||
from controllers.console import api
|
||||
from controllers.console.error import AlreadyActivateError
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import StrLen, email, timezone
|
||||
from libs.password import hash_password, valid_password
|
||||
from models.account import AccountStatus
|
||||
from services.account_service import RegisterService
|
||||
from libs.helper import StrLen, email, extract_remote_ip, timezone
|
||||
from models.account import AccountStatus, Tenant
|
||||
from services.account_service import AccountService, RegisterService
|
||||
|
||||
|
||||
class ActivateCheckApi(Resource):
|
||||
@@ -27,8 +25,18 @@ class ActivateCheckApi(Resource):
|
||||
token = args["token"]
|
||||
|
||||
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
|
||||
|
||||
return {"is_valid": invitation is not None, "workspace_name": invitation["tenant"].name if invitation else None}
|
||||
if invitation:
|
||||
data = invitation.get("data", {})
|
||||
tenant: Tenant = invitation.get("tenant", None)
|
||||
workspace_name = tenant.name if tenant else None
|
||||
workspace_id = tenant.id if tenant else None
|
||||
invitee_email = data.get("email") if data else None
|
||||
return {
|
||||
"is_valid": invitation is not None,
|
||||
"data": {"workspace_name": workspace_name, "workspace_id": workspace_id, "email": invitee_email},
|
||||
}
|
||||
else:
|
||||
return {"is_valid": False}
|
||||
|
||||
|
||||
class ActivateApi(Resource):
|
||||
@@ -38,7 +46,6 @@ class ActivateApi(Resource):
|
||||
parser.add_argument("email", type=email, required=False, nullable=True, location="json")
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"interface_language", type=supported_language, required=True, nullable=False, location="json"
|
||||
)
|
||||
@@ -54,15 +61,6 @@ class ActivateApi(Resource):
|
||||
account = invitation["account"]
|
||||
account.name = args["name"]
|
||||
|
||||
# generate password salt
|
||||
salt = secrets.token_bytes(16)
|
||||
base64_salt = base64.b64encode(salt).decode()
|
||||
|
||||
# encrypt password with salt
|
||||
password_hashed = hash_password(args["password"], salt)
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
account.interface_language = args["interface_language"]
|
||||
account.timezone = args["timezone"]
|
||||
account.interface_theme = "light"
|
||||
@@ -70,7 +68,9 @@ class ActivateApi(Resource):
|
||||
account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}
|
||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
|
||||
api.add_resource(ActivateCheckApi, "/activate/check")
|
||||
|
||||
@@ -27,5 +27,29 @@ class InvalidTokenError(BaseHTTPException):
|
||||
|
||||
class PasswordResetRateLimitExceededError(BaseHTTPException):
|
||||
error_code = "password_reset_rate_limit_exceeded"
|
||||
description = "Password reset rate limit exceeded. Try again later."
|
||||
description = "Too many password reset emails have been sent. Please try again in 1 minutes."
|
||||
code = 429
|
||||
|
||||
|
||||
class EmailCodeError(BaseHTTPException):
|
||||
error_code = "email_code_error"
|
||||
description = "Email code is invalid or expired."
|
||||
code = 400
|
||||
|
||||
|
||||
class EmailOrPasswordMismatchError(BaseHTTPException):
|
||||
error_code = "email_or_password_mismatch"
|
||||
description = "The email or password is mismatched."
|
||||
code = 400
|
||||
|
||||
|
||||
class EmailPasswordLoginLimitError(BaseHTTPException):
|
||||
error_code = "email_code_login_limit"
|
||||
description = "Too many incorrect password attempts. Please try again later."
|
||||
code = 429
|
||||
|
||||
|
||||
class EmailCodeLoginRateLimitExceededError(BaseHTTPException):
|
||||
error_code = "email_code_login_rate_limit_exceeded"
|
||||
description = "Too many login emails have been sent. Please try again in 5 minutes."
|
||||
code = 429
|
||||
|
||||
@@ -1,65 +1,82 @@
|
||||
import base64
|
||||
import logging
|
||||
import secrets
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from constants.languages import languages
|
||||
from controllers.console import api
|
||||
from controllers.console.auth.error import (
|
||||
EmailCodeError,
|
||||
InvalidEmailError,
|
||||
InvalidTokenError,
|
||||
PasswordMismatchError,
|
||||
PasswordResetRateLimitExceededError,
|
||||
)
|
||||
from controllers.console.error import EmailSendIpLimitError, NotAllowedRegister
|
||||
from controllers.console.setup import setup_required
|
||||
from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email as email_validate
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService
|
||||
from services.errors.account import RateLimitExceededError
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
class ForgotPasswordSendEmailApi(Resource):
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=str, required=True, location="json")
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("language", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
email = args["email"]
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
|
||||
if not email_validate(email):
|
||||
raise InvalidEmailError()
|
||||
|
||||
account = Account.query.filter_by(email=email).first()
|
||||
|
||||
if account:
|
||||
try:
|
||||
AccountService.send_reset_password_email(account=account)
|
||||
except RateLimitExceededError:
|
||||
logging.warning(f"Rate limit exceeded for email: {account.email}")
|
||||
raise PasswordResetRateLimitExceededError()
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
# Return success to avoid revealing email registration status
|
||||
logging.warning(f"Attempt to reset password for unregistered email: {email}")
|
||||
language = "en-US"
|
||||
|
||||
return {"result": "success"}
|
||||
account = Account.query.filter_by(email=args["email"]).first()
|
||||
token = None
|
||||
if account is None:
|
||||
if FeatureService.get_system_features().is_allow_register:
|
||||
token = AccountService.send_reset_password_email(email=args["email"], language=language)
|
||||
return {"result": "fail", "data": token, "code": "account_not_found"}
|
||||
else:
|
||||
raise NotAllowedRegister()
|
||||
else:
|
||||
token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language)
|
||||
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
class ForgotPasswordCheckApi(Resource):
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=str, required=True, location="json")
|
||||
parser.add_argument("code", type=str, required=True, location="json")
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
token = args["token"]
|
||||
|
||||
reset_data = AccountService.get_reset_password_data(token)
|
||||
user_email = args["email"]
|
||||
|
||||
if reset_data is None:
|
||||
return {"is_valid": False, "email": None}
|
||||
return {"is_valid": True, "email": reset_data.get("email")}
|
||||
token_data = AccountService.get_reset_password_data(args["token"])
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if user_email != token_data.get("email"):
|
||||
raise InvalidEmailError()
|
||||
|
||||
if args["code"] != token_data.get("code"):
|
||||
raise EmailCodeError()
|
||||
|
||||
return {"is_valid": True, "email": token_data.get("email")}
|
||||
|
||||
|
||||
class ForgotPasswordResetApi(Resource):
|
||||
@@ -92,9 +109,26 @@ class ForgotPasswordResetApi(Resource):
|
||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||
|
||||
account = Account.query.filter_by(email=reset_data.get("email")).first()
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
db.session.commit()
|
||||
if account:
|
||||
account.password = base64_password_hashed
|
||||
account.password_salt = base64_salt
|
||||
db.session.commit()
|
||||
tenant = TenantService.get_join_tenants(account)
|
||||
if not tenant and not FeatureService.get_system_features().is_allow_create_workspace:
|
||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||
TenantService.create_tenant_member(tenant, account, role="owner")
|
||||
account.current_tenant = tenant
|
||||
tenant_was_created.send(tenant)
|
||||
else:
|
||||
try:
|
||||
account = AccountService.create_account_and_tenant(
|
||||
email=reset_data.get("email"),
|
||||
name=reset_data.get("email"),
|
||||
password=password_confirm,
|
||||
interface_language=languages[0],
|
||||
)
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
pass
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@@ -5,12 +5,29 @@ from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
import services
|
||||
from constants.languages import languages
|
||||
from controllers.console import api
|
||||
from controllers.console.auth.error import (
|
||||
EmailCodeError,
|
||||
EmailOrPasswordMismatchError,
|
||||
EmailPasswordLoginLimitError,
|
||||
InvalidEmailError,
|
||||
InvalidTokenError,
|
||||
)
|
||||
from controllers.console.error import (
|
||||
AccountBannedError,
|
||||
EmailSendIpLimitError,
|
||||
NotAllowedCreateWorkspace,
|
||||
NotAllowedRegister,
|
||||
)
|
||||
from controllers.console.setup import setup_required
|
||||
from libs.helper import email, get_remote_ip
|
||||
from events.tenant_event import tenant_was_created
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
class LoginApi(Resource):
|
||||
@@ -23,15 +40,43 @@ class LoginApi(Resource):
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("password", type=valid_password, required=True, location="json")
|
||||
parser.add_argument("remember_me", type=bool, required=False, default=False, location="json")
|
||||
parser.add_argument("invite_token", type=str, required=False, default=None, location="json")
|
||||
parser.add_argument("language", type=str, required=False, default="en-US", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# todo: Verify the recaptcha
|
||||
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args["email"])
|
||||
if is_login_error_rate_limit:
|
||||
raise EmailPasswordLoginLimitError()
|
||||
|
||||
invitation = args["invite_token"]
|
||||
if invitation:
|
||||
invitation = RegisterService.get_invitation_if_token_valid(None, args["email"], invitation)
|
||||
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
try:
|
||||
account = AccountService.authenticate(args["email"], args["password"])
|
||||
except services.errors.account.AccountLoginError as e:
|
||||
return {"code": "unauthorized", "message": str(e)}, 401
|
||||
|
||||
if invitation:
|
||||
data = invitation.get("data", {})
|
||||
invitee_email = data.get("email") if data else None
|
||||
if invitee_email != args["email"]:
|
||||
raise InvalidEmailError()
|
||||
account = AccountService.authenticate(args["email"], args["password"], args["invite_token"])
|
||||
else:
|
||||
account = AccountService.authenticate(args["email"], args["password"])
|
||||
except services.errors.account.AccountLoginError:
|
||||
raise AccountBannedError()
|
||||
except services.errors.account.AccountPasswordError:
|
||||
AccountService.add_login_error_rate_limit(args["email"])
|
||||
raise EmailOrPasswordMismatchError()
|
||||
except services.errors.account.AccountNotFoundError:
|
||||
if FeatureService.get_system_features().is_allow_register:
|
||||
token = AccountService.send_reset_password_email(email=args["email"], language=language)
|
||||
return {"result": "fail", "data": token, "code": "account_not_found"}
|
||||
else:
|
||||
raise NotAllowedRegister()
|
||||
# SELF_HOSTED only have one workspace
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
if len(tenants) == 0:
|
||||
@@ -40,71 +85,138 @@ class LoginApi(Resource):
|
||||
"data": "workspace not found, please contact system admin to invite you to join in a workspace",
|
||||
}
|
||||
|
||||
token = AccountService.login(account, ip_address=get_remote_ip(request))
|
||||
|
||||
return {"result": "success", "data": token}
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
|
||||
class LogoutApi(Resource):
|
||||
@setup_required
|
||||
def get(self):
|
||||
account = cast(Account, flask_login.current_user)
|
||||
token = request.headers.get("Authorization", "").split(" ")[1]
|
||||
AccountService.logout(account=account, token=token)
|
||||
if isinstance(account, flask_login.AnonymousUserMixin):
|
||||
return {"result": "success"}
|
||||
AccountService.logout(account=account)
|
||||
flask_login.logout_user()
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class ResetPasswordApi(Resource):
|
||||
class ResetPasswordSendEmailApi(Resource):
|
||||
@setup_required
|
||||
def get(self):
|
||||
# parser = reqparse.RequestParser()
|
||||
# parser.add_argument('email', type=email, required=True, location='json')
|
||||
# args = parser.parse_args()
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("language", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# import mailchimp_transactional as MailchimpTransactional
|
||||
# from mailchimp_transactional.api_client import ApiClientError
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
# account = {'email': args['email']}
|
||||
# account = AccountService.get_by_email(args['email'])
|
||||
# if account is None:
|
||||
# raise ValueError('Email not found')
|
||||
# new_password = AccountService.generate_password()
|
||||
# AccountService.update_password(account, new_password)
|
||||
account = AccountService.get_user_through_email(args["email"])
|
||||
if account is None:
|
||||
if FeatureService.get_system_features().is_allow_register:
|
||||
token = AccountService.send_reset_password_email(email=args["email"], language=language)
|
||||
else:
|
||||
raise NotAllowedRegister()
|
||||
else:
|
||||
token = AccountService.send_reset_password_email(account=account, language=language)
|
||||
|
||||
# todo: Send email
|
||||
# MAILCHIMP_API_KEY = dify_config.MAILCHIMP_TRANSACTIONAL_API_KEY
|
||||
# mailchimp = MailchimpTransactional(MAILCHIMP_API_KEY)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
# message = {
|
||||
# 'from_email': 'noreply@example.com',
|
||||
# 'to': [{'email': account['email']}],
|
||||
# 'subject': 'Reset your Dify password',
|
||||
# 'html': """
|
||||
# <p>Dear User,</p>
|
||||
# <p>The Dify team has generated a new password for you, details as follows:</p>
|
||||
# <p><strong>{new_password}</strong></p>
|
||||
# <p>Please change your password to log in as soon as possible.</p>
|
||||
# <p>Regards,</p>
|
||||
# <p>The Dify Team</p>
|
||||
# """
|
||||
# }
|
||||
|
||||
# response = mailchimp.messages.send({
|
||||
# 'message': message,
|
||||
# # required for transactional email
|
||||
# ' settings': {
|
||||
# 'sandbox_mode': dify_config.MAILCHIMP_SANDBOX_MODE,
|
||||
# },
|
||||
# })
|
||||
class EmailCodeLoginSendEmailApi(Resource):
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("language", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Check if MSG was sent
|
||||
# if response.status_code != 200:
|
||||
# # handle error
|
||||
# pass
|
||||
ip_address = extract_remote_ip(request)
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
|
||||
return {"result": "success"}
|
||||
if args["language"] is not None and args["language"] == "zh-Hans":
|
||||
language = "zh-Hans"
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
account = AccountService.get_user_through_email(args["email"])
|
||||
if account is None:
|
||||
if FeatureService.get_system_features().is_allow_register:
|
||||
token = AccountService.send_email_code_login_email(email=args["email"], language=language)
|
||||
else:
|
||||
raise NotAllowedRegister()
|
||||
else:
|
||||
token = AccountService.send_email_code_login_email(account=account, language=language)
|
||||
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
class EmailCodeLoginApi(Resource):
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=str, required=True, location="json")
|
||||
parser.add_argument("code", type=str, required=True, location="json")
|
||||
parser.add_argument("token", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
user_email = args["email"]
|
||||
|
||||
token_data = AccountService.get_email_code_login_data(args["token"])
|
||||
if token_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
if token_data["email"] != args["email"]:
|
||||
raise InvalidEmailError()
|
||||
|
||||
if token_data["code"] != args["code"]:
|
||||
raise EmailCodeError()
|
||||
|
||||
AccountService.revoke_email_code_login_token(args["token"])
|
||||
account = AccountService.get_user_through_email(user_email)
|
||||
if account:
|
||||
tenant = TenantService.get_join_tenants(account)
|
||||
if not tenant:
|
||||
if not FeatureService.get_system_features().is_allow_create_workspace:
|
||||
raise NotAllowedCreateWorkspace()
|
||||
else:
|
||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||
TenantService.create_tenant_member(tenant, account, role="owner")
|
||||
account.current_tenant = tenant
|
||||
tenant_was_created.send(tenant)
|
||||
|
||||
if account is None:
|
||||
try:
|
||||
account = AccountService.create_account_and_tenant(
|
||||
email=user_email, name=user_email, interface_language=languages[0]
|
||||
)
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
return NotAllowedCreateWorkspace()
|
||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
|
||||
class RefreshTokenApi(Resource):
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("refresh_token", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
new_token_pair = AccountService.refresh_token(args["refresh_token"])
|
||||
return {"result": "success", "data": new_token_pair.model_dump()}
|
||||
except Exception as e:
|
||||
return {"result": "fail", "data": str(e)}, 401
|
||||
|
||||
|
||||
api.add_resource(LoginApi, "/login")
|
||||
api.add_resource(LogoutApi, "/logout")
|
||||
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
|
||||
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
|
||||
api.add_resource(ResetPasswordSendEmailApi, "/reset-password")
|
||||
api.add_resource(RefreshTokenApi, "/refresh-token")
|
||||
|
||||
@@ -5,14 +5,20 @@ from typing import Optional
|
||||
import requests
|
||||
from flask import current_app, redirect, request
|
||||
from flask_restful import Resource
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import get_remote_ip
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||
from models.account import Account, AccountStatus
|
||||
from models import Account
|
||||
from models.account import AccountStatus
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.errors.account import AccountNotFoundError
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkSpaceNotFoundError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
from .. import api
|
||||
|
||||
@@ -42,6 +48,7 @@ def get_oauth_providers():
|
||||
|
||||
class OAuthLogin(Resource):
|
||||
def get(self, provider: str):
|
||||
invite_token = request.args.get("invite_token") or None
|
||||
OAUTH_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
oauth_provider = OAUTH_PROVIDERS.get(provider)
|
||||
@@ -49,7 +56,7 @@ class OAuthLogin(Resource):
|
||||
if not oauth_provider:
|
||||
return {"error": "Invalid provider"}, 400
|
||||
|
||||
auth_url = oauth_provider.get_authorization_url()
|
||||
auth_url = oauth_provider.get_authorization_url(invite_token=invite_token)
|
||||
return redirect(auth_url)
|
||||
|
||||
|
||||
@@ -62,6 +69,11 @@ class OAuthCallback(Resource):
|
||||
return {"error": "Invalid provider"}, 400
|
||||
|
||||
code = request.args.get("code")
|
||||
state = request.args.get("state")
|
||||
invite_token = None
|
||||
if state:
|
||||
invite_token = state
|
||||
|
||||
try:
|
||||
token = oauth_provider.get_access_token(code)
|
||||
user_info = oauth_provider.get_user_info(token)
|
||||
@@ -69,21 +81,52 @@ class OAuthCallback(Resource):
|
||||
logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}")
|
||||
return {"error": "OAuth process failed"}, 400
|
||||
|
||||
account = _generate_account(provider, user_info)
|
||||
if invite_token and RegisterService.is_valid_invite_token(invite_token):
|
||||
invitation = RegisterService._get_invitation_by_token(token=invite_token)
|
||||
if invitation:
|
||||
invitation_email = invitation.get("email", None)
|
||||
if invitation_email != user_info.email:
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
|
||||
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
|
||||
|
||||
try:
|
||||
account = _generate_account(provider, user_info)
|
||||
except AccountNotFoundError:
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
|
||||
except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):
|
||||
return redirect(
|
||||
f"{dify_config.CONSOLE_WEB_URL}/signin"
|
||||
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
|
||||
)
|
||||
|
||||
# Check account status
|
||||
if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}:
|
||||
return {"error": "Account is banned or closed."}, 403
|
||||
if account.status == AccountStatus.BANNED.value:
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account is banned.")
|
||||
|
||||
if account.status == AccountStatus.PENDING.value:
|
||||
account.status = AccountStatus.ACTIVE.value
|
||||
account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
TenantService.create_owner_tenant_if_not_exist(account)
|
||||
try:
|
||||
TenantService.create_owner_tenant_if_not_exist(account)
|
||||
except Unauthorized:
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Workspace not found.")
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
return redirect(
|
||||
f"{dify_config.CONSOLE_WEB_URL}/signin"
|
||||
"?message=Workspace not found, please contact system admin to invite you to join in a workspace."
|
||||
)
|
||||
|
||||
token = AccountService.login(account, ip_address=get_remote_ip(request))
|
||||
token_pair = AccountService.login(
|
||||
account=account,
|
||||
ip_address=extract_remote_ip(request),
|
||||
)
|
||||
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}")
|
||||
return redirect(
|
||||
f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
|
||||
)
|
||||
|
||||
|
||||
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]:
|
||||
@@ -99,8 +142,20 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
|
||||
# Get account by openid or email.
|
||||
account = _get_account_by_openid_or_email(provider, user_info)
|
||||
|
||||
if account:
|
||||
tenant = TenantService.get_join_tenants(account)
|
||||
if not tenant:
|
||||
if not FeatureService.get_system_features().is_allow_create_workspace:
|
||||
raise WorkSpaceNotAllowedCreateError()
|
||||
else:
|
||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||
TenantService.create_tenant_member(tenant, account, role="owner")
|
||||
account.current_tenant = tenant
|
||||
tenant_was_created.send(tenant)
|
||||
|
||||
if not account:
|
||||
# Create account
|
||||
if not FeatureService.get_system_features().is_allow_register:
|
||||
raise AccountNotFoundError()
|
||||
account_name = user_info.name or "Dify"
|
||||
account = RegisterService.register(
|
||||
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
|
||||
|
||||
@@ -15,8 +15,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
from extensions.ext_database import db
|
||||
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
|
||||
from libs.login import login_required
|
||||
from models.dataset import Document
|
||||
from models.source import DataSourceOauthBinding
|
||||
from models import DataSourceOauthBinding, Document
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
|
||||
|
||||
@@ -24,8 +24,8 @@ from fields.app_fields import related_app_list
|
||||
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
|
||||
from fields.document_fields import document_status_fields
|
||||
from libs.login import login_required
|
||||
from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment
|
||||
from models.model import ApiToken, UploadFile
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
|
||||
|
||||
@@ -617,6 +617,9 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
| VectorType.CHROMA
|
||||
| VectorType.TENCENT
|
||||
| VectorType.PGVECTO_RS
|
||||
| VectorType.BAIDU
|
||||
| VectorType.VIKINGDB
|
||||
| VectorType.UPSTASH
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
case (
|
||||
@@ -653,6 +656,9 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
| VectorType.CHROMA
|
||||
| VectorType.TENCENT
|
||||
| VectorType.PGVECTO_RS
|
||||
| VectorType.BAIDU
|
||||
| VectorType.VIKINGDB
|
||||
| VectorType.UPSTASH
|
||||
):
|
||||
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
|
||||
case (
|
||||
|
||||
@@ -46,8 +46,7 @@ from fields.document_fields import (
|
||||
document_with_segments_fields,
|
||||
)
|
||||
from libs.login import login_required
|
||||
from models.dataset import Dataset, DatasetProcessRule, Document, DocumentSegment
|
||||
from models.model import UploadFile
|
||||
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from tasks.add_document_to_index_task import add_document_to_index_task
|
||||
from tasks.remove_document_from_index_task import remove_document_from_index_task
|
||||
|
||||
@@ -24,7 +24,7 @@ from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.segment_fields import segment_fields
|
||||
from libs.login import login_required
|
||||
from models.dataset import DocumentSegment
|
||||
from models import DocumentSegment
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
||||
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import urllib.parse
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from constants import DOCUMENT_EXTENSIONS
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.error import (
|
||||
FileTooLargeError,
|
||||
@@ -13,9 +16,10 @@ from controllers.console.datasets.error import (
|
||||
)
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||
from fields.file_fields import file_fields, upload_config_fields
|
||||
from core.helper import ssrf_proxy
|
||||
from fields.file_fields import file_fields, remote_file_info_fields, upload_config_fields
|
||||
from libs.login import login_required
|
||||
from services.file_service import ALLOWED_EXTENSIONS, UNSTRUCTURED_ALLOWED_EXTENSIONS, FileService
|
||||
from services.file_service import FileService
|
||||
|
||||
PREVIEW_WORDS_LIMIT = 3000
|
||||
|
||||
@@ -26,13 +30,12 @@ class FileApi(Resource):
|
||||
@account_initialization_required
|
||||
@marshal_with(upload_config_fields)
|
||||
def get(self):
|
||||
file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT
|
||||
batch_count_limit = dify_config.UPLOAD_FILE_BATCH_LIMIT
|
||||
image_file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT
|
||||
return {
|
||||
"file_size_limit": file_size_limit,
|
||||
"batch_count_limit": batch_count_limit,
|
||||
"image_file_size_limit": image_file_size_limit,
|
||||
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
|
||||
"batch_count_limit": dify_config.UPLOAD_FILE_BATCH_LIMIT,
|
||||
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
|
||||
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
||||
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
||||
}, 200
|
||||
|
||||
@setup_required
|
||||
@@ -44,6 +47,10 @@ class FileApi(Resource):
|
||||
# get file from request
|
||||
file = request.files["file"]
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("source", type=str, required=False, location="args")
|
||||
source = parser.parse_args().get("source")
|
||||
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
@@ -51,7 +58,7 @@ class FileApi(Resource):
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
try:
|
||||
upload_file = FileService.upload_file(file, current_user)
|
||||
upload_file = FileService.upload_file(file=file, user=current_user, source=source)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
@@ -75,11 +82,24 @@ class FileSupportTypeApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
etl_type = dify_config.ETL_TYPE
|
||||
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS
|
||||
return {"allowed_extensions": allowed_extensions}
|
||||
return {"allowed_extensions": DOCUMENT_EXTENSIONS}
|
||||
|
||||
|
||||
class RemoteFileInfoApi(Resource):
|
||||
@marshal_with(remote_file_info_fields)
|
||||
def get(self, url):
|
||||
decoded_url = urllib.parse.unquote(url)
|
||||
try:
|
||||
response = ssrf_proxy.head(decoded_url)
|
||||
return {
|
||||
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
|
||||
"file_length": int(response.headers.get("Content-Length", 0)),
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 400
|
||||
|
||||
|
||||
api.add_resource(FileApi, "/files/upload")
|
||||
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
|
||||
api.add_resource(FileSupportTypeApi, "/files/support-type")
|
||||
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
|
||||
|
||||
@@ -1,88 +1,24 @@
|
||||
import logging
|
||||
from flask_restful import Resource
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.datasets.error import DatasetNotInitializedError
|
||||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.errors.error import (
|
||||
LLMBadRequestError,
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from libs.login import login_required
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
|
||||
class HitTestingApi(Resource):
|
||||
class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
dataset = self.get_and_validate_dataset(dataset_id_str)
|
||||
args = self.parse_args()
|
||||
self.hit_testing_args_check(args)
|
||||
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("query", type=str, location="json")
|
||||
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
|
||||
try:
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
account=current_user,
|
||||
retrieval_model=args["retrieval_model"],
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
limit=10,
|
||||
)
|
||||
|
||||
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
|
||||
except services.errors.index.IndexNotInitializedError:
|
||||
raise DatasetNotInitializedError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model or Reranking Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider."
|
||||
)
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise ValueError(str(e))
|
||||
except Exception as e:
|
||||
logging.exception("Hit testing failed.")
|
||||
raise InternalServerError(str(e))
|
||||
return self.perform_hit_testing(dataset, args)
|
||||
|
||||
|
||||
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
|
||||
|
||||
85
api/controllers/console/datasets/hit_testing_base.py
Normal file
85
api/controllers/console/datasets/hit_testing_base.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import logging
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services.dataset_service
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.datasets.error import DatasetNotInitializedError
|
||||
from core.errors.error import (
|
||||
LLMBadRequestError,
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.hit_testing_fields import hit_testing_record_fields
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
|
||||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
def get_and_validate_dataset(dataset_id: str):
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def hit_testing_args_check(args):
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
|
||||
@staticmethod
|
||||
def parse_args():
|
||||
parser = reqparse.RequestParser()
|
||||
|
||||
parser.add_argument("query", type=str, location="json")
|
||||
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
return parser.parse_args()
|
||||
|
||||
@staticmethod
|
||||
def perform_hit_testing(dataset, args):
|
||||
try:
|
||||
response = HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=args["query"],
|
||||
account=current_user,
|
||||
retrieval_model=args["retrieval_model"],
|
||||
external_retrieval_model=args["external_retrieval_model"],
|
||||
limit=10,
|
||||
)
|
||||
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
|
||||
except services.errors.index.IndexNotInitializedError:
|
||||
raise DatasetNotInitializedError()
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
raise ProviderQuotaExceededError()
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
"No Embedding Model or Reranking Model available. Please configure a valid provider "
|
||||
"in the Settings -> Model Provider."
|
||||
)
|
||||
except InvokeError as e:
|
||||
raise CompletionRequestError(e.description)
|
||||
except ValueError as e:
|
||||
raise ValueError(str(e))
|
||||
except Exception as e:
|
||||
logging.exception("Hit testing failed.")
|
||||
raise InternalServerError(str(e))
|
||||
@@ -38,3 +38,27 @@ class AlreadyActivateError(BaseHTTPException):
|
||||
error_code = "already_activate"
|
||||
description = "Auth Token is invalid or account already activated, please check again."
|
||||
code = 403
|
||||
|
||||
|
||||
class NotAllowedCreateWorkspace(BaseHTTPException):
|
||||
error_code = "not_allowed_create_workspace"
|
||||
description = "Workspace not found, please contact system admin to invite you to join in a workspace."
|
||||
code = 400
|
||||
|
||||
|
||||
class AccountBannedError(BaseHTTPException):
|
||||
error_code = "account_banned"
|
||||
description = "Account is banned."
|
||||
code = 400
|
||||
|
||||
|
||||
class NotAllowedRegister(BaseHTTPException):
|
||||
error_code = "unauthorized"
|
||||
description = "Account not found."
|
||||
code = 400
|
||||
|
||||
|
||||
class EmailSendIpLimitError(BaseHTTPException):
|
||||
error_code = "email_send_ip_limit"
|
||||
description = "Too many emails have been sent from this IP address recently. Please try again later."
|
||||
code = 429
|
||||
|
||||
@@ -11,7 +11,7 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
|
||||
from extensions.ext_database import db
|
||||
from fields.installed_app_fields import installed_app_list_fields
|
||||
from libs.login import login_required
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
from models import App, InstalledApp, RecommendedApp
|
||||
from services.account_service import TenantService
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ message_fields = {
|
||||
"inputs": fields.Raw,
|
||||
"query": fields.String,
|
||||
"answer": fields.String,
|
||||
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
|
||||
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ from werkzeug.exceptions import NotFound
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from models.model import InstalledApp
|
||||
from models import InstalledApp
|
||||
|
||||
|
||||
def installed_app_required(view=None):
|
||||
|
||||
@@ -4,7 +4,7 @@ from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from configs import dify_config
|
||||
from libs.helper import StrLen, email, get_remote_ip
|
||||
from libs.helper import StrLen, email, extract_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models.model import DifySetup
|
||||
from services.account_service import RegisterService, TenantService
|
||||
@@ -46,7 +46,7 @@ class SetupApi(Resource):
|
||||
|
||||
# setup
|
||||
RegisterService.setup(
|
||||
email=args["email"], name=args["name"], password=args["password"], ip_address=get_remote_ip(request)
|
||||
email=args["email"], name=args["name"], password=args["password"], ip_address=extract_remote_ip(request)
|
||||
)
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
@@ -20,7 +20,7 @@ from extensions.ext_database import db
|
||||
from fields.member_fields import account_fields
|
||||
from libs.helper import TimestampField, timezone
|
||||
from libs.login import login_required
|
||||
from models.account import AccountIntegrate, InvitationCode
|
||||
from models import AccountIntegrate, InvitationCode
|
||||
from services.account_service import AccountService
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
|
||||
|
||||
@@ -360,16 +360,15 @@ class ToolWorkflowProviderCreateApi(Resource):
|
||||
args = reqparser.parse_args()
|
||||
|
||||
return WorkflowToolManageService.create_workflow_tool(
|
||||
user_id,
|
||||
tenant_id,
|
||||
args["workflow_app_id"],
|
||||
args["name"],
|
||||
args["label"],
|
||||
args["icon"],
|
||||
args["description"],
|
||||
args["parameters"],
|
||||
args["privacy_policy"],
|
||||
args.get("labels", []),
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
workflow_app_id=args["workflow_app_id"],
|
||||
name=args["name"],
|
||||
label=args["label"],
|
||||
icon=args["icon"],
|
||||
description=args["description"],
|
||||
parameters=args["parameters"],
|
||||
privacy_policy=args["privacy_policy"],
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -198,7 +198,7 @@ class WebappLogoWorkspaceApi(Resource):
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(file, current_user, True)
|
||||
upload_file = FileService.upload_file(file=file, user=current_user)
|
||||
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from flask import Response, request
|
||||
from flask_restful import Resource
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import services
|
||||
@@ -10,6 +10,10 @@ from services.file_service import FileService
|
||||
|
||||
|
||||
class ImagePreviewApi(Resource):
|
||||
"""
|
||||
Deprecated
|
||||
"""
|
||||
|
||||
def get(self, file_id):
|
||||
file_id = str(file_id)
|
||||
|
||||
@@ -21,13 +25,57 @@ class ImagePreviewApi(Resource):
|
||||
return {"content": "Invalid request."}, 400
|
||||
|
||||
try:
|
||||
generator, mimetype = FileService.get_image_preview(file_id, timestamp, nonce, sign)
|
||||
generator, mimetype = FileService.get_image_preview(
|
||||
file_id=file_id,
|
||||
timestamp=timestamp,
|
||||
nonce=nonce,
|
||||
sign=sign,
|
||||
)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
return Response(generator, mimetype=mimetype)
|
||||
|
||||
|
||||
class FilePreviewApi(Resource):
|
||||
def get(self, file_id):
|
||||
file_id = str(file_id)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("timestamp", type=str, required=True, location="args")
|
||||
parser.add_argument("nonce", type=str, required=True, location="args")
|
||||
parser.add_argument("sign", type=str, required=True, location="args")
|
||||
parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args["timestamp"] or not args["nonce"] or not args["sign"]:
|
||||
return {"content": "Invalid request."}, 400
|
||||
|
||||
try:
|
||||
generator, upload_file = FileService.get_file_generator_by_file_id(
|
||||
file_id=file_id,
|
||||
timestamp=args["timestamp"],
|
||||
nonce=args["nonce"],
|
||||
sign=args["sign"],
|
||||
)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
response = Response(
|
||||
generator,
|
||||
mimetype=upload_file.mime_type,
|
||||
direct_passthrough=True,
|
||||
headers={},
|
||||
)
|
||||
if upload_file.size > 0:
|
||||
response.headers["Content-Length"] = str(upload_file.size)
|
||||
if args["as_attachment"]:
|
||||
response.headers["Content-Disposition"] = f"attachment; filename={upload_file.name}"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class WorkspaceWebappLogoApi(Resource):
|
||||
def get(self, workspace_id):
|
||||
workspace_id = str(workspace_id)
|
||||
@@ -49,4 +97,5 @@ class WorkspaceWebappLogoApi(Resource):
|
||||
|
||||
|
||||
api.add_resource(ImagePreviewApi, "/files/<uuid:file_id>/image-preview")
|
||||
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/file-preview")
|
||||
api.add_resource(WorkspaceWebappLogoApi, "/files/workspaces/<uuid:workspace_id>/webapp-logo")
|
||||
|
||||
@@ -16,6 +16,7 @@ class ToolFilePreviewApi(Resource):
|
||||
parser.add_argument("timestamp", type=str, required=True, location="args")
|
||||
parser.add_argument("nonce", type=str, required=True, location="args")
|
||||
parser.add_argument("sign", type=str, required=True, location="args")
|
||||
parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -28,18 +29,27 @@ class ToolFilePreviewApi(Resource):
|
||||
raise Forbidden("Invalid request.")
|
||||
|
||||
try:
|
||||
result = ToolFileManager.get_file_generator_by_tool_file_id(
|
||||
stream, tool_file = ToolFileManager.get_file_generator_by_tool_file_id(
|
||||
file_id,
|
||||
)
|
||||
|
||||
if not result:
|
||||
if not stream or not tool_file:
|
||||
raise NotFound("file is not found")
|
||||
|
||||
generator, mimetype = result
|
||||
except Exception:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
return Response(generator, mimetype=mimetype)
|
||||
response = Response(
|
||||
stream,
|
||||
mimetype=tool_file.mimetype,
|
||||
direct_passthrough=True,
|
||||
headers={},
|
||||
)
|
||||
if tool_file.size > 0:
|
||||
response.headers["Content-Length"] = str(tool_file.size)
|
||||
if args["as_attachment"]:
|
||||
response.headers["Content-Disposition"] = f"attachment; filename={tool_file.name}"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
api.add_resource(ToolFilePreviewApi, "/files/tools/<uuid:file_id>.<string:extension>")
|
||||
|
||||
@@ -5,7 +5,6 @@ from libs.external_api import ExternalApi
|
||||
bp = Blueprint("service_api", __name__, url_prefix="/v1")
|
||||
api = ExternalApi(bp)
|
||||
|
||||
|
||||
from . import index
|
||||
from .app import app, audio, completion, conversation, file, message, workflow
|
||||
from .dataset import dataset, document, segment
|
||||
from .dataset import dataset, document, hit_testing, segment
|
||||
|
||||
@@ -4,7 +4,6 @@ from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from constants import UUID_NIL
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
@@ -108,7 +107,6 @@ class ChatApi(Resource):
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||
parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, default=UUID_NIL, location="json")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ class MessageListApi(Resource):
|
||||
"tool_input": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"observation": fields.String,
|
||||
"message_files": fields.List(fields.String, attribute="files"),
|
||||
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||
}
|
||||
|
||||
message_fields = {
|
||||
@@ -58,7 +58,7 @@ class MessageListApi(Resource):
|
||||
"inputs": fields.Raw,
|
||||
"query": fields.String,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
|
||||
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
|
||||
"created_at": TimestampField,
|
||||
|
||||
17
api/controllers/service_api/dataset/hit_testing.py
Normal file
17
api/controllers/service_api/dataset/hit_testing.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
|
||||
|
||||
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
||||
def post(self, tenant_id, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset = self.get_and_validate_dataset(dataset_id_str)
|
||||
args = self.parse_args()
|
||||
self.hit_testing_args_check(args)
|
||||
|
||||
return self.perform_hit_testing(dataset, args)
|
||||
|
||||
|
||||
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
|
||||
@@ -1,11 +1,14 @@
|
||||
import urllib.parse
|
||||
|
||||
from flask import request
|
||||
from flask_restful import marshal_with
|
||||
from flask_restful import marshal_with, reqparse
|
||||
|
||||
import services
|
||||
from controllers.web import api
|
||||
from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from fields.file_fields import file_fields
|
||||
from core.helper import ssrf_proxy
|
||||
from fields.file_fields import file_fields, remote_file_info_fields
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
@@ -15,6 +18,10 @@ class FileApi(WebApiResource):
|
||||
# get file from request
|
||||
file = request.files["file"]
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("source", type=str, required=False, location="args")
|
||||
source = parser.parse_args().get("source")
|
||||
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
@@ -22,7 +29,7 @@ class FileApi(WebApiResource):
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
try:
|
||||
upload_file = FileService.upload_file(file, end_user)
|
||||
upload_file = FileService.upload_file(file=file, user=end_user, source=source)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
@@ -31,4 +38,19 @@ class FileApi(WebApiResource):
|
||||
return upload_file, 201
|
||||
|
||||
|
||||
class RemoteFileInfoApi(WebApiResource):
|
||||
@marshal_with(remote_file_info_fields)
|
||||
def get(self, url):
|
||||
decoded_url = urllib.parse.unquote(url)
|
||||
try:
|
||||
response = ssrf_proxy.head(decoded_url)
|
||||
return {
|
||||
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
|
||||
"file_length": int(response.headers.get("Content-Length", -1)),
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": str(e)}, 400
|
||||
|
||||
|
||||
api.add_resource(FileApi, "/files/upload")
|
||||
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
|
||||
|
||||
@@ -22,6 +22,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from fields.message_fields import agent_thought_fields
|
||||
from fields.raws import FilesContainedField
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from models.model import AppMode
|
||||
@@ -58,10 +59,10 @@ class MessageListApi(WebApiResource):
|
||||
"id": fields.String,
|
||||
"conversation_id": fields.String,
|
||||
"parent_message_id": fields.String,
|
||||
"inputs": fields.Raw,
|
||||
"inputs": FilesContainedField,
|
||||
"query": fields.String,
|
||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
|
||||
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
|
||||
"created_at": TimestampField,
|
||||
|
||||
@@ -17,7 +17,7 @@ message_fields = {
|
||||
"inputs": fields.Raw,
|
||||
"query": fields.String,
|
||||
"answer": fields.String,
|
||||
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
|
||||
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
@@ -16,13 +16,14 @@ from core.app.entities.app_invoke_entities import (
|
||||
)
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.file.message_file_parser import MessageFileParser
|
||||
from core.file import file_manager
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
PromptMessageContent,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
@@ -40,9 +41,9 @@ from core.tools.entities.tool_entities import (
|
||||
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation, Message, MessageAgentThought
|
||||
from factories import file_factory
|
||||
from models.model import Conversation, Message, MessageAgentThought, MessageFile
|
||||
from models.tools import ToolConversationVariables
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -66,23 +67,6 @@ class BaseAgentRunner(AppRunner):
|
||||
db_variables: Optional[ToolConversationVariables] = None,
|
||||
model_instance: ModelInstance = None,
|
||||
) -> None:
|
||||
"""
|
||||
Agent runner
|
||||
:param tenant_id: tenant id
|
||||
:param application_generate_entity: application generate entity
|
||||
:param conversation: conversation
|
||||
:param app_config: app generate entity
|
||||
:param model_config: model config
|
||||
:param config: dataset config
|
||||
:param queue_manager: queue manager
|
||||
:param message: message
|
||||
:param user_id: user id
|
||||
:param memory: memory
|
||||
:param prompt_messages: prompt messages
|
||||
:param variables_pool: variables pool
|
||||
:param db_variables: db variables
|
||||
:param model_instance: model instance
|
||||
"""
|
||||
self.tenant_id = tenant_id
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self.conversation = conversation
|
||||
@@ -180,7 +164,7 @@ class BaseAgentRunner(AppRunner):
|
||||
if parameter.form != ToolParameter.ToolParameterForm.LLM:
|
||||
continue
|
||||
|
||||
parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
|
||||
parameter_type = parameter.type.as_normal_type()
|
||||
enum = []
|
||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||
enum = [option.value for option in parameter.options]
|
||||
@@ -265,7 +249,7 @@ class BaseAgentRunner(AppRunner):
|
||||
if parameter.form != ToolParameter.ToolParameterForm.LLM:
|
||||
continue
|
||||
|
||||
parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
|
||||
parameter_type = parameter.type.as_normal_type()
|
||||
enum = []
|
||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||
enum = [option.value for option in parameter.options]
|
||||
@@ -511,26 +495,24 @@ class BaseAgentRunner(AppRunner):
|
||||
return result
|
||||
|
||||
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
|
||||
message_file_parser = MessageFileParser(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_config.app_id,
|
||||
)
|
||||
|
||||
files = message.message_files
|
||||
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
|
||||
if files:
|
||||
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
|
||||
|
||||
if file_extra_config:
|
||||
file_objs = message_file_parser.transform_message_files(files, file_extra_config)
|
||||
file_objs = file_factory.build_from_message_files(
|
||||
message_files=files, tenant_id=self.tenant_id, config=file_extra_config
|
||||
)
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
if not file_objs:
|
||||
return UserPromptMessage(content=message.query)
|
||||
else:
|
||||
prompt_message_contents = [TextPromptMessageContent(data=message.query)]
|
||||
prompt_message_contents: list[PromptMessageContent] = []
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
||||
for file_obj in file_objs:
|
||||
prompt_message_contents.append(file_obj.prompt_message_content)
|
||||
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
|
||||
|
||||
return UserPromptMessage(content=prompt_message_contents)
|
||||
else:
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import json
|
||||
|
||||
from core.agent.cot_agent_runner import CotAgentRunner
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
from core.file import file_manager
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageContent,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
@@ -32,9 +34,10 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||
Organize user query
|
||||
"""
|
||||
if self.files:
|
||||
prompt_message_contents = [TextPromptMessageContent(data=query)]
|
||||
prompt_message_contents: list[PromptMessageContent] = []
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||
for file_obj in self.files:
|
||||
prompt_message_contents.append(file_obj.prompt_message_content)
|
||||
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
|
||||
@@ -7,10 +7,15 @@ from typing import Any, Optional, Union
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
from core.file import file_manager
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
PromptMessageContent,
|
||||
PromptMessageContentType,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
@@ -390,9 +395,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
Organize user query
|
||||
"""
|
||||
if self.files:
|
||||
prompt_message_contents = [TextPromptMessageContent(data=query)]
|
||||
prompt_message_contents: list[PromptMessageContent] = []
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||
for file_obj in self.files:
|
||||
prompt_message_contents.append(file_obj.prompt_message_content)
|
||||
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
|
||||
@@ -62,6 +62,8 @@ class CotAgentOutputParser:
|
||||
thought_str = "thought:"
|
||||
thought_idx = 0
|
||||
|
||||
last_character = ""
|
||||
|
||||
for response in llm_response:
|
||||
if response.delta.usage:
|
||||
usage_dict["usage"] = response.delta.usage
|
||||
@@ -74,35 +76,38 @@ class CotAgentOutputParser:
|
||||
while index < len(response):
|
||||
steps = 1
|
||||
delta = response[index : index + steps]
|
||||
last_character = response[index - 1] if index > 0 else ""
|
||||
yield_delta = False
|
||||
|
||||
if delta == "`":
|
||||
last_character = delta
|
||||
code_block_cache += delta
|
||||
code_block_delimiter_count += 1
|
||||
else:
|
||||
if not in_code_block:
|
||||
if code_block_delimiter_count > 0:
|
||||
last_character = delta
|
||||
yield code_block_cache
|
||||
code_block_cache = ""
|
||||
else:
|
||||
last_character = delta
|
||||
code_block_cache += delta
|
||||
code_block_delimiter_count = 0
|
||||
|
||||
if not in_code_block and not in_json:
|
||||
if delta.lower() == action_str[action_idx] and action_idx == 0:
|
||||
if last_character not in {"\n", " ", ""}:
|
||||
yield_delta = True
|
||||
else:
|
||||
last_character = delta
|
||||
action_cache += delta
|
||||
action_idx += 1
|
||||
if action_idx == len(action_str):
|
||||
action_cache = ""
|
||||
action_idx = 0
|
||||
index += steps
|
||||
yield delta
|
||||
continue
|
||||
|
||||
action_cache += delta
|
||||
action_idx += 1
|
||||
if action_idx == len(action_str):
|
||||
action_cache = ""
|
||||
action_idx = 0
|
||||
index += steps
|
||||
continue
|
||||
elif delta.lower() == action_str[action_idx] and action_idx > 0:
|
||||
last_character = delta
|
||||
action_cache += delta
|
||||
action_idx += 1
|
||||
if action_idx == len(action_str):
|
||||
@@ -112,24 +117,25 @@ class CotAgentOutputParser:
|
||||
continue
|
||||
else:
|
||||
if action_cache:
|
||||
last_character = delta
|
||||
yield action_cache
|
||||
action_cache = ""
|
||||
action_idx = 0
|
||||
|
||||
if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
|
||||
if last_character not in {"\n", " ", ""}:
|
||||
yield_delta = True
|
||||
else:
|
||||
last_character = delta
|
||||
thought_cache += delta
|
||||
thought_idx += 1
|
||||
if thought_idx == len(thought_str):
|
||||
thought_cache = ""
|
||||
thought_idx = 0
|
||||
index += steps
|
||||
yield delta
|
||||
continue
|
||||
|
||||
thought_cache += delta
|
||||
thought_idx += 1
|
||||
if thought_idx == len(thought_str):
|
||||
thought_cache = ""
|
||||
thought_idx = 0
|
||||
index += steps
|
||||
continue
|
||||
elif delta.lower() == thought_str[thought_idx] and thought_idx > 0:
|
||||
last_character = delta
|
||||
thought_cache += delta
|
||||
thought_idx += 1
|
||||
if thought_idx == len(thought_str):
|
||||
@@ -139,12 +145,20 @@ class CotAgentOutputParser:
|
||||
continue
|
||||
else:
|
||||
if thought_cache:
|
||||
last_character = delta
|
||||
yield thought_cache
|
||||
thought_cache = ""
|
||||
thought_idx = 0
|
||||
|
||||
if yield_delta:
|
||||
index += steps
|
||||
last_character = delta
|
||||
yield delta
|
||||
continue
|
||||
|
||||
if code_block_delimiter_count == 3:
|
||||
if in_code_block:
|
||||
last_character = delta
|
||||
yield from extra_json_from_code_block(code_block_cache)
|
||||
code_block_cache = ""
|
||||
|
||||
@@ -156,8 +170,10 @@ class CotAgentOutputParser:
|
||||
if delta == "{":
|
||||
json_quote_count += 1
|
||||
in_json = True
|
||||
last_character = delta
|
||||
json_cache += delta
|
||||
elif delta == "}":
|
||||
last_character = delta
|
||||
json_cache += delta
|
||||
if json_quote_count > 0:
|
||||
json_quote_count -= 1
|
||||
@@ -168,16 +184,19 @@ class CotAgentOutputParser:
|
||||
continue
|
||||
else:
|
||||
if in_json:
|
||||
last_character = delta
|
||||
json_cache += delta
|
||||
|
||||
if got_json:
|
||||
got_json = False
|
||||
last_character = delta
|
||||
yield parse_action(json_cache)
|
||||
json_cache = ""
|
||||
json_quote_count = 0
|
||||
in_json = False
|
||||
|
||||
if not in_code_block and not in_json:
|
||||
last_character = delta
|
||||
yield delta.replace("`", "")
|
||||
|
||||
index += steps
|
||||
|
||||
@@ -53,12 +53,11 @@ class BasicVariablesConfigManager:
|
||||
VariableEntity(
|
||||
type=variable_type,
|
||||
variable=variable.get("variable"),
|
||||
description=variable.get("description"),
|
||||
description=variable.get("description") or "",
|
||||
label=variable.get("label"),
|
||||
required=variable.get("required", False),
|
||||
max_length=variable.get("max_length"),
|
||||
options=variable.get("options"),
|
||||
default=variable.get("default"),
|
||||
options=variable.get("options") or [],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.file.file_obj import FileExtraConfig
|
||||
from core.file import FileExtraConfig, FileTransferMethod, FileType
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from models import AppMode
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class ModelConfigEntity(BaseModel):
|
||||
@@ -69,7 +70,7 @@ class PromptTemplateEntity(BaseModel):
|
||||
ADVANCED = "advanced"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "PromptType":
|
||||
def value_of(cls, value: str):
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
@@ -93,6 +94,8 @@ class VariableEntityType(str, Enum):
|
||||
PARAGRAPH = "paragraph"
|
||||
NUMBER = "number"
|
||||
EXTERNAL_DATA_TOOL = "external_data_tool"
|
||||
FILE = "file"
|
||||
FILE_LIST = "file-list"
|
||||
|
||||
|
||||
class VariableEntity(BaseModel):
|
||||
@@ -102,13 +105,24 @@ class VariableEntity(BaseModel):
|
||||
|
||||
variable: str
|
||||
label: str
|
||||
description: Optional[str] = None
|
||||
description: str = ""
|
||||
type: VariableEntityType
|
||||
required: bool = False
|
||||
max_length: Optional[int] = None
|
||||
options: Optional[list[str]] = None
|
||||
default: Optional[str] = None
|
||||
hint: Optional[str] = None
|
||||
options: Sequence[str] = Field(default_factory=list)
|
||||
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
|
||||
allowed_file_extensions: Sequence[str] = Field(default_factory=list)
|
||||
allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
|
||||
|
||||
@field_validator("description", mode="before")
|
||||
@classmethod
|
||||
def convert_none_description(cls, v: Any) -> str:
|
||||
return v or ""
|
||||
|
||||
@field_validator("options", mode="before")
|
||||
@classmethod
|
||||
def convert_none_options(cls, v: Any) -> Sequence[str]:
|
||||
return v or []
|
||||
|
||||
|
||||
class ExternalDataVariableEntity(BaseModel):
|
||||
@@ -136,7 +150,7 @@ class DatasetRetrieveConfigEntity(BaseModel):
|
||||
MULTIPLE = "multiple"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "RetrieveStrategy":
|
||||
def value_of(cls, value: str):
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from core.file.file_obj import FileExtraConfig
|
||||
from core.file.models import FileExtraConfig
|
||||
from models import FileUploadConfig
|
||||
|
||||
|
||||
class FileUploadConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: Mapping[str, Any], is_vision: bool = True) -> Optional[FileExtraConfig]:
|
||||
def convert(cls, config: Mapping[str, Any], is_vision: bool = True):
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@@ -15,19 +16,21 @@ class FileUploadConfigManager:
|
||||
"""
|
||||
file_upload_dict = config.get("file_upload")
|
||||
if file_upload_dict:
|
||||
if file_upload_dict.get("image"):
|
||||
if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]:
|
||||
image_config = {
|
||||
"number_limits": file_upload_dict["image"]["number_limits"],
|
||||
"transfer_methods": file_upload_dict["image"]["transfer_methods"],
|
||||
if file_upload_dict.get("enabled"):
|
||||
transform_methods = file_upload_dict.get("allowed_file_upload_methods") or file_upload_dict.get(
|
||||
"allowed_upload_methods", []
|
||||
)
|
||||
data = {
|
||||
"image_config": {
|
||||
"number_limits": file_upload_dict["number_limits"],
|
||||
"transfer_methods": transform_methods,
|
||||
}
|
||||
}
|
||||
|
||||
if is_vision:
|
||||
image_config["detail"] = file_upload_dict["image"]["detail"]
|
||||
if is_vision:
|
||||
data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low")
|
||||
|
||||
return FileExtraConfig(image_config=image_config)
|
||||
|
||||
return None
|
||||
return FileExtraConfig.model_validate(data)
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict, is_vision: bool = True) -> tuple[dict, list[str]]:
|
||||
@@ -39,29 +42,7 @@ class FileUploadConfigManager:
|
||||
"""
|
||||
if not config.get("file_upload"):
|
||||
config["file_upload"] = {}
|
||||
|
||||
if not isinstance(config["file_upload"], dict):
|
||||
raise ValueError("file_upload must be of dict type")
|
||||
|
||||
# check image config
|
||||
if not config["file_upload"].get("image"):
|
||||
config["file_upload"]["image"] = {"enabled": False}
|
||||
|
||||
if config["file_upload"]["image"]["enabled"]:
|
||||
number_limits = config["file_upload"]["image"]["number_limits"]
|
||||
if number_limits < 1 or number_limits > 6:
|
||||
raise ValueError("number_limits must be in [1, 6]")
|
||||
|
||||
if is_vision:
|
||||
detail = config["file_upload"]["image"]["detail"]
|
||||
if detail not in {"high", "low"}:
|
||||
raise ValueError("detail must be in ['high', 'low']")
|
||||
|
||||
transfer_methods = config["file_upload"]["image"]["transfer_methods"]
|
||||
if not isinstance(transfer_methods, list):
|
||||
raise ValueError("transfer_methods must be of list type")
|
||||
for method in transfer_methods:
|
||||
if method not in {"remote_url", "local_file"}:
|
||||
raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
|
||||
else:
|
||||
FileUploadConfig.model_validate(config["file_upload"])
|
||||
|
||||
return config, ["file_upload"]
|
||||
|
||||
@@ -17,6 +17,6 @@ class WorkflowVariablesConfigManager:
|
||||
|
||||
# variables
|
||||
for variable in user_input_form:
|
||||
variables.append(VariableEntity(**variable))
|
||||
variables.append(VariableEntity.model_validate(variable))
|
||||
|
||||
return variables
|
||||
|
||||
@@ -10,6 +10,7 @@ from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
|
||||
import contexts
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
||||
@@ -20,11 +21,12 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
||||
from core.file.message_file_parser import MessageFileParser
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.account import Account
|
||||
from models.enums import CreatedByRole
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
from models.workflow import Workflow
|
||||
|
||||
@@ -95,10 +97,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
|
||||
# parse files
|
||||
files = args["files"] if args.get("files") else []
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
||||
if file_extra_config:
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
|
||||
file_objs = file_factory.build_from_mappings(
|
||||
mappings=files,
|
||||
tenant_id=app_model.tenant_id,
|
||||
user_id=user.id,
|
||||
role=role,
|
||||
config=file_extra_config,
|
||||
)
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
@@ -106,8 +114,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
|
||||
|
||||
# get tracing instance
|
||||
user_id = user.id if isinstance(user, Account) else user.session_id
|
||||
trace_manager = TraceQueueManager(app_model.id, user_id)
|
||||
trace_manager = TraceQueueManager(
|
||||
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
|
||||
)
|
||||
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
# always enable retriever resource in debugger mode
|
||||
@@ -119,10 +128,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
||||
inputs=conversation.inputs
|
||||
if conversation
|
||||
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
parent_message_id=args.get("parent_message_id"),
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
@@ -215,13 +226,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
is_first_conversation = True
|
||||
|
||||
# init generate records
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
|
||||
|
||||
if is_first_conversation:
|
||||
# update conversation features
|
||||
conversation.override_model_configs = workflow.features
|
||||
db.session.commit()
|
||||
db.session.refresh(conversation)
|
||||
(conversation, message) = self._init_generate_records(
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=conversation,
|
||||
override_model_configs=workflow.features_dict if is_first_conversation else None,
|
||||
)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
|
||||
@@ -1,31 +1,28 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAnnotationReplyEvent,
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
)
|
||||
from core.moderation.base import ModerationError
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
from models.workflow import ConversationVariable, WorkflowType
|
||||
|
||||
@@ -44,12 +41,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
) -> None:
|
||||
"""
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: application queue manager
|
||||
:param conversation: conversation
|
||||
:param message: message
|
||||
"""
|
||||
super().__init__(queue_manager)
|
||||
|
||||
self.application_generate_entity = application_generate_entity
|
||||
@@ -57,10 +48,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
self.message = message
|
||||
|
||||
def run(self) -> None:
|
||||
"""
|
||||
Run application
|
||||
:return:
|
||||
"""
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(AdvancedChatAppConfig, app_config)
|
||||
|
||||
@@ -81,7 +68,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
user_id = self.application_generate_entity.user_id
|
||||
|
||||
workflow_callbacks: list[WorkflowCallback] = []
|
||||
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
|
||||
if dify_config.DEBUG:
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
@@ -115,6 +102,9 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
):
|
||||
return
|
||||
|
||||
# trace start time
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Init conversation variables
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.app_id == self.conversation.app_id,
|
||||
@@ -142,6 +132,13 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
conversation_dialogue_count = self.conversation.dialogue_count
|
||||
db.session.commit()
|
||||
|
||||
# trace end time
|
||||
end_time = time.perf_counter()
|
||||
print(f"conversation_dialogue_count time: {end_time - start_time}")
|
||||
|
||||
# trace start time
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = {
|
||||
SystemVariableKey.QUERY: query,
|
||||
@@ -165,6 +162,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
# init graph
|
||||
graph = self._init_graph(graph_config=workflow.graph_dict)
|
||||
|
||||
# trace end time
|
||||
end_time = time.perf_counter()
|
||||
print(f"init graph time: {end_time - start_time}")
|
||||
|
||||
db.session.close()
|
||||
|
||||
# RUN WORKFLOW
|
||||
@@ -201,15 +202,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
query: str,
|
||||
message_id: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Handle input moderation
|
||||
:param app_record: app record
|
||||
:param app_generate_entity: application generate entity
|
||||
:param inputs: inputs
|
||||
:param query: query
|
||||
:param message_id: message id
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
# process sensitive_word_avoidance
|
||||
_, inputs, query = self.moderation_for_inputs(
|
||||
@@ -229,14 +221,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
def handle_annotation_reply(
|
||||
self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
|
||||
) -> bool:
|
||||
"""
|
||||
Handle annotation reply
|
||||
:param app_record: app record
|
||||
:param message: message
|
||||
:param query: query
|
||||
:param app_generate_entity: application generate entity
|
||||
"""
|
||||
# annotation reply
|
||||
annotation_reply = self.query_app_annotations_to_reply(
|
||||
app_record=app_record,
|
||||
message=message,
|
||||
@@ -258,8 +242,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
|
||||
"""
|
||||
Direct output
|
||||
:param text: text
|
||||
:return:
|
||||
"""
|
||||
self._publish_event(QueueTextChunkEvent(text=text))
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||
@@ -9,6 +9,7 @@ from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGenerator
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAdvancedChatMessageEndEvent,
|
||||
@@ -50,12 +51,15 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes import NodeType
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from models import Conversation, EndUser, Message, MessageFile
|
||||
from models.account import Account
|
||||
from models.model import Conversation, EndUser, Message
|
||||
from models.enums import CreatedByRole
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
|
||||
@@ -72,6 +76,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
_workflow: Workflow
|
||||
_user: Union[Account, EndUser]
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -115,8 +120,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
}
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._wip_workflow_node_executions = {}
|
||||
|
||||
self._conversation_name_generate_thread = None
|
||||
self._recorded_files: list[Mapping[str, Any]] = []
|
||||
|
||||
def process(self):
|
||||
"""
|
||||
@@ -295,6 +302,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
elif isinstance(event, QueueNodeSucceededEvent):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_success(event)
|
||||
|
||||
# Record files if it's an answer node or end node
|
||||
if event.node_type in [NodeType.ANSWER, NodeType.END]:
|
||||
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
|
||||
|
||||
response = self._workflow_node_finish_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@@ -361,7 +372,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
outputs=json.dumps(event.outputs) if event.outputs else None,
|
||||
outputs=event.outputs,
|
||||
conversation_id=self._conversation.id,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
@@ -487,10 +498,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
self._conversation_name_generate_thread.join()
|
||||
|
||||
def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
|
||||
"""
|
||||
Save message.
|
||||
:return:
|
||||
"""
|
||||
self._refetch_message()
|
||||
|
||||
self._message.answer = self._task_state.answer
|
||||
@@ -498,6 +505,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
self._message.message_metadata = (
|
||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||
)
|
||||
message_files = [
|
||||
MessageFile(
|
||||
message_id=self._message.id,
|
||||
type=file["type"],
|
||||
transfer_method=file["transfer_method"],
|
||||
url=file["remote_url"],
|
||||
belongs_to="assistant",
|
||||
upload_file_id=file["related_id"],
|
||||
created_by_role=CreatedByRole.ACCOUNT
|
||||
if self._message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
||||
else CreatedByRole.END_USER,
|
||||
created_by=self._message.from_account_id or self._message.from_end_user_id or "",
|
||||
)
|
||||
for file in self._recorded_files
|
||||
]
|
||||
db.session.add_all(message_files)
|
||||
|
||||
if graph_runtime_state and graph_runtime_state.llm_usage:
|
||||
usage = graph_runtime_state.llm_usage
|
||||
@@ -537,7 +560,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
del extras["metadata"]["annotation_reply"]
|
||||
|
||||
return MessageEndStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
|
||||
task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras
|
||||
)
|
||||
|
||||
def _handle_output_moderation_chunk(self, text: str) -> bool:
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Any, Literal, Union, overload
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
||||
@@ -17,12 +18,12 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
|
||||
from core.file.message_file_parser import MessageFileParser
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, EndUser
|
||||
from factories import file_factory
|
||||
from models import Account, App, EndUser
|
||||
from models.enums import CreatedByRole
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -49,7 +50,12 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
) -> dict: ...
|
||||
|
||||
def generate(
|
||||
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
) -> Union[dict, Generator[dict, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
@@ -97,12 +103,19 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
# always enable retriever resource in debugger mode
|
||||
override_model_config_dict["retriever_resource"] = {"enabled": True}
|
||||
|
||||
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
||||
|
||||
# parse files
|
||||
files = args["files"] if args.get("files") else []
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
files = args.get("files") or []
|
||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||
if file_extra_config:
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
|
||||
file_objs = file_factory.build_from_mappings(
|
||||
mappings=files,
|
||||
tenant_id=app_model.tenant_id,
|
||||
user_id=user.id,
|
||||
role=role,
|
||||
config=file_extra_config,
|
||||
)
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
@@ -115,8 +128,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
)
|
||||
|
||||
# get tracing instance
|
||||
user_id = user.id if isinstance(user, Account) else user.session_id
|
||||
trace_manager = TraceQueueManager(app_model.id, user_id)
|
||||
trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = AgentChatAppGenerateEntity(
|
||||
@@ -124,10 +136,12 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
app_config=app_config,
|
||||
model_conf=ModelConfigConverter.convert(app_config),
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
||||
inputs=conversation.inputs
|
||||
if conversation
|
||||
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
parent_message_id=args.get("parent_message_id"),
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
|
||||
@@ -1,35 +1,92 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType
|
||||
from core.app.app_config.entities import VariableEntityType
|
||||
from core.file import File, FileExtraConfig
|
||||
from factories import file_factory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.app.app_config.entities import AppConfig, VariableEntity
|
||||
from models.enums import CreatedByRole
|
||||
|
||||
|
||||
class BaseAppGenerator:
|
||||
def _get_cleaned_inputs(self, user_inputs: Optional[Mapping[str, Any]], app_config: AppConfig) -> Mapping[str, Any]:
|
||||
def _prepare_user_inputs(
|
||||
self,
|
||||
*,
|
||||
user_inputs: Optional[Mapping[str, Any]],
|
||||
app_config: "AppConfig",
|
||||
user_id: str,
|
||||
role: "CreatedByRole",
|
||||
) -> Mapping[str, Any]:
|
||||
user_inputs = user_inputs or {}
|
||||
# Filter input variables from form configuration, handle required fields, default values, and option values
|
||||
variables = app_config.variables
|
||||
filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
|
||||
filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
|
||||
return filtered_inputs
|
||||
user_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
|
||||
user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()}
|
||||
# Convert files in inputs to File
|
||||
entity_dictionary = {item.variable: item for item in app_config.variables}
|
||||
# Convert single file to File
|
||||
files_inputs = {
|
||||
k: file_factory.build_from_mapping(
|
||||
mapping=v,
|
||||
tenant_id=app_config.tenant_id,
|
||||
user_id=user_id,
|
||||
role=role,
|
||||
config=FileExtraConfig(
|
||||
allowed_file_types=entity_dictionary[k].allowed_file_types,
|
||||
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
|
||||
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
|
||||
),
|
||||
)
|
||||
for k, v in user_inputs.items()
|
||||
if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE
|
||||
}
|
||||
# Convert list of files to File
|
||||
file_list_inputs = {
|
||||
k: file_factory.build_from_mappings(
|
||||
mappings=v,
|
||||
tenant_id=app_config.tenant_id,
|
||||
user_id=user_id,
|
||||
role=role,
|
||||
config=FileExtraConfig(
|
||||
allowed_file_types=entity_dictionary[k].allowed_file_types,
|
||||
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
|
||||
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
|
||||
),
|
||||
)
|
||||
for k, v in user_inputs.items()
|
||||
if isinstance(v, list)
|
||||
# Ensure skip List<File>
|
||||
and all(isinstance(item, dict) for item in v)
|
||||
and entity_dictionary[k].type == VariableEntityType.FILE_LIST
|
||||
}
|
||||
# Merge all inputs
|
||||
user_inputs = {**user_inputs, **files_inputs, **file_list_inputs}
|
||||
|
||||
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
|
||||
user_input_value = inputs.get(var.variable)
|
||||
if var.required and not user_input_value:
|
||||
raise ValueError(f"{var.variable} is required in input form")
|
||||
if not var.required and not user_input_value:
|
||||
# TODO: should we return None here if the default value is None?
|
||||
return var.default or ""
|
||||
if (
|
||||
var.type
|
||||
in {
|
||||
VariableEntityType.TEXT_INPUT,
|
||||
VariableEntityType.SELECT,
|
||||
VariableEntityType.PARAGRAPH,
|
||||
}
|
||||
and user_input_value
|
||||
and not isinstance(user_input_value, str)
|
||||
# Check if all files are converted to File
|
||||
if any(filter(lambda v: isinstance(v, dict), user_inputs.values())):
|
||||
raise ValueError("Invalid input type")
|
||||
if any(
|
||||
filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values()))
|
||||
):
|
||||
raise ValueError("Invalid input type")
|
||||
|
||||
return user_inputs
|
||||
|
||||
def _validate_input(self, *, inputs: Mapping[str, Any], var: "VariableEntity"):
|
||||
user_input_value = inputs.get(var.variable)
|
||||
if not user_input_value:
|
||||
if var.required:
|
||||
raise ValueError(f"{var.variable} is required in input form")
|
||||
else:
|
||||
return None
|
||||
|
||||
if var.type in {
|
||||
VariableEntityType.TEXT_INPUT,
|
||||
VariableEntityType.SELECT,
|
||||
VariableEntityType.PARAGRAPH,
|
||||
} and not isinstance(user_input_value, str):
|
||||
raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
|
||||
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
|
||||
# may raise ValueError if user_input_value is not a valid number
|
||||
@@ -41,12 +98,24 @@ class BaseAppGenerator:
|
||||
except ValueError:
|
||||
raise ValueError(f"{var.variable} in input form must be a valid number")
|
||||
if var.type == VariableEntityType.SELECT:
|
||||
options = var.options or []
|
||||
options = var.options
|
||||
if user_input_value not in options:
|
||||
raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
|
||||
elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
|
||||
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
|
||||
if var.max_length and len(user_input_value) > var.max_length:
|
||||
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
|
||||
elif var.type == VariableEntityType.FILE:
|
||||
if not isinstance(user_input_value, dict) and not isinstance(user_input_value, File):
|
||||
raise ValueError(f"{var.variable} in input form must be a file")
|
||||
elif var.type == VariableEntityType.FILE_LIST:
|
||||
if not (
|
||||
isinstance(user_input_value, list)
|
||||
and (
|
||||
all(isinstance(item, dict) for item in user_input_value)
|
||||
or all(isinstance(item, File) for item in user_input_value)
|
||||
)
|
||||
):
|
||||
raise ValueError(f"{var.variable} in input form must be a list of files")
|
||||
|
||||
return user_input_value
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ from core.app.entities.queue_entities import (
|
||||
QueuePingEvent,
|
||||
QueueStopEvent,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class PublishFrom(Enum):
|
||||
@@ -32,10 +31,10 @@ class AppQueueManager:
|
||||
self._user_id = user_id
|
||||
self._invoke_from = invoke_from
|
||||
|
||||
user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
|
||||
redis_client.setex(
|
||||
AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
|
||||
)
|
||||
# user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
|
||||
# redis_client.setex(
|
||||
# AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
|
||||
# )
|
||||
|
||||
q = queue.Queue()
|
||||
|
||||
@@ -114,26 +113,27 @@ class AppQueueManager:
|
||||
Set task stop flag
|
||||
:return:
|
||||
"""
|
||||
result = redis_client.get(cls._generate_task_belong_cache_key(task_id))
|
||||
if result is None:
|
||||
return
|
||||
return
|
||||
# result = redis_client.get(cls._generate_task_belong_cache_key(task_id))
|
||||
# if result is None:
|
||||
# return
|
||||
|
||||
user_prefix = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
|
||||
if result.decode("utf-8") != f"{user_prefix}-{user_id}":
|
||||
return
|
||||
# user_prefix = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user"
|
||||
# if result.decode("utf-8") != f"{user_prefix}-{user_id}":
|
||||
# return
|
||||
|
||||
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
|
||||
redis_client.setex(stopped_cache_key, 600, 1)
|
||||
# stopped_cache_key = cls._generate_stopped_cache_key(task_id)
|
||||
# redis_client.setex(stopped_cache_key, 600, 1)
|
||||
|
||||
def _is_stopped(self) -> bool:
|
||||
"""
|
||||
Check if task is stopped
|
||||
:return:
|
||||
"""
|
||||
stopped_cache_key = AppQueueManager._generate_stopped_cache_key(self._task_id)
|
||||
result = redis_client.get(stopped_cache_key)
|
||||
if result is not None:
|
||||
return True
|
||||
# stopped_cache_key = AppQueueManager._generate_stopped_cache_key(self._task_id)
|
||||
# result = redis_client.get(stopped_cache_key)
|
||||
# if result is not None:
|
||||
# return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
|
||||
from models.model import App, AppMode, Message, MessageAnnotation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.file_obj import FileVar
|
||||
from core.file.models import File
|
||||
|
||||
|
||||
class AppRunner:
|
||||
@@ -37,7 +37,7 @@ class AppRunner:
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict[str, str],
|
||||
files: list["FileVar"],
|
||||
files: list["File"],
|
||||
query: Optional[str] = None,
|
||||
) -> int:
|
||||
"""
|
||||
@@ -137,7 +137,7 @@ class AppRunner:
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
prompt_template_entity: PromptTemplateEntity,
|
||||
inputs: dict[str, str],
|
||||
files: list["FileVar"],
|
||||
files: list["File"],
|
||||
query: Optional[str] = None,
|
||||
context: Optional[str] = None,
|
||||
memory: Optional[TokenBufferMemory] = None,
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Any, Literal, Union, overload
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
@@ -17,11 +18,12 @@ from core.app.apps.chat.generate_response_converter import ChatAppGenerateRespon
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
|
||||
from core.file.message_file_parser import MessageFileParser
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.account import Account
|
||||
from models.enums import CreatedByRole
|
||||
from models.model import App, EndUser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -99,12 +101,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
# always enable retriever resource in debugger mode
|
||||
override_model_config_dict["retriever_resource"] = {"enabled": True}
|
||||
|
||||
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
||||
|
||||
# parse files
|
||||
files = args["files"] if args.get("files") else []
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||
if file_extra_config:
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
|
||||
file_objs = file_factory.build_from_mappings(
|
||||
mappings=files,
|
||||
tenant_id=app_model.tenant_id,
|
||||
user_id=user.id,
|
||||
role=role,
|
||||
config=file_extra_config,
|
||||
)
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
@@ -117,7 +126,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
)
|
||||
|
||||
# get tracing instance
|
||||
trace_manager = TraceQueueManager(app_model.id)
|
||||
trace_manager = TraceQueueManager(app_id=app_model.id)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = ChatAppGenerateEntity(
|
||||
@@ -125,15 +134,17 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
app_config=app_config,
|
||||
model_conf=ModelConfigConverter.convert(app_config),
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
||||
inputs=conversation.inputs
|
||||
if conversation
|
||||
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
parent_message_id=args.get("parent_message_id"),
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
extras=extras,
|
||||
trace_manager=trace_manager,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
# init generate records
|
||||
|
||||
@@ -17,12 +17,12 @@ from core.app.apps.completion.generate_response_converter import CompletionAppGe
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
|
||||
from core.file.message_file_parser import MessageFileParser
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, EndUser, Message
|
||||
from factories import file_factory
|
||||
from models import Account, App, EndUser, Message
|
||||
from models.enums import CreatedByRole
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
@@ -88,12 +88,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
tenant_id=app_model.tenant_id, config=args.get("model_config")
|
||||
)
|
||||
|
||||
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
||||
|
||||
# parse files
|
||||
files = args["files"] if args.get("files") else []
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||
if file_extra_config:
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
|
||||
file_objs = file_factory.build_from_mappings(
|
||||
mappings=files,
|
||||
tenant_id=app_model.tenant_id,
|
||||
user_id=user.id,
|
||||
role=role,
|
||||
config=file_extra_config,
|
||||
)
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
@@ -103,6 +110,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
)
|
||||
|
||||
# get tracing instance
|
||||
user_id = user.id if isinstance(user, Account) else user.session_id
|
||||
trace_manager = TraceQueueManager(app_model.id)
|
||||
|
||||
# init application generate entity
|
||||
@@ -110,7 +118,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
model_conf=ModelConfigConverter.convert(app_config),
|
||||
inputs=self._get_cleaned_inputs(inputs, app_config),
|
||||
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
user_id=user.id,
|
||||
@@ -251,10 +259,16 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
override_model_config_dict["model"] = model_dict
|
||||
|
||||
# parse files
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict)
|
||||
if file_extra_config:
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(message.files, file_extra_config, user)
|
||||
file_objs = file_factory.build_from_mappings(
|
||||
mappings=message.message_files,
|
||||
tenant_id=app_model.tenant_id,
|
||||
user_id=user.id,
|
||||
role=role,
|
||||
config=file_extra_config,
|
||||
)
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from sqlalchemy import and_
|
||||
|
||||
@@ -26,7 +27,8 @@ from core.app.entities.task_entities import (
|
||||
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models import Account
|
||||
from models.enums import CreatedByRole
|
||||
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
|
||||
@@ -136,6 +138,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
AdvancedChatAppGenerateEntity,
|
||||
],
|
||||
conversation: Optional[Conversation] = None,
|
||||
override_model_configs: Optional[Mapping[str, Any]] = None,
|
||||
) -> tuple[Conversation, Message]:
|
||||
"""
|
||||
Initialize generate records
|
||||
@@ -157,14 +160,12 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
|
||||
if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity):
|
||||
app_model_config_id = None
|
||||
override_model_configs = None
|
||||
model_provider = None
|
||||
model_id = None
|
||||
else:
|
||||
app_model_config_id = app_config.app_model_config_id
|
||||
model_provider = application_generate_entity.model_conf.provider
|
||||
model_id = application_generate_entity.model_conf.model
|
||||
override_model_configs = None
|
||||
if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in {
|
||||
AppMode.AGENT_CHAT,
|
||||
AppMode.CHAT,
|
||||
@@ -176,72 +177,74 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
introduction = self._get_conversation_introduction(application_generate_entity)
|
||||
|
||||
if not conversation:
|
||||
conversation = Conversation(
|
||||
app_id=app_config.app_id,
|
||||
app_model_config_id=app_model_config_id,
|
||||
model_provider=model_provider,
|
||||
model_id=model_id,
|
||||
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
||||
mode=app_config.app_mode.value,
|
||||
name="New conversation",
|
||||
inputs=application_generate_entity.inputs,
|
||||
introduction=introduction,
|
||||
system_instruction="",
|
||||
system_instruction_tokens=0,
|
||||
status="normal",
|
||||
invoke_from=application_generate_entity.invoke_from.value,
|
||||
from_source=from_source,
|
||||
from_end_user_id=end_user_id,
|
||||
from_account_id=account_id,
|
||||
)
|
||||
with db.Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
conversation = Conversation()
|
||||
conversation.id = str(uuid.uuid4())
|
||||
conversation.app_id = app_config.app_id
|
||||
conversation.app_model_config_id = app_model_config_id
|
||||
conversation.model_provider = model_provider
|
||||
conversation.model_id = model_id
|
||||
conversation.override_model_configs = (
|
||||
json.dumps(override_model_configs) if override_model_configs else None
|
||||
)
|
||||
conversation.mode = app_config.app_mode.value
|
||||
conversation.name = "New conversation"
|
||||
conversation.inputs = application_generate_entity.inputs
|
||||
conversation.introduction = introduction
|
||||
conversation.system_instruction = ""
|
||||
conversation.system_instruction_tokens = 0
|
||||
conversation.status = "normal"
|
||||
conversation.invoke_from = application_generate_entity.invoke_from.value
|
||||
conversation.from_source = from_source
|
||||
conversation.from_end_user_id = end_user_id
|
||||
conversation.from_account_id = account_id
|
||||
|
||||
db.session.add(conversation)
|
||||
db.session.commit()
|
||||
db.session.refresh(conversation)
|
||||
session.add(conversation)
|
||||
session.commit()
|
||||
session.refresh(conversation)
|
||||
else:
|
||||
conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
message = Message(
|
||||
app_id=app_config.app_id,
|
||||
model_provider=model_provider,
|
||||
model_id=model_id,
|
||||
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
||||
conversation_id=conversation.id,
|
||||
inputs=application_generate_entity.inputs,
|
||||
query=application_generate_entity.query or "",
|
||||
message="",
|
||||
message_tokens=0,
|
||||
message_unit_price=0,
|
||||
message_price_unit=0,
|
||||
answer="",
|
||||
answer_tokens=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0,
|
||||
parent_message_id=getattr(application_generate_entity, "parent_message_id", None),
|
||||
provider_response_latency=0,
|
||||
total_price=0,
|
||||
currency="USD",
|
||||
invoke_from=application_generate_entity.invoke_from.value,
|
||||
from_source=from_source,
|
||||
from_end_user_id=end_user_id,
|
||||
from_account_id=account_id,
|
||||
)
|
||||
with db.Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
message = Message()
|
||||
message.app_id = app_config.app_id
|
||||
message.model_provider = model_provider
|
||||
message.model_id = model_id
|
||||
message.override_model_configs = json.dumps(override_model_configs) if override_model_configs else None
|
||||
message.conversation_id = conversation.id
|
||||
message.inputs = application_generate_entity.inputs
|
||||
message.query = application_generate_entity.query or ""
|
||||
message.message = ""
|
||||
message.message_tokens = 0
|
||||
message.message_unit_price = 0
|
||||
message.answer = ""
|
||||
message.answer_tokens = 0
|
||||
message.answer_unit_price = 0
|
||||
message.answer_price_unit = 0
|
||||
message.parent_message_id = getattr(application_generate_entity, "parent_message_id", None)
|
||||
message.provider_response_latency = 0
|
||||
message.total_price = 0
|
||||
message.currency = "USD"
|
||||
message.invoke_from = application_generate_entity.invoke_from.value
|
||||
message.from_source = from_source
|
||||
message.from_end_user_id = end_user_id
|
||||
message.from_account_id = account_id
|
||||
|
||||
db.session.add(message)
|
||||
db.session.commit()
|
||||
db.session.refresh(message)
|
||||
session.add(message)
|
||||
session.commit()
|
||||
session.refresh(message)
|
||||
|
||||
for file in application_generate_entity.files:
|
||||
message_file = MessageFile(
|
||||
message_id=message.id,
|
||||
type=file.type.value,
|
||||
transfer_method=file.transfer_method.value,
|
||||
type=file.type,
|
||||
transfer_method=file.transfer_method,
|
||||
belongs_to="user",
|
||||
url=file.url,
|
||||
url=file.remote_url,
|
||||
upload_file_id=file.related_id,
|
||||
created_by_role=("account" if account_id else "end_user"),
|
||||
created_by=account_id or end_user_id,
|
||||
created_by_role=(CreatedByRole.ACCOUNT if account_id else CreatedByRole.END_USER),
|
||||
created_by=account_id or end_user_id or "",
|
||||
)
|
||||
db.session.add(message_file)
|
||||
db.session.commit()
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
import os
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Literal, Optional, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
@@ -20,13 +20,12 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
|
||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
||||
from core.file.message_file_parser import MessageFileParser
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import App, EndUser
|
||||
from models.workflow import Workflow
|
||||
from factories import file_factory
|
||||
from models import Account, App, EndUser, Workflow
|
||||
from models.enums import CreatedByRole
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -63,49 +62,46 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Generate App response.
|
||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||
|
||||
:param app_model: App
|
||||
:param workflow: Workflow
|
||||
:param user: account or end user
|
||||
:param args: request args
|
||||
:param invoke_from: invoke from source
|
||||
:param stream: is stream
|
||||
:param call_depth: call depth
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
"""
|
||||
inputs = args["inputs"]
|
||||
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
||||
|
||||
# parse files
|
||||
files = args["files"] if args.get("files") else []
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||
if file_extra_config:
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
|
||||
else:
|
||||
file_objs = []
|
||||
system_files = file_factory.build_from_mappings(
|
||||
mappings=files,
|
||||
tenant_id=app_model.tenant_id,
|
||||
user_id=user.id,
|
||||
role=role,
|
||||
config=file_extra_config,
|
||||
)
|
||||
|
||||
# convert to app config
|
||||
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
|
||||
app_config = WorkflowAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
)
|
||||
|
||||
# get tracing instance
|
||||
user_id = user.id if isinstance(user, Account) else user.session_id
|
||||
trace_manager = TraceQueueManager(app_model.id, user_id)
|
||||
trace_manager = TraceQueueManager(
|
||||
app_id=app_model.id,
|
||||
user_id=user.id if isinstance(user, Account) else user.session_id,
|
||||
)
|
||||
|
||||
inputs: Mapping[str, Any] = args["inputs"]
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
# init application generate entity
|
||||
application_generate_entity = WorkflowAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
inputs=self._get_cleaned_inputs(inputs, app_config),
|
||||
files=file_objs,
|
||||
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
||||
files=system_files,
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
invoke_from=invoke_from,
|
||||
|
||||
@@ -1,21 +1,20 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from models.enums import UserFrom
|
||||
from models.model import App, EndUser
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
@@ -71,7 +70,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
db.session.close()
|
||||
|
||||
workflow_callbacks: list[WorkflowCallback] = []
|
||||
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
|
||||
if dify_config.DEBUG:
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
|
||||
# if only single iteration run is requested
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
@@ -52,6 +51,7 @@ from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowAppLog,
|
||||
WorkflowAppLogCreatedFrom,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
@@ -69,6 +69,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
_task_state: WorkflowTaskState
|
||||
_application_generate_entity: WorkflowAppGenerateEntity
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -103,6 +104,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
}
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._wip_workflow_node_executions = {}
|
||||
|
||||
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
"""
|
||||
@@ -331,9 +333,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
start_at=graph_runtime_state.start_at,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
outputs=json.dumps(event.outputs)
|
||||
if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs
|
||||
else None,
|
||||
outputs=event.outputs,
|
||||
conversation_id=None,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
@@ -20,7 +20,6 @@ from core.app.entities.queue_entities import (
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
@@ -41,9 +40,9 @@ from core.workflow.graph_engine.entities.event import (
|
||||
ParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.iteration.entities import IterationNodeData
|
||||
from core.workflow.nodes.node_mapping import node_classes
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.iteration import IterationNodeData
|
||||
from core.workflow.nodes.node_mapping import node_type_classes_mapping
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from models.model import App
|
||||
@@ -137,9 +136,8 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
raise ValueError("iteration node id not found in workflow graph")
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType.value_of(iteration_node_config.get("data", {}).get("type"))
|
||||
node_cls = node_classes.get(node_type)
|
||||
node_cls = cast(type[BaseNode], node_cls)
|
||||
node_type = NodeType(iteration_node_config.get("data", {}).get("type"))
|
||||
node_cls = node_type_classes_mapping[node_type]
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
||||
|
||||
from constants import UUID_NIL
|
||||
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
|
||||
from core.entities.provider_configuration import ProviderModelBundle
|
||||
from core.file.file_obj import FileVar
|
||||
from core.file.models import File
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
|
||||
@@ -22,7 +23,7 @@ class InvokeFrom(Enum):
|
||||
DEBUGGER = "debugger"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "InvokeFrom":
|
||||
def value_of(cls, value: str):
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
@@ -81,7 +82,7 @@ class AppGenerateEntity(BaseModel):
|
||||
app_config: AppConfig
|
||||
|
||||
inputs: Mapping[str, Any]
|
||||
files: list[FileVar] = []
|
||||
files: Sequence[File]
|
||||
user_id: str
|
||||
|
||||
# extras
|
||||
@@ -116,13 +117,36 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
||||
class ConversationAppGenerateEntity(AppGenerateEntity):
|
||||
"""
|
||||
Base entity for conversation-based app generation.
|
||||
"""
|
||||
|
||||
conversation_id: Optional[str] = None
|
||||
parent_message_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Starting from v0.9.0, parent_message_id is used to support message regeneration for internal chat API."
|
||||
"For service API, we need to ensure its forward compatibility, "
|
||||
"so passing in the parent_message_id as request arg is not supported for now. "
|
||||
"It needs to be set to UUID_NIL so that the subsequent processing will treat it as legacy messages."
|
||||
),
|
||||
)
|
||||
|
||||
@field_validator("parent_message_id")
|
||||
@classmethod
|
||||
def validate_parent_message_id(cls, v, info: ValidationInfo):
|
||||
if info.data.get("invoke_from") == InvokeFrom.SERVICE_API and v != UUID_NIL:
|
||||
raise ValueError("parent_message_id should be UUID_NIL for service API")
|
||||
return v
|
||||
|
||||
|
||||
class ChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGenerateEntity):
|
||||
"""
|
||||
Chat Application Generate Entity.
|
||||
"""
|
||||
|
||||
conversation_id: Optional[str] = None
|
||||
parent_message_id: Optional[str] = None
|
||||
pass
|
||||
|
||||
|
||||
class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
||||
@@ -133,16 +157,15 @@ class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
||||
pass
|
||||
|
||||
|
||||
class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
||||
class AgentChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGenerateEntity):
|
||||
"""
|
||||
Agent Chat Application Generate Entity.
|
||||
"""
|
||||
|
||||
conversation_id: Optional[str] = None
|
||||
parent_message_id: Optional[str] = None
|
||||
pass
|
||||
|
||||
|
||||
class AdvancedChatAppGenerateEntity(AppGenerateEntity):
|
||||
class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
|
||||
"""
|
||||
Advanced Chat Application Generate Entity.
|
||||
"""
|
||||
@@ -150,8 +173,6 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
|
||||
# app config
|
||||
app_config: WorkflowUIBasedAppConfig
|
||||
|
||||
conversation_id: Optional[str] = None
|
||||
parent_message_id: Optional[str] = None
|
||||
workflow_run_id: Optional[str] = None
|
||||
query: str
|
||||
|
||||
|
||||
@@ -5,9 +5,10 @@ from typing import Any, Optional
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
|
||||
class QueueEvent(str, Enum):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
@@ -119,6 +120,7 @@ class MessageEndStreamResponse(StreamResponse):
|
||||
event: StreamEvent = StreamEvent.MESSAGE_END
|
||||
id: str
|
||||
metadata: dict = {}
|
||||
files: Optional[Sequence[Mapping[str, Any]]] = None
|
||||
|
||||
|
||||
class MessageFileStreamResponse(StreamResponse):
|
||||
@@ -211,7 +213,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
||||
created_by: Optional[dict] = None
|
||||
created_at: int
|
||||
finished_at: int
|
||||
files: Optional[list[dict]] = []
|
||||
files: Optional[Sequence[Mapping[str, Any]]] = []
|
||||
|
||||
event: StreamEvent = StreamEvent.WORKFLOW_FINISHED
|
||||
workflow_run_id: str
|
||||
@@ -296,7 +298,7 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||
execution_metadata: Optional[dict] = None
|
||||
created_at: int
|
||||
finished_at: int
|
||||
files: Optional[list[dict]] = []
|
||||
files: Optional[Sequence[Mapping[str, Any]]] = []
|
||||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
parent_parallel_id: Optional[str] = None
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
import re
|
||||
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
from . import SegmentGroup, factory
|
||||
|
||||
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
|
||||
|
||||
|
||||
def convert_template(*, template: str, variable_pool: VariablePool):
|
||||
parts = re.split(VARIABLE_PATTERN, template)
|
||||
segments = []
|
||||
for part in filter(lambda x: x, parts):
|
||||
if "." in part and (value := variable_pool.get(part.split("."))):
|
||||
segments.append(value)
|
||||
else:
|
||||
segments.append(factory.build_segment(part))
|
||||
return SegmentGroup(value=segments)
|
||||
@@ -53,7 +53,7 @@ class BasedGenerateTaskPipeline:
|
||||
self._output_moderation_handler = self._init_output_moderation()
|
||||
self._stream = stream
|
||||
|
||||
def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None) -> Exception:
|
||||
def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None):
|
||||
"""
|
||||
Handle error event.
|
||||
:param event: event
|
||||
@@ -100,7 +100,7 @@ class BasedGenerateTaskPipeline:
|
||||
|
||||
return message
|
||||
|
||||
def _error_to_stream_response(self, e: Exception) -> ErrorStreamResponse:
|
||||
def _error_to_stream_response(self, e: Exception):
|
||||
"""
|
||||
Error to stream response.
|
||||
:param e: exception
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import logging
|
||||
from threading import Thread
|
||||
from typing import Optional, Union
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
@@ -83,7 +85,9 @@ class MessageCycleManage:
|
||||
name = LLMGenerator.generate_conversation_name(app_model.tenant_id, query)
|
||||
conversation.name = name
|
||||
except Exception as e:
|
||||
logging.exception(f"generate conversation name failed: {e}")
|
||||
if dify_config.DEBUG:
|
||||
logging.exception(f"generate conversation name failed: {e}")
|
||||
pass
|
||||
|
||||
db.session.merge(conversation)
|
||||
db.session.commit()
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueIterationCompletedEvent,
|
||||
@@ -27,27 +30,26 @@ from core.app.entities.task_entities import (
|
||||
WorkflowStartStreamResponse,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.file.file_obj import FileVar
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
|
||||
from models.model import EndUser
|
||||
from models.workflow import (
|
||||
CreatedByRole,
|
||||
Workflow,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
WorkflowRunTriggeredFrom,
|
||||
)
|
||||
|
||||
|
||||
@@ -57,6 +59,7 @@ class WorkflowCycleManage:
|
||||
_user: Union[Account, EndUser]
|
||||
_task_state: WorkflowTaskState
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
|
||||
|
||||
def _handle_workflow_run_start(self) -> WorkflowRun:
|
||||
max_sequence = (
|
||||
@@ -116,7 +119,7 @@ class WorkflowCycleManage:
|
||||
start_at: float,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
outputs: Optional[str] = None,
|
||||
outputs: Mapping[str, Any] | None = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> WorkflowRun:
|
||||
@@ -132,8 +135,10 @@ class WorkflowCycleManage:
|
||||
"""
|
||||
workflow_run = self._refetch_workflow_run(workflow_run.id)
|
||||
|
||||
outputs = WorkflowEntry.handle_special_values(outputs)
|
||||
|
||||
workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
|
||||
workflow_run.outputs = outputs
|
||||
workflow_run.outputs = json.dumps(outputs or {})
|
||||
workflow_run.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_run.total_tokens = total_tokens
|
||||
workflow_run.total_steps = total_steps
|
||||
@@ -229,28 +234,30 @@ class WorkflowCycleManage:
|
||||
self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
|
||||
) -> WorkflowNodeExecution:
|
||||
# init workflow node execution
|
||||
workflow_node_execution = WorkflowNodeExecution()
|
||||
workflow_node_execution.tenant_id = workflow_run.tenant_id
|
||||
workflow_node_execution.app_id = workflow_run.app_id
|
||||
workflow_node_execution.workflow_id = workflow_run.workflow_id
|
||||
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
|
||||
workflow_node_execution.workflow_run_id = workflow_run.id
|
||||
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
|
||||
workflow_node_execution.index = event.node_run_index
|
||||
workflow_node_execution.node_execution_id = event.node_execution_id
|
||||
workflow_node_execution.node_id = event.node_id
|
||||
workflow_node_execution.node_type = event.node_type.value
|
||||
workflow_node_execution.title = event.node_data.title
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
|
||||
workflow_node_execution.created_by_role = workflow_run.created_by_role
|
||||
workflow_node_execution.created_by = workflow_run.created_by
|
||||
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
db.session.add(workflow_node_execution)
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_node_execution)
|
||||
db.session.close()
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_node_execution = WorkflowNodeExecution()
|
||||
workflow_node_execution.tenant_id = workflow_run.tenant_id
|
||||
workflow_node_execution.app_id = workflow_run.app_id
|
||||
workflow_node_execution.workflow_id = workflow_run.workflow_id
|
||||
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
|
||||
workflow_node_execution.workflow_run_id = workflow_run.id
|
||||
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
|
||||
workflow_node_execution.index = event.node_run_index
|
||||
workflow_node_execution.node_execution_id = event.node_execution_id
|
||||
workflow_node_execution.node_id = event.node_id
|
||||
workflow_node_execution.node_type = event.node_type.value
|
||||
workflow_node_execution.title = event.node_data.title
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
|
||||
workflow_node_execution.created_by_role = workflow_run.created_by_role
|
||||
workflow_node_execution.created_by = workflow_run.created_by
|
||||
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
session.add(workflow_node_execution)
|
||||
session.commit()
|
||||
session.refresh(workflow_node_execution)
|
||||
|
||||
self._wip_workflow_node_executions[workflow_node_execution.node_execution_id] = workflow_node_execution
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
|
||||
@@ -262,21 +269,39 @@ class WorkflowCycleManage:
|
||||
workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
|
||||
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
execution_metadata = (
|
||||
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
|
||||
)
|
||||
finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
elapsed_time = (finished_at - event.start_at).total_seconds()
|
||||
|
||||
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
|
||||
{
|
||||
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.SUCCEEDED.value,
|
||||
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
|
||||
WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None,
|
||||
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
|
||||
WorkflowNodeExecution.execution_metadata: execution_metadata,
|
||||
WorkflowNodeExecution.finished_at: finished_at,
|
||||
WorkflowNodeExecution.elapsed_time: elapsed_time,
|
||||
}
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
|
||||
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
|
||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||
workflow_node_execution.execution_metadata = (
|
||||
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
|
||||
)
|
||||
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
|
||||
workflow_node_execution.execution_metadata = execution_metadata
|
||||
workflow_node_execution.finished_at = finished_at
|
||||
workflow_node_execution.elapsed_time = elapsed_time
|
||||
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_node_execution)
|
||||
db.session.close()
|
||||
self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
@@ -289,19 +314,36 @@ class WorkflowCycleManage:
|
||||
workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
|
||||
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
elapsed_time = (finished_at - event.start_at).total_seconds()
|
||||
|
||||
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
|
||||
{
|
||||
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
|
||||
WorkflowNodeExecution.error: event.error,
|
||||
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
|
||||
WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None,
|
||||
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
|
||||
WorkflowNodeExecution.finished_at: finished_at,
|
||||
WorkflowNodeExecution.elapsed_time: elapsed_time,
|
||||
}
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = event.error
|
||||
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
|
||||
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
|
||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
|
||||
workflow_node_execution.finished_at = finished_at
|
||||
workflow_node_execution.elapsed_time = elapsed_time
|
||||
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_node_execution)
|
||||
db.session.close()
|
||||
self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
@@ -603,7 +645,7 @@ class WorkflowCycleManage:
|
||||
),
|
||||
)
|
||||
|
||||
def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]:
|
||||
def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping[str, Any]]:
|
||||
"""
|
||||
Fetch files from node outputs
|
||||
:param outputs_dict: node outputs dict
|
||||
@@ -612,15 +654,15 @@ class WorkflowCycleManage:
|
||||
if not outputs_dict:
|
||||
return []
|
||||
|
||||
files = []
|
||||
for output_var, output_value in outputs_dict.items():
|
||||
file_vars = self._fetch_files_from_variable_value(output_value)
|
||||
if file_vars:
|
||||
files.extend(file_vars)
|
||||
files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
|
||||
# Remove None
|
||||
files = [file for file in files if file]
|
||||
# Flatten list
|
||||
files = [file for sublist in files for file in sublist]
|
||||
|
||||
return files
|
||||
|
||||
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> list[dict]:
|
||||
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]:
|
||||
"""
|
||||
Fetch files from variable value
|
||||
:param value: variable value
|
||||
@@ -632,17 +674,17 @@ class WorkflowCycleManage:
|
||||
files = []
|
||||
if isinstance(value, list):
|
||||
for item in value:
|
||||
file_var = self._get_file_var_from_value(item)
|
||||
if file_var:
|
||||
files.append(file_var)
|
||||
file = self._get_file_var_from_value(item)
|
||||
if file:
|
||||
files.append(file)
|
||||
elif isinstance(value, dict):
|
||||
file_var = self._get_file_var_from_value(value)
|
||||
if file_var:
|
||||
files.append(file_var)
|
||||
file = self._get_file_var_from_value(value)
|
||||
if file:
|
||||
files.append(file)
|
||||
|
||||
return files
|
||||
|
||||
def _get_file_var_from_value(self, value: Union[dict, list]) -> Optional[dict]:
|
||||
def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any] | None:
|
||||
"""
|
||||
Get file var from value
|
||||
:param value: variable value
|
||||
@@ -651,14 +693,11 @@ class WorkflowCycleManage:
|
||||
if not value:
|
||||
return None
|
||||
|
||||
if isinstance(value, dict):
|
||||
if "__variant" in value and value["__variant"] == FileVar.__name__:
|
||||
return value
|
||||
elif isinstance(value, FileVar):
|
||||
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
return value
|
||||
elif isinstance(value, File):
|
||||
return value.to_dict()
|
||||
|
||||
return None
|
||||
|
||||
def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
|
||||
"""
|
||||
Refetch workflow run
|
||||
@@ -678,17 +717,7 @@ class WorkflowCycleManage:
|
||||
:param node_execution_id: workflow node execution id
|
||||
:return:
|
||||
"""
|
||||
workflow_node_execution = (
|
||||
db.session.query(WorkflowNodeExecution)
|
||||
.filter(
|
||||
WorkflowNodeExecution.tenant_id == self._application_generate_entity.app_config.tenant_id,
|
||||
WorkflowNodeExecution.app_id == self._application_generate_entity.app_config.app_id,
|
||||
WorkflowNodeExecution.workflow_id == self._workflow.id,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
WorkflowNodeExecution.node_execution_id == node_execution_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
workflow_node_execution = self._wip_workflow_node_executions.get(node_execution_id)
|
||||
|
||||
if not workflow_node_execution:
|
||||
raise Exception(f"Workflow node execution not found: {node_execution_id}")
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
import enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PromptMessageFileType(enum.Enum):
|
||||
IMAGE = "image"
|
||||
|
||||
@staticmethod
|
||||
def value_of(value):
|
||||
for member in PromptMessageFileType:
|
||||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class PromptMessageFile(BaseModel):
|
||||
type: PromptMessageFileType
|
||||
data: Any = None
|
||||
|
||||
|
||||
class ImagePromptMessageFile(PromptMessageFile):
|
||||
class DETAIL(enum.Enum):
|
||||
LOW = "low"
|
||||
HIGH = "high"
|
||||
|
||||
type: PromptMessageFileType = PromptMessageFileType.IMAGE
|
||||
detail: DETAIL = DETAIL.LOW
|
||||
@@ -0,0 +1,19 @@
|
||||
from .constants import FILE_MODEL_IDENTITY
|
||||
from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType
|
||||
from .models import (
|
||||
File,
|
||||
FileExtraConfig,
|
||||
ImageConfig,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"FileType",
|
||||
"FileExtraConfig",
|
||||
"FileTransferMethod",
|
||||
"FileBelongsTo",
|
||||
"File",
|
||||
"ImageConfig",
|
||||
"FileAttribute",
|
||||
"ArrayFileAttribute",
|
||||
"FILE_MODEL_IDENTITY",
|
||||
]
|
||||
|
||||
1
api/core/file/constants.py
Normal file
1
api/core/file/constants.py
Normal file
@@ -0,0 +1 @@
|
||||
FILE_MODEL_IDENTITY = "__dify__file__"
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user