mirror of
https://github.com/langgenius/dify.git
synced 2025-12-31 03:27:17 +00:00
Compare commits
53 Commits
fix/expose
...
chore/auto
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a65d255a32 | ||
|
|
2cf0cb471f | ||
|
|
269ba6add9 | ||
|
|
78d460a6d1 | ||
|
|
3254018ddb | ||
|
|
f2b7df94d7 | ||
|
|
59fd3aad31 | ||
|
|
a3d18d43ed | ||
|
|
20cbebeef1 | ||
|
|
2968482199 | ||
|
|
f8ac382072 | ||
|
|
aef43910b1 | ||
|
|
87efd4ab84 | ||
|
|
a8b600845e | ||
|
|
fcd9fd8513 | ||
|
|
ffe73f0124 | ||
|
|
0c57250d87 | ||
|
|
f7e012d216 | ||
|
|
c9e3c8b38d | ||
|
|
908a7b6c3d | ||
|
|
cfd7e8a829 | ||
|
|
804b818c6b | ||
|
|
9b9d14c2c4 | ||
|
|
38fc8eeaba | ||
|
|
e70221a9f1 | ||
|
|
126202648f | ||
|
|
dc8475995f | ||
|
|
3ca1373274 | ||
|
|
4aaf07d62a | ||
|
|
ff10a4603f | ||
|
|
53eb56bb1e | ||
|
|
c6209d76eb | ||
|
|
99dc8c7871 | ||
|
|
f588ccff72 | ||
|
|
69746f2f0b | ||
|
|
65da9425df | ||
|
|
b7583e95a5 | ||
|
|
9437a1a844 | ||
|
|
435564f0f2 | ||
|
|
2a6e522a87 | ||
|
|
9c1db7dca7 | ||
|
|
cd7cb19aee | ||
|
|
d84fa4d154 | ||
|
|
d574706600 | ||
|
|
8369e59b4d | ||
|
|
5be8fbab56 | ||
|
|
6101733232 | ||
|
|
778861f461 | ||
|
|
6c9d6a4d57 | ||
|
|
9962118dbd | ||
|
|
a4b2c10fb8 | ||
|
|
2c17bb2c36 | ||
|
|
da91217bc9 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -183,6 +183,7 @@ docker/nginx/conf.d/default.conf
|
||||
docker/nginx/ssl/*
|
||||
!docker/nginx/ssl/.gitkeep
|
||||
docker/middleware.env
|
||||
docker/docker-compose.override.yaml
|
||||
|
||||
sdks/python-client/build
|
||||
sdks/python-client/dist
|
||||
|
||||
153
CONTRIBUTING_TW.md
Normal file
153
CONTRIBUTING_TW.md
Normal file
@@ -0,0 +1,153 @@
|
||||
# 貢獻指南
|
||||
|
||||
您想為 Dify 做出貢獻 - 這太棒了,我們迫不及待地想看看您的成果。作為一家人力和資金有限的初創公司,我們有宏大的抱負,希望設計出最直觀的工作流程來構建和管理 LLM 應用程式。來自社群的任何幫助都非常珍貴,真的。
|
||||
|
||||
鑑於我們的現狀,我們需要靈活且快速地發展,但同時也希望確保像您這樣的貢獻者能夠獲得盡可能順暢的貢獻體驗。我們編寫了這份貢獻指南,目的是幫助您熟悉代碼庫以及我們如何與貢獻者合作,讓您可以更快地進入有趣的部分。
|
||||
|
||||
這份指南,就像 Dify 本身一樣,是不斷發展的。如果有時它落後於實際項目,我們非常感謝您的理解,也歡迎任何改進的反饋。
|
||||
|
||||
關於授權,請花一分鐘閱讀我們簡短的[授權和貢獻者協議](./LICENSE)。社群也遵守[行為準則](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)。
|
||||
|
||||
## 在開始之前
|
||||
|
||||
[尋找](https://github.com/langgenius/dify/issues?q=is:issue+is:open)現有的 issue,或[創建](https://github.com/langgenius/dify/issues/new/choose)一個新的。我們將 issues 分為 2 種類型:
|
||||
|
||||
### 功能請求
|
||||
|
||||
- 如果您要開啟新的功能請求,我們希望您能解釋所提議的功能要達成什麼目標,並且盡可能包含更多的相關背景資訊。[@perzeusss](https://github.com/perzeuss) 已經製作了一個實用的[功能請求輔助工具](https://udify.app/chat/MK2kVSnw1gakVwMX),能幫助您草擬您的需求。歡迎試用。
|
||||
|
||||
- 如果您想從現有問題中選擇一個來處理,只需在其下方留言表示即可。
|
||||
|
||||
相關方向的團隊成員會加入討論。如果一切順利,他們會同意您開始編寫代碼。我們要求您在得到許可前先不要開始處理該功能,以免我們提出變更時您的工作成果被浪費。
|
||||
|
||||
根據所提議功能的領域不同,您可能會與不同的團隊成員討論。以下是目前每位團隊成員所負責的領域概述:
|
||||
|
||||
| 成員 | 負責領域 |
|
||||
| --------------------------------------------------------------------------------------- | ------------------------------ |
|
||||
| [@yeuoly](https://github.com/Yeuoly) | 設計 Agents 架構 |
|
||||
| [@jyong](https://github.com/JohnJyong) | RAG 管道設計 |
|
||||
| [@GarfieldDai](https://github.com/GarfieldDai) | 建構工作流程編排 |
|
||||
| [@iamjoel](https://github.com/iamjoel) & [@zxhlyh](https://github.com/zxhlyh) | 打造易用的前端界面 |
|
||||
| [@guchenhe](https://github.com/guchenhe) & [@crazywoola](https://github.com/crazywoola) | 開發者體驗,各類問題的聯絡窗口 |
|
||||
| [@takatost](https://github.com/takatost) | 整體產品方向與架構 |
|
||||
|
||||
我們如何排定優先順序:
|
||||
|
||||
| 功能類型 | 優先級 |
|
||||
| ------------------------------------------------------------------------------------------------------- | -------- |
|
||||
| 被團隊成員標記為高優先級的功能 | 高優先級 |
|
||||
| 來自我們[社群回饋版](https://github.com/langgenius/dify/discussions/categories/feedbacks)的熱門功能請求 | 中優先級 |
|
||||
| 非核心功能和次要增強 | 低優先級 |
|
||||
| 有價值但非急迫的功能 | 未來功能 |
|
||||
|
||||
### 其他事項 (例如錯誤回報、效能優化、錯字更正)
|
||||
|
||||
- 可以直接開始編寫程式碼。
|
||||
|
||||
我們如何排定優先順序:
|
||||
|
||||
| 問題類型 | 優先級 |
|
||||
| ----------------------------------------------------- | -------- |
|
||||
| 核心功能的錯誤 (無法登入、應用程式無法運行、安全漏洞) | 重要 |
|
||||
| 非關鍵性錯誤、效能提升 | 中優先級 |
|
||||
| 小修正 (錯字、令人困惑但仍可運作的使用者界面) | 低優先級 |
|
||||
|
||||
## 安裝
|
||||
|
||||
以下是設置 Dify 開發環境的步驟:
|
||||
|
||||
### 1. 分叉此存儲庫
|
||||
|
||||
### 2. 複製代碼庫
|
||||
|
||||
從您的終端機複製分叉的代碼庫:
|
||||
|
||||
```shell
|
||||
git clone git@github.com:<github_username>/dify.git
|
||||
```
|
||||
|
||||
- [Docker](https://www.docker.com/)
|
||||
- [Docker Compose](https://docs.docker.com/compose/install/)
|
||||
- [Node.js v18.x (LTS)](http://nodejs.org)
|
||||
- [pnpm](https://pnpm.io/)
|
||||
- [Python](https://www.python.org/) version 3.11.x or 3.12.x
|
||||
|
||||
### 4. 安裝
|
||||
|
||||
Dify 由後端和前端組成。透過 `cd api/` 導航至後端目錄,然後按照[後端 README](api/README.md)進行安裝。在另一個終端機視窗中,透過 `cd web/` 導航至前端目錄,然後按照[前端 README](web/README.md)進行安裝。
|
||||
|
||||
查閱[安裝常見問題](https://docs.dify.ai/learn-more/faq/install-faq)了解常見問題和故障排除步驟的列表。
|
||||
|
||||
### 5. 在瀏覽器中訪問 Dify
|
||||
|
||||
要驗證您的設置,請在瀏覽器中訪問 [http://localhost:3000](http://localhost:3000)(預設值,或您自行設定的 URL 和埠號)。現在您應該能看到 Dify 已啟動並運行。
|
||||
|
||||
## 開發
|
||||
|
||||
如果您要添加模型提供者,請參考[此指南](https://github.com/langgenius/dify/blob/main/api/core/model_runtime/README.md)。
|
||||
|
||||
如果您要為 Agent 或工作流程添加工具提供者,請參考[此指南](./api/core/tools/README.md)。
|
||||
|
||||
為了幫助您快速找到您的貢獻適合的位置,以下是 Dify 後端和前端的簡要註解大綱:
|
||||
|
||||
### 後端
|
||||
|
||||
Dify 的後端使用 Python 的 [Flask](https://flask.palletsprojects.com/en/3.0.x/) 框架編寫。它使用 [SQLAlchemy](https://www.sqlalchemy.org/) 作為 ORM 工具,使用 [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) 進行任務佇列處理。授權邏輯則透過 Flask-login 實現。
|
||||
|
||||
```text
|
||||
[api/]
|
||||
├── constants // 整個專案中使用的常數與設定值
|
||||
├── controllers // API 路由定義與請求處理邏輯
|
||||
├── core // 核心應用服務、模型整合與工具實現
|
||||
├── docker // Docker 容器化相關設定檔案
|
||||
├── events // 事件處理與流程管理機制
|
||||
├── extensions // 與第三方框架或平台的整合擴充功能
|
||||
├── fields // 資料序列化與結構定義欄位
|
||||
├── libs // 可重複使用的共用程式庫與輔助工具
|
||||
├── migrations // 資料庫結構變更與遷移腳本
|
||||
├── models // 資料庫模型與資料結構定義
|
||||
├── services // 核心業務邏輯與功能實現
|
||||
├── storage // 私鑰與敏感資訊儲存機制
|
||||
├── tasks // 非同步任務與背景作業處理器
|
||||
└── tests
|
||||
```
|
||||
|
||||
### 前端
|
||||
|
||||
網站基於 [Next.js](https://nextjs.org/) 的 Typescript 樣板,並使用 [Tailwind CSS](https://tailwindcss.com/) 進行樣式設計。[React-i18next](https://react.i18next.com/) 用於國際化。
|
||||
|
||||
```text
|
||||
[web/]
|
||||
├── app // 頁面佈局與介面元件
|
||||
│ ├── (commonLayout) // 應用程式共用佈局結構
|
||||
│ ├── (shareLayout) // Token 會話專用共享佈局
|
||||
│ ├── activate // 帳號啟用頁面
|
||||
│ ├── components // 頁面與佈局共用元件
|
||||
│ ├── install // 系統安裝頁面
|
||||
│ ├── signin // 使用者登入頁面
|
||||
│ └── styles // 全域共用樣式定義
|
||||
├── assets // 靜態資源檔案庫
|
||||
├── bin // 建構流程執行腳本
|
||||
├── config // 系統可調整設定與選項
|
||||
├── context // 應用程式狀態共享上下文
|
||||
├── dictionaries // 多語系翻譯詞彙庫
|
||||
├── docker // Docker 容器設定檔
|
||||
├── hooks // 可重複使用的 React Hooks
|
||||
├── i18n // 國際化與本地化設定
|
||||
├── models // 資料結構與 API 回應模型
|
||||
├── public // 靜態資源與網站圖標
|
||||
├── service // API 操作介面定義
|
||||
├── test // 測試用例與測試框架
|
||||
├── types // TypeScript 型別定義
|
||||
└── utils // 共用輔助功能函式庫
|
||||
```
|
||||
|
||||
## 提交您的 PR
|
||||
|
||||
最後,是時候向我們的存儲庫開啟拉取請求(PR)了。對於主要功能,我們會先將它們合併到 `deploy/dev` 分支進行測試,然後才會進入 `main` 分支。如果您遇到合併衝突或不知道如何開啟拉取請求等問題,請查看 [GitHub 的拉取請求教學](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests)。
|
||||
|
||||
就是這樣!一旦您的 PR 被合併,您將作為貢獻者出現在我們的 [README](https://github.com/langgenius/dify/blob/main/README.md) 中。
|
||||
|
||||
## 獲取幫助
|
||||
|
||||
如果您在貢獻過程中遇到困難或有迫切的問題,只需通過相關的 GitHub issue 向我們提問,或加入我們的 [Discord](https://discord.gg/8Tpq4AcN9c) 進行快速交流。
|
||||
74
README.md
74
README.md
@@ -40,6 +40,7 @@
|
||||
|
||||
<p align="center">
|
||||
<a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-d9d9d9"></a>
|
||||
<a href="./README_TW.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a>
|
||||
<a href="./README_CN.md"><img alt="简体中文版自述文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
|
||||
<a href="./README_JA.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
|
||||
<a href="./README_ES.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
|
||||
@@ -53,14 +54,14 @@
|
||||
<a href="./README_BN.md"><img alt="README in বাংলা" src="https://img.shields.io/badge/বাংলা-d9d9d9"></a>
|
||||
</p>
|
||||
|
||||
|
||||
Dify is an open-source LLM app development platform. Its intuitive interface combines agentic AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production.
|
||||
Dify is an open-source LLM app development platform. Its intuitive interface combines agentic AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production.
|
||||
|
||||
## Quick start
|
||||
|
||||
> Before installing Dify, make sure your machine meets the following minimum system requirements:
|
||||
>
|
||||
>- CPU >= 2 Core
|
||||
>- RAM >= 4 GiB
|
||||
>
|
||||
> - CPU >= 2 Core
|
||||
> - RAM >= 4 GiB
|
||||
|
||||
</br>
|
||||
|
||||
@@ -76,41 +77,40 @@ docker compose up -d
|
||||
After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process.
|
||||
|
||||
#### Seeking help
|
||||
|
||||
Please refer to our [FAQ](https://docs.dify.ai/getting-started/install-self-hosted/faqs) if you encounter problems setting up Dify. Reach out to [the community and us](#community--contact) if you are still having issues.
|
||||
|
||||
> If you'd like to contribute to Dify or do additional development, refer to our [guide to deploying from source code](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code)
|
||||
|
||||
## Key features
|
||||
**1. Workflow**:
|
||||
Build and test powerful AI workflows on a visual canvas, leveraging all the following features and beyond.
|
||||
|
||||
**1. Workflow**:
|
||||
Build and test powerful AI workflows on a visual canvas, leveraging all the following features and beyond.
|
||||
|
||||
https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
|
||||
https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
|
||||
|
||||
|
||||
|
||||
**2. Comprehensive model support**:
|
||||
Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama3, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
**2. Comprehensive model support**:
|
||||
Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama3, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers).
|
||||
|
||||

|
||||
|
||||
**3. Prompt IDE**:
|
||||
Intuitive interface for crafting prompts, comparing model performance, and adding additional features such as text-to-speech to a chat-based app.
|
||||
|
||||
**3. Prompt IDE**:
|
||||
Intuitive interface for crafting prompts, comparing model performance, and adding additional features such as text-to-speech to a chat-based app.
|
||||
**4. RAG Pipeline**:
|
||||
Extensive RAG capabilities that cover everything from document ingestion to retrieval, with out-of-box support for text extraction from PDFs, PPTs, and other common document formats.
|
||||
|
||||
**4. RAG Pipeline**:
|
||||
Extensive RAG capabilities that cover everything from document ingestion to retrieval, with out-of-box support for text extraction from PDFs, PPTs, and other common document formats.
|
||||
**5. Agent capabilities**:
|
||||
You can define agents based on LLM Function Calling or ReAct, and add pre-built or custom tools for the agent. Dify provides 50+ built-in tools for AI agents, such as Google Search, DALL·E, Stable Diffusion and WolframAlpha.
|
||||
|
||||
**5. Agent capabilities**:
|
||||
You can define agents based on LLM Function Calling or ReAct, and add pre-built or custom tools for the agent. Dify provides 50+ built-in tools for AI agents, such as Google Search, DALL·E, Stable Diffusion and WolframAlpha.
|
||||
**6. LLMOps**:
|
||||
Monitor and analyze application logs and performance over time. You could continuously improve prompts, datasets, and models based on production data and annotations.
|
||||
|
||||
**6. LLMOps**:
|
||||
Monitor and analyze application logs and performance over time. You could continuously improve prompts, datasets, and models based on production data and annotations.
|
||||
|
||||
**7. Backend-as-a-Service**:
|
||||
All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic.
|
||||
**7. Backend-as-a-Service**:
|
||||
All of Dify's offerings come with corresponding APIs, so you could effortlessly integrate Dify into your own business logic.
|
||||
|
||||
## Feature Comparison
|
||||
|
||||
<table style="width: 100%;">
|
||||
<tr>
|
||||
<th align="center">Feature</th>
|
||||
@@ -180,24 +180,22 @@ Please refer to our [FAQ](https://docs.dify.ai/getting-started/install-self-host
|
||||
## Using Dify
|
||||
|
||||
- **Cloud </br>**
|
||||
We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan.
|
||||
We host a [Dify Cloud](https://dify.ai) service for anyone to try with zero setup. It provides all the capabilities of the self-deployed version, and includes 200 free GPT-4 calls in the sandbox plan.
|
||||
|
||||
- **Self-hosting Dify Community Edition</br>**
|
||||
Quickly get Dify running in your environment with this [starter guide](#quick-start).
|
||||
Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.
|
||||
Quickly get Dify running in your environment with this [starter guide](#quick-start).
|
||||
Use our [documentation](https://docs.dify.ai) for further references and more in-depth instructions.
|
||||
|
||||
- **Dify for enterprise / organizations</br>**
|
||||
We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) to discuss enterprise needs. </br>
|
||||
We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) to discuss enterprise needs. </br>
|
||||
> For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one-click. It's an affordable AMI offering with the option to create apps with custom logo and branding.
|
||||
|
||||
|
||||
## Staying ahead
|
||||
|
||||
Star Dify on GitHub and be instantly notified of new releases.
|
||||
|
||||

|
||||
|
||||
|
||||
## Advanced Setup
|
||||
|
||||
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
||||
@@ -213,32 +211,34 @@ If you'd like to configure a highly-available setup, there are community-contrib
|
||||
Deploy Dify to Cloud Platform with a single click using [terraform](https://www.terraform.io/)
|
||||
|
||||
##### Azure Global
|
||||
|
||||
- [Azure Terraform by @nikawang](https://github.com/nikawang/dify-azure-terraform)
|
||||
|
||||
##### Google Cloud
|
||||
|
||||
- [Google Cloud Terraform by @sotazum](https://github.com/DeNA/dify-google-cloud-terraform)
|
||||
|
||||
#### Using AWS CDK for Deployment
|
||||
|
||||
Deploy Dify to AWS with [CDK](https://aws.amazon.com/cdk/)
|
||||
|
||||
##### AWS
|
||||
##### AWS
|
||||
|
||||
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
|
||||
|
||||
## Contributing
|
||||
|
||||
For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
At the same time, please consider supporting Dify by sharing it on social media and at events and conferences.
|
||||
|
||||
|
||||
> We are looking for contributors to help with translating Dify to languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c).
|
||||
|
||||
## Community & contact
|
||||
|
||||
* [Github Discussion](https://github.com/langgenius/dify/discussions). Best for: sharing feedback and asking questions.
|
||||
* [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
* [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community.
|
||||
* [X(Twitter)](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community.
|
||||
- [Github Discussion](https://github.com/langgenius/dify/discussions). Best for: sharing feedback and asking questions.
|
||||
- [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
|
||||
- [Discord](https://discord.gg/FngNHpbcY7). Best for: sharing your applications and hanging out with the community.
|
||||
- [X(Twitter)](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community.
|
||||
|
||||
**Contributors**
|
||||
|
||||
@@ -250,7 +250,6 @@ At the same time, please consider supporting Dify by sharing it on social media
|
||||
|
||||
[](https://star-history.com/#langgenius/dify&Date)
|
||||
|
||||
|
||||
## Security disclosure
|
||||
|
||||
To protect your privacy, please avoid posting security issues on GitHub. Instead, send your questions to security@dify.ai and we will provide you with a more detailed answer.
|
||||
@@ -258,4 +257,3 @@ To protect your privacy, please avoid posting security issues on GitHub. Instead
|
||||
## License
|
||||
|
||||
This repository is available under the [Dify Open Source License](LICENSE), which is essentially Apache 2.0 with a few additional restrictions.
|
||||
|
||||
|
||||
258
README_TW.md
Normal file
258
README_TW.md
Normal file
@@ -0,0 +1,258 @@
|
||||

|
||||
|
||||
<p align="center">
|
||||
📌 <a href="https://dify.ai/blog/introducing-dify-workflow-file-upload-a-demo-on-ai-podcast">介紹 Dify 工作流程檔案上傳功能:重現 Google NotebookLM Podcast</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://cloud.dify.ai">Dify 雲端服務</a> ·
|
||||
<a href="https://docs.dify.ai/getting-started/install-self-hosted">自行託管</a> ·
|
||||
<a href="https://docs.dify.ai">說明文件</a> ·
|
||||
<a href="https://udify.app/chat/22L1zSxg6yW1cWQg">企業諮詢</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://dify.ai" target="_blank">
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/Product-F04438"></a>
|
||||
<a href="https://dify.ai/pricing" target="_blank">
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/free-pricing?logo=free&color=%20%23155EEF&label=pricing&labelColor=%20%23528bff"></a>
|
||||
<a href="https://discord.gg/FngNHpbcY7" target="_blank">
|
||||
<img src="https://img.shields.io/discord/1082486657678311454?logo=discord&labelColor=%20%235462eb&logoColor=%20%23f5f5f5&color=%20%235462eb"
|
||||
alt="chat on Discord"></a>
|
||||
<a href="https://reddit.com/r/difyai" target="_blank">
|
||||
<img src="https://img.shields.io/reddit/subreddit-subscribers/difyai?style=plastic&logo=reddit&label=r%2Fdifyai&labelColor=white"
|
||||
alt="join Reddit"></a>
|
||||
<a href="https://twitter.com/intent/follow?screen_name=dify_ai" target="_blank">
|
||||
<img src="https://img.shields.io/twitter/follow/dify_ai?logo=X&color=%20%23f5f5f5"
|
||||
alt="follow on X(Twitter)"></a>
|
||||
<a href="https://www.linkedin.com/company/langgenius/" target="_blank">
|
||||
<img src="https://custom-icon-badges.demolab.com/badge/LinkedIn-0A66C2?logo=linkedin-white&logoColor=fff"
|
||||
alt="follow on LinkedIn"></a>
|
||||
<a href="https://hub.docker.com/u/langgenius" target="_blank">
|
||||
<img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web?labelColor=%20%23FDB062&color=%20%23f79009"></a>
|
||||
<a href="https://github.com/langgenius/dify/graphs/commit-activity" target="_blank">
|
||||
<img alt="Commits last month" src="https://img.shields.io/github/commit-activity/m/langgenius/dify?labelColor=%20%2332b583&color=%20%2312b76a"></a>
|
||||
<a href="https://github.com/langgenius/dify/" target="_blank">
|
||||
<img alt="Issues closed" src="https://img.shields.io/github/issues-search?query=repo%3Alanggenius%2Fdify%20is%3Aclosed&label=issues%20closed&labelColor=%20%237d89b0&color=%20%235d6b98"></a>
|
||||
<a href="https://github.com/langgenius/dify/discussions/" target="_blank">
|
||||
<img alt="Discussion posts" src="https://img.shields.io/github/discussions/langgenius/dify?labelColor=%20%239b8afb&color=%20%237a5af8"></a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="./README.md"><img alt="README in English" src="https://img.shields.io/badge/English-d9d9d9"></a>
|
||||
<a href="./README_TW.md"><img alt="繁體中文文件" src="https://img.shields.io/badge/繁體中文-d9d9d9"></a>
|
||||
<a href="./README_CN.md"><img alt="简体中文版自述文件" src="https://img.shields.io/badge/简体中文-d9d9d9"></a>
|
||||
<a href="./README_JA.md"><img alt="日本語のREADME" src="https://img.shields.io/badge/日本語-d9d9d9"></a>
|
||||
<a href="./README_ES.md"><img alt="README en Español" src="https://img.shields.io/badge/Español-d9d9d9"></a>
|
||||
<a href="./README_FR.md"><img alt="README en Français" src="https://img.shields.io/badge/Français-d9d9d9"></a>
|
||||
<a href="./README_KL.md"><img alt="README tlhIngan Hol" src="https://img.shields.io/badge/Klingon-d9d9d9"></a>
|
||||
<a href="./README_KR.md"><img alt="README in Korean" src="https://img.shields.io/badge/한국어-d9d9d9"></a>
|
||||
<a href="./README_AR.md"><img alt="README بالعربية" src="https://img.shields.io/badge/العربية-d9d9d9"></a>
|
||||
<a href="./README_TR.md"><img alt="Türkçe README" src="https://img.shields.io/badge/Türkçe-d9d9d9"></a>
|
||||
<a href="./README_VI.md"><img alt="README Tiếng Việt" src="https://img.shields.io/badge/Ti%E1%BA%BFng%20Vi%E1%BB%87t-d9d9d9"></a>
|
||||
<a href="./README_DE.md"><img alt="README in Deutsch" src="https://img.shields.io/badge/German-d9d9d9"></a>
|
||||
</p>
|
||||
|
||||
Dify 是一個開源的 LLM 應用程式開發平台。其直觀的界面結合了智能代理工作流程、RAG 管道、代理功能、模型管理、可觀察性功能等,讓您能夠快速從原型進展到生產環境。
|
||||
|
||||
## 快速開始
|
||||
|
||||
> 安裝 Dify 之前,請確保您的機器符合以下最低系統要求:
|
||||
>
|
||||
> - CPU >= 2 核心
|
||||
> - 記憶體 >= 4 GiB
|
||||
|
||||
</br>
|
||||
|
||||
啟動 Dify 伺服器最簡單的方式是透過 [docker compose](docker/docker-compose.yaml)。在使用以下命令運行 Dify 之前,請確保您的機器已安裝 [Docker](https://docs.docker.com/get-docker/) 和 [Docker Compose](https://docs.docker.com/compose/install/):
|
||||
|
||||
```bash
|
||||
cd dify
|
||||
cd docker
|
||||
cp .env.example .env
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
運行後,您可以在瀏覽器中通過 [http://localhost/install](http://localhost/install) 訪問 Dify 儀表板並開始初始化過程。
|
||||
|
||||
### 尋求幫助
|
||||
|
||||
如果您在設置 Dify 時遇到問題,請參考我們的 [常見問題](https://docs.dify.ai/getting-started/install-self-hosted/faqs)。如果仍有疑問,請聯絡 [社區和我們](#community--contact)。
|
||||
|
||||
> 如果您想為 Dify 做出貢獻或進行額外開發,請參考我們的 [從原始碼部署指南](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code)
|
||||
|
||||
## 核心功能
|
||||
|
||||
**1. 工作流程**:
|
||||
在視覺化畫布上建立和測試強大的 AI 工作流程,利用以下所有功能及更多。
|
||||
|
||||
https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
|
||||
|
||||
**2. 全面的模型支援**:
|
||||
無縫整合來自數十個推理提供商和自託管解決方案的數百個專有/開源 LLM,涵蓋 GPT、Mistral、Llama3 和任何與 OpenAI API 兼容的模型。您可以在[此處](https://docs.dify.ai/getting-started/readme/model-providers)找到支援的模型提供商完整列表。
|
||||
|
||||

|
||||
|
||||
**3. 提示詞 IDE**:
|
||||
直觀的界面,用於編寫提示詞、比較模型性能,以及為聊天型應用程式添加文字轉語音等額外功能。
|
||||
|
||||
**4. RAG 管道**:
|
||||
廣泛的 RAG 功能,涵蓋從文件擷取到檢索的全部流程,內建支援從 PDF、PPT 和其他常見文件格式提取文本。
|
||||
|
||||
**5. 代理功能**:
|
||||
您可以基於 LLM 函數調用或 ReAct 定義代理,並為代理添加預構建或自定義工具。Dify 為 AI 代理提供 50 多種內建工具,如 Google 搜尋、DALL·E、Stable Diffusion 和 WolframAlpha。
|
||||
|
||||
**6. LLMOps**:
|
||||
監控並分析應用程式日誌和長期效能。您可以根據生產數據和標註持續改進提示詞、數據集和模型。
|
||||
|
||||
**7. 後端即服務**:
|
||||
Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify 整合到您自己的業務邏輯中。
|
||||
|
||||
## 功能比較
|
||||
|
||||
<table style="width: 100%;">
|
||||
<tr>
|
||||
<th align="center">功能</th>
|
||||
<th align="center">Dify.AI</th>
|
||||
<th align="center">LangChain</th>
|
||||
<th align="center">Flowise</th>
|
||||
<th align="center">OpenAI Assistants API</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">程式設計方法</td>
|
||||
<td align="center">API + 應用導向</td>
|
||||
<td align="center">Python 代碼</td>
|
||||
<td align="center">應用導向</td>
|
||||
<td align="center">API 導向</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">支援的 LLM 模型</td>
|
||||
<td align="center">豐富多樣</td>
|
||||
<td align="center">豐富多樣</td>
|
||||
<td align="center">豐富多樣</td>
|
||||
<td align="center">僅限 OpenAI</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">RAG 引擎</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">代理功能</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">✅</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">工作流程</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">可觀察性</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">企業級功能 (SSO/存取控制)</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">❌</td>
|
||||
<td align="center">❌</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">本地部署</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">✅</td>
|
||||
<td align="center">❌</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## 使用 Dify
|
||||
|
||||
- **雲端服務 </br>**
|
||||
我們提供 [Dify Cloud](https://dify.ai) 服務,任何人都可以零配置嘗試。它提供與自部署版本相同的所有功能,並在沙盒計劃中包含 200 次免費 GPT-4 調用。
|
||||
|
||||
- **自託管 Dify 社區版</br>**
|
||||
使用這份[快速指南](#快速開始)在您的環境中快速運行 Dify。
|
||||
使用我們的[文檔](https://docs.dify.ai)獲取更多參考和深入指導。
|
||||
|
||||
- **企業/組織版 Dify</br>**
|
||||
我們提供額外的企業中心功能。[通過這個聊天機器人記錄您的問題](https://udify.app/chat/22L1zSxg6yW1cWQg)或[發送電子郵件給我們](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry)討論企業需求。</br>
|
||||
> 對於使用 AWS 的初創企業和小型企業,請查看 [AWS Marketplace 上的 Dify Premium](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6),並一鍵部署到您自己的 AWS VPC。這是一個經濟實惠的 AMI 產品,可選擇使用自定義徽標和品牌創建應用。
|
||||
|
||||
## 保持領先
|
||||
|
||||
在 GitHub 上為 Dify 加星,即時獲取新版本通知。
|
||||
|
||||

|
||||
|
||||
## 進階設定
|
||||
|
||||
如果您需要自定義配置,請參考我們的 [.env.example](docker/.env.example) 文件中的註釋,並在您的 `.env` 文件中更新相應的值。此外,根據您特定的部署環境和需求,您可能需要調整 `docker-compose.yaml` 文件本身,例如更改映像版本、端口映射或卷掛載。進行任何更改後,請重新運行 `docker-compose up -d`。您可以在[這裡](https://docs.dify.ai/getting-started/install-self-hosted/environments)找到可用環境變數的完整列表。
|
||||
|
||||
如果您想配置高可用性設置,社區貢獻的 [Helm Charts](https://helm.sh/) 和 YAML 文件允許在 Kubernetes 上部署 Dify。
|
||||
|
||||
- [由 @LeoQuote 提供的 Helm Chart](https://github.com/douban/charts/tree/master/charts/dify)
|
||||
- [由 @BorisPolonsky 提供的 Helm Chart](https://github.com/BorisPolonsky/dify-helm)
|
||||
- [由 @Winson-030 提供的 YAML 文件](https://github.com/Winson-030/dify-kubernetes)
|
||||
|
||||
### 使用 Terraform 進行部署
|
||||
|
||||
使用 [terraform](https://www.terraform.io/) 一鍵部署 Dify 到雲端平台
|
||||
|
||||
### Azure 全球
|
||||
|
||||
- [由 @nikawang 提供的 Azure Terraform](https://github.com/nikawang/dify-azure-terraform)
|
||||
|
||||
### Google Cloud
|
||||
|
||||
- [由 @sotazum 提供的 Google Cloud Terraform](https://github.com/DeNA/dify-google-cloud-terraform)
|
||||
|
||||
### 使用 AWS CDK 進行部署
|
||||
|
||||
使用 [CDK](https://aws.amazon.com/cdk/) 部署 Dify 到 AWS
|
||||
|
||||
### AWS
|
||||
|
||||
- [由 @KevinZhao 提供的 AWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
|
||||
|
||||
## 貢獻
|
||||
|
||||
對於想要貢獻程式碼的開發者,請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。
|
||||
同時,也請考慮透過在社群媒體和各種活動與會議上分享 Dify 來支持我們。
|
||||
|
||||
> 我們正在尋找貢獻者協助將 Dify 翻譯成中文和英文以外的語言。如果您有興趣幫忙,請查看 [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) 獲取更多資訊,並在我們的 [Discord 社群伺服器](https://discord.gg/8Tpq4AcN9c) 的 `global-users` 頻道留言給我們。
|
||||
|
||||
## 社群與聯絡方式
|
||||
|
||||
- [Github Discussion](https://github.com/langgenius/dify/discussions):最適合分享反饋和提問。
|
||||
- [GitHub Issues](https://github.com/langgenius/dify/issues):最適合報告使用 Dify.AI 時遇到的問題和提出功能建議。請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。
|
||||
- [Discord](https://discord.gg/FngNHpbcY7):最適合分享您的應用程式並與社群互動。
|
||||
- [X(Twitter)](https://twitter.com/dify_ai):最適合分享您的應用程式並與社群互動。
|
||||
|
||||
**貢獻者**
|
||||
|
||||
<a href="https://github.com/langgenius/dify/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=langgenius/dify" />
|
||||
</a>
|
||||
|
||||
## 星星歷史
|
||||
|
||||
[](https://star-history.com/#langgenius/dify&Date)
|
||||
|
||||
## 安全揭露
|
||||
|
||||
為保護您的隱私,請避免在 GitHub 上發布安全性問題。請將您的問題發送至 security@dify.ai,我們將為您提供更詳細的答覆。
|
||||
|
||||
## 授權條款
|
||||
|
||||
本代碼庫採用 [Dify 開源授權](LICENSE),這基本上是 Apache 2.0 授權加上一些額外限制條款。
|
||||
@@ -71,7 +71,7 @@ from .app import (
|
||||
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth
|
||||
|
||||
# Import billing controllers
|
||||
from .billing import billing
|
||||
from .billing import billing, compliance
|
||||
|
||||
# Import datasets controllers
|
||||
from .datasets import (
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
from typing import cast
|
||||
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import Resource, reqparse # type: ignore
|
||||
@@ -11,8 +12,11 @@ from controllers.console.app.error import (
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.auto.workflow_generator.workflow_generator import WorkflowGenerator
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs.login import login_required
|
||||
|
||||
@@ -85,5 +89,45 @@ class RuleCodeGenerateApi(Resource):
|
||||
return code_result
|
||||
|
||||
|
||||
class AutoGenerateWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
"""
|
||||
Auto generate workflow
|
||||
"""
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("instruction", type=str, required=True, location="json")
|
||||
parser.add_argument("model_config", type=dict, required=True, location="json")
|
||||
tenant_id = cast(str, current_user.current_tenant_id)
|
||||
args = parser.parse_args()
|
||||
instruction = args.get("instruction")
|
||||
if not instruction:
|
||||
raise ValueError("Instruction is required")
|
||||
if not args.get("model_config"):
|
||||
raise ValueError("Model config is required")
|
||||
model_config = cast(dict, args.get("model_config"))
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=model_config.get("provider", ""),
|
||||
model=model_config.get("name", ""),
|
||||
)
|
||||
workflow_generator = WorkflowGenerator(
|
||||
model_instance=model_instance,
|
||||
)
|
||||
workflow_yaml = workflow_generator.generate_workflow(
|
||||
user_requirement=instruction,
|
||||
)
|
||||
return workflow_yaml
|
||||
|
||||
|
||||
api.add_resource(RuleGenerateApi, "/rule-generate")
|
||||
api.add_resource(RuleCodeGenerateApi, "/rule-code-generate")
|
||||
api.add_resource(
|
||||
AutoGenerateWorkflowApi,
|
||||
"/auto-generate",
|
||||
)
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
@@ -13,6 +15,7 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
from fields.workflow_run_fields import workflow_run_node_execution_fields
|
||||
@@ -24,7 +27,7 @@ from models.account import Account
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.workflow_service import WorkflowService
|
||||
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -246,6 +249,80 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class AdvancedChatDraftRunLoopNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Run draft workflow loop node
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_loop(
|
||||
app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class WorkflowDraftRunLoopNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Run draft workflow loop node
|
||||
"""
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate_single_loop(
|
||||
app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
except services.errors.conversation.ConversationCompletedError:
|
||||
raise ConversationCompletedError()
|
||||
except ValueError as e:
|
||||
raise e
|
||||
except Exception:
|
||||
logging.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class DraftWorkflowRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -365,10 +442,38 @@ class PublishedWorkflowApi(Resource):
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("marked_name", type=str, required=False, default="", location="json")
|
||||
parser.add_argument("marked_comment", type=str, required=False, default="", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
return {"result": "success", "created_at": TimestampField().format(workflow.created_at)}
|
||||
# Validate name and comment length
|
||||
if args.marked_name and len(args.marked_name) > 20:
|
||||
raise ValueError("Marked name cannot exceed 20 characters")
|
||||
if args.marked_comment and len(args.marked_comment) > 100:
|
||||
raise ValueError("Marked comment cannot exceed 100 characters")
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
with Session(db.engine) as session:
|
||||
workflow = workflow_service.publish_workflow(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
marked_name=args.marked_name or "",
|
||||
marked_comment=args.marked_comment or "",
|
||||
)
|
||||
|
||||
app_model.workflow_id = workflow.id
|
||||
db.session.commit()
|
||||
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
|
||||
session.commit()
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"created_at": workflow_created_at,
|
||||
}
|
||||
|
||||
|
||||
class DefaultBlockConfigsApi(Resource):
|
||||
@@ -490,32 +595,193 @@ class PublishedAllWorkflowApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
parser.add_argument("user_id", type=str, required=False, location="args")
|
||||
parser.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
||||
args = parser.parse_args()
|
||||
page = args.get("page")
|
||||
limit = args.get("limit")
|
||||
page = int(args.get("page", 1))
|
||||
limit = int(args.get("limit", 10))
|
||||
user_id = args.get("user_id")
|
||||
named_only = args.get("named_only", False)
|
||||
|
||||
if user_id:
|
||||
if user_id != current_user.id:
|
||||
raise Forbidden()
|
||||
user_id = cast(str, user_id)
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(app_model=app_model, page=page, limit=limit)
|
||||
with Session(db.engine) as session:
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
page=page,
|
||||
limit=limit,
|
||||
user_id=user_id,
|
||||
named_only=named_only,
|
||||
)
|
||||
|
||||
return {"items": workflows, "page": page, "limit": limit, "has_more": has_more}
|
||||
return {
|
||||
"items": workflows,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
"has_more": has_more,
|
||||
}
|
||||
|
||||
|
||||
api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
|
||||
api.add_resource(WorkflowConfigApi, "/apps/<uuid:app_id>/workflows/draft/config")
|
||||
api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
|
||||
api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run")
|
||||
api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
|
||||
api.add_resource(DraftWorkflowNodeRunApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run")
|
||||
class WorkflowByIdApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_fields)
|
||||
def patch(self, app_model: App, workflow_id: str):
|
||||
"""
|
||||
Update workflow attributes
|
||||
"""
|
||||
# Check permission
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("marked_name", type=str, required=False, location="json")
|
||||
parser.add_argument("marked_comment", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate name and comment length
|
||||
if args.marked_name and len(args.marked_name) > 20:
|
||||
raise ValueError("Marked name cannot exceed 20 characters")
|
||||
if args.marked_comment and len(args.marked_comment) > 100:
|
||||
raise ValueError("Marked comment cannot exceed 100 characters")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Prepare update data
|
||||
update_data = {}
|
||||
if args.get("marked_name") is not None:
|
||||
update_data["marked_name"] = args["marked_name"]
|
||||
if args.get("marked_comment") is not None:
|
||||
update_data["marked_comment"] = args["marked_comment"]
|
||||
|
||||
if not update_data:
|
||||
return {"message": "No valid fields to update"}, 400
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
# Create a session and manage the transaction
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow = workflow_service.update_workflow(
|
||||
session=session,
|
||||
workflow_id=workflow_id,
|
||||
tenant_id=app_model.tenant_id,
|
||||
account_id=current_user.id,
|
||||
data=update_data,
|
||||
)
|
||||
|
||||
if not workflow:
|
||||
raise NotFound("Workflow not found")
|
||||
|
||||
# Commit the transaction in the controller
|
||||
session.commit()
|
||||
|
||||
return workflow
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def delete(self, app_model: App, workflow_id: str):
|
||||
"""
|
||||
Delete workflow
|
||||
"""
|
||||
# Check permission
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
# Create a session and manage the transaction
|
||||
with Session(db.engine) as session:
|
||||
try:
|
||||
workflow_service.delete_workflow(
|
||||
session=session, workflow_id=workflow_id, tenant_id=app_model.tenant_id
|
||||
)
|
||||
# Commit the transaction in the controller
|
||||
session.commit()
|
||||
except WorkflowInUseError as e:
|
||||
abort(400, description=str(e))
|
||||
except DraftWorkflowDeletionError as e:
|
||||
abort(400, description=str(e))
|
||||
except ValueError as e:
|
||||
raise NotFound(str(e))
|
||||
|
||||
return None, 204
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DraftWorkflowApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowConfigApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/config",
|
||||
)
|
||||
api.add_resource(
|
||||
AdvancedChatDraftWorkflowRunApi,
|
||||
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/run",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftWorkflowRunApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/run",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowTaskStopApi,
|
||||
"/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop",
|
||||
)
|
||||
api.add_resource(
|
||||
DraftWorkflowNodeRunApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
AdvancedChatDraftRunIterationNodeApi,
|
||||
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/iteration/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowDraftRunIterationNodeApi, "/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run"
|
||||
WorkflowDraftRunIterationNodeApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/iteration/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(PublishedWorkflowApi, "/apps/<uuid:app_id>/workflows/publish")
|
||||
api.add_resource(PublishedAllWorkflowApi, "/apps/<uuid:app_id>/workflows")
|
||||
api.add_resource(DefaultBlockConfigsApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs")
|
||||
api.add_resource(
|
||||
DefaultBlockConfigApi, "/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>"
|
||||
AdvancedChatDraftRunLoopNodeApi,
|
||||
"/apps/<uuid:app_id>/advanced-chat/workflows/draft/loop/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowDraftRunLoopNodeApi,
|
||||
"/apps/<uuid:app_id>/workflows/draft/loop/nodes/<string:node_id>/run",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishedWorkflowApi,
|
||||
"/apps/<uuid:app_id>/workflows/publish",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishedAllWorkflowApi,
|
||||
"/apps/<uuid:app_id>/workflows",
|
||||
)
|
||||
api.add_resource(
|
||||
DefaultBlockConfigsApi,
|
||||
"/apps/<uuid:app_id>/workflows/default-workflow-block-configs",
|
||||
)
|
||||
api.add_resource(
|
||||
DefaultBlockConfigApi,
|
||||
"/apps/<uuid:app_id>/workflows/default-workflow-block-configs/<string:block_type>",
|
||||
)
|
||||
api.add_resource(
|
||||
ConvertToWorkflowApi,
|
||||
"/apps/<uuid:app_id>/convert-to-workflow",
|
||||
)
|
||||
api.add_resource(
|
||||
WorkflowByIdApi,
|
||||
"/apps/<uuid:app_id>/workflows/<string:workflow_id>",
|
||||
)
|
||||
api.add_resource(ConvertToWorkflowApi, "/apps/<uuid:app_id>/convert-to-workflow")
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
from datetime import datetime
|
||||
|
||||
from flask_restful import Resource, marshal_with, reqparse # type: ignore
|
||||
from flask_restful.inputs import int_range # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRunStatus
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
|
||||
@@ -24,17 +29,38 @@ class WorkflowAppLogApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
|
||||
parser.add_argument(
|
||||
"created_at__before", type=str, location="args", help="Filter logs created before this timestamp"
|
||||
)
|
||||
parser.add_argument(
|
||||
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
|
||||
)
|
||||
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
args.status = WorkflowRunStatus(args.status) if args.status else None
|
||||
if args.created_at__before:
|
||||
args.created_at__before = datetime.fromisoformat(args.created_at__before.replace("Z", "+00:00"))
|
||||
|
||||
if args.created_at__after:
|
||||
args.created_at__after = datetime.fromisoformat(args.created_at__after.replace("Z", "+00:00"))
|
||||
|
||||
# get paginate workflow app logs
|
||||
workflow_app_service = WorkflowAppService()
|
||||
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
|
||||
app_model=app_model, args=args
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
keyword=args.keyword,
|
||||
status=args.status,
|
||||
created_at_before=args.created_at__before,
|
||||
created_at_after=args.created_at__after,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
)
|
||||
|
||||
return workflow_app_log_pagination
|
||||
return workflow_app_log_pagination
|
||||
|
||||
|
||||
api.add_resource(WorkflowAppLogApi, "/apps/<uuid:app_id>/workflow-app-logs")
|
||||
|
||||
35
api/controllers/console/billing/compliance.py
Normal file
35
api/controllers/console/billing/compliance.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from flask import request
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import Resource, reqparse # type: ignore
|
||||
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
from .. import api
|
||||
from ..wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
|
||||
|
||||
class ComplianceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("doc_name", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
device_info = request.headers.get("User-Agent", "Unknown device")
|
||||
|
||||
return BillingService.get_compliance_download_link(
|
||||
doc_name=args.doc_name,
|
||||
account_id=current_user.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
ip=ip_address,
|
||||
device_info=device_info,
|
||||
)
|
||||
|
||||
|
||||
api.add_resource(ComplianceApi, "/compliance/download")
|
||||
@@ -101,3 +101,9 @@ class AccountInFreezeError(BaseHTTPException):
|
||||
"This email account has been deleted within the past 30 days"
|
||||
"and is temporarily unavailable for new account registration."
|
||||
)
|
||||
|
||||
|
||||
class CompilanceRateLimitError(BaseHTTPException):
|
||||
error_code = "compilance_rate_limit"
|
||||
description = "Rate limit exceeded for downloading compliance report."
|
||||
code = 429
|
||||
|
||||
@@ -70,7 +70,7 @@ class MessageListApi(Resource):
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(
|
||||
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"], "desc"
|
||||
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
|
||||
from flask_restful.inputs import int_range # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.service_api import api
|
||||
@@ -25,7 +27,7 @@ from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||
from libs import helper
|
||||
from models.model import App, AppMode, EndUser
|
||||
from models.workflow import WorkflowRun
|
||||
from models.workflow import WorkflowRun, WorkflowRunStatus
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
@@ -125,17 +127,34 @@ class WorkflowAppLogApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
|
||||
parser.add_argument("created_at__before", type=str, location="args")
|
||||
parser.add_argument("created_at__after", type=str, location="args")
|
||||
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
args.status = WorkflowRunStatus(args.status) if args.status else None
|
||||
if args.created_at__before:
|
||||
args.created_at__before = datetime.fromisoformat(args.created_at__before.replace("Z", "+00:00"))
|
||||
|
||||
if args.created_at__after:
|
||||
args.created_at__after = datetime.fromisoformat(args.created_at__after.replace("Z", "+00:00"))
|
||||
|
||||
# get paginate workflow app logs
|
||||
workflow_app_service = WorkflowAppService()
|
||||
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
|
||||
app_model=app_model, args=args
|
||||
)
|
||||
with Session(db.engine) as session:
|
||||
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
keyword=args.keyword,
|
||||
status=args.status,
|
||||
created_at_before=args.created_at__before,
|
||||
created_at_after=args.created_at__after,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
)
|
||||
|
||||
return workflow_app_log_pagination
|
||||
return workflow_app_log_pagination
|
||||
|
||||
|
||||
api.add_resource(WorkflowRunApi, "/workflows/run")
|
||||
|
||||
@@ -223,6 +223,61 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
stream=streaming,
|
||||
)
|
||||
|
||||
def single_loop_generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account | EndUser,
|
||||
args: Mapping,
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
:param app_model: App
|
||||
:param workflow: Workflow
|
||||
:param user: account or end user
|
||||
:param args: request args
|
||||
:param invoke_from: invoke from source
|
||||
:param stream: is stream
|
||||
"""
|
||||
if not node_id:
|
||||
raise ValueError("node_id is required")
|
||||
|
||||
if args.get("inputs") is None:
|
||||
raise ValueError("inputs is required")
|
||||
|
||||
# convert to app config
|
||||
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = AdvancedChatAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
conversation_id=None,
|
||||
inputs={},
|
||||
query="",
|
||||
files=[],
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras={"auto_generate_conversation_name": False},
|
||||
single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
|
||||
)
|
||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
return self._generate(
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=None,
|
||||
stream=streaming,
|
||||
)
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
*,
|
||||
|
||||
@@ -79,6 +79,13 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
# if only single loop run is requested
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
query = self.application_generate_entity.query
|
||||
|
||||
@@ -23,10 +23,14 @@ from core.app.entities.queue_entities import (
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueLoopCompletedEvent,
|
||||
QueueLoopNextEvent,
|
||||
QueueLoopStartEvent,
|
||||
QueueMessageReplaceEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
@@ -372,7 +376,13 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
|
||||
if node_finish_resp:
|
||||
yield node_finish_resp
|
||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
|
||||
elif isinstance(
|
||||
event,
|
||||
QueueNodeFailedEvent
|
||||
| QueueNodeInIterationFailedEvent
|
||||
| QueueNodeInLoopFailedEvent
|
||||
| QueueNodeExceptionEvent,
|
||||
):
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
|
||||
session=session, event=event
|
||||
@@ -472,6 +482,54 @@ class AdvancedChatAppGenerateTaskPipeline:
|
||||
)
|
||||
|
||||
yield iter_finish_resp
|
||||
elif isinstance(event, QueueLoopStartEvent):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
loop_start_resp = self._workflow_cycle_manager._workflow_loop_start_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
)
|
||||
|
||||
yield loop_start_resp
|
||||
elif isinstance(event, QueueLoopNextEvent):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
loop_next_resp = self._workflow_cycle_manager._workflow_loop_next_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
)
|
||||
|
||||
yield loop_next_resp
|
||||
elif isinstance(event, QueueLoopCompletedEvent):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
loop_finish_resp = self._workflow_cycle_manager._workflow_loop_completed_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
)
|
||||
|
||||
yield loop_finish_resp
|
||||
elif isinstance(event, QueueWorkflowSucceededEvent):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
@@ -250,6 +250,60 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
def single_loop_generate(
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
:param app_model: App
|
||||
:param workflow: Workflow
|
||||
:param user: account or end user
|
||||
:param args: request args
|
||||
:param invoke_from: invoke from source
|
||||
:param stream: is stream
|
||||
"""
|
||||
if not node_id:
|
||||
raise ValueError("node_id is required")
|
||||
|
||||
if args.get("inputs") is None:
|
||||
raise ValueError("inputs is required")
|
||||
|
||||
# convert to app config
|
||||
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = WorkflowAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras={"auto_generate_conversation_name": False},
|
||||
single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
)
|
||||
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
return self._generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
application_generate_entity=application_generate_entity,
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
def _generate_worker(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
|
||||
@@ -81,6 +81,13 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
# if only single loop run is requested
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_loop_run.inputs,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
files = self.application_generate_entity.files
|
||||
|
||||
@@ -18,9 +18,13 @@ from core.app.entities.queue_entities import (
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueLoopCompletedEvent,
|
||||
QueueLoopNextEvent,
|
||||
QueueLoopStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
@@ -323,7 +327,13 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
|
||||
if node_success_response:
|
||||
yield node_success_response
|
||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
|
||||
elif isinstance(
|
||||
event,
|
||||
QueueNodeFailedEvent
|
||||
| QueueNodeInIterationFailedEvent
|
||||
| QueueNodeInLoopFailedEvent
|
||||
| QueueNodeExceptionEvent,
|
||||
):
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed(
|
||||
session=session,
|
||||
@@ -429,6 +439,57 @@ class WorkflowAppGenerateTaskPipeline:
|
||||
|
||||
yield iter_finish_resp
|
||||
|
||||
elif isinstance(event, QueueLoopStartEvent):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
loop_start_resp = self._workflow_cycle_manager._workflow_loop_start_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
)
|
||||
|
||||
yield loop_start_resp
|
||||
|
||||
elif isinstance(event, QueueLoopNextEvent):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
loop_next_resp = self._workflow_cycle_manager._workflow_loop_next_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
)
|
||||
|
||||
yield loop_next_resp
|
||||
|
||||
elif isinstance(event, QueueLoopCompletedEvent):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow_run = self._workflow_cycle_manager._get_workflow_run(
|
||||
session=session, workflow_run_id=self._workflow_run_id
|
||||
)
|
||||
loop_finish_resp = self._workflow_cycle_manager._workflow_loop_completed_to_stream_response(
|
||||
session=session,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run=workflow_run,
|
||||
event=event,
|
||||
)
|
||||
|
||||
yield loop_finish_resp
|
||||
|
||||
elif isinstance(event, QueueWorkflowSucceededEvent):
|
||||
if not self._workflow_run_id:
|
||||
raise ValueError("workflow run not initialized.")
|
||||
|
||||
@@ -9,9 +9,13 @@ from core.app.entities.queue_entities import (
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueLoopCompletedEvent,
|
||||
QueueLoopNextEvent,
|
||||
QueueLoopStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
@@ -38,7 +42,12 @@ from core.workflow.graph_engine.entities.event import (
|
||||
IterationRunNextEvent,
|
||||
IterationRunStartedEvent,
|
||||
IterationRunSucceededEvent,
|
||||
LoopRunFailedEvent,
|
||||
LoopRunNextEvent,
|
||||
LoopRunStartedEvent,
|
||||
LoopRunSucceededEvent,
|
||||
NodeInIterationFailedEvent,
|
||||
NodeInLoopFailedEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
@@ -173,6 +182,96 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
|
||||
return graph, variable_pool
|
||||
|
||||
def _get_graph_and_variable_pool_of_single_loop(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single loop
|
||||
"""
|
||||
# fetch workflow graph
|
||||
graph_config = workflow.graph_dict
|
||||
if not graph_config:
|
||||
raise ValueError("workflow graph not found")
|
||||
|
||||
graph_config = cast(dict[str, Any], graph_config)
|
||||
|
||||
if "nodes" not in graph_config or "edges" not in graph_config:
|
||||
raise ValueError("nodes or edges not found in workflow graph")
|
||||
|
||||
if not isinstance(graph_config.get("nodes"), list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
# filter nodes only in loop
|
||||
node_configs = [
|
||||
node
|
||||
for node in graph_config.get("nodes", [])
|
||||
if node.get("id") == node_id or node.get("data", {}).get("loop_id", "") == node_id
|
||||
]
|
||||
|
||||
graph_config["nodes"] = node_configs
|
||||
|
||||
node_ids = [node.get("id") for node in node_configs]
|
||||
|
||||
# filter edges only in loop
|
||||
edge_configs = [
|
||||
edge
|
||||
for edge in graph_config.get("edges", [])
|
||||
if (edge.get("source") is None or edge.get("source") in node_ids)
|
||||
and (edge.get("target") is None or edge.get("target") in node_ids)
|
||||
]
|
||||
|
||||
graph_config["edges"] = edge_configs
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
# fetch node config from node id
|
||||
loop_node_config = None
|
||||
for node in node_configs:
|
||||
if node.get("id") == node_id:
|
||||
loop_node_config = node
|
||||
break
|
||||
|
||||
if not loop_node_config:
|
||||
raise ValueError("loop node id not found in workflow graph")
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType(loop_node_config.get("data", {}).get("type"))
|
||||
node_version = loop_node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
|
||||
try:
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=workflow.graph_dict, config=loop_node_config
|
||||
)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id=workflow.tenant_id,
|
||||
)
|
||||
|
||||
return graph, variable_pool
|
||||
|
||||
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Handle event
|
||||
@@ -216,6 +315,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
node_run_index=event.route_node_state.index,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
@@ -240,6 +340,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
node_run_index=event.route_node_state.index,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
agent_strategy=event.agent_strategy,
|
||||
)
|
||||
@@ -272,6 +373,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
outputs=outputs,
|
||||
execution_metadata=execution_metadata,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
@@ -302,6 +404,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunExceptionEvent):
|
||||
@@ -332,6 +435,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeInIterationFailedEvent):
|
||||
@@ -362,18 +466,49 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
error=event.error,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeInLoopFailedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeInLoopFailedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs or {}
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
in_loop_id=event.in_loop_id,
|
||||
error=event.error,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
self._publish_event(
|
||||
QueueTextChunkEvent(
|
||||
text=event.chunk_content,
|
||||
from_variable_selector=event.from_variable_selector,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunRetrieverResourceEvent):
|
||||
self._publish_event(
|
||||
QueueRetrieverResourcesEvent(
|
||||
retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id
|
||||
retriever_resources=event.retriever_resources,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, AgentLogEvent):
|
||||
@@ -387,6 +522,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
status=event.status,
|
||||
data=event.data,
|
||||
metadata=event.metadata,
|
||||
node_id=event.node_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunStartedEvent):
|
||||
@@ -397,6 +533,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
@@ -407,6 +544,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||
@@ -417,6 +555,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
in_loop_id=event.in_loop_id,
|
||||
error=event.error,
|
||||
)
|
||||
)
|
||||
@@ -476,6 +615,62 @@ class WorkflowBasedAppRunner(AppRunner):
|
||||
error=event.error if isinstance(event, IterationRunFailedEvent) else None,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, LoopRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueLoopStartEvent(
|
||||
node_execution_id=event.loop_id,
|
||||
node_id=event.loop_node_id,
|
||||
node_type=event.loop_node_type,
|
||||
node_data=event.loop_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
metadata=event.metadata,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, LoopRunNextEvent):
|
||||
self._publish_event(
|
||||
QueueLoopNextEvent(
|
||||
node_execution_id=event.loop_id,
|
||||
node_id=event.loop_node_id,
|
||||
node_type=event.loop_node_type,
|
||||
node_data=event.loop_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
index=event.index,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
output=event.pre_loop_output,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
duration=event.duration,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, (LoopRunSucceededEvent | LoopRunFailedEvent)):
|
||||
self._publish_event(
|
||||
QueueLoopCompletedEvent(
|
||||
node_execution_id=event.loop_id,
|
||||
node_id=event.loop_node_id,
|
||||
node_type=event.loop_node_type,
|
||||
node_data=event.loop_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
error=event.error if isinstance(event, LoopRunFailedEvent) else None,
|
||||
)
|
||||
)
|
||||
|
||||
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
|
||||
"""
|
||||
|
||||
@@ -187,6 +187,16 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
|
||||
|
||||
single_iteration_run: Optional[SingleIterationRunEntity] = None
|
||||
|
||||
class SingleLoopRunEntity(BaseModel):
|
||||
"""
|
||||
Single Loop Run Entity.
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: Mapping
|
||||
|
||||
single_loop_run: Optional[SingleLoopRunEntity] = None
|
||||
|
||||
|
||||
class WorkflowAppGenerateEntity(AppGenerateEntity):
|
||||
"""
|
||||
@@ -206,3 +216,13 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
||||
inputs: dict
|
||||
|
||||
single_iteration_run: Optional[SingleIterationRunEntity] = None
|
||||
|
||||
class SingleLoopRunEntity(BaseModel):
|
||||
"""
|
||||
Single Loop Run Entity.
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: dict
|
||||
|
||||
single_loop_run: Optional[SingleLoopRunEntity] = None
|
||||
|
||||
@@ -30,6 +30,9 @@ class QueueEvent(StrEnum):
|
||||
ITERATION_START = "iteration_start"
|
||||
ITERATION_NEXT = "iteration_next"
|
||||
ITERATION_COMPLETED = "iteration_completed"
|
||||
LOOP_START = "loop_start"
|
||||
LOOP_NEXT = "loop_next"
|
||||
LOOP_COMPLETED = "loop_completed"
|
||||
NODE_STARTED = "node_started"
|
||||
NODE_SUCCEEDED = "node_succeeded"
|
||||
NODE_FAILED = "node_failed"
|
||||
@@ -149,6 +152,89 @@ class QueueIterationCompletedEvent(AppQueueEvent):
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class QueueLoopStartEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueLoopStartEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.LOOP_START
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
start_at: datetime
|
||||
|
||||
node_run_index: int
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
predecessor_node_id: Optional[str] = None
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
|
||||
|
||||
class QueueLoopNextEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueLoopNextEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.LOOP_NEXT
|
||||
|
||||
index: int
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""iteratoin run in parallel mode run id"""
|
||||
node_run_index: int
|
||||
output: Optional[Any] = None # output for the current loop
|
||||
duration: Optional[float] = None
|
||||
|
||||
|
||||
class QueueLoopCompletedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueLoopCompletedEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.LOOP_COMPLETED
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
start_at: datetime
|
||||
|
||||
node_run_index: int
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
steps: int = 0
|
||||
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class QueueTextChunkEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueTextChunkEvent entity
|
||||
@@ -160,6 +246,8 @@ class QueueTextChunkEvent(AppQueueEvent):
|
||||
"""from variable selector"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
|
||||
|
||||
class QueueAgentMessageEvent(AppQueueEvent):
|
||||
@@ -189,6 +277,8 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
|
||||
retriever_resources: list[dict]
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
|
||||
|
||||
class QueueAnnotationReplyEvent(AppQueueEvent):
|
||||
@@ -278,6 +368,8 @@ class QueueNodeStartedEvent(AppQueueEvent):
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""iteratoin run in parallel mode run id"""
|
||||
@@ -305,6 +397,8 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
@@ -315,6 +409,8 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
error: Optional[str] = None
|
||||
"""single iteration duration map"""
|
||||
iteration_duration_map: Optional[dict[str, float]] = None
|
||||
"""single loop duration map"""
|
||||
loop_duration_map: Optional[dict[str, float]] = None
|
||||
|
||||
|
||||
class QueueAgentLogEvent(AppQueueEvent):
|
||||
@@ -331,6 +427,7 @@ class QueueAgentLogEvent(AppQueueEvent):
|
||||
status: str
|
||||
data: Mapping[str, Any]
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
node_id: str
|
||||
|
||||
|
||||
class QueueNodeRetryEvent(QueueNodeStartedEvent):
|
||||
@@ -368,6 +465,41 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent):
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
|
||||
|
||||
class QueueNodeInLoopFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeInLoopFailedEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.NODE_FAILED
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
@@ -399,6 +531,8 @@ class QueueNodeExceptionEvent(AppQueueEvent):
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
@@ -430,6 +564,8 @@ class QueueNodeFailedEvent(AppQueueEvent):
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
@@ -549,6 +685,8 @@ class QueueParallelBranchRunStartedEvent(AppQueueEvent):
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
|
||||
|
||||
class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
|
||||
@@ -566,6 +704,8 @@ class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
|
||||
|
||||
class QueueParallelBranchRunFailedEvent(AppQueueEvent):
|
||||
@@ -583,4 +723,6 @@ class QueueParallelBranchRunFailedEvent(AppQueueEvent):
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
error: str
|
||||
|
||||
@@ -59,6 +59,9 @@ class StreamEvent(Enum):
|
||||
ITERATION_STARTED = "iteration_started"
|
||||
ITERATION_NEXT = "iteration_next"
|
||||
ITERATION_COMPLETED = "iteration_completed"
|
||||
LOOP_STARTED = "loop_started"
|
||||
LOOP_NEXT = "loop_next"
|
||||
LOOP_COMPLETED = "loop_completed"
|
||||
TEXT_CHUNK = "text_chunk"
|
||||
TEXT_REPLACE = "text_replace"
|
||||
AGENT_LOG = "agent_log"
|
||||
@@ -248,6 +251,7 @@ class NodeStartStreamResponse(StreamResponse):
|
||||
parent_parallel_id: Optional[str] = None
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
iteration_id: Optional[str] = None
|
||||
loop_id: Optional[str] = None
|
||||
parallel_run_id: Optional[str] = None
|
||||
agent_strategy: Optional[AgentNodeStrategyInit] = None
|
||||
|
||||
@@ -275,6 +279,7 @@ class NodeStartStreamResponse(StreamResponse):
|
||||
"parent_parallel_id": self.data.parent_parallel_id,
|
||||
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
|
||||
"iteration_id": self.data.iteration_id,
|
||||
"loop_id": self.data.loop_id,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -310,6 +315,7 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||
parent_parallel_id: Optional[str] = None
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
iteration_id: Optional[str] = None
|
||||
loop_id: Optional[str] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.NODE_FINISHED
|
||||
workflow_run_id: str
|
||||
@@ -342,6 +348,7 @@ class NodeFinishStreamResponse(StreamResponse):
|
||||
"parent_parallel_id": self.data.parent_parallel_id,
|
||||
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
|
||||
"iteration_id": self.data.iteration_id,
|
||||
"loop_id": self.data.loop_id,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -377,6 +384,7 @@ class NodeRetryStreamResponse(StreamResponse):
|
||||
parent_parallel_id: Optional[str] = None
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
iteration_id: Optional[str] = None
|
||||
loop_id: Optional[str] = None
|
||||
retry_index: int = 0
|
||||
|
||||
event: StreamEvent = StreamEvent.NODE_RETRY
|
||||
@@ -410,6 +418,7 @@ class NodeRetryStreamResponse(StreamResponse):
|
||||
"parent_parallel_id": self.data.parent_parallel_id,
|
||||
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
|
||||
"iteration_id": self.data.iteration_id,
|
||||
"loop_id": self.data.loop_id,
|
||||
"retry_index": self.data.retry_index,
|
||||
},
|
||||
}
|
||||
@@ -430,6 +439,7 @@ class ParallelBranchStartStreamResponse(StreamResponse):
|
||||
parent_parallel_id: Optional[str] = None
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
iteration_id: Optional[str] = None
|
||||
loop_id: Optional[str] = None
|
||||
created_at: int
|
||||
|
||||
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED
|
||||
@@ -452,6 +462,7 @@ class ParallelBranchFinishedStreamResponse(StreamResponse):
|
||||
parent_parallel_id: Optional[str] = None
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
iteration_id: Optional[str] = None
|
||||
loop_id: Optional[str] = None
|
||||
status: str
|
||||
error: Optional[str] = None
|
||||
created_at: int
|
||||
@@ -548,6 +559,93 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
|
||||
data: Data
|
||||
|
||||
|
||||
class LoopNodeStartStreamResponse(StreamResponse):
|
||||
"""
|
||||
NodeStartStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
id: str
|
||||
node_id: str
|
||||
node_type: str
|
||||
title: str
|
||||
created_at: int
|
||||
extras: dict = {}
|
||||
metadata: Mapping = {}
|
||||
inputs: Mapping = {}
|
||||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.LOOP_STARTED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class LoopNodeNextStreamResponse(StreamResponse):
|
||||
"""
|
||||
NodeStartStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
id: str
|
||||
node_id: str
|
||||
node_type: str
|
||||
title: str
|
||||
index: int
|
||||
created_at: int
|
||||
pre_loop_output: Optional[Any] = None
|
||||
extras: dict = {}
|
||||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
duration: Optional[float] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.LOOP_NEXT
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class LoopNodeCompletedStreamResponse(StreamResponse):
|
||||
"""
|
||||
NodeCompletedStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
id: str
|
||||
node_id: str
|
||||
node_type: str
|
||||
title: str
|
||||
outputs: Optional[Mapping] = None
|
||||
created_at: int
|
||||
extras: Optional[dict] = None
|
||||
inputs: Optional[Mapping] = None
|
||||
status: WorkflowNodeExecutionStatus
|
||||
error: Optional[str] = None
|
||||
elapsed_time: float
|
||||
total_tokens: int
|
||||
execution_metadata: Optional[Mapping] = None
|
||||
finished_at: int
|
||||
steps: int
|
||||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.LOOP_COMPLETED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class TextChunkStreamResponse(StreamResponse):
|
||||
"""
|
||||
TextChunkStreamResponse entity
|
||||
@@ -719,6 +817,7 @@ class AgentLogStreamResponse(StreamResponse):
|
||||
status: str
|
||||
data: Mapping[str, Any]
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
node_id: str
|
||||
|
||||
event: StreamEvent = StreamEvent.AGENT_LOG
|
||||
data: Data
|
||||
|
||||
@@ -14,9 +14,13 @@ from core.app.entities.queue_entities import (
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueLoopCompletedEvent,
|
||||
QueueLoopNextEvent,
|
||||
QueueLoopStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
@@ -29,6 +33,9 @@ from core.app.entities.task_entities import (
|
||||
IterationNodeCompletedStreamResponse,
|
||||
IterationNodeNextStreamResponse,
|
||||
IterationNodeStartStreamResponse,
|
||||
LoopNodeCompletedStreamResponse,
|
||||
LoopNodeNextStreamResponse,
|
||||
LoopNodeStartStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeRetryStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
@@ -304,6 +311,7 @@ class WorkflowCycleManage:
|
||||
{
|
||||
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
|
||||
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
|
||||
NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
|
||||
}
|
||||
)
|
||||
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
@@ -344,7 +352,10 @@ class WorkflowCycleManage:
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent,
|
||||
event: QueueNodeFailedEvent
|
||||
| QueueNodeInIterationFailedEvent
|
||||
| QueueNodeInLoopFailedEvent
|
||||
| QueueNodeExceptionEvent,
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Workflow node execution failed
|
||||
@@ -396,6 +407,7 @@ class WorkflowCycleManage:
|
||||
origin_metadata = {
|
||||
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
|
||||
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
|
||||
NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
|
||||
}
|
||||
merged_metadata = (
|
||||
{**jsonable_encoder(event.execution_metadata), **origin_metadata}
|
||||
@@ -540,6 +552,7 @@ class WorkflowCycleManage:
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
parallel_run_id=event.parallel_mode_run_id,
|
||||
agent_strategy=event.agent_strategy,
|
||||
),
|
||||
@@ -563,6 +576,7 @@ class WorkflowCycleManage:
|
||||
event: QueueNodeSucceededEvent
|
||||
| QueueNodeFailedEvent
|
||||
| QueueNodeInIterationFailedEvent
|
||||
| QueueNodeInLoopFailedEvent
|
||||
| QueueNodeExceptionEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
@@ -601,6 +615,7 @@ class WorkflowCycleManage:
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -646,6 +661,7 @@ class WorkflowCycleManage:
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
retry_index=event.retry_index,
|
||||
),
|
||||
)
|
||||
@@ -664,6 +680,7 @@ class WorkflowCycleManage:
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
created_at=int(time.time()),
|
||||
),
|
||||
)
|
||||
@@ -687,6 +704,7 @@ class WorkflowCycleManage:
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed",
|
||||
error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
|
||||
created_at=int(time.time()),
|
||||
@@ -770,6 +788,83 @@ class WorkflowCycleManage:
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_loop_start_to_stream_response(
|
||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopStartEvent
|
||||
) -> LoopNodeStartStreamResponse:
|
||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
||||
_ = session
|
||||
return LoopNodeStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=LoopNodeStartStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
metadata=event.metadata or {},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_loop_next_to_stream_response(
|
||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopNextEvent
|
||||
) -> LoopNodeNextStreamResponse:
|
||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
||||
_ = session
|
||||
return LoopNodeNextStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=LoopNodeNextStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
index=event.index,
|
||||
pre_loop_output=event.output,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
duration=event.duration,
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_loop_completed_to_stream_response(
|
||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopCompletedEvent
|
||||
) -> LoopNodeCompletedStreamResponse:
|
||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
||||
_ = session
|
||||
return LoopNodeCompletedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=LoopNodeCompletedStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
outputs=event.outputs,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
if event.error is None
|
||||
else WorkflowNodeExecutionStatus.FAILED,
|
||||
error=None,
|
||||
elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(),
|
||||
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
|
||||
execution_metadata=event.metadata,
|
||||
finished_at=int(time.time()),
|
||||
steps=event.steps,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
def _fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any]) -> Sequence[Mapping[str, Any]]:
|
||||
"""
|
||||
Fetch files from node outputs
|
||||
@@ -864,5 +959,6 @@ class WorkflowCycleManage:
|
||||
status=event.status,
|
||||
data=event.data,
|
||||
metadata=event.metadata,
|
||||
node_id=event.node_id,
|
||||
),
|
||||
)
|
||||
|
||||
27
api/core/auto/config/custom.yaml
Normal file
27
api/core/auto/config/custom.yaml
Normal file
@@ -0,0 +1,27 @@
|
||||
# 自定义配置文件
|
||||
workflow_generator:
|
||||
# 用于生成工作流的模型配置
|
||||
models:
|
||||
default: my-gpt-4o-mini # 默认使用的模型
|
||||
available: # 可用的模型列表
|
||||
my-gpt-4o-mini:
|
||||
model_name: gpt-4o-mini
|
||||
base_url: https://api.pandalla.ai/v1
|
||||
key_path: ./openai_key
|
||||
max_tokens: 4096
|
||||
my-gpt-4o:
|
||||
model_name: gpt-4o
|
||||
base_url: https://api.pandalla.ai/v1
|
||||
key_path: ./openai_key
|
||||
max_tokens: 4096
|
||||
|
||||
# 调试配置
|
||||
debug:
|
||||
enabled: false # 默认不启用调试模式,可通过命令行参数 --debug 启用
|
||||
dir: debug/ # 调试信息保存目录
|
||||
save_options: # 调试信息保存选项
|
||||
prompt: true # 保存提示词
|
||||
response: true # 保存大模型响应
|
||||
json: true # 保存JSON解析过程
|
||||
workflow: true # 保存工作流生成过程
|
||||
case_id_format: "%Y%m%d_%H%M%S_%f" # 运行ID格式,使用datetime.strftime格式
|
||||
33
api/core/auto/config/default.yaml
Normal file
33
api/core/auto/config/default.yaml
Normal file
@@ -0,0 +1,33 @@
|
||||
# 默认配置文件
|
||||
|
||||
# 工作流生成器配置
|
||||
workflow_generator:
|
||||
# 用于生成工作流的模型配置
|
||||
models:
|
||||
default: gpt-4 # 默认使用的模型
|
||||
available: # 可用的模型列表
|
||||
gpt-4:
|
||||
model_name: gpt-4
|
||||
base_url: https://api.openai.com/v1
|
||||
key_path: ./openai_key
|
||||
max_tokens: 8192
|
||||
gpt-4-turbo:
|
||||
model_name: gpt-4-1106-preview
|
||||
base_url: https://api.openai.com/v1
|
||||
key_path: ./openai_key
|
||||
max_tokens: 4096
|
||||
|
||||
# 工作流节点配置
|
||||
workflow_nodes:
|
||||
# LLM节点默认配置(使用 Dify 平台配置的模型)
|
||||
llm:
|
||||
provider: zhipuai
|
||||
model: glm-4-flash
|
||||
max_tokens: 16384
|
||||
temperature: 0.7
|
||||
mode: chat
|
||||
|
||||
# 输出配置
|
||||
output:
|
||||
dir: output/
|
||||
filename: generated_workflow.yml
|
||||
78
api/core/auto/node_types/__init__.py
Normal file
78
api/core/auto/node_types/__init__.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from .agent import AgentNodeType
|
||||
from .answer import AnswerNodeType
|
||||
from .assigner import AssignerNodeType
|
||||
from .code import CodeLanguage, CodeNodeType, OutputVar
|
||||
from .common import (
|
||||
BlockEnum,
|
||||
CommonEdgeType,
|
||||
CommonNodeType,
|
||||
CompleteEdge,
|
||||
CompleteNode,
|
||||
Context,
|
||||
InputVar,
|
||||
InputVarType,
|
||||
Memory,
|
||||
ModelConfig,
|
||||
PromptItem,
|
||||
PromptRole,
|
||||
ValueSelector,
|
||||
Variable,
|
||||
VarType,
|
||||
VisionSetting,
|
||||
)
|
||||
from .end import EndNodeType
|
||||
from .http import HttpNodeType
|
||||
from .if_else import IfElseNodeType
|
||||
from .iteration import IterationNodeType
|
||||
from .iteration_start import IterationStartNodeType
|
||||
from .knowledge_retrieval import KnowledgeRetrievalNodeType
|
||||
from .list_operator import ListFilterNodeType
|
||||
from .llm import LLMNodeType, VisionConfig
|
||||
from .note_node import NoteNodeType
|
||||
from .parameter_extractor import ParameterExtractorNodeType
|
||||
from .question_classifier import QuestionClassifierNodeType
|
||||
from .start import StartNodeType
|
||||
from .template_transform import TemplateTransformNodeType
|
||||
from .tool import ToolNodeType
|
||||
from .variable_assigner import VariableAssignerNodeType
|
||||
|
||||
__all__ = [
|
||||
"AgentNodeType",
|
||||
"AnswerNodeType",
|
||||
"AssignerNodeType",
|
||||
"BlockEnum",
|
||||
"CodeLanguage",
|
||||
"CodeNodeType",
|
||||
"CommonEdgeType",
|
||||
"CommonNodeType",
|
||||
"CompleteEdge",
|
||||
"CompleteNode",
|
||||
"Context",
|
||||
"EndNodeType",
|
||||
"HttpNodeType",
|
||||
"IfElseNodeType",
|
||||
"InputVar",
|
||||
"InputVarType",
|
||||
"IterationNodeType",
|
||||
"IterationStartNodeType",
|
||||
"KnowledgeRetrievalNodeType",
|
||||
"LLMNodeType",
|
||||
"ListFilterNodeType",
|
||||
"Memory",
|
||||
"ModelConfig",
|
||||
"NoteNodeType",
|
||||
"OutputVar",
|
||||
"ParameterExtractorNodeType",
|
||||
"PromptItem",
|
||||
"PromptRole",
|
||||
"QuestionClassifierNodeType",
|
||||
"StartNodeType",
|
||||
"TemplateTransformNodeType",
|
||||
"ToolNodeType",
|
||||
"ValueSelector",
|
||||
"VarType",
|
||||
"Variable",
|
||||
"VariableAssignerNodeType",
|
||||
"VisionConfig",
|
||||
"VisionSetting",
|
||||
]
|
||||
34
api/core/auto/node_types/agent.py
Normal file
34
api/core/auto/node_types/agent.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .common import BlockEnum, CommonNodeType
|
||||
|
||||
# Introduce previously defined CommonNodeType and ToolVarInputs
|
||||
# Assume they are defined in the same module
|
||||
|
||||
|
||||
class ToolVarInputs(BaseModel):
|
||||
variable_name: Optional[str] = None
|
||||
default_value: Optional[Any] = None
|
||||
|
||||
|
||||
class AgentNodeType(CommonNodeType):
|
||||
agent_strategy_provider_name: Optional[str] = None
|
||||
agent_strategy_name: Optional[str] = None
|
||||
agent_strategy_label: Optional[str] = None
|
||||
agent_parameters: Optional[ToolVarInputs] = None
|
||||
output_schema: dict[str, Any]
|
||||
plugin_unique_identifier: Optional[str] = None
|
||||
|
||||
|
||||
# 示例用法
|
||||
if __name__ == "__main__":
|
||||
example_node = AgentNodeType(
|
||||
title="Example Agent",
|
||||
desc="An agent node example",
|
||||
type=BlockEnum.agent,
|
||||
output_schema={"key": "value"},
|
||||
agent_parameters=ToolVarInputs(variable_name="example_var", default_value="default"),
|
||||
)
|
||||
print(example_node)
|
||||
21
api/core/auto/node_types/answer.py
Normal file
21
api/core/auto/node_types/answer.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from .common import BlockEnum, CommonNodeType, Variable
|
||||
|
||||
|
||||
class AnswerNodeType(CommonNodeType):
|
||||
variables: list[Variable]
|
||||
answer: str
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
example_node = AnswerNodeType(
|
||||
title="Example Answer Node",
|
||||
desc="An answer node example",
|
||||
type=BlockEnum.answer,
|
||||
answer="This is the answer",
|
||||
variables=[
|
||||
Variable(variable="var1", value_selector=["node1", "key1"]),
|
||||
Variable(variable="var2", value_selector=["node2", "key2"]),
|
||||
],
|
||||
)
|
||||
print(example_node)
|
||||
62
api/core/auto/node_types/assigner.py
Normal file
62
api/core/auto/node_types/assigner.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .common import BlockEnum, CommonNodeType
|
||||
|
||||
# Import previously defined CommonNodeType and ValueSelector
|
||||
# Assume they are defined in the same module
|
||||
|
||||
|
||||
class WriteMode(str, Enum):
|
||||
overwrite = "over-write"
|
||||
clear = "clear"
|
||||
append = "append"
|
||||
extend = "extend"
|
||||
set = "set"
|
||||
increment = "+="
|
||||
decrement = "-="
|
||||
multiply = "*="
|
||||
divide = "/="
|
||||
|
||||
|
||||
class AssignerNodeInputType(str, Enum):
|
||||
variable = "variable"
|
||||
constant = "constant"
|
||||
|
||||
|
||||
class AssignerNodeOperation(BaseModel):
|
||||
variable_selector: Any # Placeholder for ValueSelector type
|
||||
input_type: AssignerNodeInputType
|
||||
operation: WriteMode
|
||||
value: Any
|
||||
|
||||
|
||||
class AssignerNodeType(CommonNodeType):
|
||||
version: Optional[str] = Field(None, pattern="^[12]$") # Version is '1' or '2'
|
||||
items: list[AssignerNodeOperation]
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
example_node = AssignerNodeType(
|
||||
title="Example Assigner Node",
|
||||
desc="An assigner node example",
|
||||
type=BlockEnum.variable_assigner,
|
||||
items=[
|
||||
AssignerNodeOperation(
|
||||
variable_selector={"nodeId": "node1", "key": "value"}, # Example ValueSelector
|
||||
input_type=AssignerNodeInputType.variable,
|
||||
operation=WriteMode.set,
|
||||
value="newValue",
|
||||
),
|
||||
AssignerNodeOperation(
|
||||
variable_selector={"nodeId": "node2", "key": "value"},
|
||||
input_type=AssignerNodeInputType.constant,
|
||||
operation=WriteMode.increment,
|
||||
value=1,
|
||||
),
|
||||
],
|
||||
)
|
||||
print(example_node)
|
||||
56
api/core/auto/node_types/code.py
Normal file
56
api/core/auto/node_types/code.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.auto.node_types.common import BlockEnum, CommonNodeType, Variable, VarType
|
||||
|
||||
# 引入之前定义的 CommonNodeType、VarType 和 Variable
|
||||
# 假设它们在同一模块中定义
|
||||
|
||||
|
||||
class CodeLanguage(str, Enum):
|
||||
python3 = "python3"
|
||||
javascript = "javascript"
|
||||
json = "json"
|
||||
|
||||
|
||||
class OutputVar(BaseModel):
|
||||
type: VarType
|
||||
children: Optional[None] = None # 未来支持嵌套
|
||||
|
||||
def dict(self, *args, **kwargs):
|
||||
"""自定义序列化方法,确保正确序列化"""
|
||||
result = {"type": self.type.value if isinstance(self.type, Enum) else self.type}
|
||||
|
||||
if self.children is not None:
|
||||
result["children"] = self.children
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class CodeNodeType(CommonNodeType):
|
||||
variables: list[Variable]
|
||||
code_language: CodeLanguage
|
||||
code: str
|
||||
outputs: dict[str, OutputVar]
|
||||
|
||||
|
||||
# 示例用法
|
||||
if __name__ == "__main__":
|
||||
# 创建示例节点
|
||||
example_node = CodeNodeType(
|
||||
title="Example Code Node",
|
||||
desc="A code node example",
|
||||
type=BlockEnum.code,
|
||||
code_language=CodeLanguage.python3,
|
||||
code="print('Hello, World!')",
|
||||
outputs={
|
||||
"output1": OutputVar(type=VarType.string),
|
||||
"output2": OutputVar(type=VarType.number),
|
||||
},
|
||||
variables=[
|
||||
Variable(variable="var1", value_selector=["node1", "key1"]),
|
||||
],
|
||||
)
|
||||
print(example_node.get_all_required_fields())
|
||||
690
api/core/auto/node_types/common.py
Normal file
690
api/core/auto/node_types/common.py
Normal file
@@ -0,0 +1,690 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# BlockEnum 枚举
|
||||
class BlockEnum(str, Enum):
|
||||
start = "start"
|
||||
end = "end"
|
||||
answer = "answer"
|
||||
llm = "llm"
|
||||
knowledge_retrieval = "knowledge-retrieval"
|
||||
question_classifier = "question-classifier"
|
||||
if_else = "if-else"
|
||||
code = "code"
|
||||
template_transform = "template-transform"
|
||||
http_request = "http-request"
|
||||
variable_assigner = "variable-assigner"
|
||||
variable_aggregator = "variable-aggregator"
|
||||
tool = "tool"
|
||||
parameter_extractor = "parameter-extractor"
|
||||
iteration = "iteration"
|
||||
document_extractor = "document-extractor"
|
||||
list_operator = "list-operator"
|
||||
iteration_start = "iteration-start"
|
||||
assigner = "assigner" # is now named as VariableAssigner
|
||||
agent = "agent"
|
||||
|
||||
|
||||
# Error枚举
|
||||
class ErrorHandleMode(str, Enum):
|
||||
terminated = "terminated"
|
||||
continue_on_error = "continue-on-error"
|
||||
remove_abnormal_output = "remove-abnormal-output"
|
||||
|
||||
|
||||
class ErrorHandleTypeEnum(str, Enum):
|
||||
none = ("none",)
|
||||
failBranch = ("fail-branch",)
|
||||
defaultValue = ("default-value",)
|
||||
|
||||
|
||||
# Branch 类型
|
||||
class Branch(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
# NodeRunningStatus 枚举
|
||||
class NodeRunningStatus(str, Enum):
|
||||
not_start = "not-start"
|
||||
waiting = "waiting"
|
||||
running = "running"
|
||||
succeeded = "succeeded"
|
||||
failed = "failed"
|
||||
exception = "exception"
|
||||
retry = "retry"
|
||||
|
||||
|
||||
# 创建一个基类来统一CommonNodeType和CommonEdgeType的序列化逻辑
|
||||
class BaseType(BaseModel):
|
||||
"""基类,用于统一CommonNodeType和CommonEdgeType的序列化逻辑"""
|
||||
|
||||
def to_json(self) -> dict[str, Any]:
|
||||
"""
|
||||
将对象转换为JSON格式的字典,通过循环模型字段来构建JSON数据
|
||||
"""
|
||||
json_data = {}
|
||||
|
||||
# 获取模型的所有字段
|
||||
for field_name, field_value in self.__dict__.items():
|
||||
if field_value is not None:
|
||||
# 特殊处理Branch类型的列表
|
||||
if field_name == "_targetBranches" and field_value is not None:
|
||||
json_data[field_name] = [branch.dict(exclude_none=True) for branch in field_value]
|
||||
# 处理枚举类型
|
||||
elif isinstance(field_value, Enum):
|
||||
json_data[field_name] = field_value.value
|
||||
# 处理嵌套的Pydantic模型
|
||||
elif hasattr(field_value, "dict") and callable(field_value.dict):
|
||||
json_data[field_name] = field_value.dict(exclude_none=True)
|
||||
# 处理列表中的Pydantic模型
|
||||
elif isinstance(field_value, list):
|
||||
processed_list = []
|
||||
for item in field_value:
|
||||
if hasattr(item, "dict") and callable(item.dict):
|
||||
processed_list.append(item.dict(exclude_none=True))
|
||||
else:
|
||||
processed_list.append(item)
|
||||
json_data[field_name] = processed_list
|
||||
# 处理字典中的Pydantic模型
|
||||
elif isinstance(field_value, dict):
|
||||
processed_dict = {}
|
||||
for key, value in field_value.items():
|
||||
if hasattr(value, "dict") and callable(value.dict):
|
||||
processed_dict[key] = value.dict(exclude_none=True)
|
||||
else:
|
||||
processed_dict[key] = value
|
||||
json_data[field_name] = processed_dict
|
||||
# 其他字段直接添加
|
||||
else:
|
||||
json_data[field_name] = field_value
|
||||
|
||||
return json_data
|
||||
|
||||
|
||||
# CommonNodeType 类型
|
||||
class CommonNodeType(BaseType):
|
||||
_connectedSourceHandleIds: Optional[list[str]] = None
|
||||
_connectedTargetHandleIds: Optional[list[str]] = None
|
||||
_targetBranches: Optional[list[Branch]] = None
|
||||
_isSingleRun: Optional[bool] = None
|
||||
_runningStatus: Optional[NodeRunningStatus] = None
|
||||
_singleRunningStatus: Optional[NodeRunningStatus] = None
|
||||
_isCandidate: Optional[bool] = None
|
||||
_isBundled: Optional[bool] = None
|
||||
_children: Optional[list[str]] = None
|
||||
_isEntering: Optional[bool] = None
|
||||
_showAddVariablePopup: Optional[bool] = None
|
||||
_holdAddVariablePopup: Optional[bool] = None
|
||||
_iterationLength: Optional[int] = None
|
||||
_iterationIndex: Optional[int] = None
|
||||
_inParallelHovering: Optional[bool] = None
|
||||
isInIteration: Optional[bool] = None
|
||||
iteration_id: Optional[str] = None
|
||||
selected: Optional[bool] = None
|
||||
title: str
|
||||
desc: str
|
||||
type: BlockEnum
|
||||
width: Optional[float] = None
|
||||
height: Optional[float] = None
|
||||
|
||||
@classmethod
|
||||
def get_all_required_fields(cls) -> dict[str, str]:
|
||||
"""
|
||||
获取所有必选字段,包括从父类继承的字段
|
||||
这是一个类方法,可以通过类直接调用
|
||||
"""
|
||||
all_required_fields = {}
|
||||
|
||||
# 获取所有父类(除了 object 和 BaseModel)
|
||||
mro = [c for c in cls.__mro__ if c not in (object, BaseModel, BaseType)]
|
||||
|
||||
# 从父类到子类的顺序处理,这样子类的字段会覆盖父类的同名字段
|
||||
for class_type in reversed(mro):
|
||||
if hasattr(class_type, "__annotations__"):
|
||||
for field_name, field_info in class_type.__annotations__.items():
|
||||
# 检查字段是否有默认值
|
||||
has_default = hasattr(class_type, field_name)
|
||||
# 检查字段是否为可选类型
|
||||
is_optional = "Optional" in str(field_info)
|
||||
|
||||
# 如果字段没有默认值且不是Optional类型,则为必选字段
|
||||
if not has_default and not is_optional:
|
||||
all_required_fields[field_name] = str(field_info)
|
||||
|
||||
return all_required_fields
|
||||
|
||||
|
||||
# CommonEdgeType 类型
|
||||
class CommonEdgeType(BaseType):
|
||||
_hovering: Optional[bool] = None
|
||||
_connectedNodeIsHovering: Optional[bool] = None
|
||||
_connectedNodeIsSelected: Optional[bool] = None
|
||||
_run: Optional[bool] = None
|
||||
_isBundled: Optional[bool] = None
|
||||
isInIteration: Optional[bool] = None
|
||||
iteration_id: Optional[str] = None
|
||||
sourceType: BlockEnum
|
||||
targetType: BlockEnum
|
||||
|
||||
|
||||
class ValueSelector(BaseModel):
|
||||
"""Value selector for selecting values from other nodes."""
|
||||
|
||||
value: list[str] = Field(default_factory=list)
|
||||
|
||||
def dict(self, *args, **kwargs):
|
||||
"""自定义序列化方法,直接返回 value 列表"""
|
||||
return self.value
|
||||
|
||||
|
||||
# Add Context class for LLM node
|
||||
class Context(BaseModel):
|
||||
"""Context configuration for LLM node."""
|
||||
|
||||
enabled: bool = False
|
||||
variable_selector: Optional[ValueSelector] = None
|
||||
|
||||
def dict(self, *args, **kwargs):
|
||||
"""自定义序列化方法,确保 variable_selector 字段正确序列化"""
|
||||
result = {"enabled": self.enabled}
|
||||
|
||||
if self.variable_selector:
|
||||
result["variable_selector"] = self.variable_selector.dict()
|
||||
else:
|
||||
result["variable_selector"] = []
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# Variable 类型
|
||||
class Variable(BaseModel):
|
||||
"""
|
||||
变量类型,用于定义节点的输入/输出变量
|
||||
与Dify中的Variable类型保持一致
|
||||
"""
|
||||
|
||||
variable: str # 变量名
|
||||
label: Optional[Union[str, dict[str, str]]] = None # 变量标签,可以是字符串或对象
|
||||
value_selector: list[str] # 变量值选择器,格式为[nodeId, key]
|
||||
variable_type: Optional[str] = None # 变量类型,对应Dify中的VarType枚举
|
||||
value: Optional[str] = None # 变量值(常量值)
|
||||
options: Optional[list[str]] = None # 选项列表(用于select类型)
|
||||
required: Optional[bool] = None # 是否必填
|
||||
isParagraph: Optional[bool] = None # 是否为段落
|
||||
max_length: Optional[int] = None # 最大长度
|
||||
|
||||
def dict(self, *args, **kwargs):
|
||||
"""自定义序列化方法,确保正确序列化"""
|
||||
result = {"variable": self.variable}
|
||||
|
||||
if self.label is not None:
|
||||
result["label"] = self.label
|
||||
|
||||
if self.value_selector:
|
||||
result["value_selector"] = self.value_selector
|
||||
|
||||
if self.variable_type is not None:
|
||||
result["type"] = self.variable_type # 使用type而不是variable_type,与Dify保持一致
|
||||
|
||||
if self.value is not None:
|
||||
result["value"] = self.value
|
||||
|
||||
if self.options is not None:
|
||||
result["options"] = self.options
|
||||
|
||||
if self.required is not None:
|
||||
result["required"] = self.required
|
||||
|
||||
if self.isParagraph is not None:
|
||||
result["isParagraph"] = self.isParagraph
|
||||
|
||||
if self.max_length is not None:
|
||||
result["max_length"] = self.max_length
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# EnvironmentVariable 类型
|
||||
class EnvironmentVariable(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
value: Any
|
||||
value_type: str # Expecting to be either 'string', 'number', or 'secret'
|
||||
|
||||
|
||||
# ConversationVariable 类型
|
||||
class ConversationVariable(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
value_type: str
|
||||
value: Any
|
||||
description: str
|
||||
|
||||
|
||||
# GlobalVariable 类型
|
||||
class GlobalVariable(BaseModel):
|
||||
name: str
|
||||
value_type: str # Expecting to be either 'string' or 'number'
|
||||
description: str
|
||||
|
||||
|
||||
# VariableWithValue 类型
|
||||
class VariableWithValue(BaseModel):
|
||||
key: str
|
||||
value: str
|
||||
|
||||
|
||||
# InputVarType 枚举
|
||||
class InputVarType(str, Enum):
|
||||
text_input = "text-input"
|
||||
paragraph = "paragraph"
|
||||
select = "select"
|
||||
number = "number"
|
||||
url = "url"
|
||||
files = "files"
|
||||
json = "json"
|
||||
contexts = "contexts"
|
||||
iterator = "iterator"
|
||||
file = "file"
|
||||
file_list = "file-list"
|
||||
|
||||
|
||||
# InputVar 类型
|
||||
class InputVar(BaseModel):
|
||||
type: InputVarType
|
||||
label: Union[str, dict[str, Any]] # 可以是字符串或对象
|
||||
variable: str
|
||||
max_length: Optional[int] = None
|
||||
default: Optional[str] = None
|
||||
required: bool
|
||||
hint: Optional[str] = None
|
||||
options: Optional[list[str]] = None
|
||||
value_selector: Optional[list[str]] = None
|
||||
|
||||
def dict(self, *args, **kwargs):
|
||||
"""自定义序列化方法,确保正确序列化"""
|
||||
result = {
|
||||
"type": self.type.value if isinstance(self.type, Enum) else self.type,
|
||||
"label": self.label,
|
||||
"variable": self.variable,
|
||||
"required": self.required,
|
||||
}
|
||||
|
||||
if self.max_length is not None:
|
||||
result["max_length"] = self.max_length
|
||||
|
||||
if self.default is not None:
|
||||
result["default"] = self.default
|
||||
|
||||
if self.hint is not None:
|
||||
result["hint"] = self.hint
|
||||
|
||||
if self.options is not None:
|
||||
result["options"] = self.options
|
||||
|
||||
if self.value_selector is not None:
|
||||
result["value_selector"] = self.value_selector
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ModelConfig 类型
|
||||
class ModelConfig(BaseModel):
|
||||
provider: str
|
||||
name: str
|
||||
mode: str
|
||||
completion_params: dict[str, Any]
|
||||
|
||||
|
||||
# PromptRole 枚举
|
||||
class PromptRole(str, Enum):
|
||||
system = "system"
|
||||
user = "user"
|
||||
assistant = "assistant"
|
||||
|
||||
|
||||
# EditionType 枚举
|
||||
class EditionType(str, Enum):
|
||||
basic = "basic"
|
||||
jinja2 = "jinja2"
|
||||
|
||||
|
||||
# PromptItem 类型
|
||||
class PromptItem(BaseModel):
|
||||
id: Optional[str] = None
|
||||
role: Optional[PromptRole] = None
|
||||
text: str
|
||||
edition_type: Optional[EditionType] = None
|
||||
jinja2_text: Optional[str] = None
|
||||
|
||||
def dict(self, *args, **kwargs):
|
||||
"""自定义序列化方法,确保 role 字段正确序列化"""
|
||||
result = {"id": self.id, "text": self.text}
|
||||
|
||||
if self.role:
|
||||
result["role"] = self.role.value
|
||||
|
||||
if self.edition_type:
|
||||
result["edition_type"] = self.edition_type.value
|
||||
|
||||
if self.jinja2_text:
|
||||
result["jinja2_text"] = self.jinja2_text
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# MemoryRole 枚举
|
||||
class MemoryRole(str, Enum):
|
||||
user = "user"
|
||||
assistant = "assistant"
|
||||
|
||||
|
||||
# RolePrefix 类型
|
||||
class RolePrefix(BaseModel):
|
||||
user: str
|
||||
assistant: str
|
||||
|
||||
|
||||
# Memory 类型
|
||||
class Memory(BaseModel):
|
||||
role_prefix: Optional[RolePrefix] = None
|
||||
window: dict[str, Any] # Expecting to have 'enabled' and 'size'
|
||||
query_prompt_template: str
|
||||
|
||||
|
||||
# VarType 枚举
|
||||
class VarType(str, Enum):
|
||||
string = "string"
|
||||
number = "number"
|
||||
secret = "secret"
|
||||
boolean = "boolean"
|
||||
object = "object"
|
||||
file = "file"
|
||||
array = "array"
|
||||
arrayString = "array[string]"
|
||||
arrayNumber = "array[number]"
|
||||
arrayObject = "array[object]"
|
||||
arrayFile = "array[file]"
|
||||
any = "any"
|
||||
|
||||
|
||||
# Var 类型
|
||||
class Var(BaseModel):
|
||||
variable: str
|
||||
type: VarType
|
||||
children: Optional[list["Var"]] = None # Self-reference
|
||||
isParagraph: Optional[bool] = None
|
||||
isSelect: Optional[bool] = None
|
||||
options: Optional[list[str]] = None
|
||||
required: Optional[bool] = None
|
||||
des: Optional[str] = None
|
||||
isException: Optional[bool] = None
|
||||
|
||||
def dict(self, *args, **kwargs):
|
||||
"""自定义序列化方法,确保type字段正确序列化"""
|
||||
result = {"variable": self.variable, "type": self.type.value if isinstance(self.type, Enum) else self.type}
|
||||
|
||||
if self.children is not None:
|
||||
result["children"] = [child.dict() for child in self.children]
|
||||
|
||||
if self.isParagraph is not None:
|
||||
result["isParagraph"] = self.isParagraph
|
||||
|
||||
if self.isSelect is not None:
|
||||
result["isSelect"] = self.isSelect
|
||||
|
||||
if self.options is not None:
|
||||
result["options"] = self.options
|
||||
|
||||
if self.required is not None:
|
||||
result["required"] = self.required
|
||||
|
||||
if self.des is not None:
|
||||
result["des"] = self.des
|
||||
|
||||
if self.isException is not None:
|
||||
result["isException"] = self.isException
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# NodeOutPutVar 类型
|
||||
class NodeOutPutVar(BaseModel):
|
||||
nodeId: str
|
||||
title: str
|
||||
vars: list[Var]
|
||||
isStartNode: Optional[bool] = None
|
||||
|
||||
|
||||
# Block 类型
|
||||
class Block(BaseModel):
|
||||
classification: Optional[str] = None
|
||||
type: BlockEnum
|
||||
title: str
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
# NodeDefault 类型
|
||||
class NodeDefault(BaseModel):
|
||||
defaultValue: dict[str, Any]
|
||||
getAvailablePrevNodes: Any # Placeholder for function reference
|
||||
getAvailableNextNodes: Any # Placeholder for function reference
|
||||
checkValid: Any # Placeholder for function reference
|
||||
|
||||
|
||||
# OnSelectBlock 类型
|
||||
class OnSelectBlock(BaseModel):
|
||||
nodeType: BlockEnum
|
||||
additional_data: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
# WorkflowRunningStatus 枚举
|
||||
class WorkflowRunningStatus(str, Enum):
|
||||
waiting = "waiting"
|
||||
running = "running"
|
||||
succeeded = "succeeded"
|
||||
failed = "failed"
|
||||
stopped = "stopped"
|
||||
|
||||
|
||||
# WorkflowVersion 枚举
|
||||
class WorkflowVersion(str, Enum):
|
||||
draft = "draft"
|
||||
latest = "latest"
|
||||
|
||||
|
||||
# OnNodeAdd 类型
|
||||
class OnNodeAdd(BaseModel):
|
||||
nodeType: BlockEnum
|
||||
sourceHandle: Optional[str] = None
|
||||
targetHandle: Optional[str] = None
|
||||
toolDefaultValue: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
# CheckValidRes 类型
|
||||
class CheckValidRes(BaseModel):
|
||||
isValid: bool
|
||||
errorMessage: Optional[str] = None
|
||||
|
||||
|
||||
# RunFile 类型
|
||||
class RunFile(BaseModel):
|
||||
type: str
|
||||
transfer_method: list[str]
|
||||
url: Optional[str] = None
|
||||
upload_file_id: Optional[str] = None
|
||||
|
||||
|
||||
# WorkflowRunningData 类型
|
||||
class WorkflowRunningData(BaseModel):
|
||||
task_id: Optional[str] = None
|
||||
message_id: Optional[str] = None
|
||||
conversation_id: Optional[str] = None
|
||||
result: dict[str, Any] # Expecting a structured object
|
||||
tracing: Optional[list[dict[str, Any]]] = None # Placeholder for NodeTracing
|
||||
|
||||
|
||||
# HistoryWorkflowData 类型
|
||||
class HistoryWorkflowData(BaseModel):
|
||||
id: str
|
||||
sequence_number: int
|
||||
status: str
|
||||
conversation_id: Optional[str] = None
|
||||
|
||||
|
||||
# ChangeType 枚举
|
||||
class ChangeType(str, Enum):
|
||||
changeVarName = "changeVarName"
|
||||
remove = "remove"
|
||||
|
||||
|
||||
# MoreInfo 类型
|
||||
class MoreInfo(BaseModel):
|
||||
type: ChangeType
|
||||
payload: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
# ToolWithProvider 类型
|
||||
class ToolWithProvider(BaseModel):
|
||||
tools: list[dict[str, Any]] # Placeholder for Tool type
|
||||
|
||||
|
||||
# SupportUploadFileTypes 枚举
|
||||
class SupportUploadFileTypes(str, Enum):
|
||||
image = "image"
|
||||
document = "document"
|
||||
audio = "audio"
|
||||
video = "video"
|
||||
custom = "custom"
|
||||
|
||||
|
||||
# UploadFileSetting 类型
|
||||
class UploadFileSetting(BaseModel):
|
||||
allowed_file_upload_methods: list[str]
|
||||
allowed_file_types: list[SupportUploadFileTypes]
|
||||
allowed_file_extensions: Optional[list[str]] = None
|
||||
max_length: int
|
||||
number_limits: Optional[int] = None
|
||||
|
||||
|
||||
# VisionSetting 类型
|
||||
class VisionSetting(BaseModel):
|
||||
variable_selector: list[str]
|
||||
detail: dict[str, Any] # Placeholder for Resolution type
|
||||
|
||||
|
||||
# 创建一个基类来统一序列化逻辑
|
||||
class CompleteBase(BaseModel):
|
||||
"""基类,用于统一CompleteNode和CompleteEdge的序列化逻辑"""
|
||||
|
||||
def to_json(self):
|
||||
"""将对象转换为JSON格式的字典"""
|
||||
json_data = {}
|
||||
|
||||
# 获取模型的所有字段
|
||||
for field_name, field_value in self.__dict__.items():
|
||||
if field_value is not None:
|
||||
# 处理嵌套的数据对象
|
||||
if field_name == "data" and hasattr(field_value, "to_json"):
|
||||
json_data[field_name] = field_value.to_json()
|
||||
# 处理枚举类型
|
||||
elif isinstance(field_value, Enum):
|
||||
json_data[field_name] = field_value.value
|
||||
# 处理嵌套的Pydantic模型
|
||||
elif hasattr(field_value, "dict") and callable(field_value.dict):
|
||||
json_data[field_name] = field_value.dict(exclude_none=True)
|
||||
# 处理列表中的Pydantic模型
|
||||
elif isinstance(field_value, list):
|
||||
processed_list = []
|
||||
for item in field_value:
|
||||
if hasattr(item, "dict") and callable(item.dict):
|
||||
processed_list.append(item.dict(exclude_none=True))
|
||||
else:
|
||||
processed_list.append(item)
|
||||
json_data[field_name] = processed_list
|
||||
# 处理字典中的Pydantic模型
|
||||
elif isinstance(field_value, dict):
|
||||
processed_dict = {}
|
||||
for key, value in field_value.items():
|
||||
if hasattr(value, "dict") and callable(value.dict):
|
||||
processed_dict[key] = value.dict(exclude_none=True)
|
||||
else:
|
||||
processed_dict[key] = value
|
||||
json_data[field_name] = processed_dict
|
||||
# 其他字段直接添加
|
||||
else:
|
||||
json_data[field_name] = field_value
|
||||
|
||||
return json_data
|
||||
|
||||
def to_yaml(self):
|
||||
"""将对象转换为YAML格式的字符串"""
|
||||
return yaml.dump(self.to_json(), allow_unicode=True)
|
||||
|
||||
|
||||
class CompleteNode(CompleteBase):
|
||||
id: str
|
||||
position: dict
|
||||
height: int
|
||||
width: float
|
||||
positionAbsolute: dict
|
||||
selected: bool
|
||||
sourcePosition: Union[dict, str]
|
||||
targetPosition: Union[dict, str]
|
||||
type: str
|
||||
data: Optional[Union[CommonNodeType, None]] = None # Flexible field to store CommonNodeType or None
|
||||
|
||||
def add_data(self, data: Union[CommonNodeType, None]):
|
||||
self.data = data
|
||||
|
||||
def to_json(self):
|
||||
json_data = super().to_json()
|
||||
|
||||
# 特殊处理sourcePosition和targetPosition
|
||||
json_data["sourcePosition"] = "right" # 直接输出为字符串"right"
|
||||
json_data["targetPosition"] = "left" # 直接输出为字符串"left"
|
||||
|
||||
# 确保 width 是整数而不是浮点数
|
||||
if isinstance(json_data["width"], float):
|
||||
json_data["width"] = int(json_data["width"])
|
||||
|
||||
return json_data
|
||||
|
||||
|
||||
class CompleteEdge(CompleteBase):
|
||||
id: str
|
||||
source: str
|
||||
sourceHandle: str
|
||||
target: str
|
||||
targetHandle: str
|
||||
type: str
|
||||
zIndex: int
|
||||
data: Optional[Union[CommonEdgeType, None]] = None # Flexible field to store CommonEdgeType or None
|
||||
|
||||
def add_data(self, data: Union[CommonEdgeType, None]):
|
||||
self.data = data
|
||||
|
||||
|
||||
# 示例用法
|
||||
if __name__ == "__main__":
|
||||
# 这里可以添加示例数据进行验证
|
||||
common_node = CompleteNode(
|
||||
id="1740019130520",
|
||||
position={"x": 80, "y": 282},
|
||||
height=100,
|
||||
width=100,
|
||||
positionAbsolute={"x": 80, "y": 282},
|
||||
selected=True,
|
||||
sourcePosition={"x": 80, "y": 282},
|
||||
targetPosition={"x": 80, "y": 282},
|
||||
type="custom",
|
||||
)
|
||||
common_data = CommonNodeType(title="示例节点", desc="这是一个示例节点", type="")
|
||||
print(CommonNodeType.get_all_required_fields())
|
||||
common_node.add_data(common_data)
|
||||
# print(common_node)
|
||||
22
api/core/auto/node_types/end.py
Normal file
22
api/core/auto/node_types/end.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from .common import BlockEnum, CommonNodeType, Variable
|
||||
|
||||
# Import previously defined CommonNodeType and Variable
|
||||
# Assume they are defined in the same module
|
||||
|
||||
|
||||
class EndNodeType(CommonNodeType):
|
||||
outputs: list[Variable]
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
example_node = EndNodeType(
|
||||
title="Example End Node",
|
||||
desc="An end node example",
|
||||
type=BlockEnum.end,
|
||||
outputs=[
|
||||
Variable(variable="outputVar1", value_selector=["node1", "key1"]),
|
||||
Variable(variable="outputVar2", value_selector=["node2", "key2"]),
|
||||
],
|
||||
)
|
||||
print(example_node)
|
||||
127
api/core/auto/node_types/http.py
Normal file
127
api/core/auto/node_types/http.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .common import BlockEnum, CommonNodeType, ValueSelector, Variable
|
||||
|
||||
# Import previously defined CommonNodeType, ValueSelector, and Variable
|
||||
# Assume they are defined in the same module
|
||||
|
||||
|
||||
class Method(str, Enum):
|
||||
"""HTTP request methods."""
|
||||
|
||||
get = "get"
|
||||
post = "post"
|
||||
head = "head"
|
||||
patch = "patch"
|
||||
put = "put"
|
||||
delete = "delete"
|
||||
|
||||
|
||||
class BodyType(str, Enum):
|
||||
"""HTTP request body types."""
|
||||
|
||||
none = "none"
|
||||
formData = "form-data"
|
||||
xWwwFormUrlencoded = "x-www-form-urlencoded"
|
||||
rawText = "raw-text"
|
||||
json = "json"
|
||||
binary = "binary"
|
||||
|
||||
|
||||
class BodyPayloadValueType(str, Enum):
|
||||
"""Types of values in body payload."""
|
||||
|
||||
text = "text"
|
||||
file = "file"
|
||||
|
||||
|
||||
class BodyPayload(BaseModel):
|
||||
"""Body payload item for HTTP requests."""
|
||||
|
||||
id: Optional[str] = None
|
||||
key: Optional[str] = None
|
||||
type: BodyPayloadValueType
|
||||
file: Optional[ValueSelector] = None # Used when type is file
|
||||
value: Optional[str] = None # Used when type is text
|
||||
|
||||
|
||||
class Body(BaseModel):
|
||||
"""HTTP request body configuration."""
|
||||
|
||||
type: BodyType
|
||||
data: Union[str, list[BodyPayload]] # string is deprecated, will convert to BodyPayload
|
||||
|
||||
|
||||
class AuthorizationType(str, Enum):
|
||||
"""HTTP authorization types."""
|
||||
|
||||
none = "no-auth"
|
||||
apiKey = "api-key"
|
||||
|
||||
|
||||
class APIType(str, Enum):
|
||||
"""API key types."""
|
||||
|
||||
basic = "basic"
|
||||
bearer = "bearer"
|
||||
custom = "custom"
|
||||
|
||||
|
||||
class AuthConfig(BaseModel):
|
||||
"""Authorization configuration."""
|
||||
|
||||
type: APIType
|
||||
api_key: str
|
||||
header: Optional[str] = None
|
||||
|
||||
|
||||
class Authorization(BaseModel):
|
||||
"""HTTP authorization settings."""
|
||||
|
||||
type: AuthorizationType
|
||||
config: Optional[AuthConfig] = None
|
||||
|
||||
|
||||
class Timeout(BaseModel):
|
||||
"""HTTP request timeout settings."""
|
||||
|
||||
connect: Optional[int] = None
|
||||
read: Optional[int] = None
|
||||
write: Optional[int] = None
|
||||
max_connect_timeout: Optional[int] = None
|
||||
max_read_timeout: Optional[int] = None
|
||||
max_write_timeout: Optional[int] = None
|
||||
|
||||
|
||||
class HttpNodeType(CommonNodeType):
|
||||
"""HTTP request node type implementation."""
|
||||
|
||||
variables: list[Variable]
|
||||
method: Method
|
||||
url: str
|
||||
headers: str
|
||||
params: str
|
||||
body: Body
|
||||
authorization: Authorization
|
||||
timeout: Timeout
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
example_node = HttpNodeType(
|
||||
title="Example HTTP Node",
|
||||
desc="An HTTP request node example",
|
||||
type=BlockEnum.http_request,
|
||||
variables=[Variable(variable="var1", value_selector=["node1", "key1"])],
|
||||
method=Method.get,
|
||||
url="https://api.example.com/data",
|
||||
headers="{}",
|
||||
params="{}",
|
||||
body=Body(type=BodyType.none, data=[]),
|
||||
authorization=Authorization(type=AuthorizationType.none),
|
||||
timeout=Timeout(connect=30, read=30, write=30),
|
||||
)
|
||||
print(example_node)
|
||||
99
api/core/auto/node_types/if_else.py
Normal file
99
api/core/auto/node_types/if_else.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .common import BlockEnum, CommonNodeType, ValueSelector, VarType
|
||||
from .tool import VarType as NumberVarType
|
||||
|
||||
# Import previously defined CommonNodeType, ValueSelector, Var, and VarType
|
||||
# Assume they are defined in the same module
|
||||
|
||||
|
||||
class LogicalOperator(str, Enum):
|
||||
and_ = "and"
|
||||
or_ = "or"
|
||||
|
||||
|
||||
class ComparisonOperator(str, Enum):
|
||||
contains = "contains"
|
||||
notContains = "not contains"
|
||||
startWith = "start with"
|
||||
endWith = "end with"
|
||||
is_ = "is"
|
||||
isNot = "is not"
|
||||
empty = "empty"
|
||||
notEmpty = "not empty"
|
||||
equal = "="
|
||||
notEqual = "≠"
|
||||
largerThan = ">"
|
||||
lessThan = "<"
|
||||
largerThanOrEqual = "≥"
|
||||
lessThanOrEqual = "≤"
|
||||
isNull = "is null"
|
||||
isNotNull = "is not null"
|
||||
in_ = "in"
|
||||
notIn = "not in"
|
||||
allOf = "all of"
|
||||
exists = "exists"
|
||||
notExists = "not exists"
|
||||
equals = "=" # Alias for equal for compatibility
|
||||
|
||||
|
||||
class Condition(BaseModel):
|
||||
id: str
|
||||
varType: VarType
|
||||
variable_selector: Optional[ValueSelector]
|
||||
key: Optional[str] = None # Sub variable key
|
||||
comparison_operator: Optional[ComparisonOperator] = None
|
||||
value: Union[str, list[str]]
|
||||
numberVarType: Optional[NumberVarType]
|
||||
sub_variable_condition: Optional["CaseItem"] = None # Recursive reference
|
||||
|
||||
|
||||
class CaseItem(BaseModel):
|
||||
case_id: str
|
||||
logical_operator: LogicalOperator
|
||||
conditions: list[Condition]
|
||||
|
||||
|
||||
class IfElseNodeType(CommonNodeType):
|
||||
logical_operator: Optional[LogicalOperator] = None
|
||||
conditions: Optional[list[Condition]] = None
|
||||
cases: list[CaseItem]
|
||||
isInIteration: bool
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
example_node = IfElseNodeType(
|
||||
title="Example IfElse Node",
|
||||
desc="An if-else node example",
|
||||
type=BlockEnum.if_else,
|
||||
logical_operator=LogicalOperator.and_,
|
||||
conditions=[
|
||||
Condition(
|
||||
id="condition1",
|
||||
varType=VarType.string,
|
||||
variable_selector={"nodeId": "varNode", "key": "value"},
|
||||
comparison_operator=ComparisonOperator.is_,
|
||||
value="exampleValue",
|
||||
)
|
||||
],
|
||||
cases=[
|
||||
CaseItem(
|
||||
case_id="case1",
|
||||
logical_operator=LogicalOperator.or_,
|
||||
conditions=[
|
||||
Condition(
|
||||
id="condition2",
|
||||
varType=VarType.number,
|
||||
value="10",
|
||||
comparison_operator=ComparisonOperator.largerThan,
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
isInIteration=True,
|
||||
)
|
||||
print(example_node)
|
||||
45
api/core/auto/node_types/iteration.py
Normal file
45
api/core/auto/node_types/iteration.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from .common import BlockEnum, CommonNodeType, ValueSelector, VarType
|
||||
|
||||
|
||||
class ErrorHandleMode(str, Enum):
|
||||
"""Error handling modes for iteration."""
|
||||
|
||||
terminated = "terminated"
|
||||
continue_on_error = "continue-on-error"
|
||||
remove_abnormal_output = "remove-abnormal-output"
|
||||
|
||||
|
||||
class IterationNodeType(CommonNodeType):
|
||||
"""Iteration node type implementation."""
|
||||
|
||||
startNodeType: Optional[BlockEnum] = None
|
||||
start_node_id: str # Start node ID in the iteration
|
||||
iteration_id: Optional[str] = None
|
||||
iterator_selector: ValueSelector
|
||||
output_selector: ValueSelector
|
||||
output_type: VarType # Output type
|
||||
is_parallel: bool # Open the parallel mode or not
|
||||
parallel_nums: int # The numbers of parallel
|
||||
error_handle_mode: ErrorHandleMode # How to handle error in the iteration
|
||||
_isShowTips: bool # Show tips when answer node in parallel mode iteration
|
||||
|
||||
|
||||
# 示例用法
|
||||
if __name__ == "__main__":
|
||||
example_node = IterationNodeType(
|
||||
title="Example Iteration Node",
|
||||
desc="An iteration node example",
|
||||
type=BlockEnum.iteration,
|
||||
start_node_id="startNode1",
|
||||
iterator_selector=ValueSelector(value=["iteratorNode", "value"]),
|
||||
output_selector=ValueSelector(value=["outputNode", "value"]),
|
||||
output_type=VarType.string,
|
||||
is_parallel=True,
|
||||
parallel_nums=5,
|
||||
error_handle_mode=ErrorHandleMode.continue_on_error,
|
||||
_isShowTips=True,
|
||||
)
|
||||
print(example_node)
|
||||
25
api/core/auto/node_types/iteration_start.py
Normal file
25
api/core/auto/node_types/iteration_start.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from .common import BlockEnum, CommonNodeType
|
||||
|
||||
# 引入之前定义的 CommonNodeType
|
||||
# 假设它们在同一模块中定义
|
||||
|
||||
|
||||
class IterationStartNodeType(CommonNodeType):
|
||||
"""
|
||||
Iteration Start node type implementation.
|
||||
|
||||
This node type is used as the starting point within an iteration block.
|
||||
It inherits all properties from CommonNodeType without adding any additional fields.
|
||||
"""
|
||||
|
||||
pass # 仅仅继承 CommonNodeType,无其他字段
|
||||
|
||||
|
||||
# 示例用法
|
||||
if __name__ == "__main__":
|
||||
example_node = IterationStartNodeType(
|
||||
title="Example Iteration Start Node",
|
||||
desc="An iteration start node example",
|
||||
type=BlockEnum.iteration_start,
|
||||
)
|
||||
print(example_node)
|
||||
115
api/core/auto/node_types/knowledge_retrieval.py
Normal file
115
api/core/auto/node_types/knowledge_retrieval.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .common import BlockEnum, CommonNodeType, ModelConfig, ValueSelector
|
||||
|
||||
|
||||
class RetrieveType(str, Enum):
|
||||
"""Retrieval mode types."""
|
||||
|
||||
single = "single"
|
||||
multiple = "multiple"
|
||||
|
||||
|
||||
class RerankingModeEnum(str, Enum):
|
||||
"""Reranking mode types."""
|
||||
|
||||
simple = "simple"
|
||||
advanced = "advanced"
|
||||
|
||||
|
||||
class VectorSetting(BaseModel):
|
||||
"""Vector weight settings."""
|
||||
|
||||
vector_weight: float
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class KeywordSetting(BaseModel):
|
||||
"""Keyword weight settings."""
|
||||
|
||||
keyword_weight: float
|
||||
|
||||
|
||||
class Weights(BaseModel):
|
||||
"""Weight configuration for retrieval."""
|
||||
|
||||
vector_setting: VectorSetting
|
||||
keyword_setting: KeywordSetting
|
||||
|
||||
|
||||
class RerankingModel(BaseModel):
|
||||
"""Reranking model configuration."""
|
||||
|
||||
provider: str
|
||||
model: str
|
||||
|
||||
|
||||
class MultipleRetrievalConfig(BaseModel):
|
||||
"""Configuration for multiple retrieval mode."""
|
||||
|
||||
top_k: int
|
||||
score_threshold: Optional[float] = None
|
||||
reranking_model: Optional[RerankingModel] = None
|
||||
reranking_mode: Optional[RerankingModeEnum] = None
|
||||
weights: Optional[Weights] = None
|
||||
reranking_enable: Optional[bool] = None
|
||||
|
||||
|
||||
class SingleRetrievalConfig(BaseModel):
|
||||
"""Configuration for single retrieval mode."""
|
||||
|
||||
model: ModelConfig
|
||||
|
||||
|
||||
class DataSet(BaseModel):
|
||||
"""Dataset information."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class KnowledgeRetrievalNodeType(CommonNodeType):
|
||||
"""Knowledge retrieval node type implementation."""
|
||||
|
||||
query_variable_selector: ValueSelector
|
||||
dataset_ids: list[str]
|
||||
retrieval_mode: RetrieveType
|
||||
multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None
|
||||
single_retrieval_config: Optional[SingleRetrievalConfig] = None
|
||||
_datasets: Optional[list[DataSet]] = None
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
example_node = KnowledgeRetrievalNodeType(
|
||||
title="Example Knowledge Retrieval Node",
|
||||
desc="A knowledge retrieval node example",
|
||||
type=BlockEnum.knowledge_retrieval,
|
||||
query_variable_selector=ValueSelector(value=["queryNode", "query"]),
|
||||
dataset_ids=["dataset1", "dataset2"],
|
||||
retrieval_mode=RetrieveType.multiple,
|
||||
multiple_retrieval_config=MultipleRetrievalConfig(
|
||||
top_k=10,
|
||||
score_threshold=0.5,
|
||||
reranking_model=RerankingModel(provider="example_provider", model="example_model"),
|
||||
reranking_mode=RerankingModeEnum.simple,
|
||||
weights=Weights(
|
||||
vector_setting=VectorSetting(
|
||||
vector_weight=0.7, embedding_provider_name="provider1", embedding_model_name="model1"
|
||||
),
|
||||
keyword_setting=KeywordSetting(keyword_weight=0.3),
|
||||
),
|
||||
reranking_enable=True,
|
||||
),
|
||||
single_retrieval_config=SingleRetrievalConfig(
|
||||
model=ModelConfig(
|
||||
provider="example_provider", name="example_model", mode="chat", completion_params={"temperature": 0.7}
|
||||
)
|
||||
),
|
||||
)
|
||||
print(example_node)
|
||||
73
api/core/auto/node_types/list_operator.py
Normal file
73
api/core/auto/node_types/list_operator.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .common import BlockEnum, CommonNodeType, ValueSelector, VarType
|
||||
|
||||
# Import ComparisonOperator from if_else.py
|
||||
from .if_else import ComparisonOperator
|
||||
|
||||
|
||||
class OrderBy(str, Enum):
|
||||
ASC = "asc"
|
||||
DESC = "desc"
|
||||
|
||||
|
||||
class Limit(BaseModel):
|
||||
enabled: bool
|
||||
size: Optional[int] = None
|
||||
|
||||
|
||||
class Condition(BaseModel):
|
||||
key: str
|
||||
comparison_operator: ComparisonOperator
|
||||
value: Union[str, int, list[str]]
|
||||
|
||||
|
||||
class FilterBy(BaseModel):
|
||||
enabled: bool
|
||||
conditions: list[Condition]
|
||||
|
||||
|
||||
class ExtractBy(BaseModel):
|
||||
enabled: bool
|
||||
serial: Optional[str] = None
|
||||
|
||||
|
||||
class OrderByConfig(BaseModel):
|
||||
enabled: bool
|
||||
key: Union[ValueSelector, str]
|
||||
value: OrderBy
|
||||
|
||||
|
||||
class ListFilterNodeType(CommonNodeType):
|
||||
"""List filter/operator node type implementation."""
|
||||
|
||||
variable: ValueSelector
|
||||
var_type: VarType
|
||||
item_var_type: VarType
|
||||
filter_by: FilterBy
|
||||
extract_by: ExtractBy
|
||||
order_by: OrderByConfig
|
||||
limit: Limit
|
||||
|
||||
|
||||
# 示例用法
|
||||
if __name__ == "__main__":
|
||||
example_node = ListFilterNodeType(
|
||||
title="Example List Filter Node",
|
||||
desc="A list filter node example",
|
||||
type=BlockEnum.list_operator, # Fixed: use list_operator instead of list_filter
|
||||
variable=ValueSelector(value=["varNode", "value"]),
|
||||
var_type=VarType.string,
|
||||
item_var_type=VarType.number,
|
||||
filter_by=FilterBy(
|
||||
enabled=True,
|
||||
conditions=[Condition(key="status", comparison_operator=ComparisonOperator.equals, value="active")],
|
||||
),
|
||||
extract_by=ExtractBy(enabled=True, serial="serial_1"),
|
||||
order_by=OrderByConfig(enabled=True, key="created_at", value=OrderBy.DESC),
|
||||
limit=Limit(enabled=True, size=100),
|
||||
)
|
||||
print(example_node)
|
||||
66
api/core/auto/node_types/llm.py
Normal file
66
api/core/auto/node_types/llm.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .common import (
|
||||
BlockEnum,
|
||||
CommonNodeType,
|
||||
Context,
|
||||
Memory,
|
||||
ModelConfig,
|
||||
PromptItem,
|
||||
Variable,
|
||||
VisionSetting,
|
||||
)
|
||||
|
||||
|
||||
class PromptConfig(BaseModel):
|
||||
"""Configuration for prompt template variables."""
|
||||
|
||||
jinja2_variables: Optional[list[Variable]] = None
|
||||
|
||||
|
||||
class VisionConfig(BaseModel):
|
||||
"""Configuration for vision settings."""
|
||||
|
||||
enabled: bool = False
|
||||
configs: Optional[VisionSetting] = None
|
||||
|
||||
def dict(self, *args, **kwargs):
|
||||
"""自定义序列化方法,确保正确序列化"""
|
||||
result = {"enabled": self.enabled}
|
||||
|
||||
if self.configs:
|
||||
result["configs"] = self.configs.dict()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class LLMNodeType(CommonNodeType):
|
||||
"""LLM node type implementation."""
|
||||
|
||||
model: ModelConfig
|
||||
prompt_template: Union[list[PromptItem], PromptItem]
|
||||
prompt_config: Optional[PromptConfig] = None
|
||||
memory: Optional[Memory] = None
|
||||
context: Optional[Context] = Context(enabled=False, variable_selector=None)
|
||||
vision: Optional[VisionConfig] = VisionConfig(enabled=False)
|
||||
|
||||
|
||||
# 示例用法
|
||||
if __name__ == "__main__":
|
||||
example_node = LLMNodeType(
|
||||
title="Example LLM Node",
|
||||
desc="A LLM node example",
|
||||
type=BlockEnum.llm,
|
||||
model=ModelConfig(provider="zhipuai", name="glm-4-flash", mode="chat", completion_params={"temperature": 0.7}),
|
||||
prompt_template=[
|
||||
PromptItem(
|
||||
id="system-id", role="system", text="你是一个代码工程师,你会根据用户的需求给出用户所需要的函数"
|
||||
),
|
||||
PromptItem(id="user-id", role="user", text="给出两数相加的python 函数代码,函数名 func 不要添加其他内容"),
|
||||
],
|
||||
context=Context(enabled=False, variable_selector=None),
|
||||
vision=VisionConfig(enabled=False),
|
||||
)
|
||||
print(example_node)
|
||||
38
api/core/auto/node_types/note_node.py
Normal file
38
api/core/auto/node_types/note_node.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from enum import Enum
|
||||
|
||||
from .common import BlockEnum, CommonNodeType
|
||||
|
||||
# Import previously defined CommonNodeType
|
||||
# Assume it is defined in the same module
|
||||
|
||||
|
||||
class NoteTheme(str, Enum):
|
||||
blue = "blue"
|
||||
cyan = "cyan"
|
||||
green = "green"
|
||||
yellow = "yellow"
|
||||
pink = "pink"
|
||||
violet = "violet"
|
||||
|
||||
|
||||
class NoteNodeType(CommonNodeType):
|
||||
"""Custom note node type implementation."""
|
||||
|
||||
text: str
|
||||
theme: NoteTheme
|
||||
author: str
|
||||
showAuthor: bool
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
example_node = NoteNodeType(
|
||||
title="Example Note Node",
|
||||
desc="A note node example",
|
||||
type=BlockEnum.custom_note,
|
||||
text="This is a note.",
|
||||
theme=NoteTheme.green,
|
||||
author="John Doe",
|
||||
showAuthor=True,
|
||||
)
|
||||
print(example_node)
|
||||
85
api/core/auto/node_types/parameter_extractor.py
Normal file
85
api/core/auto/node_types/parameter_extractor.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .common import BlockEnum, CommonNodeType, Memory, ModelConfig, ValueSelector, VisionSetting
|
||||
|
||||
# Import previously defined CommonNodeType, Memory, ModelConfig, ValueSelector, and VisionSetting
|
||||
# Assume they are defined in the same module
|
||||
|
||||
|
||||
class ParamType(str, Enum):
|
||||
"""Parameter types for extraction."""
|
||||
|
||||
string = "string"
|
||||
number = "number"
|
||||
bool = "bool"
|
||||
select = "select"
|
||||
arrayString = "array[string]"
|
||||
arrayNumber = "array[number]"
|
||||
arrayObject = "array[object]"
|
||||
|
||||
|
||||
class Param(BaseModel):
|
||||
"""Parameter definition for extraction."""
|
||||
|
||||
name: str
|
||||
type: ParamType
|
||||
options: Optional[list[str]] = None
|
||||
description: str
|
||||
required: Optional[bool] = None
|
||||
|
||||
|
||||
class ReasoningModeType(str, Enum):
|
||||
"""Reasoning mode types for parameter extraction."""
|
||||
|
||||
prompt = "prompt"
|
||||
functionCall = "function_call"
|
||||
|
||||
|
||||
class VisionConfig(BaseModel):
|
||||
"""Vision configuration."""
|
||||
|
||||
enabled: bool
|
||||
configs: Optional[VisionSetting] = None
|
||||
|
||||
|
||||
class ParameterExtractorNodeType(CommonNodeType):
|
||||
"""Parameter extractor node type implementation."""
|
||||
|
||||
model: ModelConfig
|
||||
query: ValueSelector
|
||||
reasoning_mode: ReasoningModeType
|
||||
parameters: List[Param]
|
||||
instruction: str
|
||||
memory: Optional[Memory] = None
|
||||
vision: VisionConfig
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
example_node = ParameterExtractorNodeType(
|
||||
title="Example Parameter Extractor Node",
|
||||
desc="A parameter extractor node example",
|
||||
type=BlockEnum.parameter_extractor,
|
||||
model=ModelConfig(
|
||||
provider="example_provider", name="example_model", mode="chat", completion_params={"temperature": 0.7}
|
||||
),
|
||||
query=ValueSelector(value=["queryNode", "value"]),
|
||||
reasoning_mode=ReasoningModeType.prompt,
|
||||
parameters=[
|
||||
Param(name="param1", type=ParamType.string, description="This is a string parameter", required=True),
|
||||
Param(
|
||||
name="param2",
|
||||
type=ParamType.number,
|
||||
options=["1", "2", "3"],
|
||||
description="This is a number parameter",
|
||||
required=False,
|
||||
),
|
||||
],
|
||||
instruction="Please extract the parameters from the input.",
|
||||
memory=Memory(window={"enabled": True, "size": 10}, query_prompt_template="Extract parameters from: {{query}}"),
|
||||
vision=VisionConfig(enabled=True, configs={"setting": "example_setting"}),
|
||||
)
|
||||
print(example_node)
|
||||
51
api/core/auto/node_types/question_classifier.py
Normal file
51
api/core/auto/node_types/question_classifier.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .common import BlockEnum, CommonNodeType, Memory, ModelConfig, ValueSelector, VisionSetting
|
||||
|
||||
# Import previously defined CommonNodeType, Memory, ModelConfig, ValueSelector, and VisionSetting
|
||||
# Assume they are defined in the same module
|
||||
|
||||
|
||||
class Topic(BaseModel):
|
||||
"""Topic for classification."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class VisionConfig(BaseModel):
|
||||
"""Vision configuration."""
|
||||
|
||||
enabled: bool
|
||||
configs: Optional[VisionSetting] = None
|
||||
|
||||
|
||||
class QuestionClassifierNodeType(CommonNodeType):
|
||||
"""Question classifier node type implementation."""
|
||||
|
||||
query_variable_selector: ValueSelector
|
||||
model: ModelConfig
|
||||
classes: list[Topic]
|
||||
instruction: str
|
||||
memory: Optional[Memory] = None
|
||||
vision: VisionConfig
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
example_node = QuestionClassifierNodeType(
|
||||
title="Example Question Classifier Node",
|
||||
desc="A question classifier node example",
|
||||
type=BlockEnum.question_classifier,
|
||||
query_variable_selector=ValueSelector(value=["queryNode", "value"]),
|
||||
model=ModelConfig(
|
||||
provider="example_provider", name="example_model", mode="chat", completion_params={"temperature": 0.7}
|
||||
),
|
||||
classes=[Topic(id="1", name="Science"), Topic(id="2", name="Mathematics"), Topic(id="3", name="Literature")],
|
||||
instruction="Classify the given question into the appropriate topic.",
|
||||
memory=Memory(window={"enabled": True, "size": 10}, query_prompt_template="Classify this question: {{query}}"),
|
||||
vision=VisionConfig(enabled=True, configs={"setting": "example_setting"}),
|
||||
)
|
||||
print(example_node)
|
||||
22
api/core/auto/node_types/start.py
Normal file
22
api/core/auto/node_types/start.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from .common import BlockEnum, CommonNodeType, InputVar
|
||||
|
||||
# Import previously defined CommonNodeType and InputVar
|
||||
# Assume they are defined in the same module
|
||||
|
||||
|
||||
class StartNodeType(CommonNodeType):
|
||||
variables: list[InputVar]
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
example_node = StartNodeType(
|
||||
title="Example Start Node",
|
||||
desc="A start node example",
|
||||
type=BlockEnum.start,
|
||||
variables=[
|
||||
InputVar(type="text-input", label="Input 1", variable="input1", required=True),
|
||||
InputVar(type="number", label="Input 2", variable="input2", required=True),
|
||||
],
|
||||
)
|
||||
print(example_node)
|
||||
26
api/core/auto/node_types/template_transform.py
Normal file
26
api/core/auto/node_types/template_transform.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from .common import BlockEnum, CommonNodeType, Variable
|
||||
|
||||
# 引入之前定义的 CommonNodeType 和 Variable
|
||||
# 假设它们在同一模块中定义
|
||||
|
||||
|
||||
class TemplateTransformNodeType(CommonNodeType):
|
||||
"""Template transform node type implementation."""
|
||||
|
||||
variables: list[Variable]
|
||||
template: str
|
||||
|
||||
|
||||
# 示例用法
|
||||
if __name__ == "__main__":
|
||||
example_node = TemplateTransformNodeType(
|
||||
title="Example Template Transform Node",
|
||||
desc="A template transform node example",
|
||||
type=BlockEnum.template_transform,
|
||||
variables=[
|
||||
Variable(variable="var1", value_selector=["node1", "key1"]),
|
||||
Variable(variable="var2", value_selector=["node2", "key2"]),
|
||||
],
|
||||
template="Hello, {{ var1 }}! You have {{ var2 }} new messages.",
|
||||
)
|
||||
print(example_node)
|
||||
54
api/core/auto/node_types/tool.py
Normal file
54
api/core/auto/node_types/tool.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .common import BlockEnum, CommonNodeType, ValueSelector
|
||||
|
||||
# Import previously defined CommonNodeType and ValueSelector
|
||||
# Assume they are defined in the same module
|
||||
|
||||
|
||||
class VarType(str, Enum):
|
||||
variable = "variable"
|
||||
constant = "constant"
|
||||
mixed = "mixed"
|
||||
|
||||
|
||||
class ToolVarInputs(BaseModel):
|
||||
type: VarType
|
||||
value: Optional[Union[str, ValueSelector, Any]] = None
|
||||
|
||||
|
||||
class ToolNodeType(CommonNodeType):
|
||||
"""Tool node type implementation."""
|
||||
|
||||
provider_id: str
|
||||
provider_type: Any # Placeholder for CollectionType
|
||||
provider_name: str
|
||||
tool_name: str
|
||||
tool_label: str
|
||||
tool_parameters: dict[str, ToolVarInputs]
|
||||
tool_configurations: dict[str, Any]
|
||||
output_schema: dict[str, Any]
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
example_node = ToolNodeType(
|
||||
title="Example Tool Node",
|
||||
desc="A tool node example",
|
||||
type=BlockEnum.tool,
|
||||
provider_id="12345",
|
||||
provider_type="some_collection_type", # Placeholder for CollectionType
|
||||
provider_name="Example Provider",
|
||||
tool_name="Example Tool",
|
||||
tool_label="Example Tool Label",
|
||||
tool_parameters={
|
||||
"input1": ToolVarInputs(type=VarType.variable, value="some_value"),
|
||||
"input2": ToolVarInputs(type=VarType.constant, value="constant_value"),
|
||||
},
|
||||
tool_configurations={"config1": "value1", "config2": {"nested": "value2"}},
|
||||
output_schema={"output1": "string", "output2": "number"},
|
||||
)
|
||||
print(example_node.json(indent=2)) # Print as JSON format for viewing
|
||||
56
api/core/auto/node_types/variable_assigner.py
Normal file
56
api/core/auto/node_types/variable_assigner.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .common import BlockEnum, CommonNodeType, ValueSelector, VarType
|
||||
|
||||
|
||||
class VarGroupItem(BaseModel):
|
||||
"""Variable group item configuration."""
|
||||
|
||||
output_type: VarType
|
||||
variables: list[ValueSelector]
|
||||
|
||||
|
||||
class GroupConfig(VarGroupItem):
|
||||
"""Group configuration for advanced settings."""
|
||||
|
||||
group_name: str
|
||||
groupId: str
|
||||
|
||||
|
||||
class AdvancedSettings(BaseModel):
|
||||
"""Advanced settings for variable assigner."""
|
||||
|
||||
group_enabled: bool
|
||||
groups: list[GroupConfig]
|
||||
|
||||
|
||||
class VariableAssignerNodeType(CommonNodeType, VarGroupItem):
|
||||
"""Variable assigner node type implementation."""
|
||||
|
||||
advanced_settings: AdvancedSettings
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
example_node = VariableAssignerNodeType(
|
||||
title="Example Variable Assigner Node",
|
||||
desc="A variable assigner node example",
|
||||
type=BlockEnum.variable_assigner,
|
||||
output_type=VarType.string,
|
||||
variables=[ValueSelector(value=["varNode1", "value1"]), ValueSelector(value=["varNode2", "value2"])],
|
||||
advanced_settings=AdvancedSettings(
|
||||
group_enabled=True,
|
||||
groups=[
|
||||
GroupConfig(
|
||||
group_name="Group 1",
|
||||
groupId="group1",
|
||||
output_type=VarType.number,
|
||||
variables=[ValueSelector(value=["varNode3", "value3"])],
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
print(example_node.json(indent=2)) # Print as JSON format for viewing
|
||||
239
api/core/auto/output/emotion_analysis_workflow.yml
Normal file
239
api/core/auto/output/emotion_analysis_workflow.yml
Normal file
@@ -0,0 +1,239 @@
|
||||
app:
|
||||
description: ''
|
||||
icon: 🤖
|
||||
icon_background: '#FFEAD5'
|
||||
mode: workflow
|
||||
name: 情绪分析工作流
|
||||
use_icon_as_answer_icon: false
|
||||
kind: app
|
||||
version: 0.1.2
|
||||
workflow:
|
||||
conversation_variables: []
|
||||
environment_variables: []
|
||||
features:
|
||||
file_upload:
|
||||
allowed_file_extensions:
|
||||
- .JPG
|
||||
- .JPEG
|
||||
- .PNG
|
||||
- .GIF
|
||||
- .WEBP
|
||||
- .SVG
|
||||
allowed_file_types:
|
||||
- image
|
||||
allowed_file_upload_methods:
|
||||
- local_file
|
||||
- remote_url
|
||||
enabled: false
|
||||
fileUploadConfig:
|
||||
audio_file_size_limit: 50
|
||||
batch_count_limit: 5
|
||||
file_size_limit: 15
|
||||
image_file_size_limit: 10
|
||||
video_file_size_limit: 100
|
||||
image:
|
||||
enabled: false
|
||||
number_limits: 3
|
||||
transfer_methods:
|
||||
- local_file
|
||||
- remote_url
|
||||
number_limits: 3
|
||||
opening_statement: ''
|
||||
retriever_resource:
|
||||
enabled: true
|
||||
sensitive_word_avoidance:
|
||||
enabled: false
|
||||
speech_to_text:
|
||||
enabled: false
|
||||
suggested_questions: []
|
||||
suggested_questions_after_answer:
|
||||
enabled: false
|
||||
text_to_speech:
|
||||
enabled: false
|
||||
language: ''
|
||||
voice: ''
|
||||
graph:
|
||||
edges:
|
||||
- id: 1740019130520-source-1740019130521-target
|
||||
source: '1740019130520'
|
||||
sourceHandle: source
|
||||
target: '1740019130521'
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
data:
|
||||
isInIteration: false
|
||||
sourceType: start
|
||||
targetType: llm
|
||||
- id: 1740019130521-source-1740019130522-target
|
||||
source: '1740019130521'
|
||||
sourceHandle: source
|
||||
target: '1740019130522'
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
data:
|
||||
isInIteration: false
|
||||
sourceType: llm
|
||||
targetType: code
|
||||
- id: 1740019130522-source-1740019130523-target
|
||||
source: '1740019130522'
|
||||
sourceHandle: source
|
||||
target: '1740019130523'
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
data:
|
||||
isInIteration: false
|
||||
sourceType: code
|
||||
targetType: template-transform
|
||||
- id: 1740019130523-source-1740019130524-target
|
||||
source: '1740019130523'
|
||||
sourceHandle: source
|
||||
target: '1740019130524'
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
data:
|
||||
isInIteration: false
|
||||
sourceType: template-transform
|
||||
targetType: end
|
||||
nodes:
|
||||
- id: '1740019130520'
|
||||
position:
|
||||
x: 80
|
||||
y: 282
|
||||
height: 116
|
||||
width: 244
|
||||
positionAbsolute:
|
||||
x: 80
|
||||
y: 282
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
data:
|
||||
title: 开始节点
|
||||
desc: 开始节点,接收用户输入的文本。
|
||||
type: start
|
||||
variables:
|
||||
- type: text-input
|
||||
label: input_text
|
||||
variable: input_text
|
||||
required: true
|
||||
max_length: 48
|
||||
options: []
|
||||
- id: '1740019130521'
|
||||
position:
|
||||
x: 380
|
||||
y: 282
|
||||
height: 98
|
||||
width: 244
|
||||
positionAbsolute:
|
||||
x: 380
|
||||
y: 282
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
data:
|
||||
title: LLM节点
|
||||
desc: LLM节点分析文本情绪,识别出积极、消极或中性情绪。
|
||||
type: llm
|
||||
model:
|
||||
provider: zhipuai
|
||||
name: glm-4-flash
|
||||
mode: chat
|
||||
completion_params:
|
||||
temperature: 0.7
|
||||
prompt_template:
|
||||
- id: 1740019130521-system
|
||||
text: 请分析以下文本的情绪,并返回情绪类型(积极、消极或中性)。
|
||||
role: system
|
||||
- id: 1740019130521-user
|
||||
text: 分析此文本的情绪:{{input_text}}
|
||||
role: user
|
||||
context:
|
||||
enabled: false
|
||||
variable_selector: []
|
||||
vision:
|
||||
enabled: false
|
||||
- id: '1740019130522'
|
||||
position:
|
||||
x: 680
|
||||
y: 282
|
||||
height: 54
|
||||
width: 244
|
||||
positionAbsolute:
|
||||
x: 680
|
||||
y: 282
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
data:
|
||||
title: 代码节点
|
||||
desc: 代码节点将根据LLM分析的结果处理情绪类型。
|
||||
type: code
|
||||
variables:
|
||||
- variable: emotion
|
||||
value_selector:
|
||||
- '1740019130521'
|
||||
- emotion
|
||||
code_language: python3
|
||||
code: "def analyze_sentiment(emotion):\n if emotion == 'positive':\n \
|
||||
\ return '积极'\n elif emotion == 'negative':\n return '消极'\n\
|
||||
\ else:\n return '中性'\n\nemotion = '{{emotion}}'\nresult = analyze_sentiment(emotion)\n\
|
||||
return {'result': result}"
|
||||
outputs:
|
||||
sentiment_result:
|
||||
type: string
|
||||
- id: '1740019130523'
|
||||
position:
|
||||
x: 980
|
||||
y: 282
|
||||
height: 54
|
||||
width: 244
|
||||
positionAbsolute:
|
||||
x: 980
|
||||
y: 282
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
data:
|
||||
title: 模板节点
|
||||
desc: 模板节点将情绪分析结果格式化输出。
|
||||
type: template-transform
|
||||
variables:
|
||||
- variable: sentiment_result
|
||||
value_selector:
|
||||
- '1740019130522'
|
||||
- sentiment_result
|
||||
template: 文本的情绪分析结果为:{{sentiment_result}}
|
||||
- id: '1740019130524'
|
||||
position:
|
||||
x: 1280
|
||||
y: 282
|
||||
height: 90
|
||||
width: 244
|
||||
positionAbsolute:
|
||||
x: 1280
|
||||
y: 282
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
data:
|
||||
title: 结束节点
|
||||
desc: 结束节点,返回格式化后的情绪分析结果。
|
||||
type: end
|
||||
outputs:
|
||||
- variable: output
|
||||
value_selector:
|
||||
- '1740019130523'
|
||||
- output
|
||||
viewport:
|
||||
x: 92.96659905656679
|
||||
y: 79.13437154762897
|
||||
zoom: 0.9002006986311041
|
||||
247
api/core/auto/output/test_workflow.yml
Normal file
247
api/core/auto/output/test_workflow.yml
Normal file
@@ -0,0 +1,247 @@
|
||||
app:
|
||||
description: ''
|
||||
icon: 🤖
|
||||
icon_background: '#FFEAD5'
|
||||
mode: workflow
|
||||
name: 计算两个数字之和
|
||||
use_icon_as_answer_icon: false
|
||||
kind: app
|
||||
version: 0.1.2
|
||||
workflow:
|
||||
conversation_variables: []
|
||||
environment_variables: []
|
||||
features:
|
||||
file_upload:
|
||||
allowed_file_extensions:
|
||||
- .JPG
|
||||
- .JPEG
|
||||
- .PNG
|
||||
- .GIF
|
||||
- .WEBP
|
||||
- .SVG
|
||||
allowed_file_types:
|
||||
- image
|
||||
allowed_file_upload_methods:
|
||||
- local_file
|
||||
- remote_url
|
||||
enabled: false
|
||||
fileUploadConfig:
|
||||
audio_file_size_limit: 50
|
||||
batch_count_limit: 5
|
||||
file_size_limit: 15
|
||||
image_file_size_limit: 10
|
||||
video_file_size_limit: 100
|
||||
image:
|
||||
enabled: false
|
||||
number_limits: 3
|
||||
transfer_methods:
|
||||
- local_file
|
||||
- remote_url
|
||||
number_limits: 3
|
||||
opening_statement: ''
|
||||
retriever_resource:
|
||||
enabled: true
|
||||
sensitive_word_avoidance:
|
||||
enabled: false
|
||||
speech_to_text:
|
||||
enabled: false
|
||||
suggested_questions: []
|
||||
suggested_questions_after_answer:
|
||||
enabled: false
|
||||
text_to_speech:
|
||||
enabled: false
|
||||
language: ''
|
||||
voice: ''
|
||||
graph:
|
||||
edges:
|
||||
- id: 1740019130520-source-1740019130521-target
|
||||
source: '1740019130520'
|
||||
sourceHandle: source
|
||||
target: '1740019130521'
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
data:
|
||||
isInIteration: false
|
||||
sourceType: start
|
||||
targetType: llm
|
||||
- id: 1740019130521-source-1740019130522-target
|
||||
source: '1740019130521'
|
||||
sourceHandle: source
|
||||
target: '1740019130522'
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
data:
|
||||
isInIteration: false
|
||||
sourceType: llm
|
||||
targetType: code
|
||||
- id: 1740019130522-source-1740019130523-target
|
||||
source: '1740019130522'
|
||||
sourceHandle: source
|
||||
target: '1740019130523'
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
data:
|
||||
isInIteration: false
|
||||
sourceType: code
|
||||
targetType: template-transform
|
||||
- id: 1740019130523-source-1740019130524-target
|
||||
source: '1740019130523'
|
||||
sourceHandle: source
|
||||
target: '1740019130524'
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
data:
|
||||
isInIteration: false
|
||||
sourceType: template-transform
|
||||
targetType: end
|
||||
nodes:
|
||||
- id: '1740019130520'
|
||||
position:
|
||||
x: 80
|
||||
y: 282
|
||||
height: 116
|
||||
width: 244
|
||||
positionAbsolute:
|
||||
x: 80
|
||||
y: 282
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
data:
|
||||
title: 开始节点
|
||||
desc: 开始节点,接收两个数字输入参数。
|
||||
type: start
|
||||
variables:
|
||||
- type: number
|
||||
label: num1
|
||||
variable: num1
|
||||
required: true
|
||||
max_length: 48
|
||||
options: []
|
||||
- type: number
|
||||
label: num2
|
||||
variable: num2
|
||||
required: true
|
||||
max_length: 48
|
||||
options: []
|
||||
- id: '1740019130521'
|
||||
position:
|
||||
x: 380
|
||||
y: 282
|
||||
height: 98
|
||||
width: 244
|
||||
positionAbsolute:
|
||||
x: 380
|
||||
y: 282
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
data:
|
||||
title: LLM节点
|
||||
desc: LLM节点,根据输入的两个数字生成计算它们之和的Python函数。
|
||||
type: llm
|
||||
model:
|
||||
provider: openai
|
||||
name: gpt-4
|
||||
mode: chat
|
||||
completion_params:
|
||||
temperature: 0.7
|
||||
prompt_template:
|
||||
- id: 1740019130521-system
|
||||
text: 你是一个Python开发助手,请根据以下输入生成一个计算两数之和的Python函数。
|
||||
role: system
|
||||
- id: 1740019130521-user
|
||||
text: 请为两个数字{{num1}}和{{num2}}生成一个Python函数,计算它们的和。
|
||||
role: user
|
||||
context:
|
||||
enabled: false
|
||||
variable_selector: []
|
||||
vision:
|
||||
enabled: false
|
||||
- id: '1740019130522'
|
||||
position:
|
||||
x: 680
|
||||
y: 282
|
||||
height: 54
|
||||
width: 244
|
||||
positionAbsolute:
|
||||
x: 680
|
||||
y: 282
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
data:
|
||||
title: 代码节点
|
||||
desc: 代码节点,执行LLM生成的Python函数,并计算结果。
|
||||
type: code
|
||||
variables:
|
||||
- variable: num1
|
||||
value_selector:
|
||||
- '1740019130520'
|
||||
- num1
|
||||
- variable: num2
|
||||
value_selector:
|
||||
- '1740019130520'
|
||||
- num2
|
||||
code_language: python3
|
||||
code: "def calculate_sum(num1, num2):\n return num1 + num2\n\nresult =\
|
||||
\ calculate_sum({{num1}}, {{num2}})\nreturn result"
|
||||
outputs:
|
||||
result:
|
||||
type: number
|
||||
- id: '1740019130523'
|
||||
position:
|
||||
x: 980
|
||||
y: 282
|
||||
height: 54
|
||||
width: 244
|
||||
positionAbsolute:
|
||||
x: 980
|
||||
y: 282
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
data:
|
||||
title: 模板节点
|
||||
desc: 模板节点,将计算结果格式化为输出字符串。
|
||||
type: template-transform
|
||||
variables:
|
||||
- variable: result
|
||||
value_selector:
|
||||
- '1740019130522'
|
||||
- result
|
||||
template: '计算结果为: {{result}}'
|
||||
- id: '1740019130524'
|
||||
position:
|
||||
x: 1280
|
||||
y: 282
|
||||
height: 90
|
||||
width: 244
|
||||
positionAbsolute:
|
||||
x: 1280
|
||||
y: 282
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
data:
|
||||
title: 结束节点
|
||||
desc: 结束节点,输出格式化后的结果。
|
||||
type: end
|
||||
outputs:
|
||||
- variable: output
|
||||
value_selector:
|
||||
- '1740019130523'
|
||||
- output
|
||||
viewport:
|
||||
x: 92.96659905656679
|
||||
y: 79.13437154762897
|
||||
zoom: 0.9002006986311041
|
||||
262
api/core/auto/output/text_analysis_workflow.yml
Normal file
262
api/core/auto/output/text_analysis_workflow.yml
Normal file
@@ -0,0 +1,262 @@
|
||||
app:
|
||||
description: ''
|
||||
icon: 🤖
|
||||
icon_background: '#FFEAD5'
|
||||
mode: workflow
|
||||
name: 文本分析工作流
|
||||
use_icon_as_answer_icon: false
|
||||
kind: app
|
||||
version: 0.1.2
|
||||
workflow:
|
||||
conversation_variables: []
|
||||
environment_variables: []
|
||||
features:
|
||||
file_upload:
|
||||
allowed_file_extensions:
|
||||
- .JPG
|
||||
- .JPEG
|
||||
- .PNG
|
||||
- .GIF
|
||||
- .WEBP
|
||||
- .SVG
|
||||
allowed_file_types:
|
||||
- image
|
||||
allowed_file_upload_methods:
|
||||
- local_file
|
||||
- remote_url
|
||||
enabled: false
|
||||
fileUploadConfig:
|
||||
audio_file_size_limit: 50
|
||||
batch_count_limit: 5
|
||||
file_size_limit: 15
|
||||
image_file_size_limit: 10
|
||||
video_file_size_limit: 100
|
||||
image:
|
||||
enabled: false
|
||||
number_limits: 3
|
||||
transfer_methods:
|
||||
- local_file
|
||||
- remote_url
|
||||
number_limits: 3
|
||||
opening_statement: ''
|
||||
retriever_resource:
|
||||
enabled: true
|
||||
sensitive_word_avoidance:
|
||||
enabled: false
|
||||
speech_to_text:
|
||||
enabled: false
|
||||
suggested_questions: []
|
||||
suggested_questions_after_answer:
|
||||
enabled: false
|
||||
text_to_speech:
|
||||
enabled: false
|
||||
language: ''
|
||||
voice: ''
|
||||
graph:
|
||||
edges:
|
||||
- id: 1740019130520-source-1740019130521-target
|
||||
source: '1740019130520'
|
||||
sourceHandle: source
|
||||
target: '1740019130521'
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
data:
|
||||
isInIteration: false
|
||||
sourceType: start
|
||||
targetType: llm
|
||||
- id: 1740019130520-source-1740019130522-target
|
||||
source: '1740019130520'
|
||||
sourceHandle: source
|
||||
target: '1740019130522'
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
data:
|
||||
isInIteration: false
|
||||
sourceType: start
|
||||
targetType: code
|
||||
- id: 1740019130521-source-1740019130523-target
|
||||
source: '1740019130521'
|
||||
sourceHandle: source
|
||||
target: '1740019130523'
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
data:
|
||||
isInIteration: false
|
||||
sourceType: llm
|
||||
targetType: template-transform
|
||||
- id: 1740019130522-source-1740019130523-target
|
||||
source: '1740019130522'
|
||||
sourceHandle: source
|
||||
target: '1740019130523'
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
data:
|
||||
isInIteration: false
|
||||
sourceType: code
|
||||
targetType: template-transform
|
||||
- id: 1740019130523-source-1740019130524-target
|
||||
source: '1740019130523'
|
||||
sourceHandle: source
|
||||
target: '1740019130524'
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
data:
|
||||
isInIteration: false
|
||||
sourceType: template-transform
|
||||
targetType: end
|
||||
nodes:
|
||||
- id: '1740019130520'
|
||||
position:
|
||||
x: 80
|
||||
y: 282
|
||||
height: 116
|
||||
width: 244
|
||||
positionAbsolute:
|
||||
x: 80
|
||||
y: 282
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
data:
|
||||
title: 开始节点
|
||||
desc: 接收用户输入的文本参数
|
||||
type: start
|
||||
variables:
|
||||
- type: text-input
|
||||
label: user_text
|
||||
variable: user_text
|
||||
required: true
|
||||
max_length: 48
|
||||
options: []
|
||||
- id: '1740019130521'
|
||||
position:
|
||||
x: 380
|
||||
y: 282
|
||||
height: 98
|
||||
width: 244
|
||||
positionAbsolute:
|
||||
x: 380
|
||||
y: 282
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
data:
|
||||
title: LLM节点
|
||||
desc: 使用大语言模型进行情感分析,返回文本的情感结果
|
||||
type: llm
|
||||
model:
|
||||
provider: zhipuai
|
||||
name: glm-4-flash
|
||||
mode: chat
|
||||
completion_params:
|
||||
temperature: 0.7
|
||||
prompt_template:
|
||||
- id: 1740019130521-system
|
||||
text: 请分析以下文本的情感,返回积极、消极或中性
|
||||
role: system
|
||||
- id: 1740019130521-user
|
||||
text: '{{user_text}}'
|
||||
role: user
|
||||
context:
|
||||
enabled: false
|
||||
variable_selector: []
|
||||
vision:
|
||||
enabled: false
|
||||
- id: '1740019130522'
|
||||
position:
|
||||
x: 680
|
||||
y: 282
|
||||
height: 54
|
||||
width: 244
|
||||
positionAbsolute:
|
||||
x: 680
|
||||
y: 282
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
data:
|
||||
title: 代码节点
|
||||
desc: 计算文本的统计信息,包括字符数、单词数和句子数
|
||||
type: code
|
||||
variables:
|
||||
- variable: text_for_analysis
|
||||
value_selector:
|
||||
- '1740019130520'
|
||||
- user_text
|
||||
code_language: python3
|
||||
code: "import re\n\ndef main(text):\n char_count = len(text)\n word_count\
|
||||
\ = len(text.split())\n sentence_count = len(re.findall(r'[.!?]', text))\n\
|
||||
\ return {'char_count': char_count, 'word_count': word_count, 'sentence_count':\
|
||||
\ sentence_count}"
|
||||
outputs:
|
||||
text_statistics:
|
||||
type: object
|
||||
- id: '1740019130523'
|
||||
position:
|
||||
x: 980
|
||||
y: 282
|
||||
height: 54
|
||||
width: 244
|
||||
positionAbsolute:
|
||||
x: 980
|
||||
y: 282
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
data:
|
||||
title: 模板节点
|
||||
desc: 将情感分析结果和统计信息组合成格式化报告
|
||||
type: template-transform
|
||||
variables:
|
||||
- variable: sentiment_result
|
||||
value_selector:
|
||||
- '1740019130521'
|
||||
- sentiment_result
|
||||
- variable: text_statistics
|
||||
value_selector:
|
||||
- '1740019130522'
|
||||
- text_statistics
|
||||
template: '情感分析结果:{{sentiment_result}}
|
||||
|
||||
文本统计信息:
|
||||
|
||||
字符数:{{text_statistics.char_count}}
|
||||
|
||||
单词数:{{text_statistics.word_count}}
|
||||
|
||||
句子数:{{text_statistics.sentence_count}}'
|
||||
- id: '1740019130524'
|
||||
position:
|
||||
x: 1280
|
||||
y: 282
|
||||
height: 90
|
||||
width: 244
|
||||
positionAbsolute:
|
||||
x: 1280
|
||||
y: 282
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
data:
|
||||
title: 结束节点
|
||||
desc: 返回最终的分析报告
|
||||
type: end
|
||||
outputs:
|
||||
- variable: final_report
|
||||
value_selector:
|
||||
- '1740019130523'
|
||||
- output
|
||||
viewport:
|
||||
x: 92.96659905656679
|
||||
y: 79.13437154762897
|
||||
zoom: 0.9002006986311041
|
||||
8
api/core/auto/workflow_generator/__init__.py
Normal file
8
api/core/auto/workflow_generator/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
工作流生成器包
|
||||
用于根据用户需求生成Dify工作流
|
||||
"""
|
||||
|
||||
from .workflow_generator import WorkflowGenerator
|
||||
|
||||
__all__ = ["WorkflowGenerator"]
|
||||
9
api/core/auto/workflow_generator/generators/__init__.py
Normal file
9
api/core/auto/workflow_generator/generators/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
节点和边生成器包
|
||||
"""
|
||||
|
||||
from .edge_generator import EdgeGenerator
|
||||
from .layout_engine import LayoutEngine
|
||||
from .node_generator import NodeGenerator
|
||||
|
||||
__all__ = ["EdgeGenerator", "LayoutEngine", "NodeGenerator"]
|
||||
101
api/core/auto/workflow_generator/generators/edge_generator.py
Normal file
101
api/core/auto/workflow_generator/generators/edge_generator.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Edge Generator
|
||||
Used to generate edges in the workflow
|
||||
"""
|
||||
|
||||
from core.auto.node_types.common import CommonEdgeType, CompleteEdge, CompleteNode
|
||||
from core.auto.workflow_generator.models.workflow_description import ConnectionDescription
|
||||
|
||||
|
||||
class EdgeGenerator:
|
||||
"""Edge generator for creating workflow edges"""
|
||||
|
||||
@staticmethod
|
||||
def create_edges(nodes: list[CompleteNode], connections: list[ConnectionDescription]) -> list[CompleteEdge]:
|
||||
"""
|
||||
Create edges based on nodes and connection information
|
||||
|
||||
Args:
|
||||
nodes: list of nodes
|
||||
connections: list of connection descriptions
|
||||
|
||||
Returns:
|
||||
list of edges
|
||||
"""
|
||||
edges = []
|
||||
|
||||
# If connection information is provided, create edges based on it
|
||||
if connections:
|
||||
for connection in connections:
|
||||
source_id = connection.source
|
||||
target_id = connection.target
|
||||
|
||||
if not source_id or not target_id:
|
||||
continue
|
||||
|
||||
# Find source and target nodes
|
||||
source_node = next((node for node in nodes if node.id == source_id), None)
|
||||
target_node = next((node for node in nodes if node.id == target_id), None)
|
||||
|
||||
if not source_node or not target_node:
|
||||
continue
|
||||
|
||||
# Get node types
|
||||
source_type = source_node.data.type
|
||||
target_type = target_node.data.type
|
||||
|
||||
# Create edge
|
||||
edge_id = f"{source_id}-source-{target_id}-target"
|
||||
|
||||
# Create edge data
|
||||
edge_data = CommonEdgeType(isInIteration=False, sourceType=source_type, targetType=target_type)
|
||||
|
||||
# Create complete edge
|
||||
edge = CompleteEdge(
|
||||
id=edge_id,
|
||||
source=source_id,
|
||||
sourceHandle="source",
|
||||
target=target_id,
|
||||
targetHandle="target",
|
||||
type="custom",
|
||||
zIndex=0,
|
||||
)
|
||||
|
||||
# Add edge data
|
||||
edge.add_data(edge_data)
|
||||
|
||||
edges.append(edge)
|
||||
# If no connection information is provided, automatically create edges
|
||||
else:
|
||||
# Create edges based on node order
|
||||
for i in range(len(nodes) - 1):
|
||||
source_node = nodes[i]
|
||||
target_node = nodes[i + 1]
|
||||
|
||||
# Get node types
|
||||
source_type = source_node.data.type
|
||||
target_type = target_node.data.type
|
||||
|
||||
# Create edge
|
||||
edge_id = f"{source_node.id}-source-{target_node.id}-target"
|
||||
|
||||
# Create edge data
|
||||
edge_data = CommonEdgeType(isInIteration=False, sourceType=source_type, targetType=target_type)
|
||||
|
||||
# Create complete edge
|
||||
edge = CompleteEdge(
|
||||
id=edge_id,
|
||||
source=source_node.id,
|
||||
sourceHandle="source",
|
||||
target=target_node.id,
|
||||
targetHandle="target",
|
||||
type="custom",
|
||||
zIndex=0,
|
||||
)
|
||||
|
||||
# Add edge data
|
||||
edge.add_data(edge_data)
|
||||
|
||||
edges.append(edge)
|
||||
|
||||
return edges
|
||||
77
api/core/auto/workflow_generator/generators/layout_engine.py
Normal file
77
api/core/auto/workflow_generator/generators/layout_engine.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
Layout Engine
|
||||
Used to arrange the positions of nodes and edges
|
||||
"""
|
||||
|
||||
from core.auto.node_types.common import CompleteEdge, CompleteNode
|
||||
|
||||
|
||||
class LayoutEngine:
|
||||
"""Layout engine"""
|
||||
|
||||
@staticmethod
|
||||
def apply_layout(nodes: list[CompleteNode]) -> None:
|
||||
"""
|
||||
Apply layout, arranging nodes in a row
|
||||
|
||||
Args:
|
||||
nodes: list of nodes
|
||||
"""
|
||||
# Simple linear layout, arranging nodes from left to right
|
||||
x_position = 80
|
||||
y_position = 282
|
||||
|
||||
for node in nodes:
|
||||
node.position = {"x": x_position, "y": y_position}
|
||||
node.positionAbsolute = {"x": x_position, "y": y_position}
|
||||
|
||||
# Update position for the next node
|
||||
x_position += 300 # Horizontal spacing between nodes
|
||||
|
||||
@staticmethod
|
||||
def apply_topological_layout(nodes: list[CompleteNode], edges: list[CompleteEdge]) -> None:
|
||||
"""
|
||||
Apply topological sort layout, arranging nodes based on their dependencies
|
||||
|
||||
Args:
|
||||
nodes: list of nodes
|
||||
edges: list of edges
|
||||
"""
|
||||
# Create mapping from node ID to node
|
||||
node_map = {node.id: node for node in nodes}
|
||||
|
||||
# Create adjacency list
|
||||
adjacency_list = {node.id: [] for node in nodes}
|
||||
for edge in edges:
|
||||
adjacency_list[edge.source].append(edge.target)
|
||||
|
||||
# Create in-degree table
|
||||
in_degree = {node.id: 0 for node in nodes}
|
||||
for source, targets in adjacency_list.items():
|
||||
for target in targets:
|
||||
in_degree[target] += 1
|
||||
|
||||
# Topological sort
|
||||
queue = [node_id for node_id, degree in in_degree.items() if degree == 0]
|
||||
sorted_nodes = []
|
||||
|
||||
while queue:
|
||||
current = queue.pop(0)
|
||||
sorted_nodes.append(current)
|
||||
|
||||
for neighbor in adjacency_list[current]:
|
||||
in_degree[neighbor] -= 1
|
||||
if in_degree[neighbor] == 0:
|
||||
queue.append(neighbor)
|
||||
|
||||
# Apply layout
|
||||
x_position = 80
|
||||
y_position = 282
|
||||
|
||||
for node_id in sorted_nodes:
|
||||
node = node_map[node_id]
|
||||
node.position = {"x": x_position, "y": y_position}
|
||||
node.positionAbsolute = {"x": x_position, "y": y_position}
|
||||
|
||||
# Update position for the next node
|
||||
x_position += 300 # Horizontal spacing between nodes
|
||||
446
api/core/auto/workflow_generator/generators/node_generator.py
Normal file
446
api/core/auto/workflow_generator/generators/node_generator.py
Normal file
@@ -0,0 +1,446 @@
|
||||
"""
|
||||
Node Generator
|
||||
Generate nodes based on workflow description
|
||||
"""
|
||||
|
||||
from core.auto.node_types.code import CodeLanguage, CodeNodeType, OutputVar
|
||||
from core.auto.node_types.common import (
|
||||
BlockEnum,
|
||||
CompleteNode,
|
||||
Context,
|
||||
InputVar,
|
||||
ModelConfig,
|
||||
PromptItem,
|
||||
PromptRole,
|
||||
ValueSelector,
|
||||
Variable,
|
||||
)
|
||||
from core.auto.node_types.end import EndNodeType
|
||||
from core.auto.node_types.llm import LLMNodeType, VisionConfig
|
||||
from core.auto.node_types.start import StartNodeType
|
||||
from core.auto.node_types.template_transform import TemplateTransformNodeType
|
||||
from core.auto.workflow_generator.models.workflow_description import NodeDescription
|
||||
from core.auto.workflow_generator.utils.prompts import DEFAULT_MODEL_CONFIG, DEFAULT_SYSTEM_PROMPT
|
||||
from core.auto.workflow_generator.utils.type_mapper import map_string_to_var_type, map_var_type_to_input_type
|
||||
|
||||
|
||||
class NodeGenerator:
|
||||
"""Node generator for creating workflow nodes"""
|
||||
|
||||
@staticmethod
|
||||
def create_nodes(node_descriptions: list[NodeDescription]) -> list[CompleteNode]:
|
||||
"""
|
||||
Create nodes based on node descriptions
|
||||
|
||||
Args:
|
||||
node_descriptions: list of node descriptions
|
||||
|
||||
Returns:
|
||||
list of nodes
|
||||
"""
|
||||
nodes = []
|
||||
|
||||
for node_desc in node_descriptions:
|
||||
node_type = node_desc.type
|
||||
|
||||
if node_type == "start":
|
||||
node = NodeGenerator._create_start_node(node_desc)
|
||||
elif node_type == "llm":
|
||||
node = NodeGenerator._create_llm_node(node_desc)
|
||||
elif node_type == "code":
|
||||
node = NodeGenerator._create_code_node(node_desc)
|
||||
elif node_type == "template":
|
||||
node = NodeGenerator._create_template_node(node_desc)
|
||||
elif node_type == "end":
|
||||
node = NodeGenerator._create_end_node(node_desc)
|
||||
else:
|
||||
raise ValueError(f"Unsupported node type: {node_type}")
|
||||
|
||||
nodes.append(node)
|
||||
|
||||
return nodes
|
||||
|
||||
@staticmethod
|
||||
def _create_start_node(node_desc: NodeDescription) -> CompleteNode:
|
||||
"""Create start node"""
|
||||
variables = []
|
||||
|
||||
for var in node_desc.variables or []:
|
||||
input_var = InputVar(
|
||||
type=map_var_type_to_input_type(var.type),
|
||||
label=var.name,
|
||||
variable=var.name,
|
||||
required=var.required,
|
||||
max_length=48,
|
||||
options=[],
|
||||
)
|
||||
variables.append(input_var)
|
||||
|
||||
start_node = StartNodeType(
|
||||
title=node_desc.title, desc=node_desc.description or "", type=BlockEnum.start, variables=variables
|
||||
)
|
||||
|
||||
return CompleteNode(
|
||||
id=node_desc.id,
|
||||
type="custom",
|
||||
position={"x": 0, "y": 0}, # Temporary position, will be updated later
|
||||
height=118, # Increase height to match reference file
|
||||
width=244,
|
||||
positionAbsolute={"x": 0, "y": 0},
|
||||
selected=False,
|
||||
sourcePosition="right",
|
||||
targetPosition="left",
|
||||
data=start_node,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_llm_node(node_desc: NodeDescription) -> CompleteNode:
|
||||
"""Create LLM node"""
|
||||
# Build prompt template
|
||||
prompt_template = []
|
||||
|
||||
# Add system prompt
|
||||
system_prompt = node_desc.system_prompt or DEFAULT_SYSTEM_PROMPT
|
||||
prompt_template.append(PromptItem(id=f"{node_desc.id}-system", role=PromptRole.system, text=system_prompt))
|
||||
|
||||
# Add user prompt
|
||||
user_prompt = node_desc.user_prompt or "Please answer these questions:"
|
||||
|
||||
# Build variable list
|
||||
variables = []
|
||||
for var in node_desc.variables or []:
|
||||
source_node = var.source_node or ""
|
||||
source_variable = var.source_variable or ""
|
||||
|
||||
print(
|
||||
f"DEBUG: Processing variable {var.name}, source_node={source_node}, source_variable={source_variable}"
|
||||
)
|
||||
|
||||
# If source node is an LLM node, ensure source_variable is 'text'
|
||||
if source_node:
|
||||
# Check if the source node is an LLM node by checking connections
|
||||
# This is a simple heuristic - if the source node is connected to a node with 'llm' in its ID
|
||||
# or if the source node has 'llm' in its ID, assume it's an LLM node
|
||||
if "llm" in source_node.lower():
|
||||
print(f"DEBUG: Found LLM node {source_node}")
|
||||
if source_variable != "text":
|
||||
old_var = source_variable
|
||||
source_variable = "text" # LLM nodes output variable is always 'text'
|
||||
print(
|
||||
f"Auto-fixing: Changed source variable from '{old_var}' to 'text' for LLM node {source_node}" # noqa: E501
|
||||
)
|
||||
|
||||
# Check if the user prompt already contains correctly formatted variable references
|
||||
# Variable references in LLM nodes should be in the format {{#nodeID.variableName#}}
|
||||
correct_format = f"{{{{#{source_node}.{source_variable}#}}}}"
|
||||
simple_format = f"{{{{{var.name}}}}}"
|
||||
|
||||
# If simple format is used in the prompt, replace it with the correct format
|
||||
if simple_format in user_prompt and source_node and source_variable:
|
||||
user_prompt = user_prompt.replace(simple_format, correct_format)
|
||||
|
||||
variable = Variable(variable=var.name, value_selector=[source_node, source_variable])
|
||||
variables.append(variable)
|
||||
|
||||
# Update user prompt
|
||||
prompt_template.append(PromptItem(id=f"{node_desc.id}-user", role=PromptRole.user, text=user_prompt))
|
||||
|
||||
# Use default model configuration, prioritize configuration in node description
|
||||
provider = node_desc.provider or DEFAULT_MODEL_CONFIG["provider"]
|
||||
model = node_desc.model or DEFAULT_MODEL_CONFIG["model"]
|
||||
|
||||
llm_node = LLMNodeType(
|
||||
title=node_desc.title,
|
||||
desc=node_desc.description or "",
|
||||
type=BlockEnum.llm,
|
||||
model=ModelConfig(
|
||||
provider=provider,
|
||||
name=model,
|
||||
mode=DEFAULT_MODEL_CONFIG["mode"],
|
||||
completion_params=DEFAULT_MODEL_CONFIG["completion_params"],
|
||||
),
|
||||
prompt_template=prompt_template,
|
||||
variables=variables,
|
||||
context=Context(enabled=False, variable_selector=ValueSelector(value=[])),
|
||||
vision=VisionConfig(enabled=False),
|
||||
)
|
||||
|
||||
return CompleteNode(
|
||||
id=node_desc.id,
|
||||
type="custom",
|
||||
position={"x": 0, "y": 0}, # Temporary position, will be updated later
|
||||
height=126, # Increase height to match reference file
|
||||
width=244,
|
||||
positionAbsolute={"x": 0, "y": 0},
|
||||
selected=False,
|
||||
sourcePosition="right",
|
||||
targetPosition="left",
|
||||
data=llm_node,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_code_node(node_desc: NodeDescription) -> CompleteNode:
|
||||
"""Create code node"""
|
||||
# Build variable list and function parameter names
|
||||
variables = []
|
||||
var_names = []
|
||||
var_mapping = {} # Used to store mapping from variable names to function parameter names
|
||||
|
||||
# First, identify all LLM nodes in the workflow
|
||||
llm_nodes = set()
|
||||
for connection in node_desc.workflow_description.connections:
|
||||
for node in node_desc.workflow_description.nodes:
|
||||
if node.id == connection.source and node.type.lower() == "llm":
|
||||
llm_nodes.add(node.id)
|
||||
|
||||
for var in node_desc.variables or []:
|
||||
source_node = var.source_node or ""
|
||||
source_variable = var.source_variable or ""
|
||||
|
||||
# Check if source node is an LLM node and warn if source_variable is not 'text'
|
||||
if source_node in llm_nodes and source_variable != "text":
|
||||
print(
|
||||
f"WARNING: LLM node {source_node} output variable should be 'text', but got '{source_variable}'. This may cause issues in Dify." # noqa: E501
|
||||
)
|
||||
print(" Consider changing the source_variable to 'text' in your workflow description.")
|
||||
# Auto-fix: Always use 'text' as the source variable for LLM nodes
|
||||
old_var = source_variable
|
||||
source_variable = "text"
|
||||
print(f"Auto-fixing: Changed source variable from '{old_var}' to 'text' for LLM node {source_node}")
|
||||
elif source_node and "llm" in source_node.lower() and source_variable != "text":
|
||||
# Fallback heuristic check based on node ID
|
||||
print(
|
||||
f"WARNING: Node {source_node} appears to be an LLM node based on its ID, but source_variable is not 'text'." # noqa: E501
|
||||
)
|
||||
print(" Consider changing the source_variable to 'text' in your workflow description.")
|
||||
# Auto-fix: Always use 'text' as the source variable for LLM nodes
|
||||
old_var = source_variable
|
||||
source_variable = "text"
|
||||
print(f"Auto-fixing: Changed source variable from '{old_var}' to 'text' for LLM node {source_node}")
|
||||
|
||||
# Use variable name as function parameter name
|
||||
variable_name = var.name # Variable name defined in this node
|
||||
param_name = variable_name # Function parameter name must match variable name
|
||||
|
||||
# Validate variable name format
|
||||
if not variable_name.replace("_", "").isalnum():
|
||||
raise ValueError(
|
||||
f"Invalid variable name: {variable_name}. Variable names must only contain letters, numbers, and underscores." # noqa: E501
|
||||
)
|
||||
if not variable_name[0].isalpha() and variable_name[0] != "_":
|
||||
raise ValueError(
|
||||
f"Invalid variable name: {variable_name}. Variable names must start with a letter or underscore."
|
||||
)
|
||||
|
||||
var_names.append(param_name)
|
||||
var_mapping[variable_name] = param_name
|
||||
|
||||
variable = Variable(variable=variable_name, value_selector=[source_node, source_variable])
|
||||
variables.append(variable)
|
||||
|
||||
# Build output
|
||||
outputs = {}
|
||||
for output in node_desc.outputs or []:
|
||||
# Validate output variable name format
|
||||
if not output.name.replace("_", "").isalnum():
|
||||
raise ValueError(
|
||||
f"Invalid output variable name: {output.name}. Output names must only contain letters, numbers, and underscores." # noqa: E501
|
||||
)
|
||||
if not output.name[0].isalpha() and output.name[0] != "_":
|
||||
raise ValueError(
|
||||
f"Invalid output variable name: {output.name}. Output names must start with a letter or underscore."
|
||||
)
|
||||
|
||||
outputs[output.name] = OutputVar(type=map_string_to_var_type(output.type))
|
||||
|
||||
# Generate code, ensure function parameters match variable names, return values match output names
|
||||
output_names = [output.name for output in node_desc.outputs or []]
|
||||
|
||||
# Build function parameter list
|
||||
params_str = ", ".join(var_names) if var_names else ""
|
||||
|
||||
# Build return value dictionary
|
||||
return_dict = {}
|
||||
for output_name in output_names:
|
||||
# Use the first variable as the return value by default
|
||||
return_dict[output_name] = var_names[0] if var_names else f'"{output_name}"'
|
||||
|
||||
return_dict_str = ", ".join([f'"{k}": {v}' for k, v in return_dict.items()])
|
||||
|
||||
# Default code template, ensure return dictionary type matches output variable
|
||||
default_code = f"""def main({params_str}):
|
||||
# Write your code here
|
||||
# Process input variables
|
||||
|
||||
# Return a dictionary, key names must match variable names defined in outputs
|
||||
return {{{return_dict_str}}}"""
|
||||
|
||||
# If custom code is provided, ensure it meets the specifications
|
||||
if node_desc.code:
|
||||
custom_code = node_desc.code
|
||||
# Check if it contains main function definition
|
||||
if not custom_code.strip().startswith("def main("):
|
||||
# Try to fix the code by adding main function with correct parameters
|
||||
custom_code = f"def main({params_str}):\n" + custom_code.strip()
|
||||
else:
|
||||
# Extract function parameters from the existing main function
|
||||
import re
|
||||
|
||||
func_params = re.search(r"def\s+main\s*\((.*?)\)", custom_code)
|
||||
if func_params:
|
||||
existing_params = [p.strip() for p in func_params.group(1).split(",") if p.strip()]
|
||||
# Verify that all required parameters are present
|
||||
missing_params = set(var_names) - set(existing_params)
|
||||
if missing_params:
|
||||
# 尝试修复代码,将函数参数替换为正确的参数名
|
||||
old_params = func_params.group(1)
|
||||
new_params = params_str
|
||||
custom_code = custom_code.replace(f"def main({old_params})", f"def main({new_params})")
|
||||
print(
|
||||
f"Warning: Fixed missing parameters in code node: {', '.join(missing_params)}. Function parameters must match variable names defined in this node." # noqa: E501
|
||||
)
|
||||
|
||||
# Check if the return value is a dictionary and keys match output variables
|
||||
for output_name in output_names:
|
||||
if f'"{output_name}"' not in custom_code and f"'{output_name}'" not in custom_code:
|
||||
# Code may not meet specifications, use default code
|
||||
custom_code = default_code
|
||||
break
|
||||
|
||||
# Use fixed code
|
||||
code = custom_code
|
||||
else:
|
||||
code = default_code
|
||||
|
||||
code_node = CodeNodeType(
|
||||
title=node_desc.title,
|
||||
desc=node_desc.description or "",
|
||||
type=BlockEnum.code,
|
||||
code_language=CodeLanguage.python3,
|
||||
code=code,
|
||||
variables=variables,
|
||||
outputs=outputs,
|
||||
)
|
||||
|
||||
return CompleteNode(
|
||||
id=node_desc.id,
|
||||
type="custom",
|
||||
position={"x": 0, "y": 0}, # Temporary position, will be updated later
|
||||
height=82, # Increase height to match reference file
|
||||
width=244,
|
||||
positionAbsolute={"x": 0, "y": 0},
|
||||
selected=False,
|
||||
sourcePosition="right",
|
||||
targetPosition="left",
|
||||
data=code_node,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_template_node(node_desc: NodeDescription) -> CompleteNode:
|
||||
"""Create template node"""
|
||||
# Build variable list
|
||||
variables = []
|
||||
template_text = node_desc.template or ""
|
||||
|
||||
# Collect all node IDs referenced in the template
|
||||
referenced_nodes = set()
|
||||
for var in node_desc.variables or []:
|
||||
source_node = var.source_node or ""
|
||||
source_variable = var.source_variable or ""
|
||||
|
||||
variable = Variable(variable=var.name, value_selector=[source_node, source_variable])
|
||||
variables.append(variable)
|
||||
|
||||
if source_node:
|
||||
referenced_nodes.add(source_node)
|
||||
|
||||
# Modify variable reference format in the template
|
||||
# Replace {{#node_id.variable#}} with {{ variable }}
|
||||
if source_node and source_variable:
|
||||
template_text = template_text.replace(f"{{{{#{source_node}.{source_variable}#}}}}", f"{{ {var.name} }}")
|
||||
|
||||
# Check if a reference to the start node needs to be added
|
||||
# If the template contains a reference to the start node but the variable list does not have a corresponding variable # noqa: E501
|
||||
import re
|
||||
|
||||
start_node_refs = re.findall(r"{{#(\d+)\.([^#]+)#}}", template_text)
|
||||
|
||||
for node_id, var_name in start_node_refs:
|
||||
# Check if there is already a reference to this variable
|
||||
if not any(v.variable == var_name for v in variables):
|
||||
# Add reference to start node variable
|
||||
variable = Variable(variable=var_name, value_selector=[node_id, var_name])
|
||||
variables.append(variable)
|
||||
|
||||
# Modify variable reference format in the template
|
||||
template_text = template_text.replace(f"{{{{#{node_id}.{var_name}#}}}}", f"{{ {var_name} }}")
|
||||
|
||||
# Get all variable names
|
||||
var_names = [var.variable for var in variables]
|
||||
|
||||
# Simple and crude method: directly replace all possible variable reference formats
|
||||
for var_name in var_names:
|
||||
# Replace {var_name} with {{ var_name }}
|
||||
template_text = template_text.replace("{" + var_name + "}", "{{ " + var_name + " }}")
|
||||
# Replace { var_name } with {{ var_name }}
|
||||
template_text = template_text.replace("{ " + var_name + " }", "{{ " + var_name + " }}")
|
||||
# Replace {var_name } with {{ var_name }}
|
||||
template_text = template_text.replace("{" + var_name + " }", "{{ " + var_name + " }}")
|
||||
# Replace { var_name} with {{ var_name }}
|
||||
template_text = template_text.replace("{ " + var_name + "}", "{{ " + var_name + " }}")
|
||||
# Replace {{{ var_name }}} with {{ var_name }}
|
||||
template_text = template_text.replace("{{{ " + var_name + " }}}", "{{ " + var_name + " }}")
|
||||
# Replace {{{var_name}}} with {{ var_name }}
|
||||
template_text = template_text.replace("{{{" + var_name + "}}}", "{{ " + var_name + " }}")
|
||||
|
||||
# Use regular expression to replace all triple curly braces with double curly braces
|
||||
template_text = re.sub(r"{{{([^}]+)}}}", r"{{ \1 }}", template_text)
|
||||
|
||||
template_node = TemplateTransformNodeType(
|
||||
title=node_desc.title,
|
||||
desc=node_desc.description or "",
|
||||
type=BlockEnum.template_transform,
|
||||
template=template_text,
|
||||
variables=variables,
|
||||
)
|
||||
|
||||
return CompleteNode(
|
||||
id=node_desc.id,
|
||||
type="custom",
|
||||
position={"x": 0, "y": 0}, # Temporary position, will be updated later
|
||||
height=82, # Increase height to match reference file
|
||||
width=244,
|
||||
positionAbsolute={"x": 0, "y": 0},
|
||||
selected=False,
|
||||
sourcePosition="right",
|
||||
targetPosition="left",
|
||||
data=template_node,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_end_node(node_desc: NodeDescription) -> CompleteNode:
|
||||
"""Create end node"""
|
||||
# Build output variable list
|
||||
outputs = []
|
||||
for output in node_desc.outputs or []:
|
||||
variable = Variable(
|
||||
variable=output.name, value_selector=[output.source_node or "", output.source_variable or ""]
|
||||
)
|
||||
outputs.append(variable)
|
||||
|
||||
end_node = EndNodeType(
|
||||
title=node_desc.title, desc=node_desc.description or "", type=BlockEnum.end, outputs=outputs
|
||||
)
|
||||
|
||||
return CompleteNode(
|
||||
id=node_desc.id,
|
||||
type="custom",
|
||||
position={"x": 0, "y": 0}, # Temporary position, will be updated later
|
||||
height=90,
|
||||
width=244,
|
||||
positionAbsolute={"x": 0, "y": 0},
|
||||
selected=False,
|
||||
sourcePosition="right",
|
||||
targetPosition="left",
|
||||
data=end_node,
|
||||
)
|
||||
7
api/core/auto/workflow_generator/models/__init__.py
Normal file
7
api/core/auto/workflow_generator/models/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
模型包
|
||||
"""
|
||||
|
||||
from .workflow_description import ConnectionDescription, NodeDescription, WorkflowDescription
|
||||
|
||||
__all__ = ["ConnectionDescription", "NodeDescription", "WorkflowDescription"]
|
||||
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
Workflow Description Model
|
||||
Used to represent the simplified workflow description generated by large language models
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class VariableDescription(BaseModel):
|
||||
"""Variable description"""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
description: Optional[str] = None
|
||||
required: bool = True
|
||||
source_node: Optional[str] = None
|
||||
source_variable: Optional[str] = None
|
||||
|
||||
|
||||
class OutputDescription(BaseModel):
|
||||
"""Output description"""
|
||||
|
||||
name: str
|
||||
type: str = "string"
|
||||
description: Optional[str] = None
|
||||
source_node: Optional[str] = None
|
||||
source_variable: Optional[str] = None
|
||||
|
||||
|
||||
class NodeDescription(BaseModel):
|
||||
"""Node description"""
|
||||
|
||||
id: str
|
||||
type: str
|
||||
title: str
|
||||
description: Optional[str] = ""
|
||||
variables: Optional[list[VariableDescription]] = Field(default_factory=list)
|
||||
outputs: Optional[list[OutputDescription]] = Field(default_factory=list)
|
||||
|
||||
# LLM node specific fields
|
||||
system_prompt: Optional[str] = None
|
||||
user_prompt: Optional[str] = None
|
||||
provider: Optional[str] = "zhipuai"
|
||||
model: Optional[str] = "glm-4-flash"
|
||||
|
||||
# Code node specific fields
|
||||
code: Optional[str] = None
|
||||
|
||||
# Template node specific fields
|
||||
template: Optional[str] = None
|
||||
|
||||
# Reference to workflow description, used for node relationship analysis
|
||||
workflow_description: Optional["WorkflowDescription"] = Field(default=None, exclude=True)
|
||||
|
||||
class Config:
|
||||
exclude = {"workflow_description"}
|
||||
|
||||
|
||||
class ConnectionDescription(BaseModel):
|
||||
"""Connection description"""
|
||||
|
||||
source: str
|
||||
target: str
|
||||
|
||||
|
||||
class WorkflowDescription(BaseModel):
|
||||
"""Workflow description"""
|
||||
|
||||
name: str
|
||||
description: Optional[str] = ""
|
||||
nodes: list[NodeDescription]
|
||||
connections: list[ConnectionDescription]
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
# Add workflow description reference to each node
|
||||
for node in self.nodes:
|
||||
node.workflow_description = self
|
||||
16
api/core/auto/workflow_generator/utils/__init__.py
Normal file
16
api/core/auto/workflow_generator/utils/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
工具函数包
|
||||
"""
|
||||
|
||||
from .llm_client import LLMClient
|
||||
from .prompts import DEFAULT_MODEL_CONFIG, DEFAULT_SYSTEM_PROMPT, build_workflow_prompt
|
||||
from .type_mapper import map_string_to_var_type, map_var_type_to_input_type
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_MODEL_CONFIG",
|
||||
"DEFAULT_SYSTEM_PROMPT",
|
||||
"LLMClient",
|
||||
"build_workflow_prompt",
|
||||
"map_string_to_var_type",
|
||||
"map_var_type_to_input_type",
|
||||
]
|
||||
142
api/core/auto/workflow_generator/utils/config_manager.py
Normal file
142
api/core/auto/workflow_generator/utils/config_manager.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
Configuration Manager
|
||||
Used to manage configurations and prompts
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
"""Configuration manager for managing configurations"""
|
||||
|
||||
def __init__(self, config_dir: str = "config"):
|
||||
"""
|
||||
Initialize configuration manager
|
||||
|
||||
Args:
|
||||
config_dir: Configuration directory path
|
||||
"""
|
||||
self.config_dir = Path(config_dir)
|
||||
self.config: dict[str, Any] = {}
|
||||
self.last_load_time: float = 0
|
||||
self.reload_interval: float = 60 # Check every 60 seconds
|
||||
self._load_config()
|
||||
|
||||
def _should_reload(self) -> bool:
|
||||
"""Check if configuration needs to be reloaded"""
|
||||
return time.time() - self.last_load_time > self.reload_interval
|
||||
|
||||
def _load_config(self) -> dict[str, Any]:
|
||||
"""Load configuration files"""
|
||||
default_config = self._load_yaml(self.config_dir / "default.yaml")
|
||||
custom_config = self._load_yaml(self.config_dir / "custom.yaml")
|
||||
self.config = self._deep_merge(default_config, custom_config)
|
||||
self.last_load_time = time.time()
|
||||
return self.config
|
||||
|
||||
@staticmethod
|
||||
def _load_yaml(path: Path) -> dict[str, Any]:
|
||||
"""Load YAML file"""
|
||||
try:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
return yaml.safe_load(f) or {}
|
||||
except FileNotFoundError:
|
||||
print(f"Warning: Config file not found at {path}")
|
||||
return {}
|
||||
except Exception as e:
|
||||
print(f"Error loading config file {path}: {e}")
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def _deep_merge(dict1: dict, dict2: dict) -> dict:
|
||||
"""Recursively merge two dictionaries"""
|
||||
result = dict1.copy()
|
||||
for key, value in dict2.items():
|
||||
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
|
||||
result[key] = ConfigManager._deep_merge(result[key], value)
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
def get(self, *keys: str, default: Any = None) -> Any:
|
||||
"""
|
||||
Get configuration value
|
||||
|
||||
Args:
|
||||
*keys: Configuration key path
|
||||
default: Default value
|
||||
|
||||
Returns:
|
||||
Configuration value or default value
|
||||
"""
|
||||
if self._should_reload():
|
||||
self._load_config()
|
||||
|
||||
current = self.config
|
||||
for key in keys:
|
||||
if isinstance(current, dict) and key in current:
|
||||
current = current[key]
|
||||
else:
|
||||
return default
|
||||
return current
|
||||
|
||||
@property
|
||||
def workflow_generator(self) -> dict[str, Any]:
|
||||
"""Get workflow generator configuration"""
|
||||
return self.get("workflow_generator", default={})
|
||||
|
||||
@property
|
||||
def workflow_nodes(self) -> dict[str, Any]:
|
||||
"""Get workflow nodes configuration"""
|
||||
return self.get("workflow_nodes", default={})
|
||||
|
||||
@property
|
||||
def output(self) -> dict[str, Any]:
|
||||
"""Get output configuration"""
|
||||
return self.get("output", default={})
|
||||
|
||||
def get_output_path(self, filename: Optional[str] = None) -> str:
|
||||
"""
|
||||
Get output file path
|
||||
|
||||
Args:
|
||||
filename: Optional filename, uses default filename from config if not specified
|
||||
|
||||
Returns:
|
||||
Complete output file path
|
||||
"""
|
||||
output_config = self.output
|
||||
output_dir = output_config.get("dir", "output/")
|
||||
output_filename = filename or output_config.get("filename", "generated_workflow.yml")
|
||||
return os.path.join(output_dir, output_filename)
|
||||
|
||||
def get_workflow_model(self, model_name: Optional[str] = None) -> dict[str, Any]:
|
||||
"""
|
||||
Get workflow generation model configuration
|
||||
|
||||
Args:
|
||||
model_name: Model name, uses default model if not specified
|
||||
|
||||
Returns:
|
||||
Model configuration
|
||||
"""
|
||||
models = self.workflow_generator.get("models", {})
|
||||
|
||||
if not model_name:
|
||||
model_name = models.get("default")
|
||||
|
||||
return models.get("available", {}).get(model_name, {})
|
||||
|
||||
def get_llm_node_config(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get LLM node configuration
|
||||
|
||||
Returns:
|
||||
LLM node configuration
|
||||
"""
|
||||
return self.workflow_nodes.get("llm", {})
|
||||
151
api/core/auto/workflow_generator/utils/debug_manager.py
Normal file
151
api/core/auto/workflow_generator/utils/debug_manager.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""
|
||||
Debug Manager
|
||||
Used to manage debug information saving
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
|
||||
class DebugManager:
|
||||
"""Debug manager for managing debug information saving"""
|
||||
|
||||
_instance = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""Singleton pattern"""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, config: dict[str, Any] = {}, debug_enabled: bool = False):
|
||||
"""
|
||||
Initialize debug manager
|
||||
|
||||
Args:
|
||||
config: Debug configuration
|
||||
debug_enabled: Whether to enable debug mode
|
||||
"""
|
||||
# Avoid repeated initialization
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._initialized = True
|
||||
self.config = config or {}
|
||||
self.debug_enabled = debug_enabled or self.config.get("enabled", False)
|
||||
self.debug_dir = self.config.get("dir", "debug/")
|
||||
self.save_options = self.config.get(
|
||||
"save_options", {"prompt": True, "response": True, "json": True, "workflow": True}
|
||||
)
|
||||
|
||||
# Generate run ID
|
||||
self.case_id = self._generate_case_id()
|
||||
self.case_dir = os.path.join(self.debug_dir, self.case_id)
|
||||
|
||||
# If debug is enabled, create debug directory
|
||||
if self.debug_enabled:
|
||||
os.makedirs(self.case_dir, exist_ok=True)
|
||||
print(f"Debug mode enabled, debug information will be saved to: {self.case_dir}")
|
||||
|
||||
def _generate_case_id(self) -> str:
|
||||
"""
|
||||
Generate run ID
|
||||
|
||||
Returns:
|
||||
Run ID
|
||||
"""
|
||||
# Use format from configuration to generate run ID
|
||||
format_str = self.config.get("case_id_format", "%Y%m%d_%H%M%S_%f")
|
||||
timestamp = datetime.datetime.now().strftime(format_str)
|
||||
|
||||
# Add random string
|
||||
random_str = str(uuid.uuid4())[:8]
|
||||
|
||||
return f"{timestamp}_{random_str}"
|
||||
|
||||
def save_text(self, content: str, filename: str, subdir: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
Save text content to file
|
||||
|
||||
Args:
|
||||
content: Text content
|
||||
filename: File name
|
||||
subdir: Subdirectory name
|
||||
|
||||
Returns:
|
||||
Saved file path, returns None if debug is not enabled
|
||||
"""
|
||||
if not self.debug_enabled:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Determine save path
|
||||
save_dir = self.case_dir
|
||||
if subdir:
|
||||
save_dir = os.path.join(save_dir, subdir)
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
file_path = os.path.join(save_dir, filename)
|
||||
|
||||
# Save content
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
|
||||
print(f"Debug information saved to: {file_path}")
|
||||
return file_path
|
||||
except Exception as e:
|
||||
print(f"Failed to save debug information: {e}")
|
||||
return None
|
||||
|
||||
def save_json(self, data: Union[dict, list], filename: str, subdir: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
Save JSON data to file
|
||||
|
||||
Args:
|
||||
data: JSON data
|
||||
filename: File name
|
||||
subdir: Subdirectory name
|
||||
|
||||
Returns:
|
||||
Saved file path, returns None if debug is not enabled
|
||||
"""
|
||||
if not self.debug_enabled:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Determine save path
|
||||
save_dir = self.case_dir
|
||||
if subdir:
|
||||
save_dir = os.path.join(save_dir, subdir)
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
file_path = os.path.join(save_dir, filename)
|
||||
|
||||
# Save content
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
print(f"Debug information saved to: {file_path}")
|
||||
return file_path
|
||||
except Exception as e:
|
||||
print(f"Failed to save debug information: {e}")
|
||||
return None
|
||||
|
||||
def should_save(self, option: str) -> bool:
|
||||
"""
|
||||
Check if specified type of debug information should be saved
|
||||
|
||||
Args:
|
||||
option: Debug information type
|
||||
|
||||
Returns:
|
||||
Whether it should be saved
|
||||
"""
|
||||
if not self.debug_enabled:
|
||||
return False
|
||||
|
||||
return self.save_options.get(option, False)
|
||||
438
api/core/auto/workflow_generator/utils/llm_client.py
Normal file
438
api/core/auto/workflow_generator/utils/llm_client.py
Normal file
@@ -0,0 +1,438 @@
|
||||
"""
|
||||
LLM Client
|
||||
Used to call LLM API
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from core.auto.workflow_generator.utils.debug_manager import DebugManager
|
||||
from core.auto.workflow_generator.utils.prompts import DEFAULT_SYSTEM_PROMPT
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
|
||||
|
||||
|
||||
class LLMClient:
|
||||
"""LLM Client"""
|
||||
|
||||
def __init__(self, model_instance: ModelInstance, debug_manager: DebugManager):
|
||||
"""
|
||||
Initialize LLM client
|
||||
|
||||
Args:
|
||||
api_key: API key
|
||||
model: Model name
|
||||
api_base: API base URL
|
||||
max_tokens: Maximum number of tokens to generate
|
||||
debug_manager: Debug manager
|
||||
"""
|
||||
|
||||
self.debug_manager = debug_manager or DebugManager()
|
||||
self.model_instance = model_instance
|
||||
|
||||
def generate(self, prompt: str) -> str:
|
||||
"""
|
||||
Generate text
|
||||
|
||||
Args:
|
||||
prompt: Prompt text
|
||||
|
||||
Returns:
|
||||
Generated text
|
||||
"""
|
||||
|
||||
# Save prompt
|
||||
if self.debug_manager.should_save("prompt"):
|
||||
self.debug_manager.save_text(prompt, "prompt.txt", "llm")
|
||||
|
||||
try:
|
||||
response = self.model_instance.invoke_llm(
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(content=DEFAULT_SYSTEM_PROMPT),
|
||||
UserPromptMessage(content=prompt),
|
||||
],
|
||||
model_parameters={"temperature": 0.7, "max_tokens": 4900},
|
||||
)
|
||||
content = ""
|
||||
for chunk in response:
|
||||
content += chunk.delta.message.content
|
||||
print(f"Generation complete, text length: {len(content)} characters")
|
||||
|
||||
# Save response
|
||||
if self.debug_manager.should_save("response"):
|
||||
self.debug_manager.save_text(content, "response.txt", "llm")
|
||||
|
||||
return content
|
||||
except Exception as e:
|
||||
print(f"Error generating text: {e}")
|
||||
raise e
|
||||
|
||||
def extract_json(self, text: str) -> dict[str, Any]:
|
||||
"""
|
||||
Extract JSON from text
|
||||
|
||||
Args:
|
||||
text: Text containing JSON
|
||||
|
||||
Returns:
|
||||
Extracted JSON object
|
||||
"""
|
||||
print("Starting JSON extraction from text...")
|
||||
|
||||
# Save original text
|
||||
if self.debug_manager.should_save("json"):
|
||||
self.debug_manager.save_text(text, "original_text.txt", "json")
|
||||
|
||||
# Use regex to extract JSON part
|
||||
json_match = re.search(r"```json\n(.*?)\n```", text, re.DOTALL)
|
||||
if json_match:
|
||||
json_str = json_match.group(1)
|
||||
print("Successfully extracted JSON from code block")
|
||||
else:
|
||||
# Try to match code block without language identifier
|
||||
json_match = re.search(r"```\n(.*?)\n```", text, re.DOTALL)
|
||||
if json_match:
|
||||
json_str = json_match.group(1)
|
||||
print("Successfully extracted JSON from code block without language identifier")
|
||||
else:
|
||||
# Try to match JSON surrounded by curly braces
|
||||
json_match = re.search(r"(\{.*\})", text, re.DOTALL)
|
||||
if json_match:
|
||||
json_str = json_match.group(1)
|
||||
print("Successfully extracted JSON from curly braces")
|
||||
else:
|
||||
# Try to parse entire text
|
||||
json_str = text
|
||||
print("No JSON code block found, attempting to parse entire text")
|
||||
|
||||
# Save extracted JSON string
|
||||
if self.debug_manager.should_save("json"):
|
||||
self.debug_manager.save_text(json_str, "extracted_json.txt", "json")
|
||||
|
||||
# Try multiple methods to parse JSON
|
||||
try:
|
||||
# Try direct parsing
|
||||
result = json.loads(json_str)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Direct JSON parsing failed: {e}, attempting basic cleaning")
|
||||
try:
|
||||
# Try basic cleaning
|
||||
cleaned_text = self._clean_text(json_str)
|
||||
if self.debug_manager.should_save("json"):
|
||||
self.debug_manager.save_text(cleaned_text, "cleaned_json_1.txt", "json")
|
||||
result = json.loads(cleaned_text)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"JSON parsing after basic cleaning failed: {e}, attempting to fix common errors")
|
||||
try:
|
||||
# Try fixing common errors
|
||||
fixed_text = self._fix_json_errors(json_str)
|
||||
if self.debug_manager.should_save("json"):
|
||||
self.debug_manager.save_text(fixed_text, "cleaned_json_2.txt", "json")
|
||||
result = json.loads(fixed_text)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"JSON parsing after fixing common errors failed: {e}, attempting aggressive cleaning")
|
||||
try:
|
||||
# Try aggressive cleaning
|
||||
aggressive_cleaned = self._aggressive_clean(json_str)
|
||||
if self.debug_manager.should_save("json"):
|
||||
self.debug_manager.save_text(aggressive_cleaned, "cleaned_json_3.txt", "json")
|
||||
result = json.loads(aggressive_cleaned)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"JSON parsing after aggressive cleaning failed: {e}, attempting manual JSON extraction")
|
||||
# Try manual JSON structure extraction
|
||||
result = self._manual_json_extraction(json_str)
|
||||
if self.debug_manager.should_save("json"):
|
||||
self.debug_manager.save_json(result, "manual_json.json", "json")
|
||||
|
||||
# Check for nested workflow structure
|
||||
if "workflow" in result and isinstance(result["workflow"], dict):
|
||||
print("Detected nested workflow structure, extracting top-level data")
|
||||
# Extract workflow name and description
|
||||
name = result.get("name", "Text Analysis Workflow")
|
||||
description = result.get("description", "")
|
||||
|
||||
# Extract nodes and connections
|
||||
nodes = result["workflow"].get("nodes", [])
|
||||
connections = []
|
||||
|
||||
# If there are connections, extract them
|
||||
if "connections" in result["workflow"]:
|
||||
connections = result["workflow"]["connections"]
|
||||
|
||||
# Build standard format workflow description
|
||||
result = {"name": name, "description": description, "nodes": nodes, "connections": connections}
|
||||
|
||||
# Save final parsed JSON
|
||||
if self.debug_manager.should_save("json"):
|
||||
self.debug_manager.save_json(result, "final_json.json", "json")
|
||||
|
||||
print(
|
||||
f"JSON parsing successful, contains {len(result.get('nodes', []))} nodes and {len(result.get('connections', []))} connections" # noqa: E501
|
||||
)
|
||||
return result
|
||||
|
||||
def _clean_text(self, text: str) -> str:
|
||||
"""
|
||||
Clean text by removing non-JSON characters
|
||||
|
||||
Args:
|
||||
text: Text to clean
|
||||
|
||||
Returns:
|
||||
Cleaned text
|
||||
"""
|
||||
print("Starting text cleaning...")
|
||||
# Remove characters that might cause JSON parsing to fail
|
||||
lines = text.split("\n")
|
||||
cleaned_lines = []
|
||||
|
||||
in_json = False
|
||||
for line in lines:
|
||||
if line.strip().startswith("{") or line.strip().startswith("["):
|
||||
in_json = True
|
||||
|
||||
if in_json:
|
||||
cleaned_lines.append(line)
|
||||
|
||||
if line.strip().endswith("}") or line.strip().endswith("]"):
|
||||
in_json = False
|
||||
|
||||
cleaned_text = "\n".join(cleaned_lines)
|
||||
print(f"Text cleaning complete, length before: {len(text)}, length after: {len(cleaned_text)}")
|
||||
return cleaned_text
|
||||
|
||||
def _fix_json_errors(self, text: str) -> str:
|
||||
"""
|
||||
Fix common JSON errors
|
||||
|
||||
Args:
|
||||
text: Text to fix
|
||||
|
||||
Returns:
|
||||
Fixed text
|
||||
"""
|
||||
print("Attempting to fix JSON errors...")
|
||||
|
||||
# Replace single quotes with double quotes
|
||||
text = re.sub(r"'([^']*)'", r'"\1"', text)
|
||||
|
||||
# Fix missing commas
|
||||
text = re.sub(r"}\s*{", "},{", text)
|
||||
text = re.sub(r"]\s*{", "],{", text)
|
||||
text = re.sub(r"}\s*\[", r"},\[", text)
|
||||
text = re.sub(r"]\s*\[", r"],\[", text)
|
||||
|
||||
# Fix extra commas
|
||||
text = re.sub(r",\s*}", "}", text)
|
||||
text = re.sub(r",\s*]", "]", text)
|
||||
|
||||
# Ensure property names have double quotes
|
||||
text = re.sub(r"([{,]\s*)(\w+)(\s*:)", r'\1"\2"\3', text)
|
||||
|
||||
return text
|
||||
|
||||
def _aggressive_clean(self, text: str) -> str:
|
||||
"""
|
||||
More aggressive text cleaning
|
||||
|
||||
Args:
|
||||
text: Text to clean
|
||||
|
||||
Returns:
|
||||
Cleaned text
|
||||
"""
|
||||
print("Using aggressive cleaning method...")
|
||||
|
||||
# Try to find outermost curly braces
|
||||
start_idx = text.find("{")
|
||||
end_idx = text.rfind("}")
|
||||
|
||||
if start_idx != -1 and end_idx != -1 and start_idx < end_idx:
|
||||
text = text[start_idx : end_idx + 1]
|
||||
|
||||
# Remove comments
|
||||
text = re.sub(r"//.*?\n", "\n", text)
|
||||
text = re.sub(r"/\*.*?\*/", "", text, flags=re.DOTALL)
|
||||
|
||||
# Fix JSON format
|
||||
text = self._fix_json_errors(text)
|
||||
|
||||
# Remove escape characters
|
||||
text = text.replace("\\n", "\n").replace("\\t", "\t").replace('\\"', '"')
|
||||
|
||||
# Fix potential Unicode escape issues
|
||||
text = re.sub(r"\\u([0-9a-fA-F]{4})", lambda m: chr(int(m.group(1), 16)), text)
|
||||
|
||||
return text
|
||||
|
||||
def _manual_json_extraction(self, text: str) -> dict[str, Any]:
|
||||
"""
|
||||
Manual JSON structure extraction
|
||||
|
||||
Args:
|
||||
text: Text to extract from
|
||||
|
||||
Returns:
|
||||
Extracted JSON object
|
||||
"""
|
||||
print("Attempting manual JSON structure extraction...")
|
||||
|
||||
# Extract workflow name
|
||||
name_match = re.search(r'"name"\s*:\s*"([^"]*)"', text)
|
||||
name = name_match.group(1) if name_match else "Simple Workflow"
|
||||
|
||||
# Extract workflow description
|
||||
desc_match = re.search(r'"description"\s*:\s*"([^"]*)"', text)
|
||||
description = desc_match.group(1) if desc_match else "Automatically generated workflow"
|
||||
|
||||
# Extract nodes
|
||||
nodes = []
|
||||
node_matches = re.finditer(r'\{\s*"id"\s*:\s*"([^"]*)"\s*,\s*"type"\s*:\s*"([^"]*)"', text)
|
||||
|
||||
for match in node_matches:
|
||||
node_id = match.group(1)
|
||||
node_type = match.group(2)
|
||||
|
||||
# Extract node title
|
||||
title_match = re.search(rf'"id"\s*:\s*"{node_id}".*?"title"\s*:\s*"([^"]*)"', text, re.DOTALL)
|
||||
title = title_match.group(1) if title_match else f"{node_type.capitalize()} Node"
|
||||
|
||||
# Extract node description
|
||||
desc_match = re.search(rf'"id"\s*:\s*"{node_id}".*?"description"\s*:\s*"([^"]*)"', text, re.DOTALL)
|
||||
desc = desc_match.group(1) if desc_match else ""
|
||||
|
||||
# Create basic node based on node type
|
||||
if node_type == "start":
|
||||
# Extract variables
|
||||
variables = []
|
||||
var_section_match = re.search(rf'"id"\s*:\s*"{node_id}".*?"variables"\s*:\s*\[(.*?)\]', text, re.DOTALL)
|
||||
|
||||
if var_section_match:
|
||||
var_section = var_section_match.group(1)
|
||||
var_matches = re.finditer(r'\{\s*"name"\s*:\s*"([^"]*)"\s*,\s*"type"\s*:\s*"([^"]*)"', var_section)
|
||||
|
||||
for var_match in var_matches:
|
||||
var_name = var_match.group(1)
|
||||
var_type = var_match.group(2)
|
||||
|
||||
# Extract variable description
|
||||
var_desc_match = re.search(
|
||||
rf'"name"\s*:\s*"{var_name}".*?"description"\s*:\s*"([^"]*)"', var_section, re.DOTALL
|
||||
)
|
||||
var_desc = var_desc_match.group(1) if var_desc_match else ""
|
||||
|
||||
# Extract required status
|
||||
var_required_match = re.search(
|
||||
rf'"name"\s*:\s*"{var_name}".*?"required"\s*:\s*(true|false)', var_section, re.DOTALL
|
||||
)
|
||||
var_required = var_required_match.group(1).lower() == "true" if var_required_match else True
|
||||
|
||||
variables.append(
|
||||
{"name": var_name, "type": var_type, "description": var_desc, "required": var_required}
|
||||
)
|
||||
|
||||
# If no variables found but this is a greeting workflow, add default user_name variable
|
||||
if not variables and ("greeting" in name.lower()):
|
||||
variables.append(
|
||||
{"name": "user_name", "type": "string", "description": "User's name", "required": True}
|
||||
)
|
||||
|
||||
nodes.append({"id": node_id, "type": "start", "title": title, "desc": desc, "variables": variables})
|
||||
elif node_type == "llm":
|
||||
# Extract system prompt
|
||||
system_prompt_match = re.search(
|
||||
rf'"id"\s*:\s*"{node_id}".*?"system_prompt"\s*:\s*"([^"]*)"', text, re.DOTALL
|
||||
)
|
||||
system_prompt = system_prompt_match.group(1) if system_prompt_match else "You are a helpful assistant"
|
||||
|
||||
# Extract user prompt
|
||||
user_prompt_match = re.search(
|
||||
rf'"id"\s*:\s*"{node_id}".*?"user_prompt"\s*:\s*"([^"]*)"', text, re.DOTALL
|
||||
)
|
||||
user_prompt = user_prompt_match.group(1) if user_prompt_match else "Please answer the user's question"
|
||||
|
||||
nodes.append(
|
||||
{
|
||||
"id": node_id,
|
||||
"type": "llm",
|
||||
"title": title,
|
||||
"desc": desc,
|
||||
"provider": "zhipuai",
|
||||
"model": "glm-4-flash",
|
||||
"system_prompt": system_prompt,
|
||||
"user_prompt": user_prompt,
|
||||
"variables": [],
|
||||
}
|
||||
)
|
||||
elif node_type in ("template", "template-transform"):
|
||||
# Extract template content
|
||||
template_match = re.search(rf'"id"\s*:\s*"{node_id}".*?"template"\s*:\s*"([^"]*)"', text, re.DOTALL)
|
||||
template = template_match.group(1) if template_match else ""
|
||||
|
||||
# Fix triple curly brace issue in template, replace {{{ with {{ and }}} with }}
|
||||
template = template.replace("{{{", "{{").replace("}}}", "}}")
|
||||
|
||||
nodes.append(
|
||||
{
|
||||
"id": node_id,
|
||||
"type": "template-transform",
|
||||
"title": title,
|
||||
"desc": desc,
|
||||
"template": template,
|
||||
"variables": [],
|
||||
}
|
||||
)
|
||||
elif node_type == "end":
|
||||
# Extract outputs
|
||||
outputs = []
|
||||
output_section_match = re.search(
|
||||
rf'"id"\s*:\s*"{node_id}".*?"outputs"\s*:\s*\[(.*?)\]', text, re.DOTALL
|
||||
)
|
||||
|
||||
if output_section_match:
|
||||
output_section = output_section_match.group(1)
|
||||
output_matches = re.finditer(
|
||||
r'\{\s*"name"\s*:\s*"([^"]*)"\s*,\s*"type"\s*:\s*"([^"]*)"', output_section
|
||||
)
|
||||
|
||||
for output_match in output_matches:
|
||||
output_name = output_match.group(1)
|
||||
output_type = output_match.group(2)
|
||||
|
||||
# Extract source node
|
||||
source_node_match = re.search(
|
||||
rf'"name"\s*:\s*"{output_name}".*?"source_node"\s*:\s*"([^"]*)"', output_section, re.DOTALL
|
||||
)
|
||||
source_node = source_node_match.group(1) if source_node_match else ""
|
||||
|
||||
# Extract source variable
|
||||
source_var_match = re.search(
|
||||
rf'"name"\s*:\s*"{output_name}".*?"source_variable"\s*:\s*"([^"]*)"',
|
||||
output_section,
|
||||
re.DOTALL,
|
||||
)
|
||||
source_var = source_var_match.group(1) if source_var_match else ""
|
||||
|
||||
outputs.append(
|
||||
{
|
||||
"name": output_name,
|
||||
"type": output_type,
|
||||
"source_node": source_node,
|
||||
"source_variable": source_var,
|
||||
}
|
||||
)
|
||||
|
||||
nodes.append({"id": node_id, "type": "end", "title": title, "desc": desc, "outputs": outputs})
|
||||
else:
|
||||
# Other node types
|
||||
nodes.append({"id": node_id, "type": node_type, "title": title, "desc": desc})
|
||||
|
||||
# Extract connections
|
||||
connections = []
|
||||
conn_matches = re.finditer(r'\{\s*"source"\s*:\s*"([^"]*)"\s*,\s*"target"\s*:\s*"([^"]*)"', text)
|
||||
|
||||
for match in conn_matches:
|
||||
connections.append({"source": match.group(1), "target": match.group(2)})
|
||||
|
||||
return {"name": name, "description": description, "nodes": nodes, "connections": connections}
|
||||
171
api/core/auto/workflow_generator/utils/prompts.py
Normal file
171
api/core/auto/workflow_generator/utils/prompts.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""
|
||||
Prompt Template Collection
|
||||
Contains all prompt templates used for generating workflows
|
||||
"""
|
||||
|
||||
# Default model configuration
|
||||
DEFAULT_MODEL_CONFIG = {
|
||||
"provider": "zhipuai",
|
||||
"model": "glm-4-flash",
|
||||
"mode": "chat",
|
||||
"completion_params": {"temperature": 0.7},
|
||||
}
|
||||
|
||||
|
||||
# Default system prompt
|
||||
DEFAULT_SYSTEM_PROMPT = "You are a workflow design expert who can design Dify workflows based on user requirements."
|
||||
|
||||
|
||||
# Code node template
|
||||
CODE_NODE_TEMPLATE = """def main(input_var):
|
||||
# Process input variable
|
||||
result = input_var
|
||||
|
||||
# Return a dictionary; keys must exactly match variable names defined in outputs
|
||||
return {"output_var_name": result}"""
|
||||
|
||||
|
||||
def build_workflow_prompt(user_requirement: str) -> str:
|
||||
"""
|
||||
Build workflow generation prompt
|
||||
|
||||
Args:
|
||||
user_requirement: User requirement description
|
||||
|
||||
Returns:
|
||||
Prompt string
|
||||
"""
|
||||
# String concatenation to avoid brace escaping
|
||||
prompt_part1 = (
|
||||
"""
|
||||
Please design a Dify workflow based on the following user requirement:
|
||||
|
||||
User requirement: """
|
||||
+ user_requirement
|
||||
+ """
|
||||
|
||||
The description's language should align consistently with the user's requirements.
|
||||
|
||||
Generate a concise workflow description containing the following node types:
|
||||
- Start: Start node, defines workflow input parameters
|
||||
- LLM: Large Language Model node for text generation
|
||||
- Code: Code node to execute Python code
|
||||
- Template: Template node for formatting outputs
|
||||
- End: End node, defines workflow output
|
||||
|
||||
【Important Guidelines】:
|
||||
1. When referencing variables in LLM nodes, use the format {{#nodeID.variable_name#}}, e.g., {{#1740019130520.user_question#}}, where 1740019130520 is the source node ID. Otherwise, in most cases, the user prompt should define a template to guide the LLM’s response.
|
||||
2. Code nodes must define a `main` function that directly receives variables from upstream nodes as parameters; do not use template syntax inside the function.
|
||||
3. Dictionary keys returned by Code nodes must exactly match the variable names defined in outputs.
|
||||
4. Variables in Template nodes must strictly use double curly braces format "{{ variable_name }}"; note exactly two curly braces, neither one nor three. For example, "User question is: {{ user_question }}, answer: {{ answer }}". Triple curly braces such as "{{{ variable_name }}}" are strictly forbidden.
|
||||
5. IMPORTANT: In Code nodes, the function parameter names MUST EXACTLY MATCH the variable names defined in that Code node. For example, if a Code node defines a variable with name "input_text" that receives data from an upstream node, the function parameter must also be named "input_text" (e.g., def main(input_text): ...).
|
||||
6. CRITICAL: LLM nodes ALWAYS output their result in a variable named "text". When a Code node receives data from an LLM node, the source_variable MUST be "text". For example, if a Code node has a variable named "llm_output" that receives data from an LLM node, the source_variable should be "text", not "input_text" or any other name.
|
||||
|
||||
Return the workflow description in JSON format as follows:
|
||||
```json
|
||||
{
|
||||
"name": "Workflow Name",
|
||||
"description": "Workflow description",
|
||||
"nodes": [
|
||||
{
|
||||
"id": "node1",
|
||||
"type": "start",
|
||||
"title": "Start Node",
|
||||
"description": "Description of the start node",
|
||||
"variables": [
|
||||
{
|
||||
"name": "variable_name",
|
||||
"type": "string|number",
|
||||
"description": "Variable description",
|
||||
"required": true|false
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "node2",
|
||||
"type": "llm",
|
||||
"title": "LLM Node",
|
||||
"description": "Description of LLM node",
|
||||
"system_prompt": "System prompt",
|
||||
"user_prompt": "User prompt, variables referenced using {{#nodeID.variable_name#}}, e.g., {{#node1.variable_name#}}",
|
||||
"provider": "zhipuai",
|
||||
"model": "glm-4-flash",
|
||||
"variables": [
|
||||
{
|
||||
"name": "variable_name",
|
||||
"type": "string|number",
|
||||
"source_node": "node1",
|
||||
"source_variable": "variable_name"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "node3",
|
||||
"type": "code",
|
||||
"title": "Code Node",
|
||||
"description": "Description of the code node",
|
||||
"code": "def main(input_var):\n import re\n match = re.search(r'Result[::](.*?)(?=[.]|$)', input_var)\n result = match.group(1).strip() if match else 'Not found'\n return {'output': result}",
|
||||
"variables": [
|
||||
{
|
||||
"name": "input_var",
|
||||
"type": "string|number",
|
||||
"source_node": "node2",
|
||||
"source_variable": "text"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "output_var_name",
|
||||
"type": "string|number|object"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "node4",
|
||||
"type": "template",
|
||||
"title": "Template Node",
|
||||
"description": "Description of the template node",
|
||||
"template": "Template content using double curly braces, e.g.: The result is: {{ result }}",
|
||||
"variables": [
|
||||
{
|
||||
"name": "variable_name",
|
||||
"type": "string|number",
|
||||
"source_node": "node3",
|
||||
"source_variable": "output_var_name"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "node5",
|
||||
"type": "end",
|
||||
"title": "End Node",
|
||||
"description": "Description of the end node",
|
||||
"outputs": [
|
||||
{
|
||||
"name": "output_variable_name",
|
||||
"type": "string|number",
|
||||
"source_node": "node4",
|
||||
"source_variable": "output"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"connections": [
|
||||
{"source": "node1", "target": "node2"},
|
||||
{"source": "node2", "target": "node3"},
|
||||
{"source": "node3", "target": "node4"},
|
||||
{"source": "node4", "target": "node5"}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Ensure the workflow logic is coherent, node connections are correct, and variable passing is logical.
|
||||
Generate unique numeric IDs for each node, e.g., 1740019130520.
|
||||
Generate appropriate unique names for each variable across the workflow.
|
||||
Ensure all LLM nodes use provider "zhipuai" and model "glm-4-flash".
|
||||
|
||||
Note: LLM nodes usually return a long text; Code nodes typically require regex to extract relevant information.
|
||||
""" # noqa: E501
|
||||
)
|
||||
|
||||
return prompt_part1
|
||||
50
api/core/auto/workflow_generator/utils/type_mapper.py
Normal file
50
api/core/auto/workflow_generator/utils/type_mapper.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
Type Mapping Utility
|
||||
Used to map string types to Dify types
|
||||
"""
|
||||
|
||||
from core.auto.node_types.common import InputVarType, VarType
|
||||
|
||||
|
||||
def map_var_type_to_input_type(var_type: str) -> InputVarType:
|
||||
"""
|
||||
Map variable type to input variable type
|
||||
|
||||
Args:
|
||||
var_type: Variable type string
|
||||
|
||||
Returns:
|
||||
Input variable type
|
||||
"""
|
||||
type_map = {
|
||||
"string": InputVarType.text_input,
|
||||
"number": InputVarType.number,
|
||||
"boolean": InputVarType.select,
|
||||
"object": InputVarType.json,
|
||||
"array": InputVarType.json,
|
||||
"file": InputVarType.file,
|
||||
}
|
||||
|
||||
return type_map.get(var_type.lower(), InputVarType.text_input)
|
||||
|
||||
|
||||
def map_string_to_var_type(type_str: str) -> VarType:
|
||||
"""
|
||||
Map string to variable type
|
||||
|
||||
Args:
|
||||
type_str: Type string
|
||||
|
||||
Returns:
|
||||
Variable type
|
||||
"""
|
||||
type_map = {
|
||||
"string": VarType.string,
|
||||
"number": VarType.number,
|
||||
"boolean": VarType.boolean,
|
||||
"object": VarType.object,
|
||||
"array": VarType.array,
|
||||
"file": VarType.file,
|
||||
}
|
||||
|
||||
return type_map.get(type_str.lower(), VarType.string)
|
||||
134
api/core/auto/workflow_generator/workflow.py
Normal file
134
api/core/auto/workflow_generator/workflow.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
from core.auto.node_types.common import CompleteEdge, CompleteNode
|
||||
|
||||
|
||||
class Workflow:
|
||||
"""
|
||||
Workflow class
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, nodes: list[CompleteNode], edges: list[CompleteEdge]):
|
||||
"""
|
||||
Initialize workflow
|
||||
|
||||
Args:
|
||||
name: Workflow name
|
||||
nodes: List of nodes
|
||||
edges: List of edges
|
||||
"""
|
||||
self.name = name
|
||||
self.nodes = nodes
|
||||
self.edges = edges
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""
|
||||
Convert workflow to dictionary
|
||||
|
||||
Returns:
|
||||
Workflow dictionary
|
||||
"""
|
||||
# Apply basic information (fixed template)
|
||||
app_info = {
|
||||
"description": "",
|
||||
"icon": "🤖",
|
||||
"icon_background": "#FFEAD5",
|
||||
"mode": "workflow",
|
||||
"name": self.name,
|
||||
"use_icon_as_answer_icon": False,
|
||||
}
|
||||
|
||||
# Feature configuration (fixed template)
|
||||
features = {
|
||||
"file_upload": {
|
||||
"allowed_file_extensions": [".JPG", ".JPEG", ".PNG", ".GIF", ".WEBP", ".SVG"],
|
||||
"allowed_file_types": ["image"],
|
||||
"allowed_file_upload_methods": ["local_file", "remote_url"],
|
||||
"enabled": False,
|
||||
"fileUploadConfig": {
|
||||
"audio_file_size_limit": 50,
|
||||
"batch_count_limit": 5,
|
||||
"file_size_limit": 15,
|
||||
"image_file_size_limit": 10,
|
||||
"video_file_size_limit": 100,
|
||||
},
|
||||
"image": {"enabled": False, "number_limits": 3, "transfer_methods": ["local_file", "remote_url"]},
|
||||
"number_limits": 3,
|
||||
},
|
||||
"opening_statement": "",
|
||||
"retriever_resource": {"enabled": True},
|
||||
"sensitive_word_avoidance": {"enabled": False},
|
||||
"speech_to_text": {"enabled": False},
|
||||
"suggested_questions": [],
|
||||
"suggested_questions_after_answer": {"enabled": False},
|
||||
"text_to_speech": {"enabled": False, "language": "", "voice": ""},
|
||||
}
|
||||
|
||||
# View configuration (fixed template)
|
||||
viewport = {"x": 92.96659905656679, "y": 79.13437154762897, "zoom": 0.9002006986311041}
|
||||
|
||||
# Nodes and edges
|
||||
nodes_data = []
|
||||
for node in self.nodes:
|
||||
node_data = node.to_json()
|
||||
nodes_data.append(node_data)
|
||||
|
||||
edges_data = []
|
||||
for edge in self.edges:
|
||||
edge_data = edge.to_json()
|
||||
edges_data.append(edge_data)
|
||||
|
||||
# Build a complete workflow dictionary
|
||||
workflow_dict = {
|
||||
"app": app_info,
|
||||
"kind": "app",
|
||||
"version": "0.1.2",
|
||||
"workflow": {
|
||||
"conversation_variables": [],
|
||||
"environment_variables": [],
|
||||
"features": features,
|
||||
"graph": {"edges": edges_data, "nodes": nodes_data, "viewport": viewport},
|
||||
},
|
||||
}
|
||||
|
||||
return workflow_dict
|
||||
|
||||
def save_to_yaml(self, file_path: str):
|
||||
"""
|
||||
Save workflow to YAML file
|
||||
|
||||
Args:
|
||||
file_path: File path
|
||||
"""
|
||||
workflow_dict = self.to_dict()
|
||||
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
yaml.dump(workflow_dict, f, allow_unicode=True, sort_keys=False)
|
||||
|
||||
print(f"Workflow saved to: {file_path}")
|
||||
|
||||
def save_to_json(self, file_path: str):
|
||||
"""
|
||||
Save workflow to JSON file
|
||||
|
||||
Args:
|
||||
file_path: File path
|
||||
"""
|
||||
workflow_dict = self.to_dict()
|
||||
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(workflow_dict, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"Workflow saved to: {file_path}")
|
||||
|
||||
def to_yaml(self) -> str:
|
||||
"""
|
||||
Convert workflow to YAML string
|
||||
|
||||
Returns:
|
||||
YAML string
|
||||
"""
|
||||
return yaml.dump(self.to_dict(), allow_unicode=True, sort_keys=False)
|
||||
159
api/core/auto/workflow_generator/workflow_generator.py
Normal file
159
api/core/auto/workflow_generator/workflow_generator.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
Workflow Generator
|
||||
Used to generate Dify workflows based on user requirements
|
||||
"""
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.auto.workflow_generator.generators.edge_generator import EdgeGenerator
|
||||
from core.auto.workflow_generator.generators.layout_engine import LayoutEngine
|
||||
from core.auto.workflow_generator.generators.node_generator import NodeGenerator
|
||||
from core.auto.workflow_generator.models.workflow_description import WorkflowDescription
|
||||
from core.auto.workflow_generator.utils.config_manager import ConfigManager
|
||||
from core.auto.workflow_generator.utils.debug_manager import DebugManager
|
||||
from core.auto.workflow_generator.utils.llm_client import LLMClient
|
||||
from core.auto.workflow_generator.utils.prompts import build_workflow_prompt
|
||||
from core.auto.workflow_generator.workflow import Workflow
|
||||
from core.model_manager import ModelInstance
|
||||
|
||||
|
||||
class WorkflowGenerator:
|
||||
"""Workflow generator for creating Dify workflows based on user requirements"""
|
||||
|
||||
def __init__(self, model_instance: ModelInstance, config_dir: str = "config", debug_enabled: bool = False):
|
||||
"""
|
||||
Initialize workflow generator
|
||||
|
||||
Args:
|
||||
api_key: LLM API key
|
||||
config_dir: Configuration directory path
|
||||
model_name: Specified model name, uses default model if not specified
|
||||
debug_enabled: Whether to enable debug mode
|
||||
"""
|
||||
# Load configuration
|
||||
self.config = ConfigManager(config_dir)
|
||||
|
||||
# Initialize debug manager
|
||||
self.debug_manager = DebugManager(config=self.config.get("debug", default={}), debug_enabled=debug_enabled)
|
||||
|
||||
# Get model configuration
|
||||
|
||||
# Initialize LLM client
|
||||
self.llm_client = LLMClient(model_instance=model_instance, debug_manager=self.debug_manager)
|
||||
|
||||
def generate_workflow(self, user_requirement: str) -> str:
|
||||
"""
|
||||
Generate workflow based on user requirements
|
||||
|
||||
Args:
|
||||
user_requirement: User requirement description
|
||||
output_path: Output file path, uses default path from config if None
|
||||
|
||||
Returns:
|
||||
Generated workflow YAML file path
|
||||
"""
|
||||
print("\n===== Starting Workflow Generation =====")
|
||||
print(f"User requirement: {user_requirement}")
|
||||
|
||||
# Save user requirement
|
||||
if self.debug_manager.should_save("workflow"):
|
||||
self.debug_manager.save_text(user_requirement, "user_requirement.txt", "workflow")
|
||||
|
||||
# Use default path from config if output path not specified
|
||||
|
||||
# Step 1: Generate simple workflow description
|
||||
print("\n----- Step 1: Generating Simple Workflow Description -----")
|
||||
workflow_description = self._generate_workflow_description(user_requirement)
|
||||
print(f"Workflow name: {workflow_description.name}")
|
||||
print(f"Workflow description: {workflow_description.description}")
|
||||
print(f"Number of nodes: {len(workflow_description.nodes)}")
|
||||
print(f"Number of connections: {len(workflow_description.connections)}")
|
||||
|
||||
# Save workflow description
|
||||
if self.debug_manager.should_save("workflow"):
|
||||
self.debug_manager.save_json(workflow_description.dict(), "workflow_description.json", "workflow")
|
||||
|
||||
# Step 2: Parse description and generate nodes
|
||||
print("\n----- Step 2: Parsing Description, Generating Nodes -----")
|
||||
nodes = NodeGenerator.create_nodes(workflow_description.nodes)
|
||||
print(f"Generated nodes: {len(nodes)}")
|
||||
for i, node in enumerate(nodes):
|
||||
print(f"Node {i + 1}: ID={node.id}, Type={node.data.type.value}, Title={node.data.title}")
|
||||
|
||||
# Save node information
|
||||
if self.debug_manager.should_save("workflow"):
|
||||
nodes_data = [node.dict() for node in nodes]
|
||||
self.debug_manager.save_json(nodes_data, "nodes.json", "workflow")
|
||||
|
||||
# Step 3: Generate edges
|
||||
print("\n----- Step 3: Generating Edges -----")
|
||||
edges = EdgeGenerator.create_edges(nodes, workflow_description.connections)
|
||||
print(f"Generated edges: {len(edges)}")
|
||||
for i, edge in enumerate(edges):
|
||||
print(f"Edge {i + 1}: ID={edge.id}, Source={edge.source}, Target={edge.target}")
|
||||
|
||||
# Save edge information
|
||||
if self.debug_manager.should_save("workflow"):
|
||||
edges_data = [edge.dict() for edge in edges]
|
||||
self.debug_manager.save_json(edges_data, "edges.json", "workflow")
|
||||
|
||||
# Step 4: Apply layout
|
||||
print("\n----- Step 4: Applying Layout -----")
|
||||
LayoutEngine.apply_topological_layout(nodes, edges)
|
||||
print("Applied topological sort layout")
|
||||
|
||||
# Save nodes with layout
|
||||
if self.debug_manager.should_save("workflow"):
|
||||
nodes_with_layout = [node.dict() for node in nodes]
|
||||
self.debug_manager.save_json(nodes_with_layout, "nodes_with_layout.json", "workflow")
|
||||
|
||||
# Step 5: Generate YAML
|
||||
print("\n----- Step 5: Generating YAML -----")
|
||||
workflow = Workflow(name=workflow_description.name, nodes=nodes, edges=edges)
|
||||
|
||||
# Ensure output directory exists
|
||||
|
||||
# Save as YAML
|
||||
|
||||
# Save final YAML
|
||||
print("\n===== Workflow Generation Complete =====")
|
||||
return workflow.to_yaml()
|
||||
|
||||
def _generate_workflow_description(self, user_requirement: str) -> WorkflowDescription:
|
||||
"""
|
||||
Generate simple workflow description using LLM
|
||||
|
||||
Args:
|
||||
user_requirement: User requirement description
|
||||
|
||||
Returns:
|
||||
Simple workflow description
|
||||
"""
|
||||
# Build prompt
|
||||
print("Building prompt...")
|
||||
prompt = build_workflow_prompt(user_requirement)
|
||||
|
||||
# Call LLM
|
||||
print("Calling LLM to generate workflow description...")
|
||||
response_text = self.llm_client.generate(prompt)
|
||||
|
||||
# Parse LLM response
|
||||
print("Parsing LLM response...")
|
||||
workflow_description_dict = self.llm_client.extract_json(response_text)
|
||||
|
||||
try:
|
||||
# Parse into WorkflowDescription object
|
||||
print("Converting JSON to WorkflowDescription object...")
|
||||
workflow_description = WorkflowDescription.parse_obj(workflow_description_dict)
|
||||
return workflow_description
|
||||
except ValidationError as e:
|
||||
# If parsing fails, print error and raise exception
|
||||
error_msg = f"Failed to parse workflow description: {e}"
|
||||
print(error_msg)
|
||||
|
||||
# Save error information
|
||||
if self.debug_manager.should_save("workflow"):
|
||||
self.debug_manager.save_text(str(e), "validation_error.txt", "workflow")
|
||||
self.debug_manager.save_json(workflow_description_dict, "invalid_workflow_description.json", "workflow")
|
||||
|
||||
raise ValueError(error_msg)
|
||||
@@ -1,9 +1,11 @@
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DatasetQuery, DocumentSegment
|
||||
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.model import DatasetRetrieverResource
|
||||
|
||||
|
||||
@@ -41,15 +43,29 @@ class DatasetIndexToolCallbackHandler:
|
||||
"""Handle tool end."""
|
||||
for document in documents:
|
||||
if document.metadata is not None:
|
||||
query = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.index_node_id == document.metadata["doc_id"]
|
||||
)
|
||||
dataset_document = DatasetDocument.query.filter(
|
||||
DatasetDocument.id == document.metadata["document_id"]
|
||||
).first()
|
||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||
child_chunk = ChildChunk.query.filter(
|
||||
ChildChunk.index_node_id == document.metadata["doc_id"],
|
||||
ChildChunk.dataset_id == dataset_document.dataset_id,
|
||||
ChildChunk.document_id == dataset_document.id,
|
||||
).first()
|
||||
if child_chunk:
|
||||
segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update(
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
|
||||
)
|
||||
else:
|
||||
query = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.index_node_id == document.metadata["doc_id"]
|
||||
)
|
||||
|
||||
if "dataset_id" in document.metadata:
|
||||
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
|
||||
if "dataset_id" in document.metadata:
|
||||
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
|
||||
|
||||
# add hit count to document segment
|
||||
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
|
||||
# add hit count to document segment
|
||||
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@@ -187,18 +187,30 @@ class ProviderConfiguration(BaseModel):
|
||||
:return:
|
||||
"""
|
||||
# get provider
|
||||
provider_record = (
|
||||
db.session.query(Provider)
|
||||
.filter(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
or_(
|
||||
Provider.provider_name == ModelProviderID(self.provider.provider).plugin_name,
|
||||
Provider.provider_name == self.provider.provider,
|
||||
),
|
||||
model_provider_id = ModelProviderID(self.provider.provider)
|
||||
if model_provider_id.is_langgenius():
|
||||
provider_record = (
|
||||
db.session.query(Provider)
|
||||
.filter(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
or_(
|
||||
Provider.provider_name == model_provider_id.provider_name,
|
||||
Provider.provider_name == self.provider.provider,
|
||||
),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
provider_record = (
|
||||
db.session.query(Provider)
|
||||
.filter(
|
||||
Provider.tenant_id == self.tenant_id,
|
||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||
Provider.provider_name == self.provider.provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
|
||||
@@ -97,32 +97,18 @@ class File(BaseModel):
|
||||
return text
|
||||
|
||||
def generate_url(self) -> Optional[str]:
|
||||
if self.type == FileType.IMAGE:
|
||||
if self.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
return self.remote_url
|
||||
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
if self.related_id is None:
|
||||
raise ValueError("Missing file related_id")
|
||||
return helpers.get_signed_file_url(upload_file_id=self.related_id)
|
||||
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
assert self.related_id is not None
|
||||
assert self.extension is not None
|
||||
return ToolFileParser.get_tool_file_manager().sign_file(
|
||||
tool_file_id=self.related_id, extension=self.extension
|
||||
)
|
||||
else:
|
||||
if self.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
return self.remote_url
|
||||
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
if self.related_id is None:
|
||||
raise ValueError("Missing file related_id")
|
||||
return helpers.get_signed_file_url(upload_file_id=self.related_id)
|
||||
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
assert self.related_id is not None
|
||||
assert self.extension is not None
|
||||
return ToolFileParser.get_tool_file_manager().sign_file(
|
||||
tool_file_id=self.related_id, extension=self.extension
|
||||
)
|
||||
if self.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
return self.remote_url
|
||||
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
if self.related_id is None:
|
||||
raise ValueError("Missing file related_id")
|
||||
return helpers.get_signed_file_url(upload_file_id=self.related_id)
|
||||
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
assert self.related_id is not None
|
||||
assert self.extension is not None
|
||||
return ToolFileParser.get_tool_file_manager().sign_file(
|
||||
tool_file_id=self.related_id, extension=self.extension
|
||||
)
|
||||
|
||||
def to_plugin_parameter(self) -> dict[str, Any]:
|
||||
return {
|
||||
|
||||
@@ -5,6 +5,7 @@ from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.agent.plugin_entities import AgentStrategyProviderEntity
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
@@ -153,6 +154,8 @@ class GenericProviderID:
|
||||
return f"{self.organization}/{self.plugin_name}/{self.provider_name}"
|
||||
|
||||
def __init__(self, value: str, is_hardcoded: bool = False) -> None:
|
||||
if not value:
|
||||
raise NotFound("plugin not found, please add plugin")
|
||||
# check if the value is a valid plugin id with format: $organization/$plugin_name/$provider_name
|
||||
if not re.match(r"^[a-z0-9_-]+\/[a-z0-9_-]+\/[a-z0-9_-]+$", value):
|
||||
# check if matches [a-z0-9_-]+, if yes, append with langgenius/$value/$value
|
||||
@@ -164,6 +167,9 @@ class GenericProviderID:
|
||||
self.organization, self.plugin_name, self.provider_name = value.split("/")
|
||||
self.is_hardcoded = is_hardcoded
|
||||
|
||||
def is_langgenius(self) -> bool:
|
||||
return self.organization == "langgenius"
|
||||
|
||||
@property
|
||||
def plugin_id(self) -> str:
|
||||
return f"{self.organization}/{self.plugin_name}"
|
||||
|
||||
@@ -61,7 +61,7 @@ CREATE TABLE IF NOT EXISTS {table_name} (
|
||||
SQL_CREATE_INDEX = """
|
||||
CREATE INDEX IF NOT EXISTS idx_docs_{table_name} ON {table_name}(text)
|
||||
INDEXTYPE IS CTXSYS.CONTEXT PARAMETERS
|
||||
('FILTER CTXSYS.NULL_FILTER SECTION GROUP CTXSYS.HTML_SECTION_GROUP LEXER multilingual_lexer')
|
||||
('FILTER CTXSYS.NULL_FILTER SECTION GROUP CTXSYS.HTML_SECTION_GROUP LEXER world_lexer')
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor
|
||||
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.rerank.rerank_type import RerankMode
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
@@ -28,7 +29,7 @@ from core.rag.retrieval.router.multi_dataset_function_call_router import Functio
|
||||
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DatasetQuery, DocumentSegment
|
||||
from models.dataset import ChildChunk, Dataset, DatasetQuery, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
|
||||
@@ -429,16 +430,31 @@ class DatasetRetrieval:
|
||||
dify_documents = [document for document in documents if document.provider == "dify"]
|
||||
for document in dify_documents:
|
||||
if document.metadata is not None:
|
||||
query = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.index_node_id == document.metadata["doc_id"]
|
||||
)
|
||||
dataset_document = DatasetDocument.query.filter(
|
||||
DatasetDocument.id == document.metadata["document_id"]
|
||||
).first()
|
||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||
child_chunk = ChildChunk.query.filter(
|
||||
ChildChunk.index_node_id == document.metadata["doc_id"],
|
||||
ChildChunk.dataset_id == dataset_document.dataset_id,
|
||||
ChildChunk.document_id == dataset_document.id,
|
||||
).first()
|
||||
if child_chunk:
|
||||
segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update(
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
|
||||
)
|
||||
db.session.commit()
|
||||
else:
|
||||
query = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.index_node_id == document.metadata["doc_id"]
|
||||
)
|
||||
|
||||
# if 'dataset_id' in document.metadata:
|
||||
if "dataset_id" in document.metadata:
|
||||
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
|
||||
# if 'dataset_id' in document.metadata:
|
||||
if "dataset_id" in document.metadata:
|
||||
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
|
||||
|
||||
# add hit count to document segment
|
||||
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
|
||||
# add hit count to document segment
|
||||
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ If you need to return a text message, you can use the following interface.
|
||||
If you need to return the raw data of a file, such as images, audio, video, PPT, Word, Excel, etc., you can use the following interface.
|
||||
|
||||
- `blob` The raw data of the file, of bytes type
|
||||
- `meta` The metadata of the file, if you know the type of the file, it is best to pass a `mime_type`, otherwise Dify will use `octet/stream` as the default type
|
||||
- `meta` The metadata of the file, if you know the type of the file, it is best to pass a `mime_type`, otherwise Dify will use `application/octet-stream` as the default type
|
||||
|
||||
```python
|
||||
def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage:
|
||||
|
||||
@@ -58,7 +58,7 @@ Difyは`テキスト`、`リンク`、`画像`、`ファイルBLOB`、`JSON`な
|
||||
画像、音声、動画、PPT、Word、Excelなどのファイルの生データを返す必要がある場合は、以下のインターフェースを使用できます。
|
||||
|
||||
- `blob` ファイルの生データ(bytes型)
|
||||
- `meta` ファイルのメタデータ。ファイルの種類が分かっている場合は、`mime_type`を渡すことをお勧めします。そうでない場合、Difyはデフォルトタイプとして`octet/stream`を使用します。
|
||||
- `meta` ファイルのメタデータ。ファイルの種類が分かっている場合は、`mime_type`を渡すことをお勧めします。そうでない場合、Difyはデフォルトタイプとして`application/octet-stream`を使用します。
|
||||
|
||||
```python
|
||||
def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage:
|
||||
|
||||
@@ -58,7 +58,7 @@ Dify支持`文本` `链接` `图片` `文件BLOB` `JSON` 等多种消息类型
|
||||
如果你需要返回文件的原始数据,如图片、音频、视频、PPT、Word、Excel等,可以使用以下接口。
|
||||
|
||||
- `blob` 文件的原始数据,bytes类型
|
||||
- `meta` 文件的元数据,如果你知道该文件的类型,最好传递一个`mime_type`,否则Dify将使用`octet/stream`作为默认类型
|
||||
- `meta` 文件的元数据,如果你知道该文件的类型,最好传递一个`mime_type`,否则Dify将使用`application/octet-stream`作为默认类型
|
||||
|
||||
```python
|
||||
def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage:
|
||||
|
||||
@@ -290,14 +290,16 @@ class ToolEngine:
|
||||
raise ValueError("missing meta data")
|
||||
|
||||
yield ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get("mime_type", "octet/stream"),
|
||||
mimetype=response.meta.get("mime_type", "application/octet-stream"),
|
||||
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
|
||||
)
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
# check if there is a mime type in meta
|
||||
if response.meta and "mime_type" in response.meta:
|
||||
yield ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get("mime_type", "octet/stream") if response.meta else "octet/stream",
|
||||
mimetype=response.meta.get("mime_type", "application/octet-stream")
|
||||
if response.meta
|
||||
else "application/octet-stream",
|
||||
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
|
||||
)
|
||||
|
||||
|
||||
@@ -101,7 +101,7 @@ class ToolFileManager:
|
||||
except httpx.TimeoutException:
|
||||
raise ValueError(f"timeout when downloading file from {file_url}")
|
||||
|
||||
mimetype = guess_type(file_url)[0] or "octet/stream"
|
||||
mimetype = guess_type(file_url)[0] or "application/octet-stream"
|
||||
extension = guess_extension(mimetype) or ".bin"
|
||||
unique_name = uuid4().hex
|
||||
filename = f"{unique_name}{extension}"
|
||||
|
||||
@@ -765,17 +765,22 @@ class ToolManager:
|
||||
|
||||
@classmethod
|
||||
def generate_builtin_tool_icon_url(cls, provider_id: str) -> str:
|
||||
return (
|
||||
dify_config.CONSOLE_API_URL
|
||||
+ "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
+ provider_id
|
||||
+ "/icon"
|
||||
return str(
|
||||
URL(dify_config.CONSOLE_API_URL or "/")
|
||||
/ "console"
|
||||
/ "api"
|
||||
/ "workspaces"
|
||||
/ "current"
|
||||
/ "tool-provider"
|
||||
/ "builtin"
|
||||
/ provider_id
|
||||
/ "icon"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_plugin_tool_icon_url(cls, tenant_id: str, filename: str) -> str:
|
||||
return str(
|
||||
URL(dify_config.CONSOLE_API_URL)
|
||||
URL(dify_config.CONSOLE_API_URL or "/")
|
||||
/ "console"
|
||||
/ "api"
|
||||
/ "workspaces"
|
||||
|
||||
@@ -58,7 +58,7 @@ class ToolFileMessageTransformer:
|
||||
# get mime type and save blob to storage
|
||||
meta = message.meta or {}
|
||||
|
||||
mimetype = meta.get("mime_type", "octet/stream")
|
||||
mimetype = meta.get("mime_type", "application/octet-stream")
|
||||
# if message is str, encode it to bytes
|
||||
|
||||
if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
|
||||
|
||||
@@ -136,7 +136,7 @@ class ArrayStringSegment(ArraySegment):
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
return json.dumps(self.value)
|
||||
return json.dumps(self.value, ensure_ascii=False)
|
||||
|
||||
|
||||
class ArrayNumberSegment(ArraySegment):
|
||||
|
||||
@@ -11,6 +11,10 @@ from core.workflow.graph_engine.entities.event import (
|
||||
IterationRunNextEvent,
|
||||
IterationRunStartedEvent,
|
||||
IterationRunSucceededEvent,
|
||||
LoopRunFailedEvent,
|
||||
LoopRunNextEvent,
|
||||
LoopRunStartedEvent,
|
||||
LoopRunSucceededEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
@@ -62,6 +66,12 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
self.on_workflow_iteration_next(event=event)
|
||||
elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
|
||||
self.on_workflow_iteration_completed(event=event)
|
||||
elif isinstance(event, LoopRunStartedEvent):
|
||||
self.on_workflow_loop_started(event=event)
|
||||
elif isinstance(event, LoopRunNextEvent):
|
||||
self.on_workflow_loop_next(event=event)
|
||||
elif isinstance(event, LoopRunSucceededEvent | LoopRunFailedEvent):
|
||||
self.on_workflow_loop_completed(event=event)
|
||||
else:
|
||||
self.print_text(f"\n[{event.__class__.__name__}]", color="blue")
|
||||
|
||||
@@ -160,6 +170,8 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color="blue")
|
||||
if event.in_iteration_id:
|
||||
self.print_text(f"Iteration ID: {event.in_iteration_id}", color="blue")
|
||||
if event.in_loop_id:
|
||||
self.print_text(f"Loop ID: {event.in_loop_id}", color="blue")
|
||||
|
||||
def on_workflow_parallel_completed(
|
||||
self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
|
||||
@@ -182,6 +194,8 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
|
||||
if event.in_iteration_id:
|
||||
self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color)
|
||||
if event.in_loop_id:
|
||||
self.print_text(f"Loop ID: {event.in_loop_id}", color=color)
|
||||
|
||||
if isinstance(event, ParallelBranchRunFailedEvent):
|
||||
self.print_text(f"Error: {event.error}", color=color)
|
||||
@@ -213,6 +227,31 @@ class WorkflowLoggingCallback(WorkflowCallback):
|
||||
)
|
||||
self.print_text(f"Node ID: {event.iteration_id}", color="blue")
|
||||
|
||||
def on_workflow_loop_started(self, event: LoopRunStartedEvent) -> None:
|
||||
"""
|
||||
Publish loop started
|
||||
"""
|
||||
self.print_text("\n[LoopRunStartedEvent]", color="blue")
|
||||
self.print_text(f"Loop Node ID: {event.loop_id}", color="blue")
|
||||
|
||||
def on_workflow_loop_next(self, event: LoopRunNextEvent) -> None:
|
||||
"""
|
||||
Publish loop next
|
||||
"""
|
||||
self.print_text("\n[LoopRunNextEvent]", color="blue")
|
||||
self.print_text(f"Loop Node ID: {event.loop_id}", color="blue")
|
||||
self.print_text(f"Loop Index: {event.index}", color="blue")
|
||||
|
||||
def on_workflow_loop_completed(self, event: LoopRunSucceededEvent | LoopRunFailedEvent) -> None:
|
||||
"""
|
||||
Publish loop completed
|
||||
"""
|
||||
self.print_text(
|
||||
"\n[LoopRunSucceededEvent]" if isinstance(event, LoopRunSucceededEvent) else "\n[LoopRunFailedEvent]",
|
||||
color="blue",
|
||||
)
|
||||
self.print_text(f"Node ID: {event.loop_id}", color="blue")
|
||||
|
||||
def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None:
|
||||
"""Print text with highlighting and no end characters."""
|
||||
text_to_print = self._get_colored_text(text, color) if color else text
|
||||
|
||||
@@ -20,12 +20,15 @@ class NodeRunMetadataKey(StrEnum):
|
||||
AGENT_LOG = "agent_log"
|
||||
ITERATION_ID = "iteration_id"
|
||||
ITERATION_INDEX = "iteration_index"
|
||||
LOOP_ID = "loop_id"
|
||||
LOOP_INDEX = "loop_index"
|
||||
PARALLEL_ID = "parallel_id"
|
||||
PARALLEL_START_NODE_ID = "parallel_start_node_id"
|
||||
PARENT_PARALLEL_ID = "parent_parallel_id"
|
||||
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
|
||||
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
|
||||
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
|
||||
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
|
||||
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.nodes.base import BaseIterationState, BaseNode
|
||||
from core.workflow.nodes.base import BaseIterationState, BaseLoopState, BaseNode
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import Workflow, WorkflowType
|
||||
|
||||
@@ -41,11 +41,13 @@ class WorkflowRunState:
|
||||
class NodeRun(BaseModel):
|
||||
node_id: str
|
||||
iteration_node_id: str
|
||||
loop_node_id: str
|
||||
|
||||
workflow_node_runs: list[NodeRun]
|
||||
workflow_node_steps: int
|
||||
|
||||
current_iteration_state: Optional[BaseIterationState]
|
||||
current_loop_state: Optional[BaseLoopState]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -74,3 +76,4 @@ class WorkflowRunState:
|
||||
self.workflow_node_steps = 1
|
||||
self.workflow_node_runs = []
|
||||
self.current_iteration_state = None
|
||||
self.current_loop_state = None
|
||||
|
||||
@@ -63,6 +63,8 @@ class BaseNodeEvent(GraphEngineEvent):
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
|
||||
|
||||
class NodeRunStartedEvent(BaseNodeEvent):
|
||||
@@ -100,6 +102,10 @@ class NodeInIterationFailedEvent(BaseNodeEvent):
|
||||
error: str = Field(..., description="error")
|
||||
|
||||
|
||||
class NodeInLoopFailedEvent(BaseNodeEvent):
|
||||
error: str = Field(..., description="error")
|
||||
|
||||
|
||||
class NodeRunRetryEvent(NodeRunStartedEvent):
|
||||
error: str = Field(..., description="error")
|
||||
retry_index: int = Field(..., description="which retry attempt is about to be performed")
|
||||
@@ -122,6 +128,8 @@ class BaseParallelBranchEvent(GraphEngineEvent):
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: Optional[str] = None
|
||||
"""loop id if node is in loop"""
|
||||
|
||||
|
||||
class ParallelBranchRunStartedEvent(BaseParallelBranchEvent):
|
||||
@@ -189,6 +197,59 @@ class IterationRunFailedEvent(BaseIterationEvent):
|
||||
error: str = Field(..., description="failed reason")
|
||||
|
||||
|
||||
###########################################
|
||||
# Loop Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseLoopEvent(GraphEngineEvent):
|
||||
loop_id: str = Field(..., description="loop node execution id")
|
||||
loop_node_id: str = Field(..., description="loop node id")
|
||||
loop_node_type: NodeType = Field(..., description="node type, loop or loop")
|
||||
loop_node_data: BaseNodeData = Field(..., description="node data")
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""loop run in parallel mode run id"""
|
||||
|
||||
|
||||
class LoopRunStartedEvent(BaseLoopEvent):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
predecessor_node_id: Optional[str] = None
|
||||
|
||||
|
||||
class LoopRunNextEvent(BaseLoopEvent):
|
||||
index: int = Field(..., description="index")
|
||||
pre_loop_output: Optional[Any] = None
|
||||
duration: Optional[float] = None
|
||||
|
||||
|
||||
class LoopRunSucceededEvent(BaseLoopEvent):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
steps: int = 0
|
||||
loop_duration_map: Optional[dict[str, float]] = None
|
||||
|
||||
|
||||
class LoopRunFailedEvent(BaseLoopEvent):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
metadata: Optional[Mapping[str, Any]] = None
|
||||
steps: int = 0
|
||||
error: str = Field(..., description="failed reason")
|
||||
|
||||
|
||||
###########################################
|
||||
# Agent Events
|
||||
###########################################
|
||||
@@ -207,6 +268,7 @@ class AgentLogEvent(BaseAgentEvent):
|
||||
status: str = Field(..., description="status")
|
||||
data: Mapping[str, Any] = Field(..., description="data")
|
||||
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="metadata")
|
||||
node_id: str = Field(..., description="agent node id")
|
||||
|
||||
|
||||
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent
|
||||
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent | BaseLoopEvent
|
||||
|
||||
@@ -18,7 +18,9 @@ from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunM
|
||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseAgentEvent,
|
||||
BaseIterationEvent,
|
||||
BaseLoopEvent,
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
@@ -501,7 +503,7 @@ class GraphEngine:
|
||||
break
|
||||
|
||||
yield event
|
||||
if event.parallel_id == parallel_id:
|
||||
if not isinstance(event, BaseAgentEvent) and event.parallel_id == parallel_id:
|
||||
if isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
succeeded_count += 1
|
||||
if succeeded_count == len(futures):
|
||||
@@ -648,6 +650,12 @@ class GraphEngine:
|
||||
item.parallel_start_node_id = parallel_start_node_id
|
||||
item.parent_parallel_id = parent_parallel_id
|
||||
item.parent_parallel_start_node_id = parent_parallel_start_node_id
|
||||
elif isinstance(item, BaseLoopEvent):
|
||||
# add parallel info to loop event
|
||||
item.parallel_id = parallel_id
|
||||
item.parallel_start_node_id = parallel_start_node_id
|
||||
item.parent_parallel_id = parent_parallel_id
|
||||
item.parent_parallel_start_node_id = parent_parallel_start_node_id
|
||||
|
||||
yield item
|
||||
else:
|
||||
|
||||
@@ -158,6 +158,7 @@ class AnswerStreamGeneratorRouter:
|
||||
NodeType.IF_ELSE,
|
||||
NodeType.QUESTION_CLASSIFIER,
|
||||
NodeType.ITERATION,
|
||||
NodeType.LOOP,
|
||||
NodeType.VARIABLE_ASSIGNER,
|
||||
}
|
||||
or source_node_data.get("error_strategy") == ErrorStrategy.FAIL_BRANCH
|
||||
|
||||
@@ -35,7 +35,7 @@ class AnswerStreamProcessor(StreamProcessor):
|
||||
|
||||
yield event
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
if event.in_iteration_id:
|
||||
if event.in_iteration_id or event.in_loop_id:
|
||||
yield event
|
||||
continue
|
||||
|
||||
|
||||
@@ -1,4 +1,11 @@
|
||||
from .entities import BaseIterationNodeData, BaseIterationState, BaseNodeData
|
||||
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
|
||||
from .node import BaseNode
|
||||
|
||||
__all__ = ["BaseIterationNodeData", "BaseIterationState", "BaseNode", "BaseNodeData"]
|
||||
__all__ = [
|
||||
"BaseIterationNodeData",
|
||||
"BaseIterationState",
|
||||
"BaseLoopNodeData",
|
||||
"BaseLoopState",
|
||||
"BaseNode",
|
||||
"BaseNodeData",
|
||||
]
|
||||
|
||||
@@ -147,3 +147,18 @@ class BaseIterationState(BaseModel):
|
||||
pass
|
||||
|
||||
metadata: MetaData
|
||||
|
||||
|
||||
class BaseLoopNodeData(BaseNodeData):
|
||||
start_node_id: Optional[str] = None
|
||||
|
||||
|
||||
class BaseLoopState(BaseModel):
|
||||
loop_node_id: str
|
||||
index: int
|
||||
inputs: dict
|
||||
|
||||
class MetaData(BaseModel):
|
||||
pass
|
||||
|
||||
metadata: MetaData
|
||||
|
||||
@@ -33,7 +33,7 @@ class EndStreamProcessor(StreamProcessor):
|
||||
|
||||
yield event
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
if event.in_iteration_id:
|
||||
if event.in_iteration_id or event.in_loop_id:
|
||||
if self.has_output and event.node_id not in self.output_node_ids:
|
||||
event.chunk_content = "\n" + event.chunk_content
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ class NodeType(StrEnum):
|
||||
VARIABLE_AGGREGATOR = "variable-aggregator"
|
||||
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
|
||||
LOOP = "loop"
|
||||
LOOP_START = "loop-start"
|
||||
ITERATION = "iteration"
|
||||
ITERATION_START = "iteration-start" # Fake start node for iteration.
|
||||
PARAMETER_EXTRACTOR = "parameter-extractor"
|
||||
|
||||
@@ -120,6 +120,10 @@ class Response:
|
||||
if disp_type == "attachment" or filename is not None:
|
||||
return True
|
||||
|
||||
# For 'text/' types, only 'csv' should be downloaded as file
|
||||
if content_type.startswith("text/") and "csv" not in content_type:
|
||||
return False
|
||||
|
||||
# For application types, try to detect if it's a text-based format
|
||||
if content_type.startswith("application/"):
|
||||
# Common text-based application types
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
from .entities import LoopNodeData
|
||||
from .loop_node import LoopNode
|
||||
from .loop_start_node import LoopStartNode
|
||||
|
||||
__all__ = ["LoopNode", "LoopNodeData", "LoopStartNode"]
|
||||
|
||||
@@ -1,13 +1,54 @@
|
||||
from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
|
||||
|
||||
class LoopNodeData(BaseIterationNodeData):
|
||||
class LoopNodeData(BaseLoopNodeData):
|
||||
"""
|
||||
Loop Node Data.
|
||||
"""
|
||||
|
||||
loop_count: int # Maximum number of loops
|
||||
break_conditions: list[Condition] # Conditions to break the loop
|
||||
logical_operator: Literal["and", "or"]
|
||||
|
||||
class LoopState(BaseIterationState):
|
||||
|
||||
class LoopStartNodeData(BaseNodeData):
|
||||
"""
|
||||
Loop Start Node Data.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LoopState(BaseLoopState):
|
||||
"""
|
||||
Loop State.
|
||||
"""
|
||||
|
||||
outputs: list[Any] = Field(default_factory=list)
|
||||
current_output: Optional[Any] = None
|
||||
|
||||
class MetaData(BaseLoopState.MetaData):
|
||||
"""
|
||||
Data.
|
||||
"""
|
||||
|
||||
loop_length: int
|
||||
|
||||
def get_last_output(self) -> Optional[Any]:
|
||||
"""
|
||||
Get last output.
|
||||
"""
|
||||
if self.outputs:
|
||||
return self.outputs[-1]
|
||||
return None
|
||||
|
||||
def get_current_output(self) -> Optional[Any]:
|
||||
"""
|
||||
Get current output.
|
||||
"""
|
||||
return self.current_output
|
||||
|
||||
@@ -1,9 +1,35 @@
|
||||
from typing import Any
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.variables import IntegerSegment
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseGraphEvent,
|
||||
BaseNodeEvent,
|
||||
BaseParallelBranchEvent,
|
||||
GraphRunFailedEvent,
|
||||
InNodeEvent,
|
||||
LoopRunFailedEvent,
|
||||
LoopRunNextEvent,
|
||||
LoopRunStartedEvent,
|
||||
LoopRunSucceededEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.loop.entities import LoopNodeData, LoopState
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from core.workflow.nodes.loop.entities import LoopNodeData
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoopNode(BaseNode[LoopNodeData]):
|
||||
@@ -14,24 +40,323 @@ class LoopNode(BaseNode[LoopNodeData]):
|
||||
_node_data_cls = LoopNodeData
|
||||
_node_type = NodeType.LOOP
|
||||
|
||||
def _run(self) -> LoopState: # type: ignore
|
||||
return super()._run() # type: ignore
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
"""Run the node."""
|
||||
# Get inputs
|
||||
loop_count = self.node_data.loop_count
|
||||
break_conditions = self.node_data.break_conditions
|
||||
logical_operator = self.node_data.logical_operator
|
||||
|
||||
inputs = {"loop_count": loop_count}
|
||||
|
||||
if not self.node_data.start_node_id:
|
||||
raise ValueError(f"field start_node_id in loop {self.node_id} not found")
|
||||
|
||||
# Initialize graph
|
||||
loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self.node_data.start_node_id)
|
||||
if not loop_graph:
|
||||
raise ValueError("loop graph not found")
|
||||
|
||||
# Initialize variable pool
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
variable_pool.add([self.node_id, "index"], 0)
|
||||
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
workflow_type=self.workflow_type,
|
||||
workflow_id=self.workflow_id,
|
||||
user_id=self.user_id,
|
||||
user_from=self.user_from,
|
||||
invoke_from=self.invoke_from,
|
||||
call_depth=self.workflow_call_depth,
|
||||
graph=loop_graph,
|
||||
graph_config=self.graph_config,
|
||||
variable_pool=variable_pool,
|
||||
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
|
||||
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
|
||||
thread_pool_id=self.thread_pool_id,
|
||||
)
|
||||
|
||||
start_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
condition_processor = ConditionProcessor()
|
||||
|
||||
# Start Loop event
|
||||
yield LoopRunStartedEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
metadata={"loop_length": loop_count},
|
||||
predecessor_node_id=self.previous_node_id,
|
||||
)
|
||||
|
||||
yield LoopRunNextEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
index=0,
|
||||
pre_loop_output=None,
|
||||
)
|
||||
|
||||
try:
|
||||
check_break_result = False
|
||||
for i in range(loop_count):
|
||||
# Run workflow
|
||||
rst = graph_engine.run()
|
||||
current_index_variable = variable_pool.get([self.node_id, "index"])
|
||||
if not isinstance(current_index_variable, IntegerSegment):
|
||||
raise ValueError(f"loop {self.node_id} current index not found")
|
||||
current_index = current_index_variable.value
|
||||
|
||||
check_break_result = False
|
||||
|
||||
for event in rst:
|
||||
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_loop_id:
|
||||
event.in_loop_id = self.node_id
|
||||
|
||||
if (
|
||||
isinstance(event, BaseNodeEvent)
|
||||
and event.node_type == NodeType.LOOP_START
|
||||
and not isinstance(event, NodeRunStreamChunkEvent)
|
||||
):
|
||||
continue
|
||||
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
yield self._handle_event_metadata(event=event, iter_run_index=current_index)
|
||||
|
||||
# Check if all variables in break conditions exist
|
||||
exists_variable = False
|
||||
for condition in break_conditions:
|
||||
if not self.graph_runtime_state.variable_pool.get(condition.variable_selector):
|
||||
exists_variable = False
|
||||
break
|
||||
else:
|
||||
exists_variable = True
|
||||
if exists_variable:
|
||||
input_conditions, group_result, check_break_result = condition_processor.process_conditions(
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
conditions=break_conditions,
|
||||
operator=logical_operator,
|
||||
)
|
||||
if check_break_result:
|
||||
break
|
||||
|
||||
elif isinstance(event, BaseGraphEvent):
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
# Loop run failed
|
||||
yield LoopRunFailedEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
steps=i,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
"completed_reason": "error",
|
||||
},
|
||||
error=event.error,
|
||||
)
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens
|
||||
},
|
||||
)
|
||||
)
|
||||
return
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
# Loop run failed
|
||||
yield event
|
||||
yield LoopRunFailedEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
steps=i,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
"completed_reason": "error",
|
||||
},
|
||||
error=event.error,
|
||||
)
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens
|
||||
},
|
||||
)
|
||||
)
|
||||
return
|
||||
else:
|
||||
yield self._handle_event_metadata(event=cast(InNodeEvent, event), iter_run_index=current_index)
|
||||
|
||||
# Remove all nodes outputs from variable pool
|
||||
for node_id in loop_graph.node_ids:
|
||||
variable_pool.remove([node_id])
|
||||
|
||||
if check_break_result:
|
||||
break
|
||||
|
||||
# Move to next loop
|
||||
next_index = current_index + 1
|
||||
variable_pool.add([self.node_id, "index"], next_index)
|
||||
|
||||
yield LoopRunNextEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
index=next_index,
|
||||
pre_loop_output=None,
|
||||
)
|
||||
|
||||
# Loop completed successfully
|
||||
yield LoopRunSucceededEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
steps=loop_count,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
|
||||
"completed_reason": "loop_break" if check_break_result else "loop_completed",
|
||||
},
|
||||
)
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Loop failed
|
||||
logger.exception("Loop run failed")
|
||||
yield LoopRunFailedEvent(
|
||||
loop_id=self.id,
|
||||
loop_node_id=self.node_id,
|
||||
loop_node_type=self.node_type,
|
||||
loop_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
steps=loop_count,
|
||||
metadata={
|
||||
"total_tokens": graph_engine.graph_runtime_state.total_tokens,
|
||||
"completed_reason": "error",
|
||||
},
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens},
|
||||
)
|
||||
)
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
variable_pool.remove([self.node_id, "index"])
|
||||
|
||||
def _handle_event_metadata(
|
||||
self,
|
||||
*,
|
||||
event: BaseNodeEvent | InNodeEvent,
|
||||
iter_run_index: int,
|
||||
) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent:
|
||||
"""
|
||||
add iteration metadata to event.
|
||||
"""
|
||||
if not isinstance(event, BaseNodeEvent):
|
||||
return event
|
||||
if event.route_node_state.node_run_result:
|
||||
metadata = event.route_node_state.node_run_result.metadata
|
||||
if not metadata:
|
||||
metadata = {}
|
||||
if NodeRunMetadataKey.LOOP_ID not in metadata:
|
||||
metadata = {
|
||||
**metadata,
|
||||
NodeRunMetadataKey.LOOP_ID: self.node_id,
|
||||
NodeRunMetadataKey.LOOP_INDEX: iter_run_index,
|
||||
}
|
||||
event.route_node_state.node_run_result.metadata = metadata
|
||||
return event
|
||||
|
||||
@classmethod
|
||||
def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: LoopNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Get conditions.
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
node_id = node_config.get("id")
|
||||
if not node_id:
|
||||
return []
|
||||
variable_mapping = {}
|
||||
|
||||
# TODO waiting for implementation
|
||||
return [
|
||||
Condition( # type: ignore
|
||||
variable_selector=[node_id, "index"],
|
||||
comparison_operator="≤",
|
||||
value_type="value_selector",
|
||||
value_selector=[],
|
||||
)
|
||||
]
|
||||
# init graph
|
||||
loop_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id)
|
||||
|
||||
if not loop_graph:
|
||||
raise ValueError("loop graph not found")
|
||||
|
||||
for sub_node_id, sub_node_config in loop_graph.node_id_config_mapping.items():
|
||||
if sub_node_config.get("data", {}).get("loop_id") != node_id:
|
||||
continue
|
||||
|
||||
# variable selector to variable mapping
|
||||
try:
|
||||
# Get node class
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
node_type = NodeType(sub_node_config.get("data", {}).get("type"))
|
||||
if node_type not in NODE_TYPE_CLASSES_MAPPING:
|
||||
continue
|
||||
node_version = sub_node_config.get("data", {}).get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, config=sub_node_config
|
||||
)
|
||||
sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
|
||||
except NotImplementedError:
|
||||
sub_node_variable_mapping = {}
|
||||
|
||||
# remove loop variables
|
||||
sub_node_variable_mapping = {
|
||||
sub_node_id + "." + key: value
|
||||
for key, value in sub_node_variable_mapping.items()
|
||||
if value[0] != node_id
|
||||
}
|
||||
|
||||
variable_mapping.update(sub_node_variable_mapping)
|
||||
|
||||
# remove variable out from loop
|
||||
variable_mapping = {
|
||||
key: value for key, value in variable_mapping.items() if value[0] not in loop_graph.node_ids
|
||||
}
|
||||
|
||||
return variable_mapping
|
||||
|
||||
20
api/core/workflow/nodes/loop/loop_start_node.py
Normal file
20
api/core/workflow/nodes/loop/loop_start_node.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.loop.entities import LoopStartNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class LoopStartNode(BaseNode):
|
||||
"""
|
||||
Loop Start Node.
|
||||
"""
|
||||
|
||||
_node_data_cls = LoopStartNodeData
|
||||
_node_type = NodeType.LOOP_START
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run the node.
|
||||
"""
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED)
|
||||
@@ -13,6 +13,7 @@ from core.workflow.nodes.iteration import IterationNode, IterationStartNode
|
||||
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.list_operator import ListOperatorNode
|
||||
from core.workflow.nodes.llm import LLMNode
|
||||
from core.workflow.nodes.loop import LoopNode, LoopStartNode
|
||||
from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from core.workflow.nodes.start import StartNode
|
||||
@@ -85,6 +86,14 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
|
||||
LATEST_VERSION: IterationStartNode,
|
||||
"1": IterationStartNode,
|
||||
},
|
||||
NodeType.LOOP: {
|
||||
LATEST_VERSION: LoopNode,
|
||||
"1": LoopNode,
|
||||
},
|
||||
NodeType.LOOP_START: {
|
||||
LATEST_VERSION: LoopStartNode,
|
||||
"1": LoopStartNode,
|
||||
},
|
||||
NodeType.PARAMETER_EXTRACTOR: {
|
||||
LATEST_VERSION: ParameterExtractorNode,
|
||||
"1": ParameterExtractorNode,
|
||||
|
||||
@@ -338,6 +338,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||
data=message.message.data,
|
||||
label=message.message.label,
|
||||
metadata=message.message.metadata,
|
||||
node_id=self.node_id,
|
||||
)
|
||||
|
||||
# check if the agent log is already in the list
|
||||
|
||||
@@ -7,7 +7,7 @@ import httpx
|
||||
from sqlalchemy import select
|
||||
|
||||
from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
|
||||
from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig
|
||||
from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from models import MessageFile, ToolFile, UploadFile
|
||||
@@ -158,6 +158,39 @@ def _build_from_remote_url(
|
||||
tenant_id: str,
|
||||
transfer_method: FileTransferMethod,
|
||||
) -> File:
|
||||
upload_file_id = mapping.get("upload_file_id")
|
||||
if upload_file_id:
|
||||
try:
|
||||
uuid.UUID(upload_file_id)
|
||||
except ValueError:
|
||||
raise ValueError("Invalid upload file id format")
|
||||
stmt = select(UploadFile).where(
|
||||
UploadFile.id == upload_file_id,
|
||||
UploadFile.tenant_id == tenant_id,
|
||||
)
|
||||
|
||||
upload_file = db.session.scalar(stmt)
|
||||
if upload_file is None:
|
||||
raise ValueError("Invalid upload file")
|
||||
|
||||
file_type = FileType(mapping.get("type", "custom"))
|
||||
file_type = _standardize_file_type(
|
||||
file_type, extension="." + upload_file.extension, mime_type=upload_file.mime_type
|
||||
)
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
tenant_id=tenant_id,
|
||||
type=file_type,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)),
|
||||
related_id=mapping.get("upload_file_id"),
|
||||
size=upload_file.size,
|
||||
storage_key=upload_file.key,
|
||||
)
|
||||
url = mapping.get("url") or mapping.get("remote_url")
|
||||
if not url:
|
||||
raise ValueError("Invalid file url")
|
||||
|
||||
@@ -17,8 +17,8 @@ workflow_app_log_partial_fields = {
|
||||
|
||||
workflow_app_log_pagination_fields = {
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer(attribute="per_page"),
|
||||
"limit": fields.Integer,
|
||||
"total": fields.Integer,
|
||||
"has_more": fields.Boolean(attribute="has_next"),
|
||||
"data": fields.List(fields.Nested(workflow_app_log_partial_fields), attribute="items"),
|
||||
"has_more": fields.Boolean,
|
||||
"data": fields.List(fields.Nested(workflow_app_log_partial_fields)),
|
||||
}
|
||||
|
||||
@@ -45,7 +45,9 @@ workflow_fields = {
|
||||
"graph": fields.Raw(attribute="graph_dict"),
|
||||
"features": fields.Raw(attribute="features_dict"),
|
||||
"hash": fields.String(attribute="unique_hash"),
|
||||
"version": fields.String(attribute="version"),
|
||||
"version": fields.String,
|
||||
"marked_name": fields.String,
|
||||
"marked_comment": fields.String,
|
||||
"created_by": fields.Nested(simple_account_fields, attribute="created_by_account"),
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True),
|
||||
|
||||
@@ -77,7 +77,7 @@ def login_required(func):
|
||||
)
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
account = Account.query.filter_by(id=ta.account_id).first()
|
||||
account = db.session.query(Account).filter_by(id=ta.account_id).first()
|
||||
# Login admin
|
||||
if account:
|
||||
account.current_tenant = tenant
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user