Mirror of https://github.com/langgenius/dify.git (synced 2026-02-09 01:33:59 +00:00)

Compare commits: 1.5.0 ... v141-hotfix (4 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 151d68b5b3 | |
| | 9d38441c8d | |
| | 7ab0b7346e | |
| | 861cdbd8b6 | |
@@ -1,6 +1,6 @@
#!/bin/bash
npm add -g pnpm@10.11.1
npm add -g pnpm@10.8.0
cd web && pnpm install
pipx install uv
2 .github/actions/setup-uv/action.yml (vendored)
@@ -8,7 +8,7 @@ inputs:
uv-version:
description: UV version to set up
required: true
default: '~=0.7.11'
default: '0.6.14'
uv-lockfile:
description: Path to the UV lockfile to restore cache from
required: true
20 .github/pull_request_template.md (vendored)
@@ -1,23 +1,25 @@
> [!IMPORTANT]
>
> 1. Make sure you have read our [contribution guidelines](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)
> 2. Ensure there is an associated issue and you have been assigned to it
> 3. Use the correct syntax to link this PR: `Fixes #<issue number>`.
# Summary
## Summary
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
<!-- Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. -->
> [!Tip]
> Close issue syntax: `Fixes #<issue number>` or `Resolves #<issue number>`, see [documentation](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue#linking-a-pull-request-to-an-issue-using-a-keyword) for more details.
## Screenshots
# Screenshots
| Before | After |
|--------|-------|
| ... | ... |
## Checklist
# Checklist
> [!IMPORTANT]
> Please review the checklist below before submitting your pull request.
- [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs)
- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
- [x] I've updated the documentation accordingly.
- [x] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
6 .github/workflows/api-tests.yml (vendored)
@@ -83,15 +83,9 @@ jobs:
compose-file: |
docker/docker-compose.middleware.yaml
services: |
db
redis
sandbox
ssrf_proxy
- name: setup test config
run: |
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
- name: Run Workflow
run: uv run --project api bash dev/pytest/pytest_workflow.sh
1 .github/workflows/build-push.yml (vendored)
@@ -6,6 +6,7 @@ on:
- "main"
- "deploy/dev"
- "deploy/enterprise"
- "v141-hotfix"
tags:
- "*"
28 .github/workflows/deploy-rag-dev.yml (vendored)
@@ -1,28 +0,0 @@
name: Deploy RAG Dev
permissions:
contents: read
on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "deploy/rag-dev"
types:
- completed
jobs:
deploy:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/rag-dev'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.RAG_SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |
${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}
1 .github/workflows/expose_service_ports.sh (vendored)
@@ -10,7 +10,6 @@ yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-com
yq eval '.services.couchbase-server.ports += ["8091-8096:8091-8096"]' -i docker/docker-compose.yaml
yq eval '.services.couchbase-server.ports += ["11210:11210"]' -i docker/docker-compose.yaml
yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.yaml
yq eval '.services.oceanbase.ports += ["2881:2881"]' -i docker/docker-compose.yaml
yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml
echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"
18 .github/workflows/vdb-tests.yml (vendored)
@@ -31,13 +31,6 @@ jobs:
with:
persist-credentials: false
- name: Free Disk Space
uses: endersonmenezes/free-disk-space@v2
with:
remove_dotnet: true
remove_haskell: true
remove_tool_cache: true
- name: Setup UV and Python
uses: ./.github/actions/setup-uv
with:
@@ -66,7 +59,7 @@ jobs:
tidb
tiflash
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase)
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
uses: hoverkraft-tech/compose-action@v2.0.2
with:
compose-file: |
@@ -82,15 +75,8 @@ jobs:
pgvector
chroma
elasticsearch
oceanbase
- name: setup test config
run: |
echo $(pwd)
ls -lah .
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
- name: Check VDB Ready (TiDB)
- name: Check TiDB Ready
run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
- name: Test Vector Stores
11 .gitignore (vendored)
@@ -179,7 +179,6 @@ docker/volumes/pgvecto_rs/data/*
docker/volumes/couchbase/*
docker/volumes/oceanbase/*
docker/volumes/plugin_daemon/*
docker/volumes/matrixone/*
!docker/volumes/oceanbase/init.d
docker/nginx/conf.d/default.conf
@@ -193,12 +192,12 @@ sdks/python-client/dist
sdks/python-client/dify_client.egg-info
.vscode/*
!.vscode/launch.json.template
!.vscode/README.md
!.vscode/launch.json
pyrightconfig.json
api/.vscode
.idea/
.vscode
# pnpm
/.pnpm-store
@@ -208,9 +207,3 @@ plugins.jsonl
# mise
mise.toml
# Next.js build output
.next/
# AI Assistant
.roo/
14 .vscode/README.md (vendored)
@@ -1,14 +0,0 @@
# Debugging with VS Code
This `launch.json.template` file provides various debug configurations for the Dify project within VS Code / Cursor. To use these configurations, you should copy the contents of this file into a new file named `launch.json` in the same `.vscode` directory.
## How to Use
1. **Create `launch.json`**: If you don't have one, create a file named `launch.json` inside the `.vscode` directory.
2. **Copy Content**: Copy the entire content from `launch.json.template` into your newly created `launch.json` file.
3. **Select Debug Configuration**: Go to the Run and Debug view in VS Code / Cursor (Ctrl+Shift+D or Cmd+Shift+D).
4. **Start Debugging**: Select the desired configuration from the dropdown menu and click the green play button.
## Tips
- If you need to debug with Edge browser instead of Chrome, modify the `serverReadyAction` configuration in the "Next.js: debug full stack" section, change `"debugWithChrome"` to `"debugWithEdge"` to use Microsoft Edge for debugging.
68 .vscode/launch.json.template (vendored)
@@ -1,68 +0,0 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "Python: Flask API",
"type": "debugpy",
"request": "launch",
"module": "flask",
"env": {
"FLASK_APP": "app.py",
"FLASK_ENV": "development",
"GEVENT_SUPPORT": "True"
},
"args": [
"run",
"--host=0.0.0.0",
"--port=5001",
"--no-debugger",
"--no-reload"
],
"jinja": true,
"justMyCode": true,
"cwd": "${workspaceFolder}/api",
"python": "${workspaceFolder}/api/.venv/bin/python"
},
{
"name": "Python: Celery Worker (Solo)",
"type": "debugpy",
"request": "launch",
"module": "celery",
"env": {
"GEVENT_SUPPORT": "True"
},
"args": [
"-A",
"app.celery",
"worker",
"-P",
"solo",
"-c",
"1",
"-Q",
"dataset,generation,mail,ops_trace",
"--loglevel",
"INFO"
],
"justMyCode": false,
"cwd": "${workspaceFolder}/api",
"python": "${workspaceFolder}/api/.venv/bin/python"
},
{
"name": "Next.js: debug full stack",
"type": "node",
"request": "launch",
"program": "${workspaceFolder}/web/node_modules/next/dist/bin/next",
"runtimeArgs": ["--inspect"],
"skipFiles": ["<node_internals>/**"],
"serverReadyAction": {
"action": "debugWithChrome",
"killOnServerStop": true,
"pattern": "- Local:.+(https?://.+)",
"uriFormat": "%s",
"webRoot": "${workspaceFolder}/web"
},
"cwd": "${workspaceFolder}/web"
}
]
}
@@ -226,15 +226,6 @@ Deploy Dify to AWS with [CDK](https://aws.amazon.com/cdk/)
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Using Alibaba Cloud Computing Nest
Quickly deploy Dify to Alibaba cloud with [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Using Alibaba Cloud Data Management
One-Click deploy Dify to Alibaba Cloud with [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Contributing
For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
@@ -209,14 +209,6 @@ docker compose up -d
- [AWS CDK بواسطة @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### استخدام Alibaba Cloud للنشر
[بسرعة نشر Dify إلى سحابة علي بابا مع عش الحوسبة السحابية علي بابا](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### استخدام Alibaba Cloud Data Management للنشر
انشر Dify على علي بابا كلاود بنقرة واحدة باستخدام [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## المساهمة
لأولئك الذين يرغبون في المساهمة، انظر إلى [دليل المساهمة](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) لدينا.
@@ -225,15 +225,6 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud ব্যবহার করে ডিপ্লয়
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management ব্যবহার করে ডিপ্লয়
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Contributing
যারা কোড অবদান রাখতে চান, তাদের জন্য আমাদের [অবদান নির্দেশিকা] দেখুন (https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)।
@@ -221,15 +221,6 @@ docker compose up -d
##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### 使用 阿里云计算巢 部署
使用 [阿里云计算巢](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) 将 Dify 一键部署到 阿里云
#### 使用 阿里云数据管理DMS 部署
使用 [阿里云数据管理DMS](https://help.aliyun.com/zh/dms/dify-in-invitational-preview) 将 Dify 一键部署到 阿里云
## Star History
[](https://star-history.com/#langgenius/dify&Date)
@@ -221,15 +221,6 @@ Bereitstellung von Dify auf AWS mit [CDK](https://aws.amazon.com/cdk/)
##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
Ein-Klick-Bereitstellung von Dify in der Alibaba Cloud mit [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Contributing
Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). Gleichzeitig bitten wir Sie, Dify zu unterstützen, indem Sie es in den sozialen Medien teilen und auf Veranstaltungen und Konferenzen präsentieren.
@@ -221,15 +221,6 @@ Despliegue Dify en AWS usando [CDK](https://aws.amazon.com/cdk/)
##### AWS
- [AWS CDK por @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
Despliega Dify en Alibaba Cloud con un solo clic con [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Contribuir
Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
@@ -219,15 +219,6 @@ Déployez Dify sur AWS en utilisant [CDK](https://aws.amazon.com/cdk/)
##### AWS
- [AWS CDK par @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
Déployez Dify en un clic sur Alibaba Cloud avec [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Contribuer
Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
@@ -155,7 +155,7 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ
[こちら](https://dify.ai)のDify Cloudサービスを利用して、セットアップ不要で試すことができます。サンドボックスプランには、200回のGPT-4呼び出しが無料で含まれています。
- **Dify Community Editionのセルフホスティング</br>**
この[スタートガイド](#クイックスタート)を使用して、ローカル環境でDifyを簡単に実行できます。
この[スタートガイド](#quick-start)を使用して、ローカル環境でDifyを簡単に実行できます。
詳しくは[ドキュメント](https://docs.dify.ai)をご覧ください。
- **企業/組織向けのDify</br>**
@@ -220,13 +220,6 @@ docker compose up -d
##### AWS
- [@KevinZhaoによるAWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) を利用して、DifyをAlibaba Cloudへワンクリックでデプロイできます
## 貢献
コードに貢献したい方は、[Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)を参照してください。
@@ -219,15 +219,6 @@ wa'logh nIqHom neH ghun deployment toy'wI' [CDK](https://aws.amazon.com/cdk/) lo
##### AWS
- [AWS CDK qachlot @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Contributing
For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
@@ -213,15 +213,6 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했
##### AWS
- [KevinZhao의 AWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)를 통해 원클릭으로 Dify를 Alibaba Cloud에 배포할 수 있습니다
## 기여
코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요.
@@ -218,15 +218,6 @@ Implante o Dify na AWS usando [CDK](https://aws.amazon.com/cdk/)
##### AWS
- [AWS CDK por @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
Implante o Dify na Alibaba Cloud com um clique usando o [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Contribuindo
Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
@@ -219,15 +219,6 @@ Uvedite Dify v AWS z uporabo [CDK](https://aws.amazon.com/cdk/)
##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
Z enim klikom namestite Dify na Alibaba Cloud z [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Prispevam
Za tiste, ki bi radi prispevali kodo, si oglejte naš vodnik za prispevke . Hkrati vas prosimo, da podprete Dify tako, da ga delite na družbenih medijih ter na dogodkih in konferencah.
@@ -212,15 +212,6 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter
##### AWS
- [AWS CDK tarafından @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) kullanarak Dify'ı tek tıkla Alibaba Cloud'a dağıtın
## Katkıda Bulunma
Kod katkısında bulunmak isteyenler için [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) bakabilirsiniz.
@@ -224,15 +224,6 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify
- [由 @KevinZhao 提供的 AWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### 使用 阿里云计算巢進行部署
[阿里云](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### 使用 阿里雲數據管理DMS 進行部署
透過 [阿里雲數據管理DMS](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/),一鍵將 Dify 部署至阿里雲
## 貢獻
對於想要貢獻程式碼的開發者,請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。
10 README_VI.md
@@ -214,16 +214,6 @@ Triển khai Dify trên AWS bằng [CDK](https://aws.amazon.com/cdk/)
##### AWS
- [AWS CDK bởi @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
Triển khai Dify lên Alibaba Cloud chỉ với một cú nhấp chuột bằng [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Đóng góp
Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) của chúng tôi.
@@ -137,7 +137,7 @@ WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
# Vector database configuration
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore, matrixone
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore
VECTOR_STORE=weaviate
# Weaviate configuration
@@ -294,13 +294,6 @@ VIKINGDB_SCHEMA=http
VIKINGDB_CONNECTION_TIMEOUT=30
VIKINGDB_SOCKET_TIMEOUT=30
# Matrixone configration
MATRIXONE_HOST=127.0.0.1
MATRIXONE_PORT=6001
MATRIXONE_USER=dump
MATRIXONE_PASSWORD=111
MATRIXONE_DATABASE=dify
# Lindorm configuration
LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
LINDORM_USERNAME=admin
@@ -339,11 +332,9 @@ PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024
PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false
# Mail configuration, support: resend, smtp, sendgrid
# Mail configuration, support: resend, smtp
MAIL_TYPE=
# If using SendGrid, use the 'from' field for authentication if necessary.
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
# resend configuration
RESEND_API_KEY=
RESEND_API_URL=https://api.resend.com
# smtp configuration
@@ -353,8 +344,7 @@ SMTP_USERNAME=123
SMTP_PASSWORD=abc
SMTP_USE_TLS=true
SMTP_OPPORTUNISTIC_TLS=false
# Sendgid configuration
SENDGRID_API_KEY=
# Sentry configuration
SENTRY_DSN=
@@ -501,10 +491,3 @@ OTEL_METRIC_EXPORT_TIMEOUT=30000
# Prevent Clickjacking
ALLOW_EMBED=false
# Dataset queue monitor configuration
QUEUE_MONITOR_THRESHOLD=200
# You can configure multiple ones, separated by commas. eg: test1@dify.ai,test2@dify.ai
QUEUE_MONITOR_ALERT_EMAILS=
# Monitor interval in minutes, default is 30 minutes
QUEUE_MONITOR_INTERVAL=30
104 api/.ruff.toml
@@ -1,4 +1,6 @@
exclude = ["migrations/*"]
exclude = [
"migrations/*",
]
line-length = 120
[format]
@@ -7,14 +9,14 @@ quote-style = "double"
[lint]
preview = false
select = [
"B", # flake8-bugbear rules
"C4", # flake8-comprehensions
"E", # pycodestyle E rules
"F", # pyflakes rules
"FURB", # refurb rules
"I", # isort rules
"N", # pep8-naming
"PT", # flake8-pytest-style rules
"B", # flake8-bugbear rules
"C4", # flake8-comprehensions
"E", # pycodestyle E rules
"F", # pyflakes rules
"FURB", # refurb rules
"I", # isort rules
"N", # pep8-naming
"PT", # flake8-pytest-style rules
"PLC0208", # iteration-over-set
"PLC0414", # useless-import-alias
"PLE0604", # invalid-all-object
@@ -22,60 +24,58 @@ select = [
"PLR0402", # manual-from-import
"PLR1711", # useless-return
"PLR1714", # repeated-equality-comparison
"RUF013", # implicit-optional
"RUF019", # unnecessary-key-check
"RUF100", # unused-noqa
"RUF101", # redirected-noqa
"RUF200", # invalid-pyproject-toml
"RUF022", # unsorted-dunder-all
"S506", # unsafe-yaml-load
"SIM", # flake8-simplify rules
"TRY400", # error-instead-of-exception
"TRY401", # verbose-log-message
"UP", # pyupgrade rules
"W191", # tab-indentation
"W605", # invalid-escape-sequence
"RUF013", # implicit-optional
"RUF019", # unnecessary-key-check
"RUF100", # unused-noqa
"RUF101", # redirected-noqa
"RUF200", # invalid-pyproject-toml
"RUF022", # unsorted-dunder-all
"S506", # unsafe-yaml-load
"SIM", # flake8-simplify rules
"TRY400", # error-instead-of-exception
"TRY401", # verbose-log-message
"UP", # pyupgrade rules
"W191", # tab-indentation
"W605", # invalid-escape-sequence
# security related linting rules
# RCE proctection (sort of)
"S102", # exec-builtin, disallow use of `exec`
"S307", # suspicious-eval-usage, disallow use of `eval` and `ast.literal_eval`
"S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
"S302", # suspicious-marshal-usage, disallow use of `marshal` module
"S311", # suspicious-non-cryptographic-random-usage
]
ignore = [
"E402", # module-import-not-at-top-of-file
"E711", # none-comparison
"E712", # true-false-comparison
"E721", # type-comparison
"E722", # bare-except
"F821", # undefined-name
"F841", # unused-variable
"E402", # module-import-not-at-top-of-file
"E711", # none-comparison
"E712", # true-false-comparison
"E721", # type-comparison
"E722", # bare-except
"F821", # undefined-name
"F841", # unused-variable
"FURB113", # repeated-append
"FURB152", # math-constant
"UP007", # non-pep604-annotation
"UP032", # f-string
"UP045", # non-pep604-annotation-optional
"B005", # strip-with-multi-characters
"B006", # mutable-argument-default
"B007", # unused-loop-control-variable
"B026", # star-arg-unpacking-after-keyword-arg
"B903", # class-as-data-structure
"B904", # raise-without-from-inside-except
"B905", # zip-without-explicit-strict
"N806", # non-lowercase-variable-in-function
"N815", # mixed-case-variable-in-class-scope
"PT011", # pytest-raises-too-broad
"SIM102", # collapsible-if
"SIM103", # needless-bool
"SIM105", # suppressible-exception
"SIM107", # return-in-try-except-finally
"SIM108", # if-else-block-instead-of-if-exp
"SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false
"UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/
"UP007", # non-pep604-annotation
"UP032", # f-string
"UP045", # non-pep604-annotation-optional
"B005", # strip-with-multi-characters
"B006", # mutable-argument-default
"B007", # unused-loop-control-variable
"B026", # star-arg-unpacking-after-keyword-arg
"B903", # class-as-data-structure
"B904", # raise-without-from-inside-except
"B905", # zip-without-explicit-strict
"N806", # non-lowercase-variable-in-function
"N815", # mixed-case-variable-in-class-scope
"PT011", # pytest-raises-too-broad
"SIM102", # collapsible-if
"SIM103", # needless-bool
"SIM105", # suppressible-exception
"SIM107", # return-in-try-except-finally
"SIM108", # if-else-block-instead-of-if-exp
"SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false
]
[lint.per-file-ignores]
@@ -4,7 +4,7 @@ FROM python:3.12-slim-bookworm AS base
WORKDIR /app/api
# Install uv
ENV UV_VERSION=0.7.11
ENV UV_VERSION=0.6.14
RUN pip install --no-cache-dir uv==${UV_VERSION}
@@ -27,7 +27,7 @@ from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, D
from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
from models.provider import Provider, ProviderModel
from services.account_service import AccountService, RegisterService, TenantService
from services.account_service import RegisterService, TenantService
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration
@@ -68,7 +68,6 @@ def reset_password(email, new_password, password_confirm):
account.password = base64_password_hashed
account.password_salt = base64_salt
db.session.commit()
AccountService.reset_login_error_rate_limit(email)
click.echo(click.style("Password reset successfully.", fg="green"))
@@ -281,7 +280,6 @@ def migrate_knowledge_vector_database():
VectorType.ELASTICSEARCH,
VectorType.OPENGAUSS,
VectorType.TABLESTORE,
VectorType.MATRIXONE,
}
lower_collection_vector_types = {
VectorType.ANALYTICDB,
@@ -848,9 +846,6 @@ def clear_orphaned_file_records(force: bool):
{"type": "text", "table": "workflow_node_executions", "column": "outputs"},
{"type": "text", "table": "conversations", "column": "introduction"},
{"type": "text", "table": "conversations", "column": "system_instruction"},
{"type": "text", "table": "accounts", "column": "avatar"},
{"type": "text", "table": "apps", "column": "icon"},
{"type": "text", "table": "sites", "column": "icon"},
{"type": "json", "table": "messages", "column": "inputs"},
{"type": "json", "table": "messages", "column": "message"},
]
@@ -609,7 +609,7 @@ class MailConfig(BaseSettings):
"""
MAIL_TYPE: Optional[str] = Field(
description="Email service provider type ('smtp' or 'resend' or 'sendGrid), default to None.",
description="Email service provider type ('smtp' or 'resend'), default to None.",
default=None,
)
@@ -663,11 +663,6 @@ class MailConfig(BaseSettings):
default=50,
)
SENDGRID_API_KEY: Optional[str] = Field(
description="API key for SendGrid service",
default=None,
)
class RagEtlConfig(BaseSettings):
"""
@@ -2,7 +2,7 @@ import os
from typing import Any, Literal, Optional
from urllib.parse import parse_qsl, quote_plus
from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
from pydantic_settings import BaseSettings
from .cache.redis_config import RedisConfig
@@ -24,7 +24,6 @@ from .vdb.couchbase_config import CouchbaseConfig
from .vdb.elasticsearch_config import ElasticsearchConfig
from .vdb.huawei_cloud_config import HuaweiCloudConfig
from .vdb.lindorm_config import LindormConfig
from .vdb.matrixone_config import MatrixoneConfig
from .vdb.milvus_config import MilvusConfig
from .vdb.myscale_config import MyScaleConfig
from .vdb.oceanbase_config import OceanBaseVectorConfig
@@ -257,25 +256,6 @@ class InternalTestConfig(BaseSettings):
)
class DatasetQueueMonitorConfig(BaseSettings):
"""
Configuration settings for Dataset Queue Monitor
"""
QUEUE_MONITOR_THRESHOLD: Optional[NonNegativeInt] = Field(
description="Threshold for dataset queue monitor",
default=200,
)
QUEUE_MONITOR_ALERT_EMAILS: Optional[str] = Field(
description="Emails for dataset queue monitor alert, separated by commas",
default=None,
)
QUEUE_MONITOR_INTERVAL: Optional[NonNegativeFloat] = Field(
description="Interval for dataset queue monitor in minutes",
default=30,
)
class MiddlewareConfig(
# place the configs in alphabet order
CeleryConfig,
@@ -323,7 +303,5 @@ class MiddlewareConfig(
BaiduVectorDBConfig,
OpenGaussConfig,
TableStoreConfig,
DatasetQueueMonitorConfig,
MatrixoneConfig,
):
pass
@@ -1,14 +0,0 @@
from pydantic import BaseModel, Field
class MatrixoneConfig(BaseModel):
"""Matrixone vector database configuration."""
MATRIXONE_HOST: str = Field(default="localhost", description="Host address of the Matrixone server")
MATRIXONE_PORT: int = Field(default=6001, description="Port number of the Matrixone server")
MATRIXONE_USER: str = Field(default="dump", description="Username for authenticating with Matrixone")
MATRIXONE_PASSWORD: str = Field(default="111", description="Password for authenticating with Matrixone")
MATRIXONE_DATABASE: str = Field(default="dify", description="Name of the Matrixone database to connect to")
MATRIXONE_METRIC: str = Field(
default="l2", description="Distance metric type for vector similarity search (cosine or l2)"
)
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="1.5.0",
default="1.4.1",
)
COMMIT_SHA: str = Field(
@@ -60,7 +60,8 @@ class NacosHttpClient:
sign_str = tenant + "+"
if group:
sign_str = sign_str + group + "+"
sign_str += ts # Directly concatenate ts without conditional checks, because the nacos auth header forced it.
if sign_str:
sign_str += ts
return sign_str
def get_access_token(self, force_refresh=False):
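For readers skimming the hunk above: the change drops the `if sign_str:` guard and always appends the timestamp. A minimal sketch of the resulting signing-string construction (the standalone function name and parameters are assumed for illustration; only the concatenation logic is taken from the diff):

```python
def _build_sign_str(tenant: str, group: str, ts: str) -> str:
    # The tenant always comes first, followed by the "+" separator.
    sign_str = tenant + "+"
    if group:
        sign_str = sign_str + group + "+"
    # Directly concatenate ts without a conditional check,
    # because the Nacos auth header requires it.
    sign_str += ts
    return sign_str
```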
@@ -63,7 +63,6 @@ from .app import (
statistic,
workflow,
workflow_app_log,
workflow_draft_variable,
workflow_run,
workflow_statistic,
)
@@ -56,7 +56,8 @@ class InsertExploreAppListApi(Resource):
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
args = parser.parse_args()
app = db.session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none()
with Session(db.engine) as session:
app = session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none()
if not app:
raise NotFound(f"App '{args['app_id']}' is not found")
@@ -77,38 +78,38 @@ class InsertExploreAppListApi(Resource):
select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"])
).scalar_one_or_none()
if not recommended_app:
recommended_app = RecommendedApp(
app_id=app.id,
description=desc,
copyright=copy_right,
privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer,
language=args["language"],
category=args["category"],
position=args["position"],
)
if not recommended_app:
recommended_app = RecommendedApp(
app_id=app.id,
description=desc,
copyright=copy_right,
privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer,
language=args["language"],
category=args["category"],
position=args["position"],
)
db.session.add(recommended_app)
db.session.add(recommended_app)
app.is_public = True
db.session.commit()
app.is_public = True
db.session.commit()
return {"result": "success"}, 201
else:
recommended_app.description = desc
recommended_app.copyright = copy_right
recommended_app.privacy_policy = privacy_policy
recommended_app.custom_disclaimer = custom_disclaimer
recommended_app.language = args["language"]
recommended_app.category = args["category"]
recommended_app.position = args["position"]
return {"result": "success"}, 201
else:
recommended_app.description = desc
recommended_app.copyright = copy_right
recommended_app.privacy_policy = privacy_policy
recommended_app.custom_disclaimer = custom_disclaimer
recommended_app.language = args["language"]
recommended_app.category = args["category"]
recommended_app.position = args["position"]
app.is_public = True
app.is_public = True
db.session.commit()
db.session.commit()
return {"result": "success"}, 200
return {"result": "success"}, 200
class InsertExploreAppApi(Resource):
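The reshuffled controller code above boils down to a session-scoped upsert of the explore-app record. A condensed sketch of that pattern (the helper name, `fields` mapping, and import paths are illustrative assumptions, not part of the diff):

```python
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound

from models.model import App, RecommendedApp  # assumed model import path


def upsert_recommended_app(engine, app_id: str, fields: dict):
    # Open an explicit Session bound to the engine rather than touching the global session.
    with Session(engine) as session:
        app = session.execute(select(App).filter(App.id == app_id)).scalar_one_or_none()
        if not app:
            raise NotFound(f"App '{app_id}' is not found")

        recommended_app = session.execute(
            select(RecommendedApp).filter(RecommendedApp.app_id == app_id)
        ).scalar_one_or_none()

        if not recommended_app:
            # First time this app is promoted to Explore: create the record.
            session.add(RecommendedApp(app_id=app.id, **fields))
        else:
            # Otherwise update the existing record in place.
            for key, value in fields.items():
                setattr(recommended_app, key, value)

        app.is_public = True
        session.commit()
```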
@@ -208,7 +208,7 @@ class AnnotationBatchImportApi(Resource):
if len(request.files) > 1:
raise TooManyFilesError()
# check file type
if not file.filename or not file.filename.lower().endswith(".csv"):
if not file.filename.endswith(".csv"):
raise ValueError("Invalid file type. Only CSV files are allowed")
return AppAnnotationService.batch_import_app_annotations(app_id, file)
@@ -17,8 +17,6 @@ from libs.login import login_required
from models import Account
from models.model import App
from services.app_dsl_service import AppDslService, ImportStatus
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
class AppImportApi(Resource):
@@ -62,9 +60,7 @@ class AppImportApi(Resource):
app_id=args.get("app_id"),
)
session.commit()
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
# update web app setting as private
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED.value:
@@ -1,6 +1,5 @@
import json
import logging
from collections.abc import Sequence
from typing import cast
from flask import abort, request
@@ -19,12 +18,10 @@ from controllers.console.app.error import (
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from extensions.ext_database import db
from factories import file_factory, variable_factory
from factories import variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
@@ -33,7 +30,6 @@ from libs.login import current_user, login_required
from models import App
from models.account import Account
from models.model import AppMode
from models.workflow import Workflow
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError
@@ -42,24 +38,6 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE
logger = logging.getLogger(__name__)
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
# at the controller level rather than in the workflow logic. This would improve separation
# of concerns and make the code more maintainable.
def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence[File]:
files = files or []
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
file_objs: Sequence[File] = []
if file_extra_config is None:
return file_objs
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=workflow.tenant_id,
config=file_extra_config,
)
return file_objs
class DraftWorkflowApi(Resource):
@setup_required
@login_required
@@ -424,30 +402,15 @@ class DraftWorkflowNodeRunApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("query", type=str, required=False, location="json", default="")
parser.add_argument("files", type=list, location="json", default=[])
args = parser.parse_args()
user_inputs = args.get("inputs")
if user_inputs is None:
inputs = args.get("inputs")
if inputs == None:
raise ValueError("missing inputs")
workflow_srv = WorkflowService()
# fetch draft workflow by app_model
draft_workflow = workflow_srv.get_draft_workflow(app_model=app_model)
if not draft_workflow:
raise ValueError("Workflow not initialized")
files = _parse_file(draft_workflow, args.get("files"))
workflow_service = WorkflowService()
workflow_node_execution = workflow_service.run_draft_workflow_node(
app_model=app_model,
draft_workflow=draft_workflow,
node_id=node_id,
user_inputs=user_inputs,
account=current_user,
query=args.get("query", ""),
files=files,
app_model=app_model, node_id=node_id, user_inputs=inputs, account=current_user
)
return workflow_node_execution
@@ -768,27 +731,6 @@ class WorkflowByIdApi(Resource):
return None, 204
class DraftWorkflowNodeLastRunApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_node_execution_fields)
def get(self, app_model: App, node_id: str):
srv = WorkflowService()
workflow = srv.get_draft_workflow(app_model)
if not workflow:
raise NotFound("Workflow not found")
node_exec = srv.get_node_last_run(
app_model=app_model,
workflow=workflow,
node_id=node_id,
)
if node_exec is None:
raise NotFound("last run not found")
return node_exec
api.add_resource(
DraftWorkflowApi,
"/apps/<uuid:app_id>/workflows/draft",
@@ -853,7 +795,3 @@ api.add_resource(
WorkflowByIdApi,
"/apps/<uuid:app_id>/workflows/<string:workflow_id>",
)
api.add_resource(
DraftWorkflowNodeLastRunApi,
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/last-run",
)
@@ -6,12 +6,12 @@ from sqlalchemy.orm import Session
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
from extensions.ext_database import db
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from libs.login import login_required
from models import App
from models.model import AppMode
from models.workflow import WorkflowRunStatus
from services.workflow_app_service import WorkflowAppService
@@ -34,25 +34,11 @@ class WorkflowAppLogApi(Resource):
parser.add_argument(
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
)
parser.add_argument(
"created_by_end_user_session_id",
type=str,
location="args",
required=False,
default=None,
)
parser.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args()
args.status = WorkflowExecutionStatus(args.status) if args.status else None
args.status = WorkflowRunStatus(args.status) if args.status else None
if args.created_at__before:
args.created_at__before = isoparse(args.created_at__before)
@@ -71,8 +57,6 @@ class WorkflowAppLogApi(Resource):
created_at_after=args.created_at__after,
page=args.page,
limit=args.limit,
created_by_end_user_session_id=args.created_by_end_user_session_id,
created_by_account=args.created_by_account,
)
return workflow_app_log_pagination
@@ -1,421 +0,0 @@
|
||||
import logging
|
||||
from typing import Any, NoReturn
|
||||
|
||||
from flask import Response
|
||||
from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (
|
||||
DraftWorkflowNotExist,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.variables.segment_group import SegmentGroup
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import current_user, login_required
|
||||
from models import App, AppMode, db
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
|
||||
if isinstance(value, FileSegment):
|
||||
return value.value.model_dump()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
return [i.model_dump() for i in value.value]
|
||||
elif isinstance(value, SegmentGroup):
|
||||
return [_convert_values_to_json_serializable_object(i) for i in value.value]
|
||||
else:
|
||||
return value.value
|
||||
|
||||
|
||||
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
|
||||
value = variable.get_value()
|
||||
# create a copy of the value to avoid affecting the model cache.
|
||||
value = value.model_copy(deep=True)
|
||||
# Refresh the url signature before returning it to client.
|
||||
if isinstance(value, FileSegment):
|
||||
file = value.value
|
||||
file.remote_url = file.generate_url()
|
||||
elif isinstance(value, ArrayFileSegment):
|
||||
files = value.value
|
||||
for file in files:
|
||||
file.remote_url = file.generate_url()
|
||||
return _convert_values_to_json_serializable_object(value)
|
||||
|
||||
|
||||
def _create_pagination_parser():
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"page",
|
||||
type=inputs.int_range(1, 100_000),
|
||||
required=False,
|
||||
default=1,
|
||||
location="args",
|
||||
help="the page of data requested",
|
||||
)
|
||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
return parser
|
||||
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
|
||||
"id": fields.String,
|
||||
"type": fields.String(attribute=lambda model: model.get_variable_type()),
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
|
||||
"value_type": fields.String,
|
||||
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
||||
"visible": fields.Boolean,
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
|
||||
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
||||
value=fields.Raw(attribute=_serialize_var_value),
|
||||
)
|
||||
|
||||
_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
|
||||
"id": fields.String,
|
||||
"type": fields.String(attribute=lambda _: "env"),
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
|
||||
"value_type": fields.String,
|
||||
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
||||
"visible": fields.Boolean,
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = {
|
||||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)),
|
||||
}
|
||||
|
||||
|
||||
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
|
||||
return var_list.variables
|
||||
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
|
||||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
|
||||
"total": fields.Raw(),
|
||||
}
|
||||
|
||||
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
|
||||
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
|
||||
}
|
||||
|
||||
|
||||
def _api_prerequisite(f):
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
|
||||
- Dify has been property setup.
|
||||
- The request user has logged in and initialized.
|
||||
- The requested app is a workflow or a chat flow.
|
||||
- The request user has the edit permission for the app.
|
||||
"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def wrapper(*args, **kwargs):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class WorkflowVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
parser = _create_pagination_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
workflow_service = WorkflowService()
|
||||
workflow_exist = workflow_service.is_workflow_exist(app_model=app_model)
|
||||
if not workflow_exist:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
workflow_vars = draft_var_srv.list_variables_without_values(
|
||||
app_id=app_model.id,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
)
|
||||
|
||||
return workflow_vars
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
draft_var_srv.delete_workflow_variables(app_model.id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
def validate_node_id(node_id: str) -> NoReturn | None:
|
||||
if node_id in [
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
]:
|
||||
# NOTE(QuantumGhost): While we store the system and conversation variables as node variables
|
||||
# with specific `node_id` in database, we still want to make the API separated. By disallowing
|
||||
# accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`,
|
||||
# we mitigate the risk that user of the API depending on the implementation detail of the API.
|
||||
#
|
||||
# ref: [Hyrum's Law](https://www.hyrumslaw.com/)
|
||||
|
||||
raise InvalidArgumentError(
|
||||
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
class NodeVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
def get(self, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
node_vars = draft_var_srv.list_node_variables(app_model.id, node_id)
|
||||
|
||||
return node_vars
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
srv = WorkflowDraftVariableService(db.session())
|
||||
srv.delete_node_variables(app_model.id, node_id)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
class VariableApi(Resource):
|
||||
_PATCH_NAME_FIELD = "name"
|
||||
_PATCH_VALUE_FIELD = "value"
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
def get(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
return variable
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
def patch(self, app_model: App, variable_id: str):
|
||||
# Request payload for file types:
|
||||
#
|
||||
# Local File:
|
||||
#
|
||||
# {
|
||||
# "type": "image",
|
||||
# "transfer_method": "local_file",
|
||||
# "url": "",
|
||||
# "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190"
|
||||
# }
|
||||
#
|
||||
# Remote File:
|
||||
#
|
||||
#
|
||||
# {
|
||||
# "type": "image",
|
||||
# "transfer_method": "remote_url",
|
||||
# "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=",
|
||||
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||
# }
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||
# Parse 'value' field as-is to maintain its original data structure
|
||||
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
args = parser.parse_args(strict=True)
|
||||
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
new_name = args.get(self._PATCH_NAME_FIELD, None)
|
||||
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
|
||||
if new_name is None and raw_value is None:
|
||||
return variable
|
||||
|
||||
new_value = None
|
||||
if raw_value is not None:
|
||||
if variable.value_type == SegmentType.FILE:
|
||||
if not isinstance(raw_value, dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
|
||||
raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id)
|
||||
elif variable.value_type == SegmentType.ARRAY_FILE:
|
||||
if not isinstance(raw_value, list):
|
||||
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
|
||||
new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
db.session.commit()
|
||||
return variable
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
draft_var_srv.delete_variable(variable)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
class VariableResetApi(Resource):
|
||||
@_api_prerequisite
|
||||
def put(self, app_model: App, variable_id: str):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
|
||||
workflow_srv = WorkflowService()
|
||||
draft_workflow = workflow_srv.get_draft_workflow(app_model)
|
||||
if draft_workflow is None:
|
||||
raise NotFoundError(
|
||||
f"Draft workflow not found, app_id={app_model.id}",
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
if variable.app_id != app_model.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
|
||||
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||
db.session.commit()
|
||||
if resetted is None:
|
||||
return Response("", 204)
|
||||
else:
|
||||
return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
|
||||
|
||||
def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
if node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_conversation_variables(app_model.id)
|
||||
elif node_id == SYSTEM_VARIABLE_NODE_ID:
|
||||
draft_vars = draft_var_srv.list_system_variables(app_model.id)
|
||||
else:
|
||||
draft_vars = draft_var_srv.list_node_variables(app_id=app_model.id, node_id=node_id)
|
||||
return draft_vars
|
||||
|
||||
|
||||
class ConversationVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
def get(self, app_model: App):
|
||||
# NOTE(QuantumGhost): Prefill conversation variables into the draft variables table
|
||||
# so their IDs can be returned to the caller.
|
||||
workflow_srv = WorkflowService()
|
||||
draft_workflow = workflow_srv.get_draft_workflow(app_model)
|
||||
if draft_workflow is None:
|
||||
raise NotFoundError(description=f"draft workflow not found, id={app_model.id}")
|
||||
draft_var_srv = WorkflowDraftVariableService(db.session())
|
||||
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)
|
||||
db.session.commit()
|
||||
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
class SystemVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||
def get(self, app_model: App):
|
||||
return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
class EnvironmentVariableCollectionApi(Resource):
|
||||
@_api_prerequisite
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get environment variables of the draft workflow.
|
||||
"""
|
||||
# fetch draft workflow by app_model
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
||||
if workflow is None:
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
env_vars = workflow.environment_variables
|
||||
env_vars_list = []
|
||||
for v in env_vars:
|
||||
env_vars_list.append(
|
||||
{
|
||||
"id": v.id,
|
||||
"type": "env",
|
||||
"name": v.name,
|
||||
"description": v.description,
|
||||
"selector": v.selector,
|
||||
"value_type": v.value_type.value,
|
||||
"value": v.value,
|
||||
# Do not track edited for env vars.
|
||||
"edited": False,
|
||||
"visible": True,
|
||||
"editable": True,
|
||||
}
|
||||
)
|
||||
|
||||
return {"items": env_vars_list}
|
||||
|
||||
|
||||
api.add_resource(
|
||||
WorkflowVariableCollectionApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/variables",
|
||||
)
|
||||
api.add_resource(NodeVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||
api.add_resource(VariableApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>")
|
||||
api.add_resource(VariableResetApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||
|
||||
api.add_resource(ConversationVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/conversation-variables")
|
||||
api.add_resource(SystemVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/system-variables")
|
||||
api.add_resource(EnvironmentVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/environment-variables")
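A rough sketch of how a console client might exercise the draft-variable routes registered above; the host, console API prefix, app ID, and session handling are placeholders rather than documented values.

```python
# Hypothetical client calls against the draft-variable routes registered above.
# Base URL, app id, variable id, and auth cookies are placeholders.
import requests

BASE = "http://localhost:5001/console/api"
APP_ID = "00000000-0000-0000-0000-000000000000"
session = requests.Session()  # assume console login cookies are already attached

# List draft variables (paginated; this endpoint omits values).
resp = session.get(f"{BASE}/apps/{APP_ID}/workflows/draft/variables",
                   params={"page": 1, "limit": 20}, timeout=10)
items = resp.json().get("items", [])

if items:
    var_id = items[0]["id"]

    # Rename one variable and give it a new value.
    session.patch(f"{BASE}/apps/{APP_ID}/workflows/draft/variables/{var_id}",
                  json={"name": "renamed_var", "value": "hello"}, timeout=10)

    # Reset it back to the default defined in the draft workflow.
    session.put(f"{BASE}/apps/{APP_ID}/workflows/draft/variables/{var_id}/reset", timeout=10)
```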
|
||||
@@ -8,15 +8,6 @@ from libs.login import current_user
|
||||
from models import App, AppMode
|
||||
|
||||
|
||||
def _load_app_model(app_id: str) -> Optional[App]:
|
||||
app_model = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
return app_model
|
||||
|
||||
|
||||
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||
def decorator(view_func):
|
||||
@wraps(view_func)
|
||||
@@ -29,7 +20,11 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[
|
||||
|
||||
del kwargs["app_id"]
|
||||
|
||||
app_model = _load_app_model(app_id)
|
||||
app_model = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
if not app_model:
|
||||
raise AppNotFoundError()
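Also outside the diff: an illustrative, runnable reduction of how a mode-restricted loader like `get_app_model` injects `app_model` into the view. The `FAKE_DB` lookup and trimmed-down `AppMode` enum are stand-ins, not Dify internals.

```python
# Illustrative only: a mode-restricted app loader injected into view kwargs.
from enum import Enum
from functools import wraps


class AppMode(Enum):
    WORKFLOW = "workflow"
    ADVANCED_CHAT = "advanced-chat"


class AppNotFoundError(Exception):
    pass


FAKE_DB = {"app-1": {"id": "app-1", "mode": AppMode.WORKFLOW}}  # stand-in for the DB query


def get_app_model(*, mode=None):
    allowed = set(mode or [])

    def decorator(view_func):
        @wraps(view_func)
        def wrapper(*args, app_id: str, **kwargs):
            app_model = FAKE_DB.get(app_id)
            if not app_model or (allowed and app_model["mode"] not in allowed):
                raise AppNotFoundError()
            return view_func(*args, app_model=app_model, **kwargs)

        return wrapper

    return decorator


@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
def show(app_model):
    return app_model["id"]


print(show(app_id="app-1"))  # "app-1"; unknown ids or disallowed modes raise AppNotFoundError
```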
|
||||
|
||||
@@ -119,6 +119,9 @@ class ForgotPasswordResetApi(Resource):
|
||||
if not reset_data:
|
||||
raise InvalidTokenError()
|
||||
# Must use token in reset phase
|
||||
if reset_data.get("phase", "") != "reset":
|
||||
raise InvalidTokenError()
|
||||
# Must use token in reset phase
|
||||
if reset_data.get("phase", "") != "reset":
|
||||
raise InvalidTokenError()
|
||||
|
||||
|
||||
@@ -686,7 +686,6 @@ class DatasetRetrievalSettingApi(Resource):
|
||||
| VectorType.TABLESTORE
|
||||
| VectorType.HUAWEI_CLOUD
|
||||
| VectorType.TENCENT
|
||||
| VectorType.MATRIXONE
|
||||
):
|
||||
return {
|
||||
"retrieval_method": [
|
||||
@@ -734,7 +733,6 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
| VectorType.TABLESTORE
|
||||
| VectorType.TENCENT
|
||||
| VectorType.HUAWEI_CLOUD
|
||||
| VectorType.MATRIXONE
|
||||
):
|
||||
return {
|
||||
"retrieval_method": [
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import cast
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal, marshal_with, reqparse
|
||||
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import asc, desc, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
@@ -43,6 +43,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.document_fields import (
|
||||
dataset_and_document_fields,
|
||||
document_fields,
|
||||
@@ -53,6 +54,8 @@ from libs.login import login_required
|
||||
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
||||
from tasks.add_document_to_index_task import add_document_to_index_task
|
||||
from tasks.remove_document_from_index_task import remove_document_from_index_task
|
||||
|
||||
|
||||
class DocumentResource(Resource):
|
||||
@@ -239,10 +242,12 @@ class DatasetDocumentListApi(Resource):
|
||||
|
||||
return response
|
||||
|
||||
documents_and_batch_fields = {"documents": fields.List(fields.Nested(document_fields)), "batch": fields.String}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(dataset_and_document_fields)
|
||||
@marshal_with(documents_and_batch_fields)
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id):
|
||||
@@ -288,8 +293,6 @@ class DatasetDocumentListApi(Resource):
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
@@ -297,7 +300,7 @@ class DatasetDocumentListApi(Resource):
|
||||
except ModelCurrentlyNotSupportError:
|
||||
raise ProviderModelCurrentlyNotSupportError()
|
||||
|
||||
return {"dataset": dataset, "documents": documents, "batch": batch}
|
||||
return {"documents": documents, "batch": batch}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -859,16 +862,77 @@ class DocumentStatusApi(DocumentResource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
document_ids = request.args.getlist("document_id")
|
||||
for document_id in document_ids:
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
try:
|
||||
DocumentService.batch_update_document_status(dataset, document_ids, action, current_user)
|
||||
except services.errors.document.DocumentIndexingError as e:
|
||||
raise InvalidActionError(str(e))
|
||||
except ValueError as e:
|
||||
raise InvalidActionError(str(e))
|
||||
except NotFound as e:
|
||||
raise NotFound(str(e))
|
||||
indexing_cache_key = "document_{}_indexing".format(document.id)
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
raise InvalidActionError(f"Document:{document.name} is being indexed, please try again later")
|
||||
|
||||
if action == "enable":
|
||||
if document.enabled:
|
||||
continue
|
||||
document.enabled = True
|
||||
document.disabled_at = None
|
||||
document.disabled_by = None
|
||||
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
# Set cache to prevent indexing the same document multiple times
|
||||
redis_client.setex(indexing_cache_key, 600, 1)
|
||||
|
||||
add_document_to_index_task.delay(document_id)
|
||||
|
||||
elif action == "disable":
|
||||
if not document.completed_at or document.indexing_status != "completed":
|
||||
raise InvalidActionError(f"Document: {document.name} is not completed.")
|
||||
if not document.enabled:
|
||||
continue
|
||||
|
||||
document.enabled = False
|
||||
document.disabled_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
document.disabled_by = current_user.id
|
||||
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
# Set cache to prevent indexing the same document multiple times
|
||||
redis_client.setex(indexing_cache_key, 600, 1)
|
||||
|
||||
remove_document_from_index_task.delay(document_id)
|
||||
|
||||
elif action == "archive":
|
||||
if document.archived:
|
||||
continue
|
||||
|
||||
document.archived = True
|
||||
document.archived_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
document.archived_by = current_user.id
|
||||
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
if document.enabled:
|
||||
# Set cache to prevent indexing the same document multiple times
|
||||
redis_client.setex(indexing_cache_key, 600, 1)
|
||||
|
||||
remove_document_from_index_task.delay(document_id)
|
||||
|
||||
elif action == "un_archive":
|
||||
if not document.archived:
|
||||
continue
|
||||
document.archived = False
|
||||
document.archived_at = None
|
||||
document.archived_by = None
|
||||
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
# Set cache to prevent indexing the same document multiple times
|
||||
redis_client.setex(indexing_cache_key, 600, 1)
|
||||
|
||||
add_document_to_index_task.delay(document_id)
|
||||
|
||||
else:
|
||||
raise InvalidActionError()
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
|
||||
@@ -374,7 +374,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
if len(request.files) > 1:
|
||||
raise TooManyFilesError()
|
||||
# check file type
|
||||
if not file.filename or not file.filename.lower().endswith(".csv"):
|
||||
if not file.filename.endswith(".csv"):
|
||||
raise ValueError("Invalid file type. Only CSV files are allowed")
|
||||
|
||||
try:
|
||||
|
||||
@@ -59,14 +59,7 @@ class InstalledAppsListApi(Resource):
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
user_id = current_user.id
|
||||
res = []
|
||||
app_ids = [installed_app["app"].id for installed_app in installed_app_list]
|
||||
webapp_settings = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids)
|
||||
for installed_app in installed_app_list:
|
||||
webapp_setting = webapp_settings.get(installed_app["app"].id)
|
||||
if not webapp_setting:
|
||||
continue
|
||||
if webapp_setting.access_mode == "sso_verified":
|
||||
continue
|
||||
app_code = AppService.get_app_code_by_id(str(installed_app["app"].id))
|
||||
if EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||
user_id=user_id,
|
||||
|
||||
@@ -15,7 +15,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str):
|
||||
if not TenantAccountRole.is_privileged_role(current_user.current_role):
|
||||
if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
@@ -64,7 +64,7 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider: str, config_id: str):
|
||||
if not TenantAccountRole.is_privileged_role(current_user.current_role):
|
||||
if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
|
||||
raise Forbidden()
|
||||
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
@@ -44,17 +44,6 @@ def only_edition_cloud(view):
|
||||
return decorated
|
||||
|
||||
|
||||
def only_edition_enterprise(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
if not dify_config.ENTERPRISE_ENABLED:
|
||||
abort(404)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def only_edition_self_hosted(view):
|
||||
@wraps(view)
|
||||
def decorated(*args, **kwargs):
|
||||
|
||||
@@ -29,7 +29,7 @@ from core.plugin.entities.request import (
|
||||
RequestRequestUploadFile,
|
||||
)
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from libs.helper import length_prefixed_response
|
||||
from libs.helper import compact_generate_response
|
||||
from models.account import Account, Tenant
|
||||
from models.model import EndUser
|
||||
|
||||
@@ -44,7 +44,7 @@ class PluginInvokeLLMApi(Resource):
|
||||
response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload)
|
||||
return PluginModelBackwardsInvocation.convert_to_event_stream(response)
|
||||
|
||||
return length_prefixed_response(0xF, generator())
|
||||
return compact_generate_response(generator())
|
||||
|
||||
|
||||
class PluginInvokeTextEmbeddingApi(Resource):
|
||||
@@ -101,7 +101,7 @@ class PluginInvokeTTSApi(Resource):
|
||||
)
|
||||
return PluginModelBackwardsInvocation.convert_to_event_stream(response)
|
||||
|
||||
return length_prefixed_response(0xF, generator())
|
||||
return compact_generate_response(generator())
|
||||
|
||||
|
||||
class PluginInvokeSpeech2TextApi(Resource):
|
||||
@@ -162,7 +162,7 @@ class PluginInvokeToolApi(Resource):
|
||||
),
|
||||
)
|
||||
|
||||
return length_prefixed_response(0xF, generator())
|
||||
return compact_generate_response(generator())
|
||||
|
||||
|
||||
class PluginInvokeParameterExtractorNodeApi(Resource):
|
||||
@@ -228,7 +228,7 @@ class PluginInvokeAppApi(Resource):
|
||||
files=payload.files,
|
||||
)
|
||||
|
||||
return length_prefixed_response(0xF, PluginAppBackwardsInvocation.convert_to_event_stream(response))
|
||||
return compact_generate_response(PluginAppBackwardsInvocation.convert_to_event_stream(response))
|
||||
|
||||
|
||||
class PluginInvokeEncryptApi(Resource):
|
||||
|
||||
@@ -2,14 +2,12 @@ from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
from flask import request
|
||||
from flask_restful import reqparse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.login import _get_user
|
||||
from models.account import Account, Tenant
|
||||
from models.model import EndUser
|
||||
from services.account_service import AccountService
|
||||
@@ -32,7 +30,6 @@ def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
|
||||
)
|
||||
session.add(user_model)
|
||||
session.commit()
|
||||
session.refresh(user_model)
|
||||
else:
|
||||
user_model = AccountService.load_user(user_id)
|
||||
if not user_model:
|
||||
@@ -83,12 +80,7 @@ def get_user_tenant(view: Optional[Callable] = None):
|
||||
raise ValueError("tenant not found")
|
||||
|
||||
kwargs["tenant_model"] = tenant_model
|
||||
|
||||
user = get_user(tenant_id, user_id)
|
||||
kwargs["user_model"] = user
|
||||
|
||||
current_app.login_manager._update_request_context_with_user(user) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore
|
||||
kwargs["user_model"] = get_user(tenant_id, user_id)
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
|
||||
@@ -9,13 +9,13 @@ from fields.annotation_fields import (
|
||||
annotation_fields,
|
||||
)
|
||||
from libs.login import current_user
|
||||
from models.model import App
|
||||
from models.model import App, EndUser
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
|
||||
class AnnotationReplyActionApi(Resource):
|
||||
@validate_app_token
|
||||
def post(self, app_model: App, action):
|
||||
def post(self, app_model: App, end_user: EndUser, action):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("score_threshold", required=True, type=float, location="json")
|
||||
parser.add_argument("embedding_provider_name", required=True, type=str, location="json")
|
||||
@@ -32,7 +32,7 @@ class AnnotationReplyActionApi(Resource):
|
||||
|
||||
class AnnotationReplyActionStatusApi(Resource):
|
||||
@validate_app_token
|
||||
def get(self, app_model: App, job_id, action):
|
||||
def get(self, app_model: App, end_user: EndUser, job_id, action):
|
||||
job_id = str(job_id)
|
||||
app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id))
|
||||
cache_result = redis_client.get(app_annotation_job_key)
|
||||
@@ -50,7 +50,7 @@ class AnnotationReplyActionStatusApi(Resource):
|
||||
|
||||
class AnnotationListApi(Resource):
|
||||
@validate_app_token
|
||||
def get(self, app_model: App):
|
||||
def get(self, app_model: App, end_user: EndUser):
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
keyword = request.args.get("keyword", default="", type=str)
|
||||
@@ -67,7 +67,7 @@ class AnnotationListApi(Resource):
|
||||
|
||||
@validate_app_token
|
||||
@marshal_with(annotation_fields)
|
||||
def post(self, app_model: App):
|
||||
def post(self, app_model: App, end_user: EndUser):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("question", required=True, type=str, location="json")
|
||||
parser.add_argument("answer", required=True, type=str, location="json")
|
||||
@@ -79,7 +79,7 @@ class AnnotationListApi(Resource):
|
||||
class AnnotationUpdateDeleteApi(Resource):
|
||||
@validate_app_token
|
||||
@marshal_with(annotation_fields)
|
||||
def put(self, app_model: App, annotation_id):
|
||||
def put(self, app_model: App, end_user: EndUser, annotation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
@@ -92,7 +92,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
return annotation
|
||||
|
||||
@validate_app_token
|
||||
def delete(self, app_model: App, annotation_id):
|
||||
def delete(self, app_model: App, end_user: EndUser, annotation_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
@@ -47,13 +47,7 @@ class AppInfoApi(Resource):
|
||||
def get(self, app_model: App):
|
||||
"""Get app information"""
|
||||
tags = [tag.name for tag in app_model.tags]
|
||||
return {
|
||||
"name": app_model.name,
|
||||
"description": app_model.description,
|
||||
"tags": tags,
|
||||
"mode": app_model.mode,
|
||||
"author_name": app_model.author_name,
|
||||
}
|
||||
return {"name": app_model.name, "description": app_model.description, "tags": tags, "mode": app_model.mode}
|
||||
|
||||
|
||||
api.add_resource(AppParameterApi, "/parameters")
|
||||
|
||||
@@ -24,13 +24,12 @@ from core.errors.error import (
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField
|
||||
from models.model import App, AppMode, EndUser
|
||||
from models.workflow import WorkflowRun
|
||||
from models.workflow import WorkflowRun, WorkflowRunStatus
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
@@ -135,25 +134,11 @@ class WorkflowAppLogApi(Resource):
|
||||
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
|
||||
parser.add_argument("created_at__before", type=str, location="args")
|
||||
parser.add_argument("created_at__after", type=str, location="args")
|
||||
parser.add_argument(
|
||||
"created_by_end_user_session_id",
|
||||
type=str,
|
||||
location="args",
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"created_by_account",
|
||||
type=str,
|
||||
location="args",
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
args.status = WorkflowExecutionStatus(args.status) if args.status else None
|
||||
args.status = WorkflowRunStatus(args.status) if args.status else None
|
||||
if args.created_at__before:
|
||||
args.created_at__before = isoparse(args.created_at__before)
|
||||
|
||||
@@ -172,8 +157,6 @@ class WorkflowAppLogApi(Resource):
|
||||
created_at_after=args.created_at__after,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
created_by_end_user_session_id=args.created_by_end_user_session_id,
|
||||
created_by_account=args.created_by_account,
|
||||
)
|
||||
|
||||
return workflow_app_log_pagination
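A hypothetical Service API call against the log listing shown above; the host, route path, and bearer token are assumptions, and only the filters still parsed on this branch are used.

```python
# Placeholder host, path, and token; filters match the parser kept in this hunk.
import requests

resp = requests.get(
    "http://localhost:5001/v1/workflows/logs",
    headers={"Authorization": "Bearer app-xxxxxxxx"},
    params={
        "status": "succeeded",                       # one of succeeded / failed / stopped
        "created_at__after": "2024-01-01T00:00:00Z",
        "page": 1,
        "limit": 20,
    },
    timeout=10,
)
print(resp.json().get("total"), "log entries")
```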
|
||||
|
||||
@@ -1,25 +1,19 @@
|
||||
from flask import request
|
||||
from flask_restful import marshal, marshal_with, reqparse
|
||||
from flask_restful import marshal, reqparse
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services.dataset_service
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
|
||||
from controllers.service_api.wraps import (
|
||||
DatasetApiResource,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
validate_dataset_token,
|
||||
)
|
||||
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.tag_fields import tag_fields
|
||||
from libs.login import current_user
|
||||
from models.dataset import Dataset, DatasetPermissionEnum
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
from services.tag_service import TagService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
@@ -74,7 +68,6 @@ class DatasetListApi(DatasetApiResource):
|
||||
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
|
||||
return response, 200
|
||||
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id):
|
||||
"""Resource for creating datasets."""
|
||||
parser = reqparse.RequestParser()
|
||||
@@ -198,7 +191,6 @@ class DatasetApi(DatasetApiResource):
|
||||
|
||||
return data, 200
|
||||
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def patch(self, _, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@@ -299,7 +291,6 @@ class DatasetApi(DatasetApiResource):
|
||||
|
||||
return result_data, 200
|
||||
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, _, dataset_id):
|
||||
"""
|
||||
Deletes a dataset given its ID.
|
||||
@@ -329,186 +320,5 @@ class DatasetApi(DatasetApiResource):
|
||||
raise DatasetInUseError()
|
||||
|
||||
|
||||
class DocumentStatusApi(DatasetApiResource):
|
||||
"""Resource for batch document status operations."""
|
||||
|
||||
def patch(self, tenant_id, dataset_id, action):
|
||||
"""
|
||||
Batch update document status.
|
||||
|
||||
Args:
|
||||
tenant_id: tenant id
|
||||
dataset_id: dataset id
|
||||
action: action to perform (enable, disable, archive, un_archive)
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with a key 'result' and a value 'success'
|
||||
int: HTTP status code 200 indicating that the operation was successful.
|
||||
|
||||
Raises:
|
||||
NotFound: If the dataset with the given ID does not exist.
|
||||
Forbidden: If the user does not have permission.
|
||||
InvalidActionError: If the action is invalid or cannot be performed.
|
||||
"""
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
# Check user's permission
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
# Check dataset model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
|
||||
# Get document IDs from request body
|
||||
data = request.get_json()
|
||||
document_ids = data.get("document_ids", [])
|
||||
|
||||
try:
|
||||
DocumentService.batch_update_document_status(dataset, document_ids, action, current_user)
|
||||
except services.errors.document.DocumentIndexingError as e:
|
||||
raise InvalidActionError(str(e))
|
||||
except ValueError as e:
|
||||
raise InvalidActionError(str(e))
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
class DatasetTagsApi(DatasetApiResource):
|
||||
@validate_dataset_token
|
||||
@marshal_with(tag_fields)
|
||||
def get(self, _, dataset_id):
|
||||
"""Get all knowledge type tags."""
|
||||
tags = TagService.get_tags("knowledge", current_user.current_tenant_id)
|
||||
|
||||
return tags, 200
|
||||
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
"""Add a knowledge type tag."""
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 50 characters.",
|
||||
type=DatasetTagsApi._validate_tag_name,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
tag = TagService.save_tags(args)
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
|
||||
return response, 200
|
||||
|
||||
@validate_dataset_token
|
||||
def patch(self, _, dataset_id):
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 50 characters.",
|
||||
type=DatasetTagsApi._validate_tag_name,
|
||||
)
|
||||
parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
|
||||
args = parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
tag = TagService.update_tags(args, args.get("tag_id"))
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(args.get("tag_id"))
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
|
||||
|
||||
return response, 200
|
||||
|
||||
@validate_dataset_token
|
||||
def delete(self, _, dataset_id):
|
||||
"""Delete a knowledge type tag."""
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
|
||||
args = parser.parse_args()
|
||||
TagService.delete_tag(args.get("tag_id"))
|
||||
|
||||
return 204
|
||||
|
||||
@staticmethod
|
||||
def _validate_tag_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 50:
|
||||
raise ValueError("Name must be between 1 to 50 characters.")
|
||||
return name
|
||||
|
||||
|
||||
class DatasetTagBindingApi(DatasetApiResource):
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
# The current user's role in the tenant-account join table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required."
|
||||
)
|
||||
parser.add_argument(
|
||||
"target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
TagService.save_tag_binding(args)
|
||||
|
||||
return 204
|
||||
|
||||
|
||||
class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
# The current user's role in the tenant-account join table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.is_editor or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
|
||||
parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
|
||||
|
||||
args = parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
TagService.delete_tag_binding(args)
|
||||
|
||||
return 204
|
||||
|
||||
|
||||
class DatasetTagsBindingStatusApi(DatasetApiResource):
|
||||
@validate_dataset_token
|
||||
def get(self, _, *args, **kwargs):
|
||||
"""Get all knowledge type tags."""
|
||||
dataset_id = kwargs.get("dataset_id")
|
||||
tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
|
||||
tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
|
||||
response = {"data": tags_list, "total": len(tags)}
|
||||
return response, 200
|
||||
|
||||
|
||||
api.add_resource(DatasetListApi, "/datasets")
|
||||
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
|
||||
api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>")
|
||||
api.add_resource(DatasetTagsApi, "/datasets/tags")
|
||||
api.add_resource(DatasetTagBindingApi, "/datasets/tags/binding")
|
||||
api.add_resource(DatasetTagUnbindingApi, "/datasets/tags/unbinding")
|
||||
api.add_resource(DatasetTagsBindingStatusApi, "/datasets/<uuid:dataset_id>/tags")
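A quick, illustrative client pass over the tag routes registered above; the `/v1` prefix, host, and dataset token are assumed placeholders.

```python
# Placeholder host and token; routes are the ones registered just above.
import requests

BASE = "http://localhost:5001/v1"
HEADERS = {"Authorization": "Bearer dataset-xxxxxxxx", "Content-Type": "application/json"}
DATASET_ID = "00000000-0000-0000-0000-000000000000"

# Create a knowledge-type tag.
tag = requests.post(f"{BASE}/datasets/tags", headers=HEADERS,
                    json={"name": "product-docs"}, timeout=10).json()

# Bind it to a dataset, then read back that dataset's tags.
requests.post(f"{BASE}/datasets/tags/binding", headers=HEADERS,
              json={"tag_ids": [tag["id"]], "target_id": DATASET_ID}, timeout=10)
bound = requests.get(f"{BASE}/datasets/{DATASET_ID}/tags", headers=HEADERS, timeout=10).json()
print(bound["total"])
```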
|
||||
|
||||
@@ -19,11 +19,7 @@ from controllers.service_api.dataset.error import (
|
||||
ArchivedDocumentImmutableError,
|
||||
DocumentIndexingError,
|
||||
)
|
||||
from controllers.service_api.wraps import (
|
||||
DatasetApiResource,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
cloud_edition_billing_resource_check,
|
||||
)
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
from extensions.ext_database import db
|
||||
from fields.document_fields import document_fields, document_status_fields
|
||||
@@ -39,7 +35,6 @@ class DocumentAddByTextApi(DatasetApiResource):
|
||||
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_resource_check("documents", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Create document by text."""
|
||||
parser = reqparse.RequestParser()
|
||||
@@ -104,7 +99,6 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
"""Resource for update documents."""
|
||||
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id, document_id):
|
||||
"""Update document by text."""
|
||||
parser = reqparse.RequestParser()
|
||||
@@ -164,7 +158,6 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_resource_check("documents", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Create document by upload file."""
|
||||
args = {}
|
||||
@@ -182,11 +175,8 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
indexing_technique = args.get("indexing_technique") or dataset.indexing_technique
|
||||
if not indexing_technique:
|
||||
if not dataset.indexing_technique and not args.get("indexing_technique"):
|
||||
raise ValueError("indexing_technique is required.")
|
||||
args["indexing_technique"] = indexing_technique
|
||||
|
||||
# save file info
|
||||
file = request.files["file"]
|
||||
@@ -216,16 +206,12 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
knowledge_config = KnowledgeConfig(**args)
|
||||
DocumentService.document_create_args_validate(knowledge_config)
|
||||
|
||||
dataset_process_rule = dataset.latest_process_rule if "process_rule" not in args else None
|
||||
if not knowledge_config.original_document_id and not dataset_process_rule and not knowledge_config.process_rule:
|
||||
raise ValueError("process_rule is required.")
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(
|
||||
dataset=dataset,
|
||||
knowledge_config=knowledge_config,
|
||||
account=dataset.created_by_account,
|
||||
dataset_process_rule=dataset_process_rule,
|
||||
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
|
||||
created_from="api",
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
@@ -239,7 +225,6 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
"""Resource for update documents."""
|
||||
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id, document_id):
|
||||
"""Update document by upload file."""
|
||||
args = {}
|
||||
@@ -310,7 +295,6 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
|
||||
|
||||
class DocumentDeleteApi(DatasetApiResource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, tenant_id, dataset_id, document_id):
|
||||
"""Delete document."""
|
||||
document_id = str(document_id)
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
|
||||
|
||||
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from flask_restful import marshal, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
from fields.dataset_fields import dataset_metadata_fields
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
@@ -14,7 +14,6 @@ from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("type", type=str, required=True, nullable=True, location="json")
|
||||
@@ -40,7 +39,6 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
||||
|
||||
|
||||
class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def patch(self, tenant_id, dataset_id, metadata_id):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("name", type=str, required=True, nullable=True, location="json")
|
||||
@@ -56,7 +54,6 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
|
||||
return marshal(metadata, dataset_metadata_fields), 200
|
||||
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, tenant_id, dataset_id, metadata_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
metadata_id_str = str(metadata_id)
|
||||
@@ -76,7 +73,6 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
|
||||
|
||||
|
||||
class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id, action):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@@ -92,7 +88,6 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
|
||||
|
||||
|
||||
class DocumentMetadataEditServiceApi(DatasetApiResource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
|
||||
@@ -8,7 +8,6 @@ from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
from controllers.service_api.wraps import (
|
||||
DatasetApiResource,
|
||||
cloud_edition_billing_knowledge_limit_check,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
cloud_edition_billing_resource_check,
|
||||
)
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
@@ -36,7 +35,6 @@ class SegmentApi(DatasetApiResource):
|
||||
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id, document_id):
|
||||
"""Create single segment."""
|
||||
# check dataset
|
||||
@@ -141,7 +139,6 @@ class SegmentApi(DatasetApiResource):
|
||||
|
||||
|
||||
class DatasetSegmentApi(DatasetApiResource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, tenant_id, dataset_id, document_id, segment_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
@@ -165,7 +162,6 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
return 204
|
||||
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id, document_id, segment_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
@@ -212,35 +208,12 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
)
|
||||
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
def get(self, tenant_id, dataset_id, document_id, segment_id):
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
|
||||
class ChildChunkApi(DatasetApiResource):
|
||||
"""Resource for child chunks."""
|
||||
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id, document_id, segment_id):
|
||||
"""Create child chunk."""
|
||||
# check dataset
|
||||
@@ -337,7 +310,6 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
"""Resource for updating child chunks."""
|
||||
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id):
|
||||
"""Delete child chunk."""
|
||||
# check dataset
|
||||
@@ -376,7 +348,6 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def patch(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id):
|
||||
"""Update child chunk."""
|
||||
# check dataset
|
||||
|
||||
@@ -15,17 +15,4 @@ api.add_resource(FileApi, "/files/upload")
|
||||
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
|
||||
api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
|
||||
|
||||
from . import (
|
||||
app,
|
||||
audio,
|
||||
completion,
|
||||
conversation,
|
||||
feature,
|
||||
forgot_password,
|
||||
login,
|
||||
message,
|
||||
passport,
|
||||
saved_message,
|
||||
site,
|
||||
workflow,
|
||||
)
|
||||
from . import app, audio, completion, conversation, feature, message, passport, saved_message, site, workflow
|
||||
|
||||
@@ -10,8 +10,6 @@ from libs.passport import PassportService
|
||||
from models.model import App, AppMode
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
|
||||
class AppParameterApi(WebApiResource):
|
||||
@@ -48,22 +46,10 @@ class AppMeta(WebApiResource):
|
||||
class AppAccessMode(Resource):
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("appId", type=str, required=False, location="args")
|
||||
parser.add_argument("appCode", type=str, required=False, location="args")
|
||||
parser.add_argument("appId", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
features = FeatureService.get_system_features()
|
||||
if not features.webapp_auth.enabled:
|
||||
return {"accessMode": "public"}
|
||||
|
||||
app_id = args.get("appId")
|
||||
if args.get("appCode"):
|
||||
app_code = args["appCode"]
|
||||
app_id = AppService.get_app_id_by_code(app_code)
|
||||
|
||||
if not app_id:
|
||||
raise ValueError("appId or appCode must be provided")
|
||||
|
||||
app_id = args["appId"]
|
||||
res = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
|
||||
|
||||
return {"accessMode": res.access_mode}
|
||||
@@ -89,10 +75,6 @@ class AppWebAuthPermission(Resource):
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
features = FeatureService.get_system_features()
|
||||
if not features.webapp_auth.enabled:
|
||||
return {"result": True}
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("appId", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
@@ -100,9 +82,7 @@ class AppWebAuthPermission(Resource):
|
||||
app_id = args["appId"]
|
||||
app_code = AppService.get_app_code_by_id(app_id)
|
||||
|
||||
res = True
|
||||
if WebAppAuthService.is_app_require_permission_check(app_id=app_id):
|
||||
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code)
|
||||
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code)
|
||||
return {"result": res}
|
||||
|
||||
|
||||
|
||||
@@ -139,13 +139,3 @@ class InvokeRateLimitError(BaseHTTPException):
|
||||
error_code = "rate_limit_error"
|
||||
description = "Rate Limit Error"
|
||||
code = 429
|
||||
|
||||
|
||||
class NotFoundError(BaseHTTPException):
|
||||
error_code = "not_found"
|
||||
code = 404
|
||||
|
||||
|
||||
class InvalidArgumentError(BaseHTTPException):
|
||||
error_code = "invalid_param"
|
||||
code = 400
|
||||
|
||||
@@ -1,147 +0,0 @@
|
||||
import base64
|
||||
import secrets
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console.auth.error import (
|
||||
EmailCodeError,
|
||||
EmailPasswordResetLimitError,
|
||||
InvalidEmailError,
|
||||
InvalidTokenError,
|
||||
PasswordMismatchError,
|
||||
)
|
||||
from controllers.console.error import AccountNotFound, EmailSendIpLimitError
|
||||
from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required
|
||||
from controllers.web import api
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.password import hash_password, valid_password
|
||||
from models.account import Account
|
||||
from services.account_service import AccountService
|
||||
|
||||
|
||||
class ForgotPasswordSendEmailApi(Resource):
|
||||
@only_edition_enterprise
|
||||
@setup_required
|
||||
@email_password_login_enabled
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("language", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
        if AccountService.is_email_send_ip_limit(ip_address):
            raise EmailSendIpLimitError()

        if args["language"] is not None and args["language"] == "zh-Hans":
            language = "zh-Hans"
        else:
            language = "en-US"

        with Session(db.engine) as session:
            account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
        token = None
        if account is None:
            raise AccountNotFound()
        else:
            token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language)

        return {"result": "success", "data": token}


class ForgotPasswordCheckApi(Resource):
    @only_edition_enterprise
    @setup_required
    @email_password_login_enabled
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument("email", type=str, required=True, location="json")
        parser.add_argument("code", type=str, required=True, location="json")
        parser.add_argument("token", type=str, required=True, nullable=False, location="json")
        args = parser.parse_args()

        user_email = args["email"]

        is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"])
        if is_forgot_password_error_rate_limit:
            raise EmailPasswordResetLimitError()

        token_data = AccountService.get_reset_password_data(args["token"])
        if token_data is None:
            raise InvalidTokenError()

        if user_email != token_data.get("email"):
            raise InvalidEmailError()

        if args["code"] != token_data.get("code"):
            AccountService.add_forgot_password_error_rate_limit(args["email"])
            raise EmailCodeError()

        # Verified, revoke the first token
        AccountService.revoke_reset_password_token(args["token"])

        # Refresh token data by generating a new token
        _, new_token = AccountService.generate_reset_password_token(
            user_email, code=args["code"], additional_data={"phase": "reset"}
        )

        AccountService.reset_forgot_password_error_rate_limit(args["email"])
        return {"is_valid": True, "email": token_data.get("email"), "token": new_token}


class ForgotPasswordResetApi(Resource):
    @only_edition_enterprise
    @setup_required
    @email_password_login_enabled
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument("token", type=str, required=True, nullable=False, location="json")
        parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json")
        parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json")
        args = parser.parse_args()

        # Validate passwords match
        if args["new_password"] != args["password_confirm"]:
            raise PasswordMismatchError()

        # Validate token and get reset data
        reset_data = AccountService.get_reset_password_data(args["token"])
        if not reset_data:
            raise InvalidTokenError()
        # Must use token in reset phase
        if reset_data.get("phase", "") != "reset":
            raise InvalidTokenError()

        # Revoke token to prevent reuse
        AccountService.revoke_reset_password_token(args["token"])

        # Generate secure salt and hash password
        salt = secrets.token_bytes(16)
        password_hashed = hash_password(args["new_password"], salt)

        email = reset_data.get("email", "")

        with Session(db.engine) as session:
            account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()

            if account:
                self._update_existing_account(account, password_hashed, salt, session)
            else:
                raise AccountNotFound()

        return {"result": "success"}

    def _update_existing_account(self, account, password_hashed, salt, session):
        # Update existing account credentials
        account.password = base64.b64encode(password_hashed).decode()
        account.password_salt = base64.b64encode(salt).decode()
        session.commit()


api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password")
api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity")
api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets")
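For orientation, a minimal client-side sketch of how the three endpoints registered above chain together. The base URL, concrete field values, and the use of `requests` are assumptions for illustration; only the routes and JSON field names come from the handlers themselves.

```python
import requests

# Assumed console API base URL; adjust to your deployment.
BASE = "https://dify.example.com/console/api"

# Step 1: ask for a reset code to be emailed to the account.
send = requests.post(f"{BASE}/forgot-password", json={"email": "user@example.com"}).json()

# Step 2: validate the emailed code; a fresh token scoped to the "reset" phase comes back.
check = requests.post(
    f"{BASE}/forgot-password/validity",
    json={"email": "user@example.com", "code": "123456", "token": send["data"]},
).json()

# Step 3: submit the new password with the reset-phase token.
requests.post(
    f"{BASE}/forgot-password/resets",
    json={"token": check["token"], "new_password": "N3w-Passw0rd", "password_confirm": "N3w-Passw0rd"},
)
```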
@@ -1,11 +1,12 @@
from flask import request
from flask_restful import Resource, reqparse
from jwt import InvalidTokenError  # type: ignore
from werkzeug.exceptions import BadRequest

import services
from controllers.console.auth.error import EmailCodeError, EmailOrPasswordMismatchError, InvalidEmailError
from controllers.console.error import AccountBannedError, AccountNotFound
from controllers.console.wraps import only_edition_enterprise, setup_required
from controllers.web import api
from controllers.console.wraps import setup_required
from libs.helper import email
from libs.password import valid_password
from services.account_service import AccountService
@@ -15,8 +16,6 @@ from services.webapp_auth_service import WebAppAuthService
class LoginApi(Resource):
    """Resource for web app email/password login."""

    @setup_required
    @only_edition_enterprise
    def post(self):
        """Authenticate user and login."""
        parser = reqparse.RequestParser()
@@ -24,6 +23,10 @@ class LoginApi(Resource):
        parser.add_argument("password", type=valid_password, required=True, location="json")
        args = parser.parse_args()

        app_code = request.headers.get("X-App-Code")
        if app_code is None:
            raise BadRequest("X-App-Code header is missing.")

        try:
            account = WebAppAuthService.authenticate(args["email"], args["password"])
        except services.errors.account.AccountLoginError:
@@ -33,8 +36,12 @@ class LoginApi(Resource):
        except services.errors.account.AccountNotFoundError:
            raise AccountNotFound()

        token = WebAppAuthService.login(account=account)
        return {"result": "success", "data": {"access_token": token}}
        WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code)

        end_user = WebAppAuthService.create_end_user(email=args["email"], app_code=app_code)

        token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id)
        return {"result": "success", "token": token}


# class LogoutApi(Resource):
@@ -49,7 +56,6 @@ class LoginApi(Resource):

class EmailCodeLoginSendEmailApi(Resource):
    @setup_required
    @only_edition_enterprise
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument("email", type=email, required=True, location="json")
@@ -72,7 +78,6 @@ class EmailCodeLoginSendEmailApi(Resource):

class EmailCodeLoginApi(Resource):
    @setup_required
    @only_edition_enterprise
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument("email", type=str, required=True, location="json")
@@ -81,6 +86,9 @@ class EmailCodeLoginApi(Resource):
        args = parser.parse_args()

        user_email = args["email"]
        app_code = request.headers.get("X-App-Code")
        if app_code is None:
            raise BadRequest("X-App-Code header is missing.")

        token_data = WebAppAuthService.get_email_code_login_data(args["token"])
        if token_data is None:
@@ -97,12 +105,16 @@ class EmailCodeLoginApi(Resource):
        if not account:
            raise AccountNotFound()

        token = WebAppAuthService.login(account=account)
        WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code)

        end_user = WebAppAuthService.create_end_user(email=user_email, app_code=app_code)

        token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id)
        AccountService.reset_login_error_rate_limit(args["email"])
        return {"result": "success", "data": {"access_token": token}}
        return {"result": "success", "token": token}


api.add_resource(LoginApi, "/login")
# api.add_resource(LoginApi, "/login")
# api.add_resource(LogoutApi, "/logout")
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")
# api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
# api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")

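Assuming the `/login` route is registered (one of the two registrations above is commented out in this comparison), a hedged sketch of the enterprise web-app login call: the `X-App-Code` header is now mandatory, and on this branch the access token is returned under `token` rather than `data.access_token`. Host and app code below are placeholders.

```python
import requests

# Placeholder deployment values.
BASE = "https://dify.example.com/api"
APP_CODE = "your-web-app-code"

resp = requests.post(
    f"{BASE}/login",
    headers={"X-App-Code": APP_CODE},  # missing header -> 400 Bad Request per the handler above
    json={"email": "user@example.com", "password": "secret-password"},
)
resp.raise_for_status()
token = resp.json()["token"]  # hotfix response shape: {"result": "success", "token": "..."}
```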
@@ -1,11 +1,9 @@
import uuid
from datetime import UTC, datetime, timedelta

from flask import request
from flask_restful import Resource
from werkzeug.exceptions import NotFound, Unauthorized

from configs import dify_config
from controllers.web import api
from controllers.web.error import WebAppAuthRequiredError
from extensions.ext_database import db
@@ -13,7 +11,6 @@ from libs.passport import PassportService
from models.model import App, EndUser, Site
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType


class PassportResource(Resource):
@@ -23,19 +20,10 @@ class PassportResource(Resource):
        system_features = FeatureService.get_system_features()
        app_code = request.headers.get("X-App-Code")
        user_id = request.args.get("user_id")
        web_app_access_token = request.args.get("web_app_access_token")

        if app_code is None:
            raise Unauthorized("X-App-Code header is missing.")

        # exchange token for enterprise logined web user
        enterprise_user_decoded = decode_enterprise_webapp_user_id(web_app_access_token)
        if enterprise_user_decoded:
            # a web user has already logged in, exchange a token for this app without redirecting to the login page
            return exchange_token_for_existing_web_user(
                app_code=app_code, enterprise_user_decoded=enterprise_user_decoded
            )

        if system_features.webapp_auth.enabled:
            app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
            if not app_settings or not app_settings.access_mode == "public":
@@ -96,128 +84,6 @@ class PassportResource(Resource):
api.add_resource(PassportResource, "/passport")


def decode_enterprise_webapp_user_id(jwt_token: str | None):
    """
    Decode the enterprise user session from the Authorization header.
    """
    if not jwt_token:
        return None

    decoded = PassportService().verify(jwt_token)
    source = decoded.get("token_source")
    if not source or source != "webapp_login_token":
        raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.")
    return decoded


def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict):
    """
    Exchange a token for an existing web user session.
    """
    user_id = enterprise_user_decoded.get("user_id")
    end_user_id = enterprise_user_decoded.get("end_user_id")
    session_id = enterprise_user_decoded.get("session_id")
    user_auth_type = enterprise_user_decoded.get("auth_type")
    if not user_auth_type:
        raise Unauthorized("Missing auth_type in the token.")

    site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
    if not site:
        raise NotFound()

    app_model = db.session.query(App).filter(App.id == site.app_id).first()
    if not app_model or app_model.status != "normal" or not app_model.enable_site:
        raise NotFound()

    app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code)

    if app_auth_type == WebAppAuthType.PUBLIC:
        return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded)
    elif app_auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external":
        raise WebAppAuthRequiredError("Please login as external user.")
    elif app_auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal":
        raise WebAppAuthRequiredError("Please login as internal user.")

    end_user = None
    if end_user_id:
        end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
    if session_id:
        end_user = (
            db.session.query(EndUser)
            .filter(
                EndUser.session_id == session_id,
                EndUser.tenant_id == app_model.tenant_id,
                EndUser.app_id == app_model.id,
            )
            .first()
        )
    if not end_user:
        if not session_id:
            raise NotFound("Missing session_id for existing web user.")
        end_user = EndUser(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            type="browser",
            is_anonymous=True,
            session_id=session_id,
        )
        db.session.add(end_user)
        db.session.commit()
    exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
    exp = int(exp_dt.timestamp())
    payload = {
        "iss": site.id,
        "sub": "Web API Passport",
        "app_id": site.app_id,
        "app_code": site.code,
        "user_id": user_id,
        "end_user_id": end_user.id,
        "auth_type": user_auth_type,
        "granted_at": int(datetime.now(UTC).timestamp()),
        "token_source": "webapp",
        "exp": exp,
    }
    token: str = PassportService().issue(payload)
    return {
        "access_token": token,
    }


def _exchange_for_public_app_token(app_model, site, token_decoded):
    user_id = token_decoded.get("user_id")
    end_user = None
    if user_id:
        end_user = (
            db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first()
        )

    if not end_user:
        end_user = EndUser(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            type="browser",
            is_anonymous=True,
            session_id=generate_session_id(),
        )

        db.session.add(end_user)
        db.session.commit()

    payload = {
        "iss": site.app_id,
        "sub": "Web API Passport",
        "app_id": site.app_id,
        "app_code": site.code,
        "end_user_id": end_user.id,
    }

    tk = PassportService().issue(payload)

    return {
        "access_token": tk,
    }


def generate_session_id():
    """
    Generate a unique session ID.

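The exchange functions above issue a passport JWT with a fixed set of claims. As a rough illustration only, the snippet below builds and checks an equivalent payload with PyJWT; the signing algorithm and secret are assumptions, since `PassportService` internals are not part of this diff.

```python
import time

import jwt  # PyJWT; the actual signing backend of PassportService is not shown in this hunk

SECRET = "replace-with-your-SECRET_KEY"  # assumption: symmetric HS256 signing

payload = {
    "iss": "site-id",
    "sub": "Web API Passport",
    "app_id": "app-id",
    "app_code": "app-code",
    "user_id": "account-id",
    "end_user_id": "end-user-id",
    "auth_type": "internal",
    "granted_at": int(time.time()),
    "token_source": "webapp",
    "exp": int(time.time()) + 60 * 60,
}

token = jwt.encode(payload, SECRET, algorithm="HS256")
decoded = jwt.decode(token, SECRET, algorithms=["HS256"])  # raises on expiry or bad signature
```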
@@ -1,4 +1,3 @@
from datetime import UTC, datetime
from functools import wraps

from flask import request
@@ -9,9 +8,8 @@ from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequire
from extensions.ext_database import db
from libs.passport import PassportService
from models.model import App, EndUser, Site
from services.enterprise.enterprise_service import EnterpriseService, WebAppSettings
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService


def validate_jwt_token(view=None):
@@ -47,8 +45,7 @@ def decode_jwt_token():
            raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
        decoded = PassportService().verify(tk)
        app_code = decoded.get("app_code")
        app_id = decoded.get("app_id")
        app_model = db.session.query(App).filter(App.id == app_id).first()
        app_model = db.session.query(App).filter(App.id == decoded["app_id"]).first()
        site = db.session.query(Site).filter(Site.code == app_code).first()
        if not app_model:
            raise NotFound()
@@ -56,30 +53,23 @@ def decode_jwt_token():
            raise BadRequest("Site URL is no longer valid.")
        if app_model.enable_site is False:
            raise BadRequest("Site is disabled.")
        end_user_id = decoded.get("end_user_id")
        end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
        end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first()
        if not end_user:
            raise NotFound()

        # for enterprise webapp auth
        app_web_auth_enabled = False
        webapp_settings = None
        if system_features.webapp_auth.enabled:
            webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
            if not webapp_settings:
                raise NotFound("Web app settings not found.")
            app_web_auth_enabled = webapp_settings.access_mode != "public"
            app_web_auth_enabled = (
                EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code).access_mode != "public"
            )

        _validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled)
        _validate_user_accessibility(
            decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled, webapp_settings
        )
        _validate_user_accessibility(decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled)

        return app_model, end_user
    except Unauthorized as e:
        if system_features.webapp_auth.enabled:
            if not app_code:
                raise Unauthorized("Please re-login to access the web app.")
            app_web_auth_enabled = (
                EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=str(app_code)).access_mode != "public"
            )
@@ -105,41 +95,15 @@ def _validate_webapp_token(decoded, app_web_auth_enabled: bool, system_webapp_au
        raise Unauthorized("webapp token expired.")


def _validate_user_accessibility(
    decoded,
    app_code,
    app_web_auth_enabled: bool,
    system_webapp_auth_enabled: bool,
    webapp_settings: WebAppSettings | None,
):
def _validate_user_accessibility(decoded, app_code, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool):
    if system_webapp_auth_enabled and app_web_auth_enabled:
        # Check if the user is allowed to access the web app
        user_id = decoded.get("user_id")
        if not user_id:
            raise WebAppAuthRequiredError()

        if not webapp_settings:
            raise WebAppAuthRequiredError("Web app settings not found.")

        if WebAppAuthService.is_app_require_permission_check(access_mode=webapp_settings.access_mode):
            if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code):
                raise WebAppAuthAccessDeniedError()

        auth_type = decoded.get("auth_type")
        granted_at = decoded.get("granted_at")
        if not auth_type:
            raise WebAppAuthAccessDeniedError("Missing auth_type in the token.")
        if not granted_at:
            raise WebAppAuthAccessDeniedError("Missing granted_at in the token.")
        # check if sso has been updated
        if auth_type == "external":
            last_update_time = EnterpriseService.get_app_sso_settings_last_update_time()
            if granted_at and datetime.fromtimestamp(granted_at, tz=UTC) < last_update_time:
                raise WebAppAuthAccessDeniedError("SSO settings have been updated. Please re-login.")
        elif auth_type == "internal":
            last_update_time = EnterpriseService.get_workspace_sso_settings_last_update_time()
            if granted_at and datetime.fromtimestamp(granted_at, tz=UTC) < last_update_time:
                raise WebAppAuthAccessDeniedError("SSO settings have been updated. Please re-login.")
        if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code):
            raise WebAppAuthAccessDeniedError()


class WebApiResource(Resource):

@@ -63,7 +63,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
        self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)

        iteration_step = 1
        max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1
        max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1

        # convert tools into ModelRuntime Tool format
        tool_instances, prompt_messages_tools = self._init_prompt_tools()

@@ -82,7 +82,7 @@ class AgentEntity(BaseModel):
    strategy: Strategy
    prompt: Optional[AgentPromptEntity] = None
    tools: Optional[list[AgentToolEntity]] = None
    max_iteration: int = 10
    max_iteration: int = 5


class AgentInvokeMessage(ToolInvokeMessage):

@@ -48,7 +48,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
        assert app_config.agent

        iteration_step = 1
        max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1
        max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1

        # continue to run until there is not any tool call
        function_call_state = True

@@ -75,7 +75,7 @@ class AgentConfigManager:
                strategy=strategy,
                prompt=agent_prompt_entity,
                tools=agent_tools,
                max_iteration=agent_dict.get("max_iteration", 10),
                max_iteration=agent_dict.get("max_iteration", 5),
            )

        return None

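These hunks all lower the agent iteration ceiling from effectively 99 to 5. A small sketch of the clamping arithmetic; the helper name is made up for illustration and is not part of the codebase.

```python
def capped_iteration_steps(configured: int | None) -> int:
    # Mirrors the hotfix behaviour: cap at 5 iterations, plus one final answer step.
    return min(configured if configured is not None else 5, 5) + 1


assert capped_iteration_steps(99) == 6  # previously min(99, 99) + 1 == 100 steps
assert capped_iteration_steps(3) == 4
assert capped_iteration_steps(None) == 6
```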
@@ -138,11 +138,14 @@ class DatasetConfigManager:
        if not config.get("dataset_configs"):
            config["dataset_configs"] = {"retrieval_model": "single"}

        if not config["dataset_configs"].get("datasets"):
            config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}

        if not isinstance(config["dataset_configs"], dict):
            raise ValueError("dataset_configs must be of object type")

        if not config["dataset_configs"].get("datasets"):
            config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
        if not isinstance(config["dataset_configs"], dict):
            raise ValueError("dataset_configs must be of object type")

        need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
            "datasets", {}

@@ -70,7 +70,7 @@ class ModelConfigConverter:
        if not model_mode:
            model_mode = LLMMode.CHAT.value
            if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE):
                model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE]).value
                model_mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE]).value

        if not model_schema:
            raise ValueError(f"Model {model_name} not exist.")

@@ -104,7 +104,6 @@ class VariableEntity(BaseModel):
    Variable Entity.
    """

    # `variable` records the name of the variable in user inputs.
    variable: str
    label: str
    description: str = ""

@@ -5,7 +5,7 @@ import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Optional, Union, overload

from flask import Flask, current_app
from flask import Flask, copy_current_request_context, current_app, has_request_context
from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker

@@ -27,17 +27,14 @@ from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
from services.errors.message import MessageNotExistsError
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService

logger = logging.getLogger(__name__)

@@ -118,11 +115,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
        )

        # parse files
        # TODO(QuantumGhost): Move file parsing logic to the API controller layer
        # for better separation of concerns.
        #
        # For implementation reference, see the `_parse_file` function and
        # `DraftWorkflowNodeRunApi` class which handle this properly.
        files = args["files"] if args.get("files") else []
        file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
        if file_extra_config:
@@ -268,13 +260,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
            app_id=application_generate_entity.app_config.app_id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )
        var_loader = DraftVarLoader(
            engine=db.engine,
            app_id=application_generate_entity.app_config.app_id,
            tenant_id=application_generate_entity.app_config.tenant_id,
        )
        draft_var_srv = WorkflowDraftVariableService(db.session())
        draft_var_srv.prefill_conversation_variable_default_values(workflow)

        return self._generate(
            workflow=workflow,
@@ -285,7 +270,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
            workflow_node_execution_repository=workflow_node_execution_repository,
            conversation=None,
            stream=streaming,
            variable_loader=var_loader,
        )

    def single_loop_generate(
@@ -351,13 +335,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
            app_id=application_generate_entity.app_config.app_id,
            triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
        )
        var_loader = DraftVarLoader(
            engine=db.engine,
            app_id=application_generate_entity.app_config.app_id,
            tenant_id=application_generate_entity.app_config.tenant_id,
        )
        draft_var_srv = WorkflowDraftVariableService(db.session())
        draft_var_srv.prefill_conversation_variable_default_values(workflow)

        return self._generate(
            workflow=workflow,
@@ -368,7 +345,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
            workflow_node_execution_repository=workflow_node_execution_repository,
            conversation=None,
            stream=streaming,
            variable_loader=var_loader,
        )

    def _generate(
@@ -382,7 +358,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
        workflow_node_execution_repository: WorkflowNodeExecutionRepository,
        conversation: Optional[Conversation] = None,
        stream: bool = True,
        variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
    ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
        """
        Generate App response.
@@ -391,7 +366,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
        :param user: account or end user
        :param invoke_from: invoke from source
        :param application_generate_entity: application generate entity
        :param workflow_execution_repository: repository for workflow execution
        :param workflow_node_execution_repository: repository for workflow node execution
        :param conversation: conversation
        :param stream: is stream
@@ -425,18 +399,20 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
        # new thread with request context and contextvars
        context = contextvars.copy_context()

        worker_thread = threading.Thread(
            target=self._generate_worker,
            kwargs={
                "flask_app": current_app._get_current_object(),  # type: ignore
                "application_generate_entity": application_generate_entity,
                "queue_manager": queue_manager,
                "conversation_id": conversation.id,
                "message_id": message.id,
                "context": context,
                "variable_loader": variable_loader,
            },
        )
        @copy_current_request_context
        def worker_with_context():
            # Run the worker within the copied context
            return context.run(
                self._generate_worker,
                flask_app=current_app._get_current_object(),  # type: ignore
                application_generate_entity=application_generate_entity,
                queue_manager=queue_manager,
                conversation_id=conversation.id,
                message_id=message.id,
                context=context,
            )

        worker_thread = threading.Thread(target=worker_with_context)

        worker_thread.start()

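The replacement of the plain `threading.Thread(kwargs=...)` call with a `@copy_current_request_context` wrapper keeps both the Flask request context and the captured `contextvars` snapshot usable inside the worker thread. A minimal standalone sketch of the same pattern, independent of Dify's classes; the route and payload are invented for illustration.

```python
import contextvars
import threading

from flask import Flask, copy_current_request_context, request

app = Flask(__name__)


@app.route("/kick")
def kick():
    # Snapshot the contextvars of the request thread.
    context = contextvars.copy_context()

    @copy_current_request_context  # keeps `request` (and `g`) accessible inside the thread
    def worker():
        # Run inside the snapshot taken above, with the request context still active.
        return context.run(lambda: print("handling", request.path))

    threading.Thread(target=worker).start()
    return {"result": "started"}
```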
@@ -463,7 +439,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
        conversation_id: str,
        message_id: str,
        context: contextvars.Context,
        variable_loader: VariableLoader,
    ) -> None:
        """
        Generate worker in a new thread.
@@ -474,9 +449,24 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
        :param message_id: message ID
        :return:
        """
        for var, val in context.items():
            var.set(val)

        with preserve_flask_contexts(flask_app, context_vars=context):
        # FIXME(-LAN-): Save current user before entering new app context
        from flask import g

        saved_user = None
        if has_request_context() and hasattr(g, "_login_user"):
            saved_user = g._login_user

        with flask_app.app_context():
            try:
                # Restore user in new app context
                if saved_user is not None:
                    from flask import g

                    g._login_user = saved_user

                # get conversation and message
                conversation = self._get_conversation(conversation_id)
                message = self._get_message(message_id)
@@ -490,7 +480,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
                conversation=conversation,
                message=message,
                dialogue_count=self._dialogue_count,
                variable_loader=variable_loader,
            )

            runner.run()

@@ -19,7 +19,6 @@ from core.moderation.base import ModerationError
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.enums import UserFrom
@@ -41,17 +40,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
        conversation: Conversation,
        message: Message,
        dialogue_count: int,
        variable_loader: VariableLoader,
    ) -> None:
        super().__init__(queue_manager, variable_loader)
        super().__init__(queue_manager)

        self.application_generate_entity = application_generate_entity
        self.conversation = conversation
        self.message = message
        self._dialogue_count = dialogue_count

    def _get_app_id(self) -> str:
        return self.application_generate_entity.app_config.app_id

    def run(self) -> None:
        app_config = self.application_generate_entity.app_config
        app_config = cast(AdvancedChatAppConfig, app_config)
@@ -144,7 +140,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
            SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count,
            SystemVariableKey.APP_ID: app_config.app_id,
            SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
            SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_run_id,
            SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id,
        }

        # init variable pool

@@ -1,3 +1,4 @@
import json
import logging
import time
from collections.abc import Generator, Mapping
@@ -56,23 +57,26 @@ from core.app.entities.task_entities (
    WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
from events.message_event import message_was_created
from extensions.ext_database import db
from models import Conversation, EndUser, Message, MessageFile
from models.account import Account
from models.enums import CreatorUserRole
from models.workflow import Workflow
from models.workflow import (
    Workflow,
    WorkflowRunStatus,
)

logger = logging.getLogger(__name__)

@@ -122,14 +126,8 @@ class AdvancedChatAppGenerateTaskPipeline:
                SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
                SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
                SystemVariableKey.WORKFLOW_ID: workflow.id,
                SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_run_id,
                SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
            },
            workflow_info=CycleManagerWorkflowInfo(
                workflow_id=workflow.id,
                workflow_type=WorkflowType(workflow.type),
                version=workflow.version,
                graph_data=workflow.graph_dict,
            ),
            workflow_execution_repository=workflow_execution_repository,
            workflow_node_execution_repository=workflow_node_execution_repository,
        )
@@ -139,7 +137,7 @@
        )

        self._task_state = WorkflowTaskState()
        self._message_cycle_manager = MessageCycleManager(
        self._message_cycle_manager = MessageCycleManage(
            application_generate_entity=application_generate_entity, task_state=self._task_state
        )

@@ -160,7 +158,7 @@ class AdvancedChatAppGenerateTaskPipeline:
        :return:
        """
        # start generate conversation name thread
        self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
        self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name(
            conversation_id=self._conversation_id, query=self._application_generate_entity.query
        )

@@ -304,12 +302,15 @@ class AdvancedChatAppGenerateTaskPipeline:

        with Session(db.engine, expire_on_commit=False) as session:
            # init workflow run
            workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
            self._workflow_run_id = workflow_execution.id_
            workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start(
                session=session,
                workflow_id=self._workflow_id,
            )
            self._workflow_run_id = workflow_execution.id
            message = self._get_message(session=session)
            if not message:
                raise ValueError(f"Message not found: {self._message_id}")
            message.workflow_run_id = workflow_execution.id_
            message.workflow_run_id = workflow_execution.id
            workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
                task_id=self._application_generate_entity.task_id,
                workflow_execution=workflow_execution,
@@ -549,7 +550,7 @@ class AdvancedChatAppGenerateTaskPipeline:
                workflow_run_id=self._workflow_run_id,
                total_tokens=graph_runtime_state.total_tokens,
                total_steps=graph_runtime_state.node_run_steps,
                status=WorkflowExecutionStatus.FAILED,
                status=WorkflowRunStatus.FAILED,
                error_message=event.error,
                conversation_id=self._conversation_id,
                trace_manager=trace_manager,
@@ -575,7 +576,7 @@ class AdvancedChatAppGenerateTaskPipeline:
                workflow_run_id=self._workflow_run_id,
                total_tokens=graph_runtime_state.total_tokens,
                total_steps=graph_runtime_state.node_run_steps,
                status=WorkflowExecutionStatus.STOPPED,
                status=WorkflowRunStatus.STOPPED,
                error_message=event.get_stop_reason(),
                conversation_id=self._conversation_id,
                trace_manager=trace_manager,
@@ -603,18 +604,22 @@ class AdvancedChatAppGenerateTaskPipeline:
                yield self._message_end_to_stream_response()
                break
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._message_cycle_manager.handle_retriever_resources(event)
                self._message_cycle_manager._handle_retriever_resources(event)

                with Session(db.engine, expire_on_commit=False) as session:
                    message = self._get_message(session=session)
                    message.message_metadata = self._task_state.metadata.model_dump_json()
                    message.message_metadata = (
                        json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
                    )
                    session.commit()
            elif isinstance(event, QueueAnnotationReplyEvent):
                self._message_cycle_manager.handle_annotation_reply(event)
                self._message_cycle_manager._handle_annotation_reply(event)

                with Session(db.engine, expire_on_commit=False) as session:
                    message = self._get_message(session=session)
                    message.message_metadata = self._task_state.metadata.model_dump_json()
                    message.message_metadata = (
                        json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
                    )
                    session.commit()
            elif isinstance(event, QueueTextChunkEvent):
                delta_text = event.text
@@ -631,12 +636,12 @@ class AdvancedChatAppGenerateTaskPipeline:
                    tts_publisher.publish(queue_message)

                self._task_state.answer += delta_text
                yield self._message_cycle_manager.message_to_stream_response(
                yield self._message_cycle_manager._message_to_stream_response(
                    answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
                )
            elif isinstance(event, QueueMessageReplaceEvent):
                # published by moderation
                yield self._message_cycle_manager.message_replace_to_stream_response(
                yield self._message_cycle_manager._message_replace_to_stream_response(
                    answer=event.text, reason=event.reason
                )
            elif isinstance(event, QueueAdvancedChatMessageEndEvent):
@@ -648,7 +653,7 @@ class AdvancedChatAppGenerateTaskPipeline:
                )
                if output_moderation_answer:
                    self._task_state.answer = output_moderation_answer
                    yield self._message_cycle_manager.message_replace_to_stream_response(
                    yield self._message_cycle_manager._message_replace_to_stream_response(
                        answer=output_moderation_answer,
                        reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
                    )
@@ -677,7 +682,9 @@ class AdvancedChatAppGenerateTaskPipeline:
            message = self._get_message(session=session)
            message.answer = self._task_state.answer
            message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at
            message.message_metadata = self._task_state.metadata.model_dump_json()
            message.message_metadata = (
                json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
            )
            message_files = [
                MessageFile(
                    message_id=message.id,
@@ -705,9 +712,9 @@ class AdvancedChatAppGenerateTaskPipeline:
            message.answer_price_unit = usage.completion_price_unit
            message.total_price = usage.total_price
            message.currency = usage.currency
            self._task_state.metadata.usage = usage
            self._task_state.metadata["usage"] = jsonable_encoder(usage)
        else:
            self._task_state.metadata.usage = LLMUsage.empty_usage()
            self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage())
        message_was_created.send(
            message,
            application_generate_entity=self._application_generate_entity,
@@ -718,16 +725,18 @@ class AdvancedChatAppGenerateTaskPipeline:
        Message end to stream response.
        :return:
        """
        extras = self._task_state.metadata.model_dump()
        extras = {}
        if self._task_state.metadata:
            extras["metadata"] = self._task_state.metadata.copy()

        if self._task_state.metadata.annotation_reply:
            del extras["annotation_reply"]
            if "annotation_reply" in extras["metadata"]:
                del extras["metadata"]["annotation_reply"]

        return MessageEndStreamResponse(
            task_id=self._application_generate_entity.task_id,
            id=self._message_id,
            files=self._recorded_files,
            metadata=extras,
            metadata=extras.get("metadata", {}),
        )

    def _handle_output_moderation_chunk(self, text: str) -> bool:

@@ -5,7 +5,7 @@ import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload

from flask import Flask, current_app
from flask import Flask, copy_current_request_context, current_app, has_request_context
from pydantic import ValidationError

from configs import dify_config
@@ -23,7 +23,6 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, EndUser
from services.conversation_service import ConversationService
from services.errors.message import MessageNotExistsError
@@ -124,11 +123,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
        override_model_config_dict["retriever_resource"] = {"enabled": True}

        # parse files
        # TODO(QuantumGhost): Move file parsing logic to the API controller layer
        # for better separation of concerns.
        #
        # For implementation reference, see the `_parse_file` function and
        # `DraftWorkflowNodeRunApi` class which handle this properly.
        files = args.get("files") or []
        file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
        if file_extra_config:
@@ -188,17 +182,20 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
        # new thread with request context and contextvars
        context = contextvars.copy_context()

        worker_thread = threading.Thread(
            target=self._generate_worker,
            kwargs={
                "flask_app": current_app._get_current_object(),  # type: ignore
                "context": context,
                "application_generate_entity": application_generate_entity,
                "queue_manager": queue_manager,
                "conversation_id": conversation.id,
                "message_id": message.id,
            },
        )
        @copy_current_request_context
        def worker_with_context():
            # Run the worker within the copied context
            return context.run(
                self._generate_worker,
                flask_app=current_app._get_current_object(),  # type: ignore
                context=context,
                application_generate_entity=application_generate_entity,
                queue_manager=queue_manager,
                conversation_id=conversation.id,
                message_id=message.id,
            )

        worker_thread = threading.Thread(target=worker_with_context)

        worker_thread.start()

@@ -232,9 +229,24 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
        :param message_id: message ID
        :return:
        """
        for var, val in context.items():
            var.set(val)

        with preserve_flask_contexts(flask_app, context_vars=context):
        # FIXME(-LAN-): Save current user before entering new app context
        from flask import g

        saved_user = None
        if has_request_context() and hasattr(g, "_login_user"):
            saved_user = g._login_user

        with flask_app.app_context():
            try:
                # Restore user in new app context
                if saved_user is not None:
                    from flask import g

                    g._login_user = saved_user

                # get conversation and message
                conversation = self._get_conversation(conversation_id)
                message = self._get_message(message_id)

@@ -1,4 +1,3 @@
import logging
import time
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union
@@ -34,8 +33,6 @@ from models.model import App, AppMode, Message, MessageAnnotation
if TYPE_CHECKING:
    from core.file.models import File

_logger = logging.getLogger(__name__)


class AppRunner:
    def get_pre_calculate_rest_tokens(
@@ -301,7 +298,7 @@ class AppRunner:
        )

    def _handle_invoke_result_stream(
        self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool
        self, invoke_result: Generator, queue_manager: AppQueueManager, agent: bool
    ) -> None:
        """
        Handle invoke result
@@ -320,28 +317,18 @@ class AppRunner:
            else:
                queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)

            message = result.delta.message
            if isinstance(message.content, str):
                text += message.content
            elif isinstance(message.content, list):
                for content in message.content:
                    if not isinstance(content, str):
                        # TODO(QuantumGhost): Add multimodal output support for easy ui.
                        _logger.warning("received multimodal output, type=%s", type(content))
                        text += content.data
                    else:
                        text += content  # failback to str
            text += result.delta.message.content

            if not model:
                model = result.model

            if not prompt_messages:
                prompt_messages = list(result.prompt_messages)
                prompt_messages = result.prompt_messages

            if result.delta.usage:
                usage = result.delta.usage

        if usage is None:
        if not usage:
            usage = LLMUsage.empty_usage()

        llm_result = LLMResult(

@@ -115,11 +115,6 @@ class ChatAppGenerator(MessageBasedAppGenerator):
        override_model_config_dict["retriever_resource"] = {"enabled": True}

        # parse files
        # TODO(QuantumGhost): Move file parsing logic to the API controller layer
        # for better separation of concerns.
        #
        # For implementation reference, see the `_parse_file` function and
        # `DraftWorkflowNodeRunApi` class which handle this properly.
        files = args["files"] if args.get("files") else []
        file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
        if file_extra_config:

@@ -44,15 +44,15 @@ from core.app.entities.task_entities (
)
from core.file import FILE_MODEL_IDENTITY, File
from core.tools.tool_manager import ToolManager
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.entities.node_execution_entities import NodeExecution
from core.workflow.entities.workflow_execution_entities import WorkflowExecution
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models import (
    Account,
    CreatorUserRole,
    EndUser,
    WorkflowNodeExecutionStatus,
    WorkflowRun,
)

@@ -73,10 +73,11 @@ class WorkflowResponseConverter:
    ) -> WorkflowStartStreamResponse:
        return WorkflowStartStreamResponse(
            task_id=task_id,
            workflow_run_id=workflow_execution.id_,
            workflow_run_id=workflow_execution.id,
            data=WorkflowStartStreamResponse.Data(
                id=workflow_execution.id_,
                id=workflow_execution.id,
                workflow_id=workflow_execution.workflow_id,
                sequence_number=workflow_execution.sequence_number,
                inputs=workflow_execution.inputs,
                created_at=int(workflow_execution.started_at.timestamp()),
            ),
@@ -90,7 +91,7 @@ class WorkflowResponseConverter:
        workflow_execution: WorkflowExecution,
    ) -> WorkflowFinishStreamResponse:
        created_by = None
        workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_))
        workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id))
        assert workflow_run is not None
        if workflow_run.created_by_role == CreatorUserRole.ACCOUNT:
            stmt = select(Account).where(Account.id == workflow_run.created_by)
@@ -121,12 +122,13 @@ class WorkflowResponseConverter:

        return WorkflowFinishStreamResponse(
            task_id=task_id,
            workflow_run_id=workflow_execution.id_,
            workflow_run_id=workflow_execution.id,
            data=WorkflowFinishStreamResponse.Data(
                id=workflow_execution.id_,
                id=workflow_execution.id,
                workflow_id=workflow_execution.workflow_id,
                sequence_number=workflow_execution.sequence_number,
                status=workflow_execution.status,
                outputs=WorkflowRuntimeTypeConverter().to_json_encodable(workflow_execution.outputs),
                outputs=workflow_execution.outputs,
                error=workflow_execution.error_message,
                elapsed_time=workflow_execution.elapsed_time,
                total_tokens=workflow_execution.total_tokens,
@@ -144,16 +146,16 @@ class WorkflowResponseConverter:
        *,
        event: QueueNodeStartedEvent,
        task_id: str,
        workflow_node_execution: WorkflowNodeExecution,
        workflow_node_execution: NodeExecution,
    ) -> Optional[NodeStartStreamResponse]:
        if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
            return None
        if not workflow_node_execution.workflow_execution_id:
        if not workflow_node_execution.workflow_run_id:
            return None

        response = NodeStartStreamResponse(
            task_id=task_id,
            workflow_run_id=workflow_node_execution.workflow_execution_id,
            workflow_run_id=workflow_node_execution.workflow_run_id,
            data=NodeStartStreamResponse.Data(
                id=workflow_node_execution.id,
                node_id=workflow_node_execution.node_id,
@@ -194,20 +196,18 @@ class WorkflowResponseConverter:
        | QueueNodeInLoopFailedEvent
        | QueueNodeExceptionEvent,
        task_id: str,
        workflow_node_execution: WorkflowNodeExecution,
        workflow_node_execution: NodeExecution,
    ) -> Optional[NodeFinishStreamResponse]:
        if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
            return None
        if not workflow_node_execution.workflow_execution_id:
        if not workflow_node_execution.workflow_run_id:
            return None
        if not workflow_node_execution.finished_at:
            return None

        json_converter = WorkflowRuntimeTypeConverter()

        return NodeFinishStreamResponse(
            task_id=task_id,
            workflow_run_id=workflow_node_execution.workflow_execution_id,
            workflow_run_id=workflow_node_execution.workflow_run_id,
            data=NodeFinishStreamResponse.Data(
                id=workflow_node_execution.id,
                node_id=workflow_node_execution.node_id,
@@ -217,7 +217,7 @@ class WorkflowResponseConverter:
                predecessor_node_id=workflow_node_execution.predecessor_node_id,
                inputs=workflow_node_execution.inputs,
                process_data=workflow_node_execution.process_data,
                outputs=json_converter.to_json_encodable(workflow_node_execution.outputs),
                outputs=workflow_node_execution.outputs,
                status=workflow_node_execution.status,
                error=workflow_node_execution.error,
                elapsed_time=workflow_node_execution.elapsed_time,
@@ -239,20 +239,18 @@ class WorkflowResponseConverter:
        *,
        event: QueueNodeRetryEvent,
        task_id: str,
        workflow_node_execution: WorkflowNodeExecution,
        workflow_node_execution: NodeExecution,
    ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
        if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
            return None
        if not workflow_node_execution.workflow_execution_id:
        if not workflow_node_execution.workflow_run_id:
            return None
        if not workflow_node_execution.finished_at:
            return None

        json_converter = WorkflowRuntimeTypeConverter()

        return NodeRetryStreamResponse(
            task_id=task_id,
            workflow_run_id=workflow_node_execution.workflow_execution_id,
            workflow_run_id=workflow_node_execution.workflow_run_id,
            data=NodeRetryStreamResponse.Data(
                id=workflow_node_execution.id,
                node_id=workflow_node_execution.node_id,
@@ -262,7 +260,7 @@ class WorkflowResponseConverter:
                predecessor_node_id=workflow_node_execution.predecessor_node_id,
                inputs=workflow_node_execution.inputs,
                process_data=workflow_node_execution.process_data,
                outputs=json_converter.to_json_encodable(workflow_node_execution.outputs),
                outputs=workflow_node_execution.outputs,
                status=workflow_node_execution.status,
                error=workflow_node_execution.error,
                elapsed_time=workflow_node_execution.elapsed_time,
@@ -381,7 +379,6 @@ class WorkflowResponseConverter:
        workflow_execution_id: str,
        event: QueueIterationCompletedEvent,
    ) -> IterationNodeCompletedStreamResponse:
        json_converter = WorkflowRuntimeTypeConverter()
        return IterationNodeCompletedStreamResponse(
            task_id=task_id,
            workflow_run_id=workflow_execution_id,
@@ -390,7 +387,7 @@ class WorkflowResponseConverter:
                node_id=event.node_id,
                node_type=event.node_type.value,
                title=event.node_data.title,
                outputs=json_converter.to_json_encodable(event.outputs),
                outputs=event.outputs,
                created_at=int(time.time()),
                extras={},
                inputs=event.inputs or {},
@@ -469,7 +466,7 @@ class WorkflowResponseConverter:
                node_id=event.node_id,
                node_type=event.node_type.value,
                title=event.node_data.title,
                outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs),
                outputs=event.outputs,
                created_at=int(time.time()),
                extras={},
                inputs=event.inputs or {},

@@ -101,11 +101,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
        )

        # parse files
        # TODO(QuantumGhost): Move file parsing logic to the API controller layer
        # for better separation of concerns.
        #
        # For implementation reference, see the `_parse_file` function and
        # `DraftWorkflowNodeRunApi` class which handle this properly.
        files = args["files"] if args.get("files") else []
        file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
        if file_extra_config:

@@ -5,7 +5,7 @@ import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Literal, Optional, Union, overload

from flask import Flask, current_app
from flask import Flask, copy_current_request_context, current_app, has_request_context
from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker

@@ -25,15 +25,12 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService

logger = logging.getLogger(__name__)

@@ -96,11 +93,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
files: Sequence[Mapping[str, Any]] = args.get("files") or []

# parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
# for better separation of concerns.
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
system_files = file_factory.build_from_mappings(
mappings=files,
@@ -140,7 +132,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from=invoke_from,
call_depth=call_depth,
trace_manager=trace_manager,
workflow_execution_id=workflow_run_id,
workflow_run_id=workflow_run_id,
)

contexts.plugin_tool_providers.set({})
@@ -193,7 +185,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
@@ -203,7 +194,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param user: account or end user
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
:param workflow_execution_repository: repository for workflow execution
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
:param workflow_thread_pool_id: workflow thread pool id
@@ -219,17 +209,19 @@ class WorkflowAppGenerator(BaseAppGenerator):
# new thread with request context and contextvars
context = contextvars.copy_context()

worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"context": context,
"workflow_thread_pool_id": workflow_thread_pool_id,
"variable_loader": variable_loader,
},
)
@copy_current_request_context
def worker_with_context():
# Run the worker within the copied context
return context.run(
self._generate_worker,
flask_app=current_app._get_current_object(), # type: ignore
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
context=context,
workflow_thread_pool_id=workflow_thread_pool_id,
)

worker_thread = threading.Thread(target=worker_with_context)

worker_thread.start()

@@ -287,7 +279,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id, inputs=args["inputs"]
),
workflow_execution_id=str(uuid.uuid4()),
workflow_run_id=str(uuid.uuid4()),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
@@ -312,13 +304,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)

return self._generate(
app_model=app_model,
@@ -329,7 +314,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
)

def single_loop_generate(
@@ -371,7 +355,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
workflow_execution_id=str(uuid.uuid4()),
workflow_run_id=str(uuid.uuid4()),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
@@ -396,13 +380,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)

return self._generate(
app_model=app_model,
workflow=workflow,
@@ -412,7 +390,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
)

def _generate_worker(
@@ -421,7 +398,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None,
) -> None:
"""
@@ -432,15 +408,29 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param workflow_thread_pool_id: workflow thread pool id
:return:
"""
for var, val in context.items():
var.set(val)

with preserve_flask_contexts(flask_app, context_vars=context):
# FIXME(-LAN-): Save current user before entering new app context
from flask import g

saved_user = None
if has_request_context() and hasattr(g, "_login_user"):
saved_user = g._login_user

with flask_app.app_context():
try:
# Restore user in new app context
if saved_user is not None:
from flask import g

g._login_user = saved_user

# workflow app
runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id,
variable_loader=variable_loader,
)

runner.run()

@@ -12,7 +12,6 @@ from core.app.entities.app_invoke_entities import (
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.enums import UserFrom
@@ -31,7 +30,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
self,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None,
) -> None:
"""
@@ -39,13 +37,10 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
:param queue_manager: application queue manager
:param workflow_thread_pool_id: workflow thread pool id
"""
super().__init__(queue_manager, variable_loader)
self.application_generate_entity = application_generate_entity
self.queue_manager = queue_manager
self.workflow_thread_pool_id = workflow_thread_pool_id

def _get_app_id(self) -> str:
return self.application_generate_entity.app_config.app_id

def run(self) -> None:
"""
Run application
@@ -100,7 +95,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.APP_ID: app_config.app_id,
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_execution_id,
SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id,
}

variable_pool = VariablePool(

@@ -50,15 +50,16 @@ from core.app.entities.task_entities import (
WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
from core.workflow.entities.workflow_execution_entities import WorkflowExecution
from core.workflow.enums import SystemVariableKey
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_cycle_manager import WorkflowCycleManager
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatorUserRole
@@ -68,6 +69,7 @@ from models.workflow import (
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
WorkflowRun,
WorkflowRunStatus,
)

logger = logging.getLogger(__name__)
@@ -112,14 +114,8 @@ class WorkflowAppGenerateTaskPipeline:
SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_execution_id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
},
workflow_info=CycleManagerWorkflowInfo(
workflow_id=workflow.id,
workflow_type=WorkflowType(workflow.type),
version=workflow.version,
graph_data=workflow.graph_dict,
),
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
)
@@ -129,7 +125,9 @@ class WorkflowAppGenerateTaskPipeline:
)

self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._task_state = WorkflowTaskState()
self._workflow_run_id = ""

def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@@ -268,13 +266,17 @@ class WorkflowAppGenerateTaskPipeline:
# override graph runtime state
graph_runtime_state = event.graph_runtime_state

# init workflow run
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
self._workflow_run_id = workflow_execution.id_
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
with Session(db.engine, expire_on_commit=False) as session:
# init workflow run
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start(
session=session,
workflow_id=self._workflow_id,
)
self._workflow_run_id = workflow_execution.id
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)

yield start_resp
elif isinstance(
@@ -509,9 +511,9 @@ class WorkflowAppGenerateTaskPipeline:
workflow_run_id=self._workflow_run_id,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowExecutionStatus.FAILED
status=WorkflowRunStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowExecutionStatus.STOPPED,
else WorkflowRunStatus.STOPPED,
error_message=event.error
if isinstance(event, QueueWorkflowFailedEvent)
else event.get_stop_reason(),
@@ -540,6 +542,7 @@ class WorkflowAppGenerateTaskPipeline:
if tts_publisher:
tts_publisher.publish(queue_message)

self._task_state.answer += delta_text
yield self._text_chunk_to_stream_response(
delta_text, from_variable_selector=event.from_variable_selector
)
@@ -554,7 +557,7 @@ class WorkflowAppGenerateTaskPipeline:
tts_publisher.publish(None)

def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None:
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_))
workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id))
assert workflow_run is not None
invoke_from = self._application_generate_entity.invoke_from
if invoke_from == InvokeFrom.SERVICE_API:

@@ -1,8 +1,6 @@
from collections.abc import Mapping
from typing import Any, Optional, cast

from sqlalchemy.orm import Session

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.queue_entities import (
@@ -31,11 +29,10 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.event import (
AgentLogEvent,
BaseNodeEvent,
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
@@ -65,23 +62,15 @@ from core.workflow.graph_engine.entities.event import (
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes import NodeType
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App
from models.workflow import Workflow
from services.workflow_draft_variable_service import (
DraftVariableSaver,
)


class WorkflowBasedAppRunner(AppRunner):
def __init__(self, queue_manager: AppQueueManager, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER) -> None:
def __init__(self, queue_manager: AppQueueManager):
self.queue_manager = queue_manager
self._variable_loader = variable_loader

def _get_app_id(self) -> str:
raise NotImplementedError("not implemented")

def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
"""
@@ -184,13 +173,6 @@ class WorkflowBasedAppRunner(AppRunner):
except NotImplementedError:
variable_mapping = {}

load_into_variable_pool(
variable_loader=self._variable_loader,
variable_pool=variable_pool,
variable_mapping=variable_mapping,
user_inputs=user_inputs,
)

WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
@@ -280,12 +262,6 @@ class WorkflowBasedAppRunner(AppRunner):
)
except NotImplementedError:
variable_mapping = {}
load_into_variable_pool(
self._variable_loader,
variable_pool=variable_pool,
variable_mapping=variable_mapping,
user_inputs=user_inputs,
)

WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
@@ -319,7 +295,7 @@ class WorkflowBasedAppRunner(AppRunner):
inputs: Mapping[str, Any] | None = {}
process_data: Mapping[str, Any] | None = {}
outputs: Mapping[str, Any] | None = {}
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = {}
execution_metadata: Mapping[NodeRunMetadataKey, Any] | None = {}
if node_run_result:
inputs = node_run_result.inputs
process_data = node_run_result.process_data
@@ -400,8 +376,6 @@ class WorkflowBasedAppRunner(AppRunner):
in_loop_id=event.in_loop_id,
)
)
self._save_draft_var_for_event(event)

elif isinstance(event, NodeRunFailedEvent):
self._publish_event(
QueueNodeFailedEvent(
@@ -464,8 +438,6 @@ class WorkflowBasedAppRunner(AppRunner):
in_loop_id=event.in_loop_id,
)
)
self._save_draft_var_for_event(event)

elif isinstance(event, NodeInIterationFailedEvent):
self._publish_event(
QueueNodeInIterationFailedEvent(
@@ -718,30 +690,3 @@ class WorkflowBasedAppRunner(AppRunner):

def _publish_event(self, event: AppQueueEvent) -> None:
self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)

def _save_draft_var_for_event(self, event: BaseNodeEvent):
run_result = event.route_node_state.node_run_result
if run_result is None:
return
process_data = run_result.process_data
outputs = run_result.outputs
with Session(bind=db.engine) as session, session.begin():
draft_var_saver = DraftVariableSaver(
session=session,
app_id=self._get_app_id(),
node_id=event.node_id,
node_type=event.node_type,
# FIXME(QuantumGhost): rely on private state of queue_manager is not ideal.
invoke_from=self.queue_manager._invoke_from,
node_execution_id=event.id,
enclosing_node_id=event.in_loop_id or event.in_iteration_id or None,
)
draft_var_saver.save(process_data=process_data, outputs=outputs)


def _remove_first_element_from_variable_string(key: str) -> str:
"""
Remove the first element from the prefix.
"""
prefix, remaining = key.split(".", maxsplit=1)
return remaining

@@ -17,24 +17,9 @@ class InvokeFrom(Enum):
Invoke From.
"""

# SERVICE_API indicates that this invocation is from an API call to Dify app.
#
# Description of service api in Dify docs:
# https://docs.dify.ai/en/guides/application-publishing/developing-with-apis
SERVICE_API = "service-api"

# WEB_APP indicates that this invocation is from
# the web app of the workflow (or chatflow).
#
# Description of web app in Dify docs:
# https://docs.dify.ai/en/guides/application-publishing/launch-your-webapp-quickly/README
WEB_APP = "web-app"

# EXPLORE indicates that this invocation is from
# the workflow (or chatflow) explore page.
EXPLORE = "explore"
# DEBUGGER indicates that this invocation is from
# the workflow (or chatflow) edit page.
DEBUGGER = "debugger"

@classmethod
@@ -91,8 +76,6 @@ class AppGenerateEntity(BaseModel):
App Generate Entity.
"""

model_config = ConfigDict(arbitrary_types_allowed=True)

task_id: str

# app config
@@ -116,6 +99,9 @@ class AppGenerateEntity(BaseModel):
# tracing instance
trace_manager: Optional[TraceQueueManager] = None

class Config:
arbitrary_types_allowed = True


class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
"""
@@ -219,7 +205,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):

# app config
app_config: WorkflowUIBasedAppConfig
workflow_execution_id: str
workflow_run_id: str

class SingleIterationRunEntity(BaseModel):
"""

@@ -1,4 +1,4 @@
from collections.abc import Mapping, Sequence
from collections.abc import Mapping
from datetime import datetime
from enum import Enum, StrEnum
from typing import Any, Optional
@@ -6,9 +6,7 @@ from typing import Any, Optional
from pydantic import BaseModel

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNodeData
@@ -284,7 +282,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
"""

event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
retriever_resources: Sequence[RetrievalSourceMetadata]
retriever_resources: list[dict]
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
@@ -414,7 +412,7 @@ class QueueNodeSucceededEvent(AppQueueEvent):
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None

error: Optional[str] = None
"""single iteration duration map"""
@@ -448,7 +446,7 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent):
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None

error: str
retry_index: int # retry index
@@ -482,7 +480,7 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent):
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None

error: str

@@ -515,7 +513,7 @@ class QueueNodeInLoopFailedEvent(AppQueueEvent):
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None

error: str

@@ -548,7 +546,7 @@ class QueueNodeExceptionEvent(AppQueueEvent):
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None

error: str

@@ -581,7 +579,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
inputs: Optional[Mapping[str, Any]] = None
process_data: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None

error: str

@@ -2,29 +2,12 @@ from collections.abc import Mapping, Sequence
from enum import Enum
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict

from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.utils.encoders import jsonable_encoder
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities.node_entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus


class AnnotationReplyAccount(BaseModel):
id: str
name: str


class AnnotationReply(BaseModel):
id: str
account: AnnotationReplyAccount


class TaskStateMetadata(BaseModel):
annotation_reply: AnnotationReply | None = None
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(default_factory=list)
usage: LLMUsage | None = None
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey
from models.workflow import WorkflowNodeExecutionStatus


class TaskState(BaseModel):
@@ -32,7 +15,7 @@ class TaskState(BaseModel):
TaskState entity
"""

metadata: TaskStateMetadata = Field(default_factory=TaskStateMetadata)
metadata: dict = {}


class EasyUITaskState(TaskState):
@@ -206,6 +189,7 @@ class WorkflowStartStreamResponse(StreamResponse):

id: str
workflow_id: str
sequence_number: int
inputs: Mapping[str, Any]
created_at: int

@@ -226,6 +210,7 @@ class WorkflowFinishStreamResponse(StreamResponse):

id: str
workflow_id: str
sequence_number: int
status: str
outputs: Optional[Mapping[str, Any]] = None
error: Optional[str] = None
@@ -322,7 +307,7 @@ class NodeFinishStreamResponse(StreamResponse):
status: str
error: Optional[str] = None
elapsed_time: float
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
created_at: int
finished_at: int
files: Optional[Sequence[Mapping[str, Any]]] = []
@@ -391,7 +376,7 @@ class NodeRetryStreamResponse(StreamResponse):
status: str
error: Optional[str] = None
elapsed_time: float
execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
created_at: int
finished_at: int
files: Optional[Sequence[Mapping[str, Any]]] = []

@@ -1,3 +1,4 @@
import json
import logging
import time
from collections.abc import Generator
@@ -42,15 +43,15 @@ from core.app.entities.task_entities import (
StreamResponse,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
TextPromptMessageContent,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.prompt.utils.prompt_message_util import PromptMessageUtil
@@ -62,7 +63,7 @@ from models.model import AppMode, Conversation, Message, MessageAgentThought
logger = logging.getLogger(__name__)


class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleManage):
"""
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
@@ -103,11 +104,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
)
)

self._message_cycle_manager = MessageCycleManager(
application_generate_entity=application_generate_entity,
task_state=self._task_state,
)

self._conversation_name_generate_thread: Optional[Thread] = None

def process(
@@ -119,7 +115,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
]:
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
self._conversation_name_generate_thread = self._generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query or ""
)

@@ -140,9 +136,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, MessageEndStreamResponse):
extras = {"usage": self._task_state.llm_result.usage.model_dump()}
extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)}
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata.model_dump()
extras["metadata"] = self._task_state.metadata
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
if self._conversation_mode == AppMode.COMPLETION.value:
response = CompletionAppBlockingResponse(
@@ -281,9 +277,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
)
if output_moderation_answer:
self._task_state.llm_result.message.content = output_moderation_answer
yield self._message_cycle_manager.message_replace_to_stream_response(
answer=output_moderation_answer
)
yield self._message_replace_to_stream_response(answer=output_moderation_answer)

with Session(db.engine) as session:
# Save message
@@ -292,9 +286,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message_end_resp = self._message_end_to_stream_response()
yield message_end_resp
elif isinstance(event, QueueRetrieverResourcesEvent):
self._message_cycle_manager.handle_retriever_resources(event)
self._handle_retriever_resources(event)
elif isinstance(event, QueueAnnotationReplyEvent):
annotation = self._message_cycle_manager.handle_annotation_reply(event)
annotation = self._handle_annotation_reply(event)
if annotation:
self._task_state.llm_result.message.content = annotation.content
elif isinstance(event, QueueAgentThoughtEvent):
@@ -302,7 +296,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
if agent_thought_response is not None:
yield agent_thought_response
elif isinstance(event, QueueMessageFileEvent):
response = self._message_cycle_manager.message_file_to_stream_response(event)
response = self._message_file_to_stream_response(event)
if response:
yield response
elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent):
@@ -310,23 +304,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
delta_text = chunk.delta.message.content
if delta_text is None:
continue
if isinstance(chunk.delta.message.content, list):
delta_text = ""
for content in chunk.delta.message.content:
logger.debug(
"The content type %s in LLM chunk delta message content.: %r", type(content), content
)
if isinstance(content, TextPromptMessageContent):
delta_text += content.data
elif isinstance(content, str):
delta_text += content # failback to str
else:
logger.warning(
"Unsupported content type %s in LLM chunk delta message content.: %r",
type(content),
content,
)
continue

if not self._task_state.llm_result.prompt_messages:
self._task_state.llm_result.prompt_messages = chunk.prompt_messages
@@ -341,7 +318,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
self._task_state.llm_result.message.content = current_content

if isinstance(event, QueueLLMChunkEvent):
yield self._message_cycle_manager.message_to_stream_response(
yield self._message_to_stream_response(
answer=cast(str, delta_text),
message_id=self._message_id,
)
@@ -351,7 +328,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message_id=self._message_id,
)
elif isinstance(event, QueueMessageReplaceEvent):
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text)
yield self._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
else:
@@ -395,7 +372,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message.provider_response_latency = time.perf_counter() - self._start_at
message.total_price = usage.total_price
message.currency = usage.currency
message.message_metadata = self._task_state.metadata.model_dump_json()
message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)

if trace_manager:
trace_manager.add_trace_task(
@@ -444,12 +423,16 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
Message end to stream response.
:return:
"""
self._task_state.metadata.usage = self._task_state.llm_result.usage
metadata_dict = self._task_state.metadata.model_dump()
self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage)

extras = {}
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata

return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
id=self._message_id,
metadata=metadata_dict,
metadata=extras.get("metadata", {}),
)

def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:

@@ -17,8 +17,6 @@ from core.app.entities.queue_entities import (
QueueRetrieverResourcesEvent,
)
from core.app.entities.task_entities import (
AnnotationReply,
AnnotationReplyAccount,
EasyUITaskState,
MessageFileStreamResponse,
MessageReplaceStreamResponse,
@@ -32,7 +30,7 @@ from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
from services.annotation_service import AppAnnotationService


class MessageCycleManager:
class MessageCycleManage:
def __init__(
self,
*,
@@ -47,7 +45,7 @@ class MessageCycleManager:
self._application_generate_entity = application_generate_entity
self._task_state = task_state

def generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
"""
Generate conversation name.
:param conversation_id: conversation id
@@ -104,7 +102,7 @@ class MessageCycleManager:
db.session.commit()
db.session.close()

def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
"""
Handle annotation reply.
:param event: event
@@ -113,28 +111,25 @@ class MessageCycleManager:
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
if annotation:
account = annotation.account
self._task_state.metadata.annotation_reply = AnnotationReply(
id=annotation.id,
account=AnnotationReplyAccount(
id=annotation.account_id,
name=account.name if account else "Dify user",
),
)
self._task_state.metadata["annotation_reply"] = {
"id": annotation.id,
"account": {"id": annotation.account_id, "name": account.name if account else "Dify user"},
}

return annotation

return None

def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
def _handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
"""
Handle retriever resources.
:param event: event
:return:
"""
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
self._task_state.metadata.retriever_resources = event.retriever_resources
self._task_state.metadata["retriever_resources"] = event.retriever_resources

def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
"""
Message file to stream response.
:param event: event
@@ -171,7 +166,7 @@ class MessageCycleManager:

return None

def message_to_stream_response(
def _message_to_stream_response(
self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None
) -> MessageStreamResponse:
"""
@@ -187,7 +182,7 @@ class MessageCycleManager:
from_variable_selector=from_variable_selector,
)

def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
def _message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
"""
Message replace to stream response.
:param answer: answer
@@ -1,10 +1,8 @@
import logging
from collections.abc import Sequence

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.models.document import Document
from extensions.ext_database import db
@@ -87,8 +85,7 @@ class DatasetIndexToolCallbackHandler:

db.session.commit()

# TODO(-LAN-): Improve type check
def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]):
def return_retriever_resource_info(self, resource: list):
"""Handle return_retriever_resource_info."""
self._queue_manager.publish(
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER

@@ -55,25 +55,6 @@ class ProviderModelWithStatusEntity(ProviderModel):
status: ModelStatus
load_balancing_enabled: bool = False

def raise_for_status(self) -> None:
"""
Check model status and raise ValueError if not active.

:raises ValueError: When model status is not active, with a descriptive message
"""
if self.status == ModelStatus.ACTIVE:
return

error_messages = {
ModelStatus.NO_CONFIGURE: "Model is not configured",
ModelStatus.QUOTA_EXCEEDED: "Model quota has been exceeded",
ModelStatus.NO_PERMISSION: "No permission to use this model",
ModelStatus.DISABLED: "Model is disabled",
}

if self.status in error_messages:
raise ValueError(error_messages[self.status])


class ModelWithProviderEntity(ProviderModelWithStatusEntity):
"""

@@ -41,53 +41,45 @@ class Extensible:
extensions = []
position_map: dict[str, int] = {}

# Get the package name from the module path
package_name = ".".join(cls.__module__.split(".")[:-1])
# get the path of the current class
current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py")
current_dir_path = os.path.dirname(current_path)

try:
# Get package directory path
package_spec = importlib.util.find_spec(package_name)
if not package_spec or not package_spec.origin:
raise ImportError(f"Could not find package {package_name}")
# traverse subdirectories
for subdir_name in os.listdir(current_dir_path):
if subdir_name.startswith("__"):
continue

package_dir = os.path.dirname(package_spec.origin)

# Traverse subdirectories
for subdir_name in os.listdir(package_dir):
if subdir_name.startswith("__"):
continue

subdir_path = os.path.join(package_dir, subdir_name)
if not os.path.isdir(subdir_path):
continue

extension_name = subdir_name
subdir_path = os.path.join(current_dir_path, subdir_name)
extension_name = subdir_name
if os.path.isdir(subdir_path):
file_names = os.listdir(subdir_path)

# Check for extension module file
if (extension_name + ".py") not in file_names:
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
continue

# Check for builtin flag and position
# is builtin extension, builtin extension
# in the front-end page and business logic, there are special treatments.
builtin = False
# default position is 0 can not be None for sort_to_dict_by_position_map
position = 0
if "__builtin__" in file_names:
builtin = True

builtin_file_path = os.path.join(subdir_path, "__builtin__")
if os.path.exists(builtin_file_path):
position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip())
position_map[extension_name] = position

# Import the extension module
module_name = f"{package_name}.{extension_name}.{extension_name}"
spec = importlib.util.find_spec(module_name)
if (extension_name + ".py") not in file_names:
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
continue

# Dynamic loading {subdir_name}.py file and find the subclass of Extensible
py_path = os.path.join(subdir_path, extension_name + ".py")
spec = importlib.util.spec_from_file_location(extension_name, py_path)
if not spec or not spec.loader:
raise ImportError(f"Failed to load module {module_name}")
raise Exception(f"Failed to load module {extension_name} from {py_path}")
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)

# Find extension class
extension_class = None
for name, obj in vars(mod).items():
if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
@@ -95,21 +87,21 @@ class Extensible:
break

if not extension_class:
logging.warning(f"Missing subclass of {cls.__name__} in {module_name}, Skip.")
logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
continue

# Load schema if not builtin
json_data: dict[str, Any] = {}
if not builtin:
json_path = os.path.join(subdir_path, "schema.json")
if not os.path.exists(json_path):
if "schema.json" not in file_names:
logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
continue

with open(json_path, encoding="utf-8") as f:
json_data = json.load(f)
json_path = os.path.join(subdir_path, "schema.json")
json_data = {}
if os.path.exists(json_path):
with open(json_path, encoding="utf-8") as f:
json_data = json.load(f)

# Create extension
extensions.append(
ModuleExtension(
extension_class=extension_class,
@@ -121,11 +113,6 @@ class Extensible:
)
)

except Exception as e:
logging.exception("Error scanning extensions")
raise

# Sort extensions by position
sorted_extensions = sort_to_dict_by_position_map(
position_map=position_map, data=extensions, name_func=lambda x: x.name
)

@@ -1,11 +1 @@
from typing import Any

# TODO(QuantumGhost): Refactor variable type identification. Instead of directly
# comparing `dify_model_identity` with constants throughout the codebase, extract
# this logic into a dedicated function. This would encapsulate the implementation
# details of how different variable types are identified.
FILE_MODEL_IDENTITY = "__dify__file__"


def maybe_file_object(o: Any) -> bool:
return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY

@@ -15,7 +15,6 @@ from core.helper.code_executor.python3.python3_transformer import Python3Templat
from core.helper.code_executor.template_transformer import TemplateTransformer

logger = logging.getLogger(__name__)
code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT))


class CodeExecutionError(Exception):
@@ -65,7 +64,7 @@ class CodeExecutor:
:param code: code
:return:
"""
url = code_execution_endpoint_url / "v1" / "sandbox" / "run"
url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / "v1" / "sandbox" / "run"

headers = {"X-Api-Key": dify_config.CODE_EXECUTION_API_KEY}

@@ -7,28 +7,29 @@ from configs import dify_config
from core.helper.download import download_with_size_limit
from core.plugin.entities.marketplace import MarketplacePluginDeclaration

marketplace_api_url = URL(str(dify_config.MARKETPLACE_API_URL))


def get_plugin_pkg_url(plugin_unique_identifier: str) -> str:
return str((marketplace_api_url / "api/v1/plugins/download").with_query(unique_identifier=plugin_unique_identifier))
def get_plugin_pkg_url(plugin_unique_identifier: str):
return (URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/download").with_query(
unique_identifier=plugin_unique_identifier
)


def download_plugin_pkg(plugin_unique_identifier: str):
return download_with_size_limit(get_plugin_pkg_url(plugin_unique_identifier), dify_config.PLUGIN_MAX_PACKAGE_SIZE)
url = str(get_plugin_pkg_url(plugin_unique_identifier))
return download_with_size_limit(url, dify_config.PLUGIN_MAX_PACKAGE_SIZE)


def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplacePluginDeclaration]:
if len(plugin_ids) == 0:
return []

url = str(marketplace_api_url / "api/v1/plugins/batch")
url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/batch")
response = requests.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status()
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]


def record_install_plugin_event(plugin_unique_identifier: str):
url = str(marketplace_api_url / "api/v1/stats/plugins/install_count")
url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/stats/plugins/install_count")
response = requests.post(url, json={"unique_identifier": plugin_unique_identifier})
response.raise_for_status()

@@ -1,5 +1,5 @@
import logging
import secrets
import random
from typing import cast

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@@ -38,7 +38,7 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt
if len(text_chunks) == 0:
return True

text_chunk = secrets.choice(text_chunks)
text_chunk = random.choice(text_chunks)

try:
model_provider_factory = ModelProviderFactory(tenant_id)

@@ -1,20 +1,61 @@
# Written by YORKI MINAKO🤡, Edited by Xiaoyi, Edited by yasu-oh
CONVERSATION_TITLE_PROMPT = """You are asked to generate a concise chat title by decomposing the user’s input into two parts: “Intention” and “Subject”.
# Written by YORKI MINAKO🤡, Edited by Xiaoyi
CONVERSATION_TITLE_PROMPT = """You need to decompose the user's input into "subject" and "intention" in order to accurately figure out what the user's input language actually is.
Notice: the language type user uses could be diverse, which can be English, Chinese, Italian, Español, Arabic, Japanese, French, and etc.
ENSURE your output is in the SAME language as the user's input!
Your output is restricted only to: (Input language) Intention + Subject(short as possible)
Your output MUST be a valid JSON.

1. Detect Input Language
Automatically identify the language of the user’s input (e.g. English, Chinese, Italian, Español, Arabic, Japanese, French, and etc.).
Tip: When the user's question is directed at you (the language model), you can add an emoji to make it more fun.

2. Generate Title
- Combine Intention + Subject into a single, as-short-as-possible phrase.
- The title must be natural, friendly, and in the same language as the input.
- If the input is a direct question to the model, you may add an emoji at the end.

3. Output Format
Return **only** a valid JSON object with these exact keys and no additional text:
example 1:
User Input: hi, yesterday i had some burgers.
{
"Language Type": "<Detected language>",
"Your Reasoning": "<Brief explanation in that language>",
"Your Output": "<Intention + Subject>"
"Language Type": "The user's input is pure English",
"Your Reasoning": "The language of my output must be pure English.",
"Your Output": "sharing yesterday's food"
}

example 2:
User Input: hello
{
"Language Type": "The user's input is pure English",
"Your Reasoning": "The language of my output must be pure English.",
"Your Output": "Greeting myself☺️"
}

example 3:
User Input: why mmap file: oom
{
"Language Type": "The user's input is written in pure English",
"Your Reasoning": "The language of my output must be pure English.",
"Your Output": "Asking about the reason for mmap file: oom"
}

example 4:
User Input: www.convinceme.yesterday-you-ate-seafood.tv讲了什么?
{
"Language Type": "The user's input English-Chinese mixed",
"Your Reasoning": "The English-part is an URL, the main intention is still written in Chinese, so the language of my output must be using Chinese.",
"Your Output": "询问网站www.convinceme.yesterday-you-ate-seafood.tv"
}

example 5:
User Input: why小红的年龄is老than小明?
{
"Language Type": "The user's input is English-Chinese mixed",
"Your Reasoning": "The English parts are filler words, the main intention is written in Chinese, besides, Chinese occupies a greater \"actual meaning\" than English, so the language of my output must be using Chinese.",
"Your Output": "询问小红和小明的年龄"
}

example 6:
User Input: yo, 你今天咋样?
{
"Language Type": "The user's input is English-Chinese mixed",
"Your Reasoning": "The English-part is a subjective particle, the main intention is written in Chinese, so the language of my output must be using Chinese.",
"Your Output": "查询今日我的状态☺️"
}

User Input:

@@ -542,6 +542,8 @@ class LBModelManager:

return config

return None

def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> None:
"""
Cooldown model load balancing config

Some files were not shown because too many files have changed in this diff.