Compare commits


80 Commits
0.7.3 ... 0.8.0

Author SHA1 Message Date
-LAN-
ffd4bf8bf0 fix(docker/docker-compose.yaml): Set default value for REDIS_SENTINEL_SOCKET_TIMEOUT and CELERY_SENTINEL_SOCKET_TIMEOUT (#8218) 2024-09-10 18:47:59 +08:00
Jyong
bb3002b173 revert page column (#8217) 2024-09-10 18:21:22 +08:00
Yi Xiao
d4dc54447a fix the tooltip in tool nodes (#8215) 2024-09-10 17:53:44 +08:00
Bowen Liang
d109881410 chore(api/models): apply ruff reformatting (#7600)
Co-authored-by: -LAN- <laipz8200@outlook.com>
2024-09-10 17:08:06 +08:00
Joel
d1605952b0 fix: input chat input wrong padding (#8207) 2024-09-10 17:01:32 +08:00
Bowen Liang
2cf1187b32 chore(api/core): apply ruff reformatting (#7624) 2024-09-10 17:00:20 +08:00
github-actions[bot]
178730266d chore: translate i18n files (#8202)
Co-authored-by: takatost <5485478+takatost@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-09-10 16:13:26 +08:00
takatost
dabfd74622 feat: Parallel Execution of Nodes in Workflows (#8192)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Yi <yxiaoisme@gmail.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2024-09-10 15:23:16 +08:00
crazywoola
5da0182800 docs: replace docker-compose with docker compose (#8195) 2024-09-10 15:02:52 +08:00
-LAN-
ed37439ef7 refactor(api/core): Improve type hints and apply ruff formatter in agent runner and model manager. (#8166) 2024-09-10 15:00:25 +08:00
Jyong
af92f19291 filter excel empty sheet (#8194) 2024-09-10 14:55:08 +08:00
zhuhao
86f7f245e4 fix: The length of the tag should be between 1 and 50 (#8187) (#8188) 2024-09-10 14:07:06 +08:00
Jyong
2d690801d1 nvidia rerank top n missed (#8185) 2024-09-10 13:17:48 +08:00
Nam Vu
fede54be77 fix: Version '2.6.2-2' for 'expat' was not found (#8182) 2024-09-10 13:00:37 +08:00
Jyong
85ff82a694 code merge error (#8183)
Co-authored-by: crazywoola <427733928@qq.com>
2024-09-10 12:52:50 +08:00
ice yao
c8df92d0eb add volcengine tos storage (#8164) 2024-09-10 09:19:47 +08:00
Bowen Liang
144d30d7ef chore: bump super-linter to v7 (#8148) 2024-09-10 09:13:48 +08:00
-LAN-
4313d92e6b feat(api/core/model_runtime/entities/defaults.py): Add TOP_K in default parameters. (#8167) 2024-09-10 09:11:31 +08:00
Nam Vu
0695543f63 Fix variable typo (cont) (#8161) 2024-09-09 23:46:13 +08:00
crazywoola
0bec6a037c update qwen-long (#8157) 2024-09-09 19:09:42 +08:00
Chenhe Gu
3ff9a1f24a Update LICENSE - remove 'SaaS' from restriction term definition (#8143) 2024-09-09 16:52:55 +08:00
crazywoola
a771eea4f6 fix: html raw render (#8138) 2024-09-09 16:12:59 +08:00
github-actions[bot]
61a0ca9e0d chore: translate i18n files (#8135)
Co-authored-by: zxhlyh <16177003+zxhlyh@users.noreply.github.com>
Co-authored-by: crazywoola <427733928@qq.com>
2024-09-09 15:54:00 +08:00
cr-zhichen
551b33c8e5 fix: user-select style and pre-create iframe in embed.js (#8093) 2024-09-09 15:40:56 +08:00
AAEE86
fa34b9aed6 Modify model parameters in Spark LLMs and zhipuai LLMs (#8078)
Co-authored-by: Charlie.Wei <luowei@cvte.com>
2024-09-09 15:36:47 +08:00
zxhlyh
bbb609179f chore: offline n to 1 retrieval (#8134) 2024-09-09 15:32:02 +08:00
crazywoola
a27d4d58ec fix: ollama text embedding 500 error (#8131) 2024-09-09 15:27:49 +08:00
zhuhao
50d92f0fd4 add dify-sandbox health check in docker-compose.yaml (#8121) (#8124) 2024-09-09 14:39:06 +08:00
邹成卓
a15791e788 Fix: tongyi code wrapper is not stable (#7871)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: crazywoola <427733928@qq.com>
2024-09-09 11:15:17 +08:00
ybalbert001
954580a4af feat: support more model types and builtin tools on aws/sagemaker (#8061)
Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
2024-09-09 10:34:11 +08:00
crazywoola
ab7d79275e fix: Claude cannot validate credentials (#8109) 2024-09-09 10:22:42 +08:00
Thales Salazar
d3658166fb Translate billing to PT-BR (#8105)
Co-authored-by: crazywoola <427733928@qq.com>
2024-09-09 10:16:22 +08:00
Su Yang
54b72bdd0a chore: keep dify compose file consistent format (#8102) 2024-09-09 08:30:03 +08:00
呆萌闷油瓶
d28446301f feat: add fishaudio in xinference (#8100) 2024-09-08 23:58:02 +08:00
crazywoola
9050f92e5b fix: parameter input (#8076) 2024-09-08 15:43:55 +08:00
Charlie.Wei
feefeb44d7 fix LangSmith project config error (#7996) 2024-09-08 13:25:27 +08:00
Zhi
d542b15cc0 feat: support redis sentinel mode (#7756) 2024-09-08 13:23:51 +08:00
Nam Vu
2d7954c7da Fix variable typo (#8084) 2024-09-08 13:14:11 +08:00
crazywoola
b1918dae5e fix: knowledge input (#8065) 2024-09-07 17:53:39 +08:00
Nam Vu
031a0b576d fix: i18n typo (#8077) 2024-09-07 16:59:38 +08:00
AAEE86
0cef25ef8c Revert "fix: parameter rule" (#8070) 2024-09-07 10:44:56 +08:00
Yi Xiao
cdb08be951 fix: overflow issues in chat history (#8062) 2024-09-06 19:20:18 +08:00
crazywoola
900fd82a92 fix: parameter rule (#8064) 2024-09-06 19:15:24 +08:00
crazywoola
44f963f281 If else add regexmatch (#8059)
Co-authored-by: 罗威 <luowei@cvte.com>
2024-09-06 18:35:51 +08:00
Charlie.Wei
01858e1caf ifElse node add regex match (#8007) 2024-09-06 17:44:09 +08:00
ChengZi
2060db8e11 fix: change milvus init args from (host, port) to (url, token) (#8019)
Signed-off-by: ChengZi <chen.zhang@zilliz.com>
2024-09-06 17:32:48 +08:00
Nam Vu
9ded063417 chore: #7348, support query conversations by updated_at (#8047) 2024-09-06 17:31:51 +08:00
Yi Xiao
d72da2777c fix the tooltip in tools node (#8055) 2024-09-06 17:28:22 +08:00
tmuife
89aede80cc Add OCI(Oracle Cloud Infrastructure) Generative AI Service as a Model Provider (#7775)
Co-authored-by: Walter Jin <jinshuhaicc@gmail.com>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: walter from vm <walter.jin@oracle.com>
2024-09-06 14:15:40 +08:00
zhuhao
e0d3cd91c6 support huawei cloud obs storage (#7980) (#7981) 2024-09-06 14:00:47 +08:00
winsonwhe
1a054ac1f4 Update milvus-standalone version and expose required ports for the container. (#7709)
Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com>
Co-authored-by: crazywoola <427733928@qq.com>
2024-09-06 12:01:59 +08:00
Charlie.Wei
3230f4a0ec Message rendering (#6868)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-09-05 21:00:09 +08:00
Designerxsh
dadca0f91a Fix/datasets api description error (#8025) 2024-09-05 16:45:44 +08:00
Byeongjin Kang
d489b8b3e0 feat: return page number of pdf documents upon retrieval (#7749) 2024-09-05 16:43:26 +08:00
Leng Yue
bd0992275c feat: support fish audio TTS (#7982) 2024-09-05 14:18:39 +08:00
非法操作
3e7597f2bd feat: add gpt-4o-2024-08-06 and json_schema for azure openAI service (#7648) 2024-09-04 21:56:08 +08:00
Jyong
0e71f6db84 fix splitter length missed (#7987) 2024-09-04 21:47:12 +08:00
wochuideng
f6b9982c23 Fix exception when obtaining the token during concurrent calls to the Wenxin model (#7976)
Co-authored-by: puqs1 <puqs1@lenovo.com>
2024-09-04 21:44:57 +08:00
github-actions[bot]
fb113a9479 chore: translate i18n files (#7965)
Co-authored-by: JohnJyong <76649700+JohnJyong@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Hanqing Zhao <sherry9277@gmail.com>
Co-authored-by: crazywoola <427733928@qq.com>
2024-09-04 17:45:12 +08:00
Jyong
15791510c8 fix wrong error message (#7972) 2024-09-04 16:46:41 +08:00
非法操作
0f72a8e89d chore: refactor the baichuan model (#7953) 2024-09-04 16:22:31 +08:00
KVOJJJin
14af87527f Feat: remove estimation of embedding cost (#7950)
Co-authored-by: jyong <718720800@qq.com>
2024-09-04 14:41:47 +08:00
zhuhao
83e84865be feat: add health check for pg and redis in docker-compose.middleware.yaml (#7961) (#7962) 2024-09-04 14:25:46 +08:00
Joel
c2a3c5a748 fix: get commit sha failed in translate action (#7959) 2024-09-04 13:13:21 +08:00
呆萌闷油瓶
83494cb4f5 fix: empty voice when using the xinference CosyVoice TTS model (#7958) 2024-09-04 13:04:31 +08:00
Vico Chu
0bc19c3fbf Feat: update app published time after clicking publish button (#7801) 2024-09-04 13:03:06 +08:00
Sumkor
571415d1a4 fix: split text keep separator (#7930) 2024-09-04 12:59:10 +08:00
Yuki Oshima
7b2cf8215f chore: fix inverted index japanese translation (#7957) 2024-09-04 12:44:59 +08:00
Joe
fee4d3f6ca feat: ops trace add llm model (#7306) 2024-09-04 10:39:00 +08:00
takatost
161cc0cda9 Revert "fix: an issue of keyword search feature in application log list" (#7949) 2024-09-04 10:00:55 +08:00
Nam Vu
71bff9fcf3 chore: #7943 i18n (#7948) 2024-09-04 09:42:25 +08:00
陳鈞
80d14c9b22 fix(api): Code-Based Extension causes error on position map sorting (#7934)
Signed-off-by: 陳鈞 <jim60105@gmail.com>
2024-09-04 08:41:12 +08:00
crazywoola
c5bdf08558 Chore/add roadmap (#7943) 2024-09-04 08:33:02 +08:00
crazywoola
596f160a1e Chore/add default step 1x url (#7933) 2024-09-04 08:32:22 +08:00
Jyong
d8b6c053a2 fix rerank model value is empty string (#7937) 2024-09-03 21:25:21 +08:00
Nam Vu
4b262cae58 chore: #7603 i18n (#7931) 2024-09-03 19:19:52 +08:00
Jyong
1a5116cba0 Fix/segment create with api (#7928) 2024-09-03 18:14:47 +08:00
Jyong
01581dd35f improve the notion table extract (#7925) 2024-09-03 17:52:07 +08:00
Joel
7fdd964379 fix: frontend handling when the server sometimes generates a wrong follow-up data struct (#7916) 2024-09-03 14:09:46 +08:00
Joel
0cfcc97e9d feat: support auto-generating i18n translations (#6964)
Co-authored-by: crazywoola <427733928@qq.com>
2024-09-03 10:17:05 +08:00
1329 changed files with 43959 additions and 30747 deletions

View File

@@ -20,7 +20,7 @@ jobs:
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v44
uses: tj-actions/changed-files@v45
with:
files: api/**
@@ -66,7 +66,7 @@ jobs:
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v44
uses: tj-actions/changed-files@v45
with:
files: web/**
@@ -97,7 +97,7 @@ jobs:
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v44
uses: tj-actions/changed-files@v45
with:
files: |
**.sh
@@ -107,7 +107,7 @@ jobs:
dev/**
- name: Super-linter
uses: super-linter/super-linter/slim@v6
uses: super-linter/super-linter/slim@v7
if: steps.changed-files.outputs.any_changed == 'true'
env:
BASH_SEVERITY: warning

View File

@@ -0,0 +1,54 @@
name: Check i18n Files and Create PR
on:
pull_request:
types: [closed]
branches: [main]
jobs:
check-and-update:
if: github.event.pull_request.merged == true
runs-on: ubuntu-latest
defaults:
run:
working-directory: web
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 2 # last 2 commits
- name: Check for file changes in i18n/en-US
id: check_files
run: |
recent_commit_sha=$(git rev-parse HEAD)
second_recent_commit_sha=$(git rev-parse HEAD~1)
changed_files=$(git diff --name-only $recent_commit_sha $second_recent_commit_sha -- 'i18n/en-US/*.ts')
echo "Changed files: $changed_files"
if [ -n "$changed_files" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV
else
echo "FILES_CHANGED=false" >> $GITHUB_ENV
fi
- name: Set up Node.js
if: env.FILES_CHANGED == 'true'
uses: actions/setup-node@v2
with:
node-version: 'lts/*'
- name: Install dependencies
if: env.FILES_CHANGED == 'true'
run: yarn install --frozen-lockfile
- name: Run npm script
if: env.FILES_CHANGED == 'true'
run: npm run auto-gen-i18n
- name: Create Pull Request
if: env.FILES_CHANGED == 'true'
uses: peter-evans/create-pull-request@v6
with:
commit-message: Update i18n files based on en-US changes
title: 'chore: translate i18n files'
body: This PR was automatically created to update i18n files based on changes in en-US locale.
branch: chore/automated-i18n-updates

View File

@@ -4,7 +4,7 @@ Dify is licensed under the Apache License 2.0, with the following additional con
1. Dify may be utilized commercially, including as a backend service for other applications or as an application development platform for enterprises. Should the conditions below be met, a commercial license must be obtained from the producer:
a. Multi-tenant SaaS service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
a. Multi-tenant service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
- Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations.
b. LOGO and copyright information: In the process of using Dify's frontend components, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend components.

View File

@@ -39,7 +39,7 @@ DB_DATABASE=dify
# Storage configuration
# use for store upload files, private keys...
# storage type: local, s3, azure-blob, google-storage
# storage type: local, s3, azure-blob, google-storage, tencent-cos, huawei-obs, volcengine-tos
STORAGE_TYPE=local
STORAGE_LOCAL_PATH=storage
S3_USE_AWS_MANAGED_IAM=false
@@ -73,6 +73,12 @@ TENCENT_COS_SECRET_ID=your-secret-id
TENCENT_COS_REGION=your-region
TENCENT_COS_SCHEME=your-scheme
# Huawei OBS Storage Configuration
HUAWEI_OBS_BUCKET_NAME=your-bucket-name
HUAWEI_OBS_SECRET_KEY=your-secret-key
HUAWEI_OBS_ACCESS_KEY=your-access-key
HUAWEI_OBS_SERVER=your-server-url
# OCI Storage configuration
OCI_ENDPOINT=your-endpoint
OCI_BUCKET_NAME=your-bucket-name
@@ -80,6 +86,13 @@ OCI_ACCESS_KEY=your-access-key
OCI_SECRET_KEY=your-secret-key
OCI_REGION=your-region
# Volcengine tos Storage configuration
VOLCENGINE_TOS_ENDPOINT=your-endpoint
VOLCENGINE_TOS_BUCKET_NAME=your-bucket-name
VOLCENGINE_TOS_ACCESS_KEY=your-access-key
VOLCENGINE_TOS_SECRET_KEY=your-secret-key
VOLCENGINE_TOS_REGION=your-region
# CORS configuration
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
@@ -101,11 +114,10 @@ QDRANT_GRPC_ENABLED=false
QDRANT_GRPC_PORT=6334
# Milvus configuration
MILVUS_HOST=127.0.0.1
MILVUS_PORT=19530
MILVUS_URI=http://127.0.0.1:19530
MILVUS_TOKEN=
MILVUS_USER=root
MILVUS_PASSWORD=Milvus
MILVUS_SECURE=false
# MyScale configuration
MYSCALE_HOST=127.0.0.1

View File

@@ -55,7 +55,7 @@ RUN apt-get update \
&& echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
&& apt-get update \
# For Security
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-2 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
&& apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
&& apt-get autoremove -y \
&& rm -rf /var/lib/apt/lists/*

View File

@@ -46,7 +46,7 @@ class CodeExecutionSandboxConfig(BaseSettings):
"""
CODE_EXECUTION_ENDPOINT: HttpUrl = Field(
description="endpoint URL of code execution servcie",
description="endpoint URL of code execution service",
default="http://sandbox:8194",
)
@@ -415,7 +415,7 @@ class MailConfig(BaseSettings):
"""
MAIL_TYPE: Optional[str] = Field(
description="Mail provider type name, default to None, availabile values are `smtp` and `resend`.",
description="Mail provider type name, default to None, available values are `smtp` and `resend`.",
default=None,
)

View File

@@ -1,7 +1,7 @@
from typing import Any, Optional
from urllib.parse import quote_plus
from pydantic import Field, NonNegativeInt, PositiveInt, computed_field
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
from pydantic_settings import BaseSettings
from configs.middleware.cache.redis_config import RedisConfig
@@ -9,8 +9,10 @@ from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorag
from configs.middleware.storage.amazon_s3_storage_config import S3StorageConfig
from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorageConfig
from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
from configs.middleware.storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
from configs.middleware.storage.oci_storage_config import OCIStorageConfig
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
from configs.middleware.vdb.chroma_config import ChromaConfig
from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
@@ -157,6 +159,21 @@ class CeleryConfig(DatabaseConfig):
default=None,
)
CELERY_USE_SENTINEL: Optional[bool] = Field(
description="Whether to use Redis Sentinel mode",
default=False,
)
CELERY_SENTINEL_MASTER_NAME: Optional[str] = Field(
description="Redis Sentinel master name",
default=None,
)
CELERY_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
description="Redis Sentinel socket timeout",
default=0.1,
)
@computed_field
@property
def CELERY_RESULT_BACKEND(self) -> str | None:
@@ -184,6 +201,8 @@ class MiddlewareConfig(
AzureBlobStorageConfig,
GoogleCloudStorageConfig,
TencentCloudCOSStorageConfig,
HuaweiCloudOBSStorageConfig,
VolcengineTOSStorageConfig,
S3StorageConfig,
OCIStorageConfig,
# configs of vdb and vdb providers
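
For orientation only (not part of this diff): the CELERY_SENTINEL_* fields added above map onto Celery's standard Redis Sentinel support. A minimal sketch of that configuration style, where the sentinel host names and the "mymaster" master name are placeholder assumptions, not Dify defaults:

from celery import Celery

# One "sentinel://" entry per Sentinel node; Celery fails over between them.
app = Celery(
    "sketch",
    broker="sentinel://sentinel-1:26379;sentinel://sentinel-2:26379",
    backend="sentinel://sentinel-1:26379;sentinel://sentinel-2:26379",
)

# master_name corresponds to CELERY_SENTINEL_MASTER_NAME; socket_timeout is
# passed through to the Sentinel client, like CELERY_SENTINEL_SOCKET_TIMEOUT.
sentinel_options = {"master_name": "mymaster", "sentinel_kwargs": {"socket_timeout": 0.1}}
app.conf.broker_transport_options = sentinel_options
app.conf.result_backend_transport_options = sentinel_options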

View File

@@ -1,6 +1,6 @@
from typing import Optional
from pydantic import Field, NonNegativeInt, PositiveInt
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt
from pydantic_settings import BaseSettings
@@ -38,3 +38,33 @@ class RedisConfig(BaseSettings):
description="whether to use SSL for Redis connection",
default=False,
)
REDIS_USE_SENTINEL: Optional[bool] = Field(
description="Whether to use Redis Sentinel mode",
default=False,
)
REDIS_SENTINELS: Optional[str] = Field(
description="Redis Sentinel nodes",
default=None,
)
REDIS_SENTINEL_SERVICE_NAME: Optional[str] = Field(
description="Redis Sentinel service name",
default=None,
)
REDIS_SENTINEL_USERNAME: Optional[str] = Field(
description="Redis Sentinel username",
default=None,
)
REDIS_SENTINEL_PASSWORD: Optional[str] = Field(
description="Redis Sentinel password",
default=None,
)
REDIS_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
description="Redis Sentinel socket timeout",
default=0.1,
)
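
A minimal sketch (not taken from this diff) of how the new REDIS_SENTINEL_* settings are typically consumed with redis-py; the node addresses and service name are placeholders:

from redis.sentinel import Sentinel

# The node list would be parsed from REDIS_SENTINELS, e.g. "host1:26379,host2:26379".
sentinel = Sentinel([("sentinel-1", 26379), ("sentinel-2", 26379)], socket_timeout=0.1)

# Resolve the current master for REDIS_SENTINEL_SERVICE_NAME and use it like a
# regular Redis client; failover is handled by the Sentinel wrapper.
redis_client = sentinel.master_for("mymaster", socket_timeout=0.1)
redis_client.ping()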

View File

@@ -0,0 +1,29 @@
from typing import Optional
from pydantic import BaseModel, Field
class HuaweiCloudOBSStorageConfig(BaseModel):
"""
Huawei Cloud OBS storage configs
"""
HUAWEI_OBS_BUCKET_NAME: Optional[str] = Field(
description="Huawei Cloud OBS bucket name",
default=None,
)
HUAWEI_OBS_ACCESS_KEY: Optional[str] = Field(
description="Huawei Cloud OBS Access key",
default=None,
)
HUAWEI_OBS_SECRET_KEY: Optional[str] = Field(
description="Huawei Cloud OBS Secret key",
default=None,
)
HUAWEI_OBS_SERVER: Optional[str] = Field(
description="Huawei Cloud OBS server URL",
default=None,
)
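
The HUAWEI_OBS_* fields above describe credentials for Huawei's esdk-obs-python SDK. A minimal usage sketch of that SDK for orientation only; this is not the storage implementation added in this release, and the credential, server, and bucket values are placeholders:

from obs import ObsClient

client = ObsClient(
    access_key_id="your-access-key",      # HUAWEI_OBS_ACCESS_KEY
    secret_access_key="your-secret-key",  # HUAWEI_OBS_SECRET_KEY
    server="https://obs.example-region.myhuaweicloud.com",  # HUAWEI_OBS_SERVER
)
try:
    # Upload an in-memory object to the configured bucket.
    resp = client.putContent("your-bucket-name", "hello.txt", content="hello")
    print(resp.status)
finally:
    client.close()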

View File

@@ -0,0 +1,34 @@
from typing import Optional
from pydantic import BaseModel, Field
class VolcengineTOSStorageConfig(BaseModel):
"""
Volcengine tos storage configs
"""
VOLCENGINE_TOS_BUCKET_NAME: Optional[str] = Field(
description="Volcengine TOS Bucket Name",
default=None,
)
VOLCENGINE_TOS_ACCESS_KEY: Optional[str] = Field(
description="Volcengine TOS Access Key",
default=None,
)
VOLCENGINE_TOS_SECRET_KEY: Optional[str] = Field(
description="Volcengine TOS Secret Key",
default=None,
)
VOLCENGINE_TOS_ENDPOINT: Optional[str] = Field(
description="Volcengine TOS Endpoint URL",
default=None,
)
VOLCENGINE_TOS_REGION: Optional[str] = Field(
description="Volcengine TOS Region",
default=None,
)

View File

@@ -1,6 +1,6 @@
from typing import Optional
from pydantic import Field, PositiveInt
from pydantic import Field
from pydantic_settings import BaseSettings
@@ -9,14 +9,14 @@ class MilvusConfig(BaseSettings):
Milvus configs
"""
MILVUS_HOST: Optional[str] = Field(
description="Milvus host",
default=None,
MILVUS_URI: Optional[str] = Field(
description="Milvus uri",
default="http://127.0.0.1:19530",
)
MILVUS_PORT: PositiveInt = Field(
description="Milvus RestFul API port",
default=9091,
MILVUS_TOKEN: Optional[str] = Field(
description="Milvus token",
default=None,
)
MILVUS_USER: Optional[str] = Field(
@@ -29,11 +29,6 @@ class MilvusConfig(BaseSettings):
default=None,
)
MILVUS_SECURE: bool = Field(
description="whether to use SSL connection for Milvus",
default=False,
)
MILVUS_DATABASE: str = Field(
description="Milvus database, default to `default`",
default="default",

View File

@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="0.7.3",
default="0.8.0",
)
COMMIT_SHA: str = Field(

File diff suppressed because one or more lines are too long

View File

@@ -173,18 +173,21 @@ class ChatConversationApi(Resource):
if args["keyword"]:
keyword_filter = "%{}%".format(args["keyword"])
message_subquery = (
db.session.query(Message.conversation_id)
.filter(or_(Message.query.ilike(keyword_filter), Message.answer.ilike(keyword_filter)))
.subquery()
)
query = query.join(subquery, subquery.c.conversation_id == Conversation.id).filter(
or_(
Conversation.id.in_(message_subquery),
Conversation.name.ilike(keyword_filter),
Conversation.introduction.ilike(keyword_filter),
subquery.c.from_end_user_session_id.ilike(keyword_filter),
),
query = (
query.join(
Message,
Message.conversation_id == Conversation.id,
)
.join(subquery, subquery.c.conversation_id == Conversation.id)
.filter(
or_(
Message.query.ilike(keyword_filter),
Message.answer.ilike(keyword_filter),
Conversation.name.ilike(keyword_filter),
Conversation.introduction.ilike(keyword_filter),
subquery.c.from_end_user_session_id.ilike(keyword_filter),
),
)
)
account = current_user
@@ -198,7 +201,11 @@ class ChatConversationApi(Resource):
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
query = query.where(Conversation.created_at >= start_datetime_utc)
match args["sort_by"]:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at >= start_datetime_utc)
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at >= start_datetime_utc)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
@@ -207,7 +214,11 @@ class ChatConversationApi(Resource):
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
query = query.where(Conversation.created_at < end_datetime_utc)
match args["sort_by"]:
case "updated_at" | "-updated_at":
query = query.where(Conversation.updated_at <= end_datetime_utc)
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at <= end_datetime_utc)
if args["annotation_status"] == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join(
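
The sort_by handling above switches which timestamp column the start/end filters apply to. A minimal standalone sketch of that selection logic (the column names are illustrative, not the actual Conversation model):

def date_filter_column(sort_by: str) -> str:
    # "updated_at"/"-updated_at" filter on updated_at; any other value keeps the
    # previous behaviour of filtering on created_at.
    match sort_by:
        case "updated_at" | "-updated_at":
            return "updated_at"
        case _:
            return "created_at"

print(date_filter_column("-updated_at"))  # -> updated_at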

View File

@@ -18,7 +18,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.retrieval.retrival_methods import RetrievalMethod
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields

View File

@@ -302,6 +302,8 @@ class DatasetInitApi(Resource):
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
@@ -309,6 +311,8 @@ class DatasetInitApi(Resource):
raise Forbidden()
if args["indexing_technique"] == "high_quality":
if args["embedding_model"] is None or args["embedding_model_provider"] is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
try:
model_manager = ModelManager()
model_manager.get_default_model_instance(

View File

@@ -13,7 +13,7 @@ from services.tag_service import TagService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
if not name or len(name) < 1 or len(name) > 50:
raise ValueError("Name must be between 1 to 50 characters.")
return name

View File

@@ -36,6 +36,10 @@ class SegmentApi(DatasetApiResource):
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
if document.indexing_status != "completed":
raise NotFound("Document is not completed.")
if not document.enabled:
raise NotFound("Document is disabled.")
# check embedding model setting
if dataset.indexing_technique == "high_quality":
try:
@@ -63,7 +67,7 @@ class SegmentApi(DatasetApiResource):
segments = SegmentService.multi_create_segment(args["segments"], document, dataset)
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
else:
return {"error": "Segemtns is required"}, 400
return {"error": "Segments is required"}, 400
def get(self, tenant_id, dataset_id, document_id):
"""Create single segment."""

View File

@@ -1 +1 @@
import core.moderation.base
import core.moderation.base

View File

@@ -1,6 +1,7 @@
import json
import logging
import uuid
from collections.abc import Mapping, Sequence
from datetime import datetime, timezone
from typing import Optional, Union, cast
@@ -45,22 +46,25 @@ from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__)
class BaseAgentRunner(AppRunner):
def __init__(self, tenant_id: str,
application_generate_entity: AgentChatAppGenerateEntity,
conversation: Conversation,
app_config: AgentChatAppConfig,
model_config: ModelConfigWithCredentialsEntity,
config: AgentEntity,
queue_manager: AppQueueManager,
message: Message,
user_id: str,
memory: Optional[TokenBufferMemory] = None,
prompt_messages: Optional[list[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance = None
) -> None:
def __init__(
self,
tenant_id: str,
application_generate_entity: AgentChatAppGenerateEntity,
conversation: Conversation,
app_config: AgentChatAppConfig,
model_config: ModelConfigWithCredentialsEntity,
config: AgentEntity,
queue_manager: AppQueueManager,
message: Message,
user_id: str,
memory: Optional[TokenBufferMemory] = None,
prompt_messages: Optional[list[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance = None,
) -> None:
"""
Agent runner
:param tenant_id: tenant id
@@ -88,9 +92,7 @@ class BaseAgentRunner(AppRunner):
self.message = message
self.user_id = user_id
self.memory = memory
self.history_prompt_messages = self.organize_agent_history(
prompt_messages=prompt_messages or []
)
self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
self.variables_pool = variables_pool
self.db_variables_pool = db_variables
self.model_instance = model_instance
@@ -111,12 +113,16 @@ class BaseAgentRunner(AppRunner):
retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
return_resource=app_config.additional_features.show_retrieve_source,
invoke_from=application_generate_entity.invoke_from,
hit_callback=hit_callback
hit_callback=hit_callback,
)
# get how many agent thoughts have been created
self.agent_thought_count = db.session.query(MessageAgentThought).filter(
MessageAgentThought.message_id == self.message.id,
).count()
self.agent_thought_count = (
db.session.query(MessageAgentThought)
.filter(
MessageAgentThought.message_id == self.message.id,
)
.count()
)
db.session.close()
# check if model supports stream tool call
@@ -135,25 +141,26 @@ class BaseAgentRunner(AppRunner):
self.query = None
self._current_thoughts: list[PromptMessage] = []
def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
-> AgentChatAppGenerateEntity:
def _repack_app_generate_entity(
self, app_generate_entity: AgentChatAppGenerateEntity
) -> AgentChatAppGenerateEntity:
"""
Repack app generate entity
"""
if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
app_generate_entity.app_config.prompt_template.simple_prompt_template = ''
app_generate_entity.app_config.prompt_template.simple_prompt_template = ""
return app_generate_entity
def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
"""
convert tool to prompt message tool
convert tool to prompt message tool
"""
tool_entity = ToolManager.get_agent_tool_runtime(
tenant_id=self.tenant_id,
app_id=self.app_config.app_id,
agent_tool=tool,
invoke_from=self.application_generate_entity.invoke_from
invoke_from=self.application_generate_entity.invoke_from,
)
tool_entity.load_variables(self.variables_pool)
@@ -164,7 +171,7 @@ class BaseAgentRunner(AppRunner):
"type": "object",
"properties": {},
"required": [],
}
},
)
parameters = tool_entity.get_all_runtime_parameters()
@@ -177,19 +184,19 @@ class BaseAgentRunner(AppRunner):
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options]
message_tool.parameters['properties'][parameter.name] = {
message_tool.parameters["properties"][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
"description": parameter.llm_description or "",
}
if len(enum) > 0:
message_tool.parameters['properties'][parameter.name]['enum'] = enum
message_tool.parameters["properties"][parameter.name]["enum"] = enum
if parameter.required:
message_tool.parameters['required'].append(parameter.name)
message_tool.parameters["required"].append(parameter.name)
return message_tool, tool_entity
def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
"""
convert dataset retriever tool to prompt message tool
@@ -201,24 +208,24 @@ class BaseAgentRunner(AppRunner):
"type": "object",
"properties": {},
"required": [],
}
},
)
for parameter in tool.get_runtime_parameters():
parameter_type = 'string'
prompt_tool.parameters['properties'][parameter.name] = {
parameter_type = "string"
prompt_tool.parameters["properties"][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
"description": parameter.llm_description or "",
}
if parameter.required:
if parameter.name not in prompt_tool.parameters['required']:
prompt_tool.parameters['required'].append(parameter.name)
if parameter.name not in prompt_tool.parameters["required"]:
prompt_tool.parameters["required"].append(parameter.name)
return prompt_tool
def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]:
"""
Init tools
"""
@@ -261,51 +268,51 @@ class BaseAgentRunner(AppRunner):
enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options]
prompt_tool.parameters['properties'][parameter.name] = {
prompt_tool.parameters["properties"][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
"description": parameter.llm_description or "",
}
if len(enum) > 0:
prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
prompt_tool.parameters["properties"][parameter.name]["enum"] = enum
if parameter.required:
if parameter.name not in prompt_tool.parameters['required']:
prompt_tool.parameters['required'].append(parameter.name)
if parameter.name not in prompt_tool.parameters["required"]:
prompt_tool.parameters["required"].append(parameter.name)
return prompt_tool
def create_agent_thought(self, message_id: str, message: str,
tool_name: str, tool_input: str, messages_ids: list[str]
) -> MessageAgentThought:
def create_agent_thought(
self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
) -> MessageAgentThought:
"""
Create agent thought
"""
thought = MessageAgentThought(
message_id=message_id,
message_chain_id=None,
thought='',
thought="",
tool=tool_name,
tool_labels_str='{}',
tool_meta_str='{}',
tool_labels_str="{}",
tool_meta_str="{}",
tool_input=tool_input,
message=message,
message_token=0,
message_unit_price=0,
message_price_unit=0,
message_files=json.dumps(messages_ids) if messages_ids else '',
answer='',
observation='',
message_files=json.dumps(messages_ids) if messages_ids else "",
answer="",
observation="",
answer_token=0,
answer_unit_price=0,
answer_price_unit=0,
tokens=0,
total_price=0,
position=self.agent_thought_count + 1,
currency='USD',
currency="USD",
latency=0,
created_by_role='account',
created_by_role="account",
created_by=self.user_id,
)
@@ -318,22 +325,22 @@ class BaseAgentRunner(AppRunner):
return thought
def save_agent_thought(self,
agent_thought: MessageAgentThought,
tool_name: str,
tool_input: Union[str, dict],
thought: str,
observation: Union[str, dict],
tool_invoke_meta: Union[str, dict],
answer: str,
messages_ids: list[str],
llm_usage: LLMUsage = None) -> MessageAgentThought:
def save_agent_thought(
self,
agent_thought: MessageAgentThought,
tool_name: str,
tool_input: Union[str, dict],
thought: str,
observation: Union[str, dict],
tool_invoke_meta: Union[str, dict],
answer: str,
messages_ids: list[str],
llm_usage: LLMUsage = None,
) -> MessageAgentThought:
"""
Save agent thought
"""
agent_thought = db.session.query(MessageAgentThought).filter(
MessageAgentThought.id == agent_thought.id
).first()
agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
if thought is not None:
agent_thought.thought = thought
@@ -356,7 +363,7 @@ class BaseAgentRunner(AppRunner):
observation = json.dumps(observation, ensure_ascii=False)
except Exception as e:
observation = json.dumps(observation)
agent_thought.observation = observation
if answer is not None:
@@ -364,7 +371,7 @@ class BaseAgentRunner(AppRunner):
if messages_ids is not None and len(messages_ids) > 0:
agent_thought.message_files = json.dumps(messages_ids)
if llm_usage:
agent_thought.message_token = llm_usage.prompt_tokens
agent_thought.message_price_unit = llm_usage.prompt_price_unit
@@ -377,7 +384,7 @@ class BaseAgentRunner(AppRunner):
# check if tool labels is not empty
labels = agent_thought.tool_labels or {}
tools = agent_thought.tool.split(';') if agent_thought.tool else []
tools = agent_thought.tool.split(";") if agent_thought.tool else []
for tool in tools:
if not tool:
continue
@@ -386,7 +393,7 @@ class BaseAgentRunner(AppRunner):
if tool_label:
labels[tool] = tool_label.to_dict()
else:
labels[tool] = {'en_US': tool, 'zh_Hans': tool}
labels[tool] = {"en_US": tool, "zh_Hans": tool}
agent_thought.tool_labels_str = json.dumps(labels)
@@ -401,14 +408,18 @@ class BaseAgentRunner(AppRunner):
db.session.commit()
db.session.close()
def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
"""
convert tool variables to db variables
"""
db_variables = db.session.query(ToolConversationVariables).filter(
ToolConversationVariables.conversation_id == self.message.conversation_id,
).first()
db_variables = (
db.session.query(ToolConversationVariables)
.filter(
ToolConversationVariables.conversation_id == self.message.conversation_id,
)
.first()
)
db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
@@ -425,9 +436,14 @@ class BaseAgentRunner(AppRunner):
if isinstance(prompt_message, SystemPromptMessage):
result.append(prompt_message)
messages: list[Message] = db.session.query(Message).filter(
Message.conversation_id == self.message.conversation_id,
).order_by(Message.created_at.asc()).all()
messages: list[Message] = (
db.session.query(Message)
.filter(
Message.conversation_id == self.message.conversation_id,
)
.order_by(Message.created_at.asc())
.all()
)
for message in messages:
if message.id == self.message.id:
@@ -439,13 +455,13 @@ class BaseAgentRunner(AppRunner):
for agent_thought in agent_thoughts:
tools = agent_thought.tool
if tools:
tools = tools.split(';')
tools = tools.split(";")
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_call_response: list[ToolPromptMessage] = []
try:
tool_inputs = json.loads(agent_thought.tool_input)
except Exception as e:
tool_inputs = { tool: {} for tool in tools }
tool_inputs = {tool: {} for tool in tools}
try:
tool_responses = json.loads(agent_thought.observation)
except Exception as e:
@@ -454,27 +470,33 @@ class BaseAgentRunner(AppRunner):
for tool in tools:
# generate a uuid for tool call
tool_call_id = str(uuid.uuid4())
tool_calls.append(AssistantPromptMessage.ToolCall(
id=tool_call_id,
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool,
arguments=json.dumps(tool_inputs.get(tool, {})),
tool_calls.append(
AssistantPromptMessage.ToolCall(
id=tool_call_id,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool,
arguments=json.dumps(tool_inputs.get(tool, {})),
),
)
))
tool_call_response.append(ToolPromptMessage(
content=tool_responses.get(tool, agent_thought.observation),
name=tool,
tool_call_id=tool_call_id,
))
)
tool_call_response.append(
ToolPromptMessage(
content=tool_responses.get(tool, agent_thought.observation),
name=tool,
tool_call_id=tool_call_id,
)
)
result.extend([
AssistantPromptMessage(
content=agent_thought.thought,
tool_calls=tool_calls,
),
*tool_call_response
])
result.extend(
[
AssistantPromptMessage(
content=agent_thought.thought,
tool_calls=tool_calls,
),
*tool_call_response,
]
)
if not tools:
result.append(AssistantPromptMessage(content=agent_thought.thought))
else:
@@ -496,10 +518,7 @@ class BaseAgentRunner(AppRunner):
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
if file_extra_config:
file_objs = message_file_parser.transform_message_files(
files,
file_extra_config
)
file_objs = message_file_parser.transform_message_files(files, file_extra_config)
else:
file_objs = []

View File

@@ -25,17 +25,19 @@ from models.model import Message
class CotAgentRunner(BaseAgentRunner, ABC):
_is_first_iteration = True
_ignore_observation_providers = ['wenxin']
_ignore_observation_providers = ["wenxin"]
_historic_prompt_messages: list[PromptMessage] = None
_agent_scratchpad: list[AgentScratchpadUnit] = None
_instruction: str = None
_query: str = None
_prompt_messages_tools: list[PromptMessage] = None
def run(self, message: Message,
query: str,
inputs: dict[str, str],
) -> Union[Generator, LLMResult]:
def run(
self,
message: Message,
query: str,
inputs: dict[str, str],
) -> Union[Generator, LLMResult]:
"""
Run Cot agent application
"""
@@ -46,17 +48,16 @@ class CotAgentRunner(BaseAgentRunner, ABC):
trace_manager = app_generate_entity.trace_manager
# check model mode
if 'Observation' not in app_generate_entity.model_conf.stop:
if "Observation" not in app_generate_entity.model_conf.stop:
if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
app_generate_entity.model_conf.stop.append('Observation')
app_generate_entity.model_conf.stop.append("Observation")
app_config = self.app_config
# init instruction
inputs = inputs or {}
instruction = app_config.prompt_template.simple_prompt_template
self._instruction = self._fill_in_inputs_from_external_data_tools(
instruction, inputs)
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
@@ -65,16 +66,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
function_call_state = True
llm_usage = {
'usage': None
}
final_answer = ''
llm_usage = {"usage": None}
final_answer = ""
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
if not final_llm_usage_dict['usage']:
final_llm_usage_dict['usage'] = usage
if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage
else:
llm_usage = final_llm_usage_dict['usage']
llm_usage = final_llm_usage_dict["usage"]
llm_usage.prompt_tokens += usage.prompt_tokens
llm_usage.completion_tokens += usage.completion_tokens
llm_usage.prompt_price += usage.prompt_price
@@ -94,17 +93,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
message_file_ids = []
agent_thought = self.create_agent_thought(
message_id=message.id,
message='',
tool_name='',
tool_input='',
messages_ids=message_file_ids
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
)
if iteration_step > 1:
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
# recalc llm max tokens
prompt_messages = self._organize_prompt_messages()
@@ -125,21 +120,20 @@ class CotAgentRunner(BaseAgentRunner, ABC):
raise ValueError("failed to invoke llm")
usage_dict = {}
react_chunks = CotAgentOutputParser.handle_react_stream_output(
chunks, usage_dict)
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
scratchpad = AgentScratchpadUnit(
agent_response='',
thought='',
action_str='',
observation='',
agent_response="",
thought="",
action_str="",
observation="",
action=None,
)
# publish agent thought if it's first iteration
if iteration_step == 1:
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
for chunk in react_chunks:
if isinstance(chunk, AgentScratchpadUnit.Action):
@@ -154,61 +148,51 @@ class CotAgentRunner(BaseAgentRunner, ABC):
yield LLMResultChunk(
model=self.model_config.model,
prompt_messages=prompt_messages,
system_fingerprint='',
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=chunk
),
usage=None
)
system_fingerprint="",
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
)
scratchpad.thought = scratchpad.thought.strip(
) or 'I am thinking about how to help you'
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
self._agent_scratchpad.append(scratchpad)
# get llm usage
if 'usage' in usage_dict:
increase_usage(llm_usage, usage_dict['usage'])
if "usage" in usage_dict:
increase_usage(llm_usage, usage_dict["usage"])
else:
usage_dict['usage'] = LLMUsage.empty_usage()
usage_dict["usage"] = LLMUsage.empty_usage()
self.save_agent_thought(
agent_thought=agent_thought,
tool_name=scratchpad.action.action_name if scratchpad.action else '',
tool_input={
scratchpad.action.action_name: scratchpad.action.action_input
} if scratchpad.action else {},
tool_name=scratchpad.action.action_name if scratchpad.action else "",
tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
tool_invoke_meta={},
thought=scratchpad.thought,
observation='',
observation="",
answer=scratchpad.agent_response,
messages_ids=[],
llm_usage=usage_dict['usage']
llm_usage=usage_dict["usage"],
)
if not scratchpad.is_final():
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
if not scratchpad.action:
# failed to extract action, return final answer directly
final_answer = ''
final_answer = ""
else:
if scratchpad.action.action_name.lower() == "final answer":
# action is final answer, return final answer directly
try:
if isinstance(scratchpad.action.action_input, dict):
final_answer = json.dumps(
scratchpad.action.action_input)
final_answer = json.dumps(scratchpad.action.action_input)
elif isinstance(scratchpad.action.action_input, str):
final_answer = scratchpad.action.action_input
else:
final_answer = f'{scratchpad.action.action_input}'
final_answer = f"{scratchpad.action.action_input}"
except json.JSONDecodeError:
final_answer = f'{scratchpad.action.action_input}'
final_answer = f"{scratchpad.action.action_input}"
else:
function_call_state = True
# action is tool call, invoke tool
@@ -224,21 +208,18 @@ class CotAgentRunner(BaseAgentRunner, ABC):
self.save_agent_thought(
agent_thought=agent_thought,
tool_name=scratchpad.action.action_name,
tool_input={
scratchpad.action.action_name: scratchpad.action.action_input},
tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
thought=scratchpad.thought,
observation={
scratchpad.action.action_name: tool_invoke_response},
tool_invoke_meta={
scratchpad.action.action_name: tool_invoke_meta.to_dict()},
observation={scratchpad.action.action_name: tool_invoke_response},
tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
answer=scratchpad.agent_response,
messages_ids=message_file_ids,
llm_usage=usage_dict['usage']
llm_usage=usage_dict["usage"],
)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
# update prompt tool message
for prompt_tool in self._prompt_messages_tools:
@@ -250,44 +231,45 @@ class CotAgentRunner(BaseAgentRunner, ABC):
model=model_instance.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(
content=final_answer
),
usage=llm_usage['usage']
index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
),
system_fingerprint=''
system_fingerprint="",
)
# save agent thought
self.save_agent_thought(
agent_thought=agent_thought,
tool_name='',
tool_name="",
tool_input={},
tool_invoke_meta={},
thought=final_answer,
observation={},
answer=final_answer,
messages_ids=[]
messages_ids=[],
)
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(
content=final_answer
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
system_fingerprint="",
)
),
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
system_fingerprint=''
)), PublishFrom.APPLICATION_MANAGER)
PublishFrom.APPLICATION_MANAGER,
)
def _handle_invoke_action(self, action: AgentScratchpadUnit.Action,
tool_instances: dict[str, Tool],
message_file_ids: list[str],
trace_manager: Optional[TraceQueueManager] = None
) -> tuple[str, ToolInvokeMeta]:
def _handle_invoke_action(
self,
action: AgentScratchpadUnit.Action,
tool_instances: dict[str, Tool],
message_file_ids: list[str],
trace_manager: Optional[TraceQueueManager] = None,
) -> tuple[str, ToolInvokeMeta]:
"""
handle invoke action
:param action: action
@@ -326,13 +308,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# publish files
for message_file_id, save_as in message_files:
if save_as:
self.variables_pool.set_file(
tool_name=tool_call_name, value=message_file_id, name=save_as)
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
# publish message file
self.queue_manager.publish(QueueMessageFileEvent(
message_file_id=message_file_id
), PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
)
# add message file ids
message_file_ids.append(message_file_id)
@@ -342,10 +323,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
"""
convert dict to action
"""
return AgentScratchpadUnit.Action(
action_name=action['action'],
action_input=action['action_input']
)
return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])
def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
"""
@@ -353,7 +331,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
"""
for key, value in inputs.items():
try:
instruction = instruction.replace(f'{{{{{key}}}}}', str(value))
instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
except Exception as e:
continue
@@ -370,14 +348,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
@abstractmethod
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
organize prompt messages
organize prompt messages
"""
def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
"""
format assistant message
format assistant message
"""
message = ''
message = ""
for scratchpad in agent_scratchpad:
if scratchpad.is_final():
message += f"Final Answer: {scratchpad.agent_response}"
@@ -390,9 +368,11 @@ class CotAgentRunner(BaseAgentRunner, ABC):
return message
def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]:
def _organize_historic_prompt_messages(
self, current_session_messages: list[PromptMessage] = None
) -> list[PromptMessage]:
"""
organize historic prompt messages
organize historic prompt messages
"""
result: list[PromptMessage] = []
scratchpads: list[AgentScratchpadUnit] = []
@@ -403,8 +383,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
if not current_scratchpad:
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
thought=message.content or 'I am thinking about how to help you',
action_str='',
thought=message.content or "I am thinking about how to help you",
action_str="",
action=None,
observation=None,
)
@@ -413,12 +393,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
try:
current_scratchpad.action = AgentScratchpadUnit.Action(
action_name=message.tool_calls[0].function.name,
action_input=json.loads(
message.tool_calls[0].function.arguments)
)
current_scratchpad.action_str = json.dumps(
current_scratchpad.action.to_dict()
action_input=json.loads(message.tool_calls[0].function.arguments),
)
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
except:
pass
elif isinstance(message, ToolPromptMessage):
@@ -426,23 +403,19 @@ class CotAgentRunner(BaseAgentRunner, ABC):
current_scratchpad.observation = message.content
elif isinstance(message, UserPromptMessage):
if scratchpads:
result.append(AssistantPromptMessage(
content=self._format_assistant_message(scratchpads)
))
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
scratchpads = []
current_scratchpad = None
result.append(message)
if scratchpads:
result.append(AssistantPromptMessage(
content=self._format_assistant_message(scratchpads)
))
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
historic_prompts = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=current_session_messages or [],
history_messages=result,
memory=self.memory
memory=self.memory,
).get_prompt()
return historic_prompts

View File

@@ -19,14 +19,15 @@ class CotChatAgentRunner(CotAgentRunner):
prompt_entity = self.app_config.agent.prompt
first_prompt = prompt_entity.first_prompt
system_prompt = first_prompt \
.replace("{{instruction}}", self._instruction) \
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
.replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
system_prompt = (
first_prompt.replace("{{instruction}}", self._instruction)
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
)
return SystemPromptMessage(content=system_prompt)
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
Organize user query
"""
@@ -43,7 +44,7 @@ class CotChatAgentRunner(CotAgentRunner):
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
Organize
Organize
"""
# organize system prompt
system_message = self._organize_system_prompt()
@@ -53,7 +54,7 @@ class CotChatAgentRunner(CotAgentRunner):
if not agent_scratchpad:
assistant_messages = []
else:
assistant_message = AssistantPromptMessage(content='')
assistant_message = AssistantPromptMessage(content="")
for unit in agent_scratchpad:
if unit.is_final():
assistant_message.content += f"Final Answer: {unit.agent_response}"
@@ -71,18 +72,15 @@ class CotChatAgentRunner(CotAgentRunner):
if assistant_messages:
# organize historic prompt messages
historic_messages = self._organize_historic_prompt_messages([
system_message,
*query_messages,
*assistant_messages,
UserPromptMessage(content='continue')
])
historic_messages = self._organize_historic_prompt_messages(
[system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
)
messages = [
system_message,
*historic_messages,
*query_messages,
*assistant_messages,
UserPromptMessage(content='continue')
UserPromptMessage(content="continue"),
]
else:
# organize historic prompt messages

View File

@@ -13,10 +13,12 @@ class CotCompletionAgentRunner(CotAgentRunner):
prompt_entity = self.app_config.agent.prompt
first_prompt = prompt_entity.first_prompt
system_prompt = first_prompt.replace("{{instruction}}", self._instruction) \
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
.replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
system_prompt = (
first_prompt.replace("{{instruction}}", self._instruction)
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
)
return system_prompt
def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:
@@ -46,7 +48,7 @@ class CotCompletionAgentRunner(CotAgentRunner):
# organize current assistant messages
agent_scratchpad = self._agent_scratchpad
assistant_prompt = ''
assistant_prompt = ""
for unit in agent_scratchpad:
if unit.is_final():
assistant_prompt += f"Final Answer: {unit.agent_response}"
@@ -61,9 +63,10 @@ class CotCompletionAgentRunner(CotAgentRunner):
query_prompt = f"Question: {self._query}"
# join all messages
prompt = system_prompt \
.replace("{{historic_messages}}", historic_prompt) \
.replace("{{agent_scratchpad}}", assistant_prompt) \
prompt = (
system_prompt.replace("{{historic_messages}}", historic_prompt)
.replace("{{agent_scratchpad}}", assistant_prompt)
.replace("{{query}}", query_prompt)
)
return [UserPromptMessage(content=prompt)]
return [UserPromptMessage(content=prompt)]

View File

@@ -8,6 +8,7 @@ class AgentToolEntity(BaseModel):
"""
Agent Tool Entity.
"""
provider_type: Literal["builtin", "api", "workflow"]
provider_id: str
tool_name: str
@@ -18,6 +19,7 @@ class AgentPromptEntity(BaseModel):
"""
Agent Prompt Entity.
"""
first_prompt: str
next_iteration: str
@@ -31,6 +33,7 @@ class AgentScratchpadUnit(BaseModel):
"""
Action Entity.
"""
action_name: str
action_input: Union[dict, str]
@@ -39,8 +42,8 @@ class AgentScratchpadUnit(BaseModel):
Convert to dictionary.
"""
return {
'action': self.action_name,
'action_input': self.action_input,
"action": self.action_name,
"action_input": self.action_input,
}
agent_response: Optional[str] = None
@@ -54,10 +57,10 @@ class AgentScratchpadUnit(BaseModel):
Check if the scratchpad unit is final.
"""
return self.action is None or (
'final' in self.action.action_name.lower() and
'answer' in self.action.action_name.lower()
"final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower()
)
class AgentEntity(BaseModel):
"""
Agent Entity.
@@ -67,8 +70,9 @@ class AgentEntity(BaseModel):
"""
Agent Strategy.
"""
CHAIN_OF_THOUGHT = 'chain-of-thought'
FUNCTION_CALLING = 'function-calling'
CHAIN_OF_THOUGHT = "chain-of-thought"
FUNCTION_CALLING = "function-calling"
provider: str
model: str
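For reference, the is_final() rule shown above can be restated as a tiny standalone function; the action names below are made-up examples:

from typing import Optional

def looks_final(action_name: Optional[str]) -> bool:
    # Mirrors AgentScratchpadUnit.is_final(): no action, or an action whose
    # name contains both "final" and "answer", ends the ReAct loop.
    if action_name is None:
        return True
    lowered = action_name.lower()
    return "final" in lowered and "answer" in lowered

assert looks_final(None)
assert looks_final("Final Answer")
assert not looks_final("web_search")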

View File

@@ -24,11 +24,9 @@ from models.model import Message
logger = logging.getLogger(__name__)
class FunctionCallAgentRunner(BaseAgentRunner):
def run(self,
message: Message, query: str, **kwargs: Any
) -> Generator[LLMResultChunk, None, None]:
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
"""
Run FunctionCall agent application
"""
@@ -45,19 +43,17 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# continue to run until there is not any tool call
function_call_state = True
llm_usage = {
'usage': None
}
final_answer = ''
llm_usage = {"usage": None}
final_answer = ""
# get tracing instance
trace_manager = app_generate_entity.trace_manager
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
if not final_llm_usage_dict['usage']:
final_llm_usage_dict['usage'] = usage
if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage
else:
llm_usage = final_llm_usage_dict['usage']
llm_usage = final_llm_usage_dict["usage"]
llm_usage.prompt_tokens += usage.prompt_tokens
llm_usage.completion_tokens += usage.completion_tokens
llm_usage.prompt_price += usage.prompt_price
@@ -75,11 +71,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
message_file_ids = []
agent_thought = self.create_agent_thought(
message_id=message.id,
message='',
tool_name='',
tool_input='',
messages_ids=message_file_ids
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
)
# recalc llm max tokens
@@ -99,11 +91,11 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
# save full response
response = ''
response = ""
# save tool call names and inputs
tool_call_names = ''
tool_call_inputs = ''
tool_call_names = ""
tool_call_inputs = ""
current_llm_usage = None
@@ -111,24 +103,22 @@ class FunctionCallAgentRunner(BaseAgentRunner):
is_first_chunk = True
for chunk in chunks:
if is_first_chunk:
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
is_first_chunk = False
# check if there is any tool call
if self.check_tool_calls(chunk):
function_call_state = True
tool_calls.extend(self.extract_tool_calls(chunk))
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
}, ensure_ascii=False)
tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
)
except json.JSONDecodeError as e:
# ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
})
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
if chunk.delta.message and chunk.delta.message.content:
if isinstance(chunk.delta.message.content, list):
@@ -148,16 +138,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
if self.check_blocking_tool_calls(result):
function_call_state = True
tool_calls.extend(self.extract_blocking_tool_calls(result))
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
}, ensure_ascii=False)
tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
)
except json.JSONDecodeError as e:
# ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
})
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
if result.usage:
increase_usage(llm_usage, result.usage)
@@ -171,12 +159,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
response += result.message.content
if not result.message.content:
result.message.content = ''
result.message.content = ""
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
yield LLMResultChunk(
model=model_instance.model,
prompt_messages=result.prompt_messages,
@@ -185,32 +173,29 @@ class FunctionCallAgentRunner(BaseAgentRunner):
index=0,
message=result.message,
usage=result.usage,
)
),
)
assistant_message = AssistantPromptMessage(
content='',
tool_calls=[]
)
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
if tool_calls:
assistant_message.tool_calls=[
assistant_message.tool_calls = [
AssistantPromptMessage.ToolCall(
id=tool_call[0],
type='function',
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_call[1],
arguments=json.dumps(tool_call[2], ensure_ascii=False)
)
) for tool_call in tool_calls
name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
),
)
for tool_call in tool_calls
]
else:
assistant_message.content = response
self._current_thoughts.append(assistant_message)
# save thought
self.save_agent_thought(
agent_thought=agent_thought,
tool_name=tool_call_names,
tool_input=tool_call_inputs,
thought=response,
@@ -218,13 +203,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
observation=None,
answer=response,
messages_ids=[],
llm_usage=current_llm_usage
llm_usage=current_llm_usage,
)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
final_answer += response + '\n'
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
final_answer += response + "\n"
# call tools
tool_responses = []
@@ -235,7 +220,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": f"there is not a tool named {tool_call_name}",
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict()
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
}
else:
# invoke tool
@@ -255,50 +240,49 @@ class FunctionCallAgentRunner(BaseAgentRunner):
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
# publish message file
self.queue_manager.publish(QueueMessageFileEvent(
message_file_id=message_file_id
), PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
)
# add message file ids
message_file_ids.append(message_file_id)
tool_response = {
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": tool_invoke_response,
"meta": tool_invoke_meta.to_dict()
"meta": tool_invoke_meta.to_dict(),
}
tool_responses.append(tool_response)
if tool_response['tool_response'] is not None:
if tool_response["tool_response"] is not None:
self._current_thoughts.append(
ToolPromptMessage(
content=tool_response['tool_response'],
content=tool_response["tool_response"],
tool_call_id=tool_call_id,
name=tool_call_name,
)
)
)
if len(tool_responses) > 0:
# save agent thought
self.save_agent_thought(
agent_thought=agent_thought,
tool_name=None,
tool_input=None,
thought=None,
tool_invoke_meta={
tool_response['tool_call_name']: tool_response['meta']
for tool_response in tool_responses
tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
},
observation={
tool_response['tool_call_name']: tool_response['tool_response']
tool_response["tool_call_name"]: tool_response["tool_response"]
for tool_response in tool_responses
},
answer=None,
messages_ids=message_file_ids
messages_ids=message_file_ids,
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
# update prompt tool
for prompt_tool in prompt_messages_tools:
@@ -308,15 +292,18 @@ class FunctionCallAgentRunner(BaseAgentRunner):
self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(
content=final_answer
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
system_fingerprint="",
)
),
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
system_fingerprint=''
)), PublishFrom.APPLICATION_MANAGER)
PublishFrom.APPLICATION_MANAGER,
)
def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
"""
@@ -325,7 +312,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
if llm_result_chunk.delta.message.tool_calls:
return True
return False
def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
"""
Check if there is any blocking tool call in llm result
@@ -334,7 +321,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return True
return False
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
def extract_tool_calls(
self, llm_result_chunk: LLMResultChunk
) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
"""
Extract tool calls from llm result chunk
@@ -344,17 +333,19 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_calls = []
for prompt_message in llm_result_chunk.delta.message.tool_calls:
args = {}
if prompt_message.function.arguments != '':
if prompt_message.function.arguments != "":
args = json.loads(prompt_message.function.arguments)
tool_calls.append((
prompt_message.id,
prompt_message.function.name,
args,
))
tool_calls.append(
(
prompt_message.id,
prompt_message.function.name,
args,
)
)
return tool_calls
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
"""
Extract blocking tool calls from llm result
@@ -365,18 +356,22 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_calls = []
for prompt_message in llm_result.message.tool_calls:
args = {}
if prompt_message.function.arguments != '':
if prompt_message.function.arguments != "":
args = json.loads(prompt_message.function.arguments)
tool_calls.append((
prompt_message.id,
prompt_message.function.name,
args,
))
tool_calls.append(
(
prompt_message.id,
prompt_message.function.name,
args,
)
)
return tool_calls
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
def _init_system_message(
self, prompt_template: str, prompt_messages: list[PromptMessage] = None
) -> list[PromptMessage]:
"""
Initialize system message
"""
@@ -384,13 +379,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return [
SystemPromptMessage(content=prompt_template),
]
if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
return prompt_messages
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
Organize user query
"""
@@ -404,7 +399,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
prompt_messages.append(UserPromptMessage(content=query))
return prompt_messages
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
As for now, gpt supports both fc and vision at the first iteration.
@@ -415,17 +410,21 @@ class FunctionCallAgentRunner(BaseAgentRunner):
for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage):
if isinstance(prompt_message.content, list):
prompt_message.content = '\n'.join([
content.data if content.type == PromptMessageContentType.TEXT else
'[image]' if content.type == PromptMessageContentType.IMAGE else
'[file]'
for content in prompt_message.content
])
prompt_message.content = "\n".join(
[
content.data
if content.type == PromptMessageContentType.TEXT
else "[image]"
if content.type == PromptMessageContentType.IMAGE
else "[file]"
for content in prompt_message.content
]
)
return prompt_messages
def _organize_prompt_messages(self):
prompt_template = self.app_config.prompt_template.simple_prompt_template or ''
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
query_prompt_messages = self._organize_user_query(self.query, [])
@@ -433,14 +432,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
model_config=self.model_config,
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
history_messages=self.history_prompt_messages,
memory=self.memory
memory=self.memory,
).get_prompt()
prompt_messages = [
*self.history_prompt_messages,
*query_prompt_messages,
*self._current_thoughts
]
prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
if len(self._current_thoughts) != 0:
# clear messages after the first iteration
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
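As a rough guide to the data the runner accumulates each iteration, the (id, name, arguments) tuples are joined and serialized roughly as below; the tool names and arguments are invented:

import json

tool_calls = [
    ("call_1", "web_search", {"query": "dify"}),
    ("call_2", "calculator", {"expression": "1+1"}),
]  # hypothetical tool calls extracted from a chunk

tool_call_names = ";".join(tool_call[1] for tool_call in tool_calls)
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False)
print(tool_call_names)   # web_search;calculator
print(tool_call_inputs)  # {"web_search": {"query": "dify"}, "calculator": {"expression": "1+1"}}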

View File

@@ -9,8 +9,9 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk
class CotAgentOutputParser:
@classmethod
def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict) -> \
Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
def handle_react_stream_output(
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
def parse_action(json_str):
try:
action = json.loads(json_str)
@@ -22,7 +23,7 @@ class CotAgentOutputParser:
action = action[0]
for key, value in action.items():
if 'input' in key.lower():
if "input" in key.lower():
action_input = value
else:
action_name = value
@@ -33,37 +34,37 @@ class CotAgentOutputParser:
action_input=action_input,
)
else:
return json_str or ''
return json_str or ""
except:
return json_str or ''
return json_str or ""
def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL)
if not code_blocks:
return
for block in code_blocks:
json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
yield parse_action(json_text)
code_block_cache = ''
code_block_cache = ""
code_block_delimiter_count = 0
in_code_block = False
json_cache = ''
json_cache = ""
json_quote_count = 0
in_json = False
got_json = False
action_cache = ''
action_str = 'action:'
action_cache = ""
action_str = "action:"
action_idx = 0
thought_cache = ''
thought_str = 'thought:'
thought_cache = ""
thought_str = "thought:"
thought_idx = 0
for response in llm_response:
if response.delta.usage:
usage_dict['usage'] = response.delta.usage
usage_dict["usage"] = response.delta.usage
response = response.delta.message.content
if not isinstance(response, str):
continue
@@ -72,24 +73,24 @@ class CotAgentOutputParser:
index = 0
while index < len(response):
steps = 1
delta = response[index:index+steps]
last_character = response[index-1] if index > 0 else ''
delta = response[index : index + steps]
last_character = response[index - 1] if index > 0 else ""
if delta == '`':
if delta == "`":
code_block_cache += delta
code_block_delimiter_count += 1
else:
if not in_code_block:
if code_block_delimiter_count > 0:
yield code_block_cache
code_block_cache = ''
code_block_cache = ""
else:
code_block_cache += delta
code_block_delimiter_count = 0
if not in_code_block and not in_json:
if delta.lower() == action_str[action_idx] and action_idx == 0:
if last_character not in ['\n', ' ', '']:
if last_character not in ["\n", " ", ""]:
index += steps
yield delta
continue
@@ -97,7 +98,7 @@ class CotAgentOutputParser:
action_cache += delta
action_idx += 1
if action_idx == len(action_str):
action_cache = ''
action_cache = ""
action_idx = 0
index += steps
continue
@@ -105,18 +106,18 @@ class CotAgentOutputParser:
action_cache += delta
action_idx += 1
if action_idx == len(action_str):
action_cache = ''
action_cache = ""
action_idx = 0
index += steps
continue
else:
if action_cache:
yield action_cache
action_cache = ''
action_cache = ""
action_idx = 0
if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
if last_character not in ['\n', ' ', '']:
if last_character not in ["\n", " ", ""]:
index += steps
yield delta
continue
@@ -124,7 +125,7 @@ class CotAgentOutputParser:
thought_cache += delta
thought_idx += 1
if thought_idx == len(thought_str):
thought_cache = ''
thought_cache = ""
thought_idx = 0
index += steps
continue
@@ -132,31 +133,31 @@ class CotAgentOutputParser:
thought_cache += delta
thought_idx += 1
if thought_idx == len(thought_str):
thought_cache = ''
thought_cache = ""
thought_idx = 0
index += steps
continue
else:
if thought_cache:
yield thought_cache
thought_cache = ''
thought_cache = ""
thought_idx = 0
if code_block_delimiter_count == 3:
if in_code_block:
yield from extra_json_from_code_block(code_block_cache)
code_block_cache = ''
code_block_cache = ""
in_code_block = not in_code_block
code_block_delimiter_count = 0
if not in_code_block:
# handle single json
if delta == '{':
if delta == "{":
json_quote_count += 1
in_json = True
json_cache += delta
elif delta == '}':
elif delta == "}":
json_cache += delta
if json_quote_count > 0:
json_quote_count -= 1
@@ -172,12 +173,12 @@ class CotAgentOutputParser:
if got_json:
got_json = False
yield parse_action(json_cache)
json_cache = ''
json_cache = ""
json_quote_count = 0
in_json = False
if not in_code_block and not in_json:
yield delta.replace('`', '')
yield delta.replace("`", "")
index += steps
@@ -186,4 +187,3 @@ class CotAgentOutputParser:
if json_cache:
yield parse_action(json_cache)
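The core of parse_action() above reduces to this standalone snippet; the action JSON is a made-up example:

import json

json_str = '{"action": "web_search", "action_input": {"query": "dify"}}'
action = json.loads(json_str)
action_name, action_input = None, None
for key, value in action.items():
    # any key containing "input" is treated as the action input,
    # the remaining value as the action name
    if "input" in key.lower():
        action_input = value
    else:
        action_name = value
print(action_name, action_input)  # web_search {'query': 'dify'}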

View File

@@ -91,14 +91,14 @@ Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use
ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = ""
REACT_PROMPT_TEMPLATES = {
'english': {
'chat': {
'prompt': ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
'agent_scratchpad': ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES
"english": {
"chat": {
"prompt": ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
"agent_scratchpad": ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES,
},
"completion": {
"prompt": ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
"agent_scratchpad": ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES,
},
'completion': {
'prompt': ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
'agent_scratchpad': ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES
}
}
}
}
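The nested dict above is read by the agent config code further down roughly as follows; the template strings here are placeholders, not the real prompt text:

REACT_PROMPT_TEMPLATES = {
    "english": {
        "chat": {"prompt": "<chat prompt>", "agent_scratchpad": "<chat scratchpad>"},
        "completion": {"prompt": "<completion prompt>", "agent_scratchpad": "<completion scratchpad>"},
    }
}
first_prompt = REACT_PROMPT_TEMPLATES["english"]["chat"]["prompt"]
next_iteration = REACT_PROMPT_TEMPLATES["english"]["chat"]["agent_scratchpad"]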

View File

@@ -26,34 +26,24 @@ class BaseAppConfigManager:
config_dict = dict(config_dict.items())
additional_features = AppAdditionalFeatures()
additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(
config=config_dict
)
additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict)
additional_features.file_upload = FileUploadConfigManager.convert(
config=config_dict,
is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT]
config=config_dict, is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT]
)
additional_features.opening_statement, additional_features.suggested_questions = \
OpeningStatementConfigManager.convert(
config=config_dict
)
additional_features.opening_statement, additional_features.suggested_questions = (
OpeningStatementConfigManager.convert(config=config_dict)
)
additional_features.suggested_questions_after_answer = SuggestedQuestionsAfterAnswerConfigManager.convert(
config=config_dict
)
additional_features.more_like_this = MoreLikeThisConfigManager.convert(
config=config_dict
)
additional_features.more_like_this = MoreLikeThisConfigManager.convert(config=config_dict)
additional_features.speech_to_text = SpeechToTextConfigManager.convert(
config=config_dict
)
additional_features.speech_to_text = SpeechToTextConfigManager.convert(config=config_dict)
additional_features.text_to_speech = TextToSpeechConfigManager.convert(
config=config_dict
)
additional_features.text_to_speech = TextToSpeechConfigManager.convert(config=config_dict)
return additional_features

View File

@@ -7,25 +7,24 @@ from core.moderation.factory import ModerationFactory
class SensitiveWordAvoidanceConfigManager:
@classmethod
def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]:
sensitive_word_avoidance_dict = config.get('sensitive_word_avoidance')
sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance")
if not sensitive_word_avoidance_dict:
return None
if sensitive_word_avoidance_dict.get('enabled'):
if sensitive_word_avoidance_dict.get("enabled"):
return SensitiveWordAvoidanceEntity(
type=sensitive_word_avoidance_dict.get('type'),
config=sensitive_word_avoidance_dict.get('config'),
type=sensitive_word_avoidance_dict.get("type"),
config=sensitive_word_avoidance_dict.get("config"),
)
else:
return None
@classmethod
def validate_and_set_defaults(cls, tenant_id, config: dict, only_structure_validate: bool = False) \
-> tuple[dict, list[str]]:
def validate_and_set_defaults(
cls, tenant_id, config: dict, only_structure_validate: bool = False
) -> tuple[dict, list[str]]:
if not config.get("sensitive_word_avoidance"):
config["sensitive_word_avoidance"] = {
"enabled": False
}
config["sensitive_word_avoidance"] = {"enabled": False}
if not isinstance(config["sensitive_word_avoidance"], dict):
raise ValueError("sensitive_word_avoidance must be of dict type")
@@ -41,10 +40,6 @@ class SensitiveWordAvoidanceConfigManager:
typ = config["sensitive_word_avoidance"]["type"]
sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"]
ModerationFactory.validate_config(
name=typ,
tenant_id=tenant_id,
config=sensitive_word_avoidance_config
)
ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config)
return config, ["sensitive_word_avoidance"]

View File

@@ -12,67 +12,70 @@ class AgentConfigManager:
:param config: model config args
"""
if 'agent_mode' in config and config['agent_mode'] \
and 'enabled' in config['agent_mode']:
if "agent_mode" in config and config["agent_mode"] and "enabled" in config["agent_mode"]:
agent_dict = config.get("agent_mode", {})
agent_strategy = agent_dict.get("strategy", "cot")
agent_dict = config.get('agent_mode', {})
agent_strategy = agent_dict.get('strategy', 'cot')
if agent_strategy == 'function_call':
if agent_strategy == "function_call":
strategy = AgentEntity.Strategy.FUNCTION_CALLING
elif agent_strategy == 'cot' or agent_strategy == 'react':
elif agent_strategy == "cot" or agent_strategy == "react":
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
else:
# old configs, try to detect default strategy
if config['model']['provider'] == 'openai':
if config["model"]["provider"] == "openai":
strategy = AgentEntity.Strategy.FUNCTION_CALLING
else:
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
agent_tools = []
for tool in agent_dict.get('tools', []):
for tool in agent_dict.get("tools", []):
keys = tool.keys()
if len(keys) >= 4:
if "enabled" not in tool or not tool["enabled"]:
continue
agent_tool_properties = {
'provider_type': tool['provider_type'],
'provider_id': tool['provider_id'],
'tool_name': tool['tool_name'],
'tool_parameters': tool.get('tool_parameters', {})
"provider_type": tool["provider_type"],
"provider_id": tool["provider_id"],
"tool_name": tool["tool_name"],
"tool_parameters": tool.get("tool_parameters", {}),
}
agent_tools.append(AgentToolEntity(**agent_tool_properties))
if 'strategy' in config['agent_mode'] and \
config['agent_mode']['strategy'] not in ['react_router', 'router']:
agent_prompt = agent_dict.get('prompt', None) or {}
if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in [
"react_router",
"router",
]:
agent_prompt = agent_dict.get("prompt", None) or {}
# check model mode
model_mode = config.get('model', {}).get('mode', 'completion')
if model_mode == 'completion':
model_mode = config.get("model", {}).get("mode", "completion")
if model_mode == "completion":
agent_prompt_entity = AgentPromptEntity(
first_prompt=agent_prompt.get('first_prompt',
REACT_PROMPT_TEMPLATES['english']['completion']['prompt']),
next_iteration=agent_prompt.get('next_iteration',
REACT_PROMPT_TEMPLATES['english']['completion'][
'agent_scratchpad']),
first_prompt=agent_prompt.get(
"first_prompt", REACT_PROMPT_TEMPLATES["english"]["completion"]["prompt"]
),
next_iteration=agent_prompt.get(
"next_iteration", REACT_PROMPT_TEMPLATES["english"]["completion"]["agent_scratchpad"]
),
)
else:
agent_prompt_entity = AgentPromptEntity(
first_prompt=agent_prompt.get('first_prompt',
REACT_PROMPT_TEMPLATES['english']['chat']['prompt']),
next_iteration=agent_prompt.get('next_iteration',
REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']),
first_prompt=agent_prompt.get(
"first_prompt", REACT_PROMPT_TEMPLATES["english"]["chat"]["prompt"]
),
next_iteration=agent_prompt.get(
"next_iteration", REACT_PROMPT_TEMPLATES["english"]["chat"]["agent_scratchpad"]
),
)
return AgentEntity(
provider=config['model']['provider'],
model=config['model']['name'],
provider=config["model"]["provider"],
model=config["model"]["name"],
strategy=strategy,
prompt=agent_prompt_entity,
tools=agent_tools,
max_iteration=agent_dict.get('max_iteration', 5)
max_iteration=agent_dict.get("max_iteration", 5),
)
return None
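A reduced sketch of the strategy resolution above; the config dicts are illustrative, and the real code returns AgentEntity.Strategy members rather than plain strings:

def resolve_strategy(config: dict) -> str:
    agent_strategy = config.get("agent_mode", {}).get("strategy", "cot")
    if agent_strategy == "function_call":
        return "function-calling"
    if agent_strategy in ("cot", "react"):
        return "chain-of-thought"
    # old configs without a recognised strategy: infer from the provider
    return "function-calling" if config["model"]["provider"] == "openai" else "chain-of-thought"

assert resolve_strategy({"agent_mode": {"strategy": "react"}, "model": {"provider": "openai"}}) == "chain-of-thought"
assert resolve_strategy({"agent_mode": {"strategy": "router"}, "model": {"provider": "openai"}}) == "function-calling"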

View File

@@ -15,39 +15,38 @@ class DatasetConfigManager:
:param config: model config args
"""
dataset_ids = []
if 'datasets' in config.get('dataset_configs', {}):
datasets = config.get('dataset_configs', {}).get('datasets', {
'strategy': 'router',
'datasets': []
})
if "datasets" in config.get("dataset_configs", {}):
datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []})
for dataset in datasets.get('datasets', []):
for dataset in datasets.get("datasets", []):
keys = list(dataset.keys())
if len(keys) == 0 or keys[0] != 'dataset':
if len(keys) == 0 or keys[0] != "dataset":
continue
dataset = dataset['dataset']
dataset = dataset["dataset"]
if 'enabled' not in dataset or not dataset['enabled']:
if "enabled" not in dataset or not dataset["enabled"]:
continue
dataset_id = dataset.get('id', None)
dataset_id = dataset.get("id", None)
if dataset_id:
dataset_ids.append(dataset_id)
if 'agent_mode' in config and config['agent_mode'] \
and 'enabled' in config['agent_mode'] \
and config['agent_mode']['enabled']:
if (
"agent_mode" in config
and config["agent_mode"]
and "enabled" in config["agent_mode"]
and config["agent_mode"]["enabled"]
):
agent_dict = config.get("agent_mode", {})
agent_dict = config.get('agent_mode', {})
for tool in agent_dict.get('tools', []):
for tool in agent_dict.get("tools", []):
keys = tool.keys()
if len(keys) == 1:
# old standard
key = list(tool.keys())[0]
if key != 'dataset':
if key != "dataset":
continue
tool_item = tool[key]
@@ -55,30 +54,28 @@ class DatasetConfigManager:
if "enabled" not in tool_item or not tool_item["enabled"]:
continue
dataset_id = tool_item['id']
dataset_id = tool_item["id"]
dataset_ids.append(dataset_id)
if len(dataset_ids) == 0:
return None
# dataset configs
if 'dataset_configs' in config and config.get('dataset_configs'):
dataset_configs = config.get('dataset_configs')
if "dataset_configs" in config and config.get("dataset_configs"):
dataset_configs = config.get("dataset_configs")
else:
dataset_configs = {
'retrieval_model': 'multiple'
}
query_variable = config.get('dataset_query_variable')
dataset_configs = {"retrieval_model": "multiple"}
query_variable = config.get("dataset_query_variable")
if dataset_configs['retrieval_model'] == 'single':
if dataset_configs["retrieval_model"] == "single":
return DatasetEntity(
dataset_ids=dataset_ids,
retrieve_config=DatasetRetrieveConfigEntity(
query_variable=query_variable,
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs['retrieval_model']
)
)
dataset_configs["retrieval_model"]
),
),
)
else:
return DatasetEntity(
@@ -86,15 +83,15 @@ class DatasetConfigManager:
retrieve_config=DatasetRetrieveConfigEntity(
query_variable=query_variable,
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs['retrieval_model']
dataset_configs["retrieval_model"]
),
top_k=dataset_configs.get('top_k', 4),
score_threshold=dataset_configs.get('score_threshold'),
reranking_model=dataset_configs.get('reranking_model'),
weights=dataset_configs.get('weights'),
reranking_enabled=dataset_configs.get('reranking_enabled', True),
rerank_mode=dataset_configs.get('reranking_mode', 'reranking_model'),
)
top_k=dataset_configs.get("top_k", 4),
score_threshold=dataset_configs.get("score_threshold"),
reranking_model=dataset_configs.get("reranking_model"),
weights=dataset_configs.get("weights"),
reranking_enabled=dataset_configs.get("reranking_enabled", True),
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
),
)
@classmethod
@@ -111,13 +108,10 @@ class DatasetConfigManager:
# dataset_configs
if not config.get("dataset_configs"):
config["dataset_configs"] = {'retrieval_model': 'single'}
config["dataset_configs"] = {"retrieval_model": "single"}
if not config["dataset_configs"].get("datasets"):
config["dataset_configs"]["datasets"] = {
"strategy": "router",
"datasets": []
}
config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
@@ -125,8 +119,9 @@ class DatasetConfigManager:
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
need_manual_query_datasets = (config.get("dataset_configs")
and config["dataset_configs"].get("datasets", {}).get("datasets"))
need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
"datasets", {}
).get("datasets")
if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
# Only check when mode is completion
@@ -148,10 +143,7 @@ class DatasetConfigManager:
"""
# Extract dataset config for legacy compatibility
if not config.get("agent_mode"):
config["agent_mode"] = {
"enabled": False,
"tools": []
}
config["agent_mode"] = {"enabled": False, "tools": []}
if not isinstance(config["agent_mode"], dict):
raise ValueError("agent_mode must be of object type")
@@ -188,7 +180,7 @@ class DatasetConfigManager:
if not isinstance(tool_item["enabled"], bool):
raise ValueError("enabled in agent_mode.tools must be of boolean type")
if 'id' not in tool_item:
if "id" not in tool_item:
raise ValueError("id is required in dataset")
try:

View File

@@ -11,9 +11,7 @@ from core.provider_manager import ProviderManager
class ModelConfigConverter:
@classmethod
def convert(cls, app_config: EasyUIBasedAppConfig,
skip_check: bool = False) \
-> ModelConfigWithCredentialsEntity:
def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
"""
Convert app model config dict to entity.
:param app_config: app config
@@ -25,9 +23,7 @@ class ModelConfigConverter:
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=app_config.tenant_id,
provider=model_config.provider,
model_type=ModelType.LLM
tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM
)
provider_name = provider_model_bundle.configuration.provider.provider
@@ -38,8 +34,7 @@ class ModelConfigConverter:
# check model credentials
model_credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=ModelType.LLM,
model=model_config.model
model_type=ModelType.LLM, model=model_config.model
)
if model_credentials is None:
@@ -51,8 +46,7 @@ class ModelConfigConverter:
if not skip_check:
# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_config.model,
model_type=ModelType.LLM
model=model_config.model, model_type=ModelType.LLM
)
if provider_model is None:
@@ -69,24 +63,18 @@ class ModelConfigConverter:
# model config
completion_params = model_config.parameters
stop = []
if 'stop' in completion_params:
stop = completion_params['stop']
del completion_params['stop']
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]
# get model mode
model_mode = model_config.mode
if not model_mode:
mode_enum = model_type_instance.get_model_mode(
model=model_config.model,
credentials=model_credentials
)
mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials)
model_mode = mode_enum.value
model_schema = model_type_instance.get_model_schema(
model_config.model,
model_credentials
)
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
if not skip_check and not model_schema:
raise ValueError(f"Model {model_name} not exist.")

View File

@@ -13,23 +13,23 @@ class ModelConfigManager:
:param config: model config args
"""
# model config
model_config = config.get('model')
model_config = config.get("model")
if not model_config:
raise ValueError("model is required")
completion_params = model_config.get('completion_params')
completion_params = model_config.get("completion_params")
stop = []
if 'stop' in completion_params:
stop = completion_params['stop']
del completion_params['stop']
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]
# get model mode
model_mode = model_config.get('mode')
model_mode = model_config.get("mode")
return ModelConfigEntity(
provider=config['model']['provider'],
model=config['model']['name'],
provider=config["model"]["provider"],
model=config["model"]["name"],
mode=model_mode,
parameters=completion_params,
stop=stop,
@@ -43,7 +43,7 @@ class ModelConfigManager:
:param tenant_id: tenant id
:param config: app model config args
"""
if 'model' not in config:
if "model" not in config:
raise ValueError("model is required")
if not isinstance(config["model"], dict):
@@ -52,17 +52,16 @@ class ModelConfigManager:
# model.provider
provider_entities = model_provider_factory.get_providers()
model_provider_names = [provider.provider for provider in provider_entities]
if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names:
if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
# model.name
if 'name' not in config["model"]:
if "name" not in config["model"]:
raise ValueError("model.name is required")
provider_manager = ProviderManager()
models = provider_manager.get_configurations(tenant_id).get_models(
provider=config["model"]["provider"],
model_type=ModelType.LLM
provider=config["model"]["provider"], model_type=ModelType.LLM
)
if not models:
@@ -80,12 +79,12 @@ class ModelConfigManager:
# model.mode
if model_mode:
config['model']["mode"] = model_mode
config["model"]["mode"] = model_mode
else:
config['model']["mode"] = "completion"
config["model"]["mode"] = "completion"
# model.completion_params
if 'completion_params' not in config["model"]:
if "completion_params" not in config["model"]:
raise ValueError("model.completion_params is required")
config["model"]["completion_params"] = cls.validate_model_completion_params(
@@ -101,7 +100,7 @@ class ModelConfigManager:
raise ValueError("model.completion_params must be of object type")
# stop
if 'stop' not in cp:
if "stop" not in cp:
cp["stop"] = []
elif not isinstance(cp["stop"], list):
raise ValueError("stop in model.completion_params must be of list type")

View File

@@ -14,39 +14,33 @@ class PromptTemplateConfigManager:
if not config.get("prompt_type"):
raise ValueError("prompt_type is required")
prompt_type = PromptTemplateEntity.PromptType.value_of(config['prompt_type'])
prompt_type = PromptTemplateEntity.PromptType.value_of(config["prompt_type"])
if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
simple_prompt_template = config.get("pre_prompt", "")
return PromptTemplateEntity(
prompt_type=prompt_type,
simple_prompt_template=simple_prompt_template
)
return PromptTemplateEntity(prompt_type=prompt_type, simple_prompt_template=simple_prompt_template)
else:
advanced_chat_prompt_template = None
chat_prompt_config = config.get("chat_prompt_config", {})
if chat_prompt_config:
chat_prompt_messages = []
for message in chat_prompt_config.get("prompt", []):
chat_prompt_messages.append({
"text": message["text"],
"role": PromptMessageRole.value_of(message["role"])
})
chat_prompt_messages.append(
{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
)
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
messages=chat_prompt_messages
)
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)
advanced_completion_prompt_template = None
completion_prompt_config = config.get("completion_prompt_config", {})
if completion_prompt_config:
completion_prompt_template_params = {
'prompt': completion_prompt_config['prompt']['text'],
"prompt": completion_prompt_config["prompt"]["text"],
}
if 'conversation_histories_role' in completion_prompt_config:
completion_prompt_template_params['role_prefix'] = {
'user': completion_prompt_config['conversation_histories_role']['user_prefix'],
'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix']
if "conversation_histories_role" in completion_prompt_config:
completion_prompt_template_params["role_prefix"] = {
"user": completion_prompt_config["conversation_histories_role"]["user_prefix"],
"assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"],
}
advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
@@ -56,7 +50,7 @@ class PromptTemplateConfigManager:
return PromptTemplateEntity(
prompt_type=prompt_type,
advanced_chat_prompt_template=advanced_chat_prompt_template,
advanced_completion_prompt_template=advanced_completion_prompt_template
advanced_completion_prompt_template=advanced_completion_prompt_template,
)
@classmethod
@@ -72,7 +66,7 @@ class PromptTemplateConfigManager:
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value
prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
if config['prompt_type'] not in prompt_type_vals:
if config["prompt_type"] not in prompt_type_vals:
raise ValueError(f"prompt_type must be in {prompt_type_vals}")
# chat_prompt_config
@@ -89,27 +83,28 @@ class PromptTemplateConfigManager:
if not isinstance(config["completion_prompt_config"], dict):
raise ValueError("completion_prompt_config must be of object type")
if config['prompt_type'] == PromptTemplateEntity.PromptType.ADVANCED.value:
if not config['chat_prompt_config'] and not config['completion_prompt_config']:
raise ValueError("chat_prompt_config or completion_prompt_config is required "
"when prompt_type is advanced")
if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value:
if not config["chat_prompt_config"] and not config["completion_prompt_config"]:
raise ValueError(
"chat_prompt_config or completion_prompt_config is required " "when prompt_type is advanced"
)
model_mode_vals = [mode.value for mode in ModelMode]
if config['model']["mode"] not in model_mode_vals:
if config["model"]["mode"] not in model_mode_vals:
raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced")
if app_mode == AppMode.CHAT and config['model']["mode"] == ModelMode.COMPLETION.value:
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value:
user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"]
assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"]
if not user_prefix:
config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human'
config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] = "Human"
if not assistant_prefix:
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant"
if config['model']["mode"] == ModelMode.CHAT.value:
prompt_list = config['chat_prompt_config']['prompt']
if config["model"]["mode"] == ModelMode.CHAT.value:
prompt_list = config["chat_prompt_config"]["prompt"]
if len(prompt_list) > 10:
raise ValueError("prompt messages must be less than 10")

View File

@@ -16,32 +16,30 @@ class BasicVariablesConfigManager:
variable_entities = []
# old external_data_tools
external_data_tools = config.get('external_data_tools', [])
external_data_tools = config.get("external_data_tools", [])
for external_data_tool in external_data_tools:
if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
if "enabled" not in external_data_tool or not external_data_tool["enabled"]:
continue
external_data_variables.append(
ExternalDataVariableEntity(
variable=external_data_tool['variable'],
type=external_data_tool['type'],
config=external_data_tool['config']
variable=external_data_tool["variable"],
type=external_data_tool["type"],
config=external_data_tool["config"],
)
)
# variables and external_data_tools
for variables in config.get('user_input_form', []):
for variables in config.get("user_input_form", []):
variable_type = list(variables.keys())[0]
if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
variable = variables[variable_type]
if 'config' not in variable:
if "config" not in variable:
continue
external_data_variables.append(
ExternalDataVariableEntity(
variable=variable['variable'],
type=variable['type'],
config=variable['config']
variable=variable["variable"], type=variable["type"], config=variable["config"]
)
)
elif variable_type in [
@@ -54,13 +52,13 @@ class BasicVariablesConfigManager:
variable_entities.append(
VariableEntity(
type=variable_type,
variable=variable.get('variable'),
description=variable.get('description'),
label=variable.get('label'),
required=variable.get('required', False),
max_length=variable.get('max_length'),
options=variable.get('options'),
default=variable.get('default'),
variable=variable.get("variable"),
description=variable.get("description"),
label=variable.get("label"),
required=variable.get("required", False),
max_length=variable.get("max_length"),
options=variable.get("options"),
default=variable.get("default"),
)
)
@@ -103,13 +101,13 @@ class BasicVariablesConfigManager:
raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'")
form_item = item[key]
if 'label' not in form_item:
if "label" not in form_item:
raise ValueError("label is required in user_input_form")
if not isinstance(form_item["label"], str):
raise ValueError("label in user_input_form must be of string type")
if 'variable' not in form_item:
if "variable" not in form_item:
raise ValueError("variable is required in user_input_form")
if not isinstance(form_item["variable"], str):
@@ -117,26 +115,24 @@ class BasicVariablesConfigManager:
pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")
if pattern.match(form_item["variable"]) is None:
raise ValueError("variable in user_input_form must be a string, "
"and cannot start with a number")
raise ValueError("variable in user_input_form must be a string, " "and cannot start with a number")
variables.append(form_item["variable"])
if 'required' not in form_item or not form_item["required"]:
if "required" not in form_item or not form_item["required"]:
form_item["required"] = False
if not isinstance(form_item["required"], bool):
raise ValueError("required in user_input_form must be of boolean type")
if key == "select":
if 'options' not in form_item or not form_item["options"]:
if "options" not in form_item or not form_item["options"]:
form_item["options"] = []
if not isinstance(form_item["options"], list):
raise ValueError("options in user_input_form must be a list of strings")
if "default" in form_item and form_item['default'] \
and form_item["default"] not in form_item["options"]:
if "default" in form_item and form_item["default"] and form_item["default"] not in form_item["options"]:
raise ValueError("default value in user_input_form must be in the options list")
return config, ["user_input_form"]
@@ -168,10 +164,6 @@ class BasicVariablesConfigManager:
typ = tool["type"]
config = tool["config"]
ExternalDataToolFactory.validate_config(
name=typ,
tenant_id=tenant_id,
config=config
)
ExternalDataToolFactory.validate_config(name=typ, tenant_id=tenant_id, config=config)
return config, ["external_data_tools"]

View File

@@ -12,6 +12,7 @@ class ModelConfigEntity(BaseModel):
"""
Model Config Entity.
"""
provider: str
model: str
mode: Optional[str] = None
@@ -23,6 +24,7 @@ class AdvancedChatMessageEntity(BaseModel):
"""
Advanced Chat Message Entity.
"""
text: str
role: PromptMessageRole
@@ -31,6 +33,7 @@ class AdvancedChatPromptTemplateEntity(BaseModel):
"""
Advanced Chat Prompt Template Entity.
"""
messages: list[AdvancedChatMessageEntity]
@@ -43,6 +46,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel):
"""
Role Prefix Entity.
"""
user: str
assistant: str
@@ -60,11 +64,12 @@ class PromptTemplateEntity(BaseModel):
Prompt Type.
'simple', 'advanced'
"""
SIMPLE = 'simple'
ADVANCED = 'advanced'
SIMPLE = "simple"
ADVANCED = "advanced"
@classmethod
def value_of(cls, value: str) -> 'PromptType':
def value_of(cls, value: str) -> "PromptType":
"""
Get value of given mode.
@@ -74,7 +79,7 @@ class PromptTemplateEntity(BaseModel):
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid prompt type value {value}')
raise ValueError(f"invalid prompt type value {value}")
prompt_type: PromptType
simple_prompt_template: Optional[str] = None
@@ -110,6 +115,7 @@ class ExternalDataVariableEntity(BaseModel):
"""
External Data Variable Entity.
"""
variable: str
type: str
config: dict[str, Any] = {}
@@ -125,11 +131,12 @@ class DatasetRetrieveConfigEntity(BaseModel):
Dataset Retrieve Strategy.
'single' or 'multiple'
"""
SINGLE = 'single'
MULTIPLE = 'multiple'
SINGLE = "single"
MULTIPLE = "multiple"
@classmethod
def value_of(cls, value: str) -> 'RetrieveStrategy':
def value_of(cls, value: str) -> "RetrieveStrategy":
"""
Get value of given mode.
@@ -139,25 +146,24 @@ class DatasetRetrieveConfigEntity(BaseModel):
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid retrieve strategy value {value}')
raise ValueError(f"invalid retrieve strategy value {value}")
query_variable: Optional[str] = None # Only when app mode is completion
retrieve_strategy: RetrieveStrategy
top_k: Optional[int] = None
score_threshold: Optional[float] = .0
rerank_mode: Optional[str] = 'reranking_model'
score_threshold: Optional[float] = 0.0
rerank_mode: Optional[str] = "reranking_model"
reranking_model: Optional[dict] = None
weights: Optional[dict] = None
reranking_enabled: Optional[bool] = True
class DatasetEntity(BaseModel):
"""
Dataset Config Entity.
"""
dataset_ids: list[str]
retrieve_config: DatasetRetrieveConfigEntity
@@ -166,6 +172,7 @@ class SensitiveWordAvoidanceEntity(BaseModel):
"""
Sensitive Word Avoidance Entity.
"""
type: str
config: dict[str, Any] = {}
@@ -174,6 +181,7 @@ class TextToSpeechEntity(BaseModel):
"""
Text To Speech Entity.
"""
enabled: bool
voice: Optional[str] = None
language: Optional[str] = None
@@ -183,12 +191,11 @@ class TracingConfigEntity(BaseModel):
"""
Tracing Config Entity.
"""
enabled: bool
tracing_provider: str
class AppAdditionalFeatures(BaseModel):
file_upload: Optional[FileExtraConfig] = None
opening_statement: Optional[str] = None
@@ -200,10 +207,12 @@ class AppAdditionalFeatures(BaseModel):
text_to_speech: Optional[TextToSpeechEntity] = None
trace_config: Optional[TracingConfigEntity] = None
class AppConfig(BaseModel):
"""
Application Config Entity.
"""
tenant_id: str
app_id: str
app_mode: AppMode
@@ -216,15 +225,17 @@ class EasyUIBasedAppModelConfigFrom(Enum):
"""
App Model Config From.
"""
ARGS = 'args'
APP_LATEST_CONFIG = 'app-latest-config'
CONVERSATION_SPECIFIC_CONFIG = 'conversation-specific-config'
ARGS = "args"
APP_LATEST_CONFIG = "app-latest-config"
CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config"
class EasyUIBasedAppConfig(AppConfig):
"""
Easy UI Based App Config Entity.
"""
app_model_config_from: EasyUIBasedAppModelConfigFrom
app_model_config_id: str
app_model_config_dict: dict
@@ -238,4 +249,5 @@ class WorkflowUIBasedAppConfig(AppConfig):
"""
Workflow UI Based App Config Entity.
"""
workflow_id: str

View File

@@ -13,21 +13,19 @@ class FileUploadConfigManager:
:param config: model config args
:param is_vision: if True, the feature is vision feature
"""
file_upload_dict = config.get('file_upload')
file_upload_dict = config.get("file_upload")
if file_upload_dict:
if file_upload_dict.get('image'):
if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
if file_upload_dict.get("image"):
if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]:
image_config = {
'number_limits': file_upload_dict['image']['number_limits'],
'transfer_methods': file_upload_dict['image']['transfer_methods']
"number_limits": file_upload_dict["image"]["number_limits"],
"transfer_methods": file_upload_dict["image"]["transfer_methods"],
}
if is_vision:
image_config['detail'] = file_upload_dict['image']['detail']
image_config["detail"] = file_upload_dict["image"]["detail"]
return FileExtraConfig(
image_config=image_config
)
return FileExtraConfig(image_config=image_config)
return None
@@ -49,21 +47,21 @@ class FileUploadConfigManager:
if not config["file_upload"].get("image"):
config["file_upload"]["image"] = {"enabled": False}
if config['file_upload']['image']['enabled']:
number_limits = config['file_upload']['image']['number_limits']
if config["file_upload"]["image"]["enabled"]:
number_limits = config["file_upload"]["image"]["number_limits"]
if number_limits < 1 or number_limits > 6:
raise ValueError("number_limits must be in [1, 6]")
if is_vision:
detail = config['file_upload']['image']['detail']
if detail not in ['high', 'low']:
detail = config["file_upload"]["image"]["detail"]
if detail not in ["high", "low"]:
raise ValueError("detail must be in ['high', 'low']")
transfer_methods = config['file_upload']['image']['transfer_methods']
transfer_methods = config["file_upload"]["image"]["transfer_methods"]
if not isinstance(transfer_methods, list):
raise ValueError("transfer_methods must be of list type")
for method in transfer_methods:
if method not in ['remote_url', 'local_file']:
if method not in ["remote_url", "local_file"]:
raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
return config, ["file_upload"]

View File

@@ -7,9 +7,9 @@ class MoreLikeThisConfigManager:
:param config: model config args
"""
more_like_this = False
more_like_this_dict = config.get('more_like_this')
more_like_this_dict = config.get("more_like_this")
if more_like_this_dict:
if more_like_this_dict.get('enabled'):
if more_like_this_dict.get("enabled"):
more_like_this = True
return more_like_this
@@ -22,9 +22,7 @@ class MoreLikeThisConfigManager:
:param config: app model config args
"""
if not config.get("more_like_this"):
config["more_like_this"] = {
"enabled": False
}
config["more_like_this"] = {"enabled": False}
if not isinstance(config["more_like_this"], dict):
raise ValueError("more_like_this must be of dict type")

View File

@@ -1,5 +1,3 @@
class OpeningStatementConfigManager:
@classmethod
def convert(cls, config: dict) -> tuple[str, list]:
@@ -9,10 +7,10 @@ class OpeningStatementConfigManager:
:param config: model config args
"""
# opening statement
opening_statement = config.get('opening_statement')
opening_statement = config.get("opening_statement")
# suggested questions
suggested_questions_list = config.get('suggested_questions')
suggested_questions_list = config.get("suggested_questions")
return opening_statement, suggested_questions_list

View File

@@ -2,9 +2,9 @@ class RetrievalResourceConfigManager:
@classmethod
def convert(cls, config: dict) -> bool:
show_retrieve_source = False
retriever_resource_dict = config.get('retriever_resource')
retriever_resource_dict = config.get("retriever_resource")
if retriever_resource_dict:
if retriever_resource_dict.get('enabled'):
if retriever_resource_dict.get("enabled"):
show_retrieve_source = True
return show_retrieve_source
@@ -17,9 +17,7 @@ class RetrievalResourceConfigManager:
:param config: app model config args
"""
if not config.get("retriever_resource"):
config["retriever_resource"] = {
"enabled": False
}
config["retriever_resource"] = {"enabled": False}
if not isinstance(config["retriever_resource"], dict):
raise ValueError("retriever_resource must be of dict type")

View File

@@ -7,9 +7,9 @@ class SpeechToTextConfigManager:
:param config: model config args
"""
speech_to_text = False
speech_to_text_dict = config.get('speech_to_text')
speech_to_text_dict = config.get("speech_to_text")
if speech_to_text_dict:
if speech_to_text_dict.get('enabled'):
if speech_to_text_dict.get("enabled"):
speech_to_text = True
return speech_to_text
@@ -22,9 +22,7 @@ class SpeechToTextConfigManager:
:param config: app model config args
"""
if not config.get("speech_to_text"):
config["speech_to_text"] = {
"enabled": False
}
config["speech_to_text"] = {"enabled": False}
if not isinstance(config["speech_to_text"], dict):
raise ValueError("speech_to_text must be of dict type")

View File

@@ -7,9 +7,9 @@ class SuggestedQuestionsAfterAnswerConfigManager:
:param config: model config args
"""
suggested_questions_after_answer = False
suggested_questions_after_answer_dict = config.get('suggested_questions_after_answer')
suggested_questions_after_answer_dict = config.get("suggested_questions_after_answer")
if suggested_questions_after_answer_dict:
if suggested_questions_after_answer_dict.get('enabled'):
if suggested_questions_after_answer_dict.get("enabled"):
suggested_questions_after_answer = True
return suggested_questions_after_answer
@@ -22,15 +22,15 @@ class SuggestedQuestionsAfterAnswerConfigManager:
:param config: app model config args
"""
if not config.get("suggested_questions_after_answer"):
config["suggested_questions_after_answer"] = {
"enabled": False
}
config["suggested_questions_after_answer"] = {"enabled": False}
if not isinstance(config["suggested_questions_after_answer"], dict):
raise ValueError("suggested_questions_after_answer must be of dict type")
if "enabled" not in config["suggested_questions_after_answer"] or not \
config["suggested_questions_after_answer"]["enabled"]:
if (
"enabled" not in config["suggested_questions_after_answer"]
or not config["suggested_questions_after_answer"]["enabled"]
):
config["suggested_questions_after_answer"]["enabled"] = False
if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):

View File

@@ -10,13 +10,13 @@ class TextToSpeechConfigManager:
:param config: model config args
"""
text_to_speech = None
text_to_speech_dict = config.get('text_to_speech')
text_to_speech_dict = config.get("text_to_speech")
if text_to_speech_dict:
if text_to_speech_dict.get('enabled'):
if text_to_speech_dict.get("enabled"):
text_to_speech = TextToSpeechEntity(
enabled=text_to_speech_dict.get('enabled'),
voice=text_to_speech_dict.get('voice'),
language=text_to_speech_dict.get('language'),
enabled=text_to_speech_dict.get("enabled"),
voice=text_to_speech_dict.get("voice"),
language=text_to_speech_dict.get("language"),
)
return text_to_speech
@@ -29,11 +29,7 @@ class TextToSpeechConfigManager:
:param config: app model config args
"""
if not config.get("text_to_speech"):
config["text_to_speech"] = {
"enabled": False,
"voice": "",
"language": ""
}
config["text_to_speech"] = {"enabled": False, "voice": "", "language": ""}
if not isinstance(config["text_to_speech"], dict):
raise ValueError("text_to_speech must be of dict type")


@@ -1,4 +1,3 @@
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import WorkflowUIBasedAppConfig
@@ -19,13 +18,13 @@ class AdvancedChatAppConfig(WorkflowUIBasedAppConfig):
"""
Advanced Chatbot App Config Entity.
"""
pass
class AdvancedChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App,
workflow: Workflow) -> AdvancedChatAppConfig:
def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
features_dict = workflow.features_dict
app_mode = AppMode.value_of(app_model.mode)
@@ -34,13 +33,9 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
app_id=app_model.id,
app_mode=app_mode,
workflow_id=workflow.id,
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=features_dict
),
variables=WorkflowVariablesConfigManager.convert(
workflow=workflow
),
additional_features=cls.convert_features(features_dict, app_mode)
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict),
variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
additional_features=cls.convert_features(features_dict, app_mode),
)
return app_config
@@ -58,8 +53,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
# file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
config=config,
is_vision=False
config=config, is_vision=False
)
related_config_keys.extend(current_related_config_keys)
@@ -69,7 +63,8 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
# suggested_questions_after_answer
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
config)
config
)
related_config_keys.extend(current_related_config_keys)
# speech_to_text
@@ -86,9 +81,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id=tenant_id,
config=config,
only_structure_validate=only_structure_validate
tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
)
related_config_keys.extend(current_related_config_keys)
@@ -98,4 +91,3 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config
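config_validate above threads the config through a chain of validate_and_set_defaults calls, accumulates the keys each validator touched in related_config_keys, and finally returns only those keys. A hedged, generic sketch of that accumulate-and-filter pattern (the validator shown is a stand-in, not one of Dify's managers):

```python
from collections.abc import Callable

Validator = Callable[[dict], tuple[dict, list[str]]]


def validate_config(config: dict, validators: list[Validator]) -> dict:
    """Run every validator, collect the keys they claim, and keep only those keys."""
    related_config_keys: list[str] = []
    for validate in validators:
        config, current_keys = validate(config)
        related_config_keys.extend(current_keys)
    return {key: config.get(key) for key in related_config_keys}


def _validate_speech_to_text(config: dict) -> tuple[dict, list[str]]:
    config.setdefault("speech_to_text", {"enabled": False})
    return config, ["speech_to_text"]


print(validate_config({"unused": 1}, [_validate_speech_to_text]))
# {'speech_to_text': {'enabled': False}}
```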


@@ -4,12 +4,10 @@ import os
import threading
import uuid
from collections.abc import Generator
from typing import Literal, Union, overload
from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
import contexts
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@@ -20,20 +18,15 @@ from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGe
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from extensions.ext_database import db
from models.account import Account
from models.model import App, Conversation, EndUser, Message
from models.workflow import ConversationVariable, Workflow
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@@ -41,7 +34,8 @@ logger = logging.getLogger(__name__)
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
@@ -51,7 +45,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
@@ -60,13 +55,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
) -> dict: ...
def generate(
self, app_model: App,
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
):
) -> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
@@ -77,44 +73,37 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
if not args.get('query'):
raise ValueError('query is required')
if not args.get("query"):
raise ValueError("query is required")
query = args['query']
query = args["query"]
if not isinstance(query, str):
raise ValueError('query must be a string')
raise ValueError("query must be a string")
query = query.replace('\x00', '')
inputs = args['inputs']
query = query.replace("\x00", "")
inputs = args["inputs"]
extras = {
"auto_generate_conversation_name": args.get('auto_generate_name', False)
}
extras = {"auto_generate_conversation_name": args.get("auto_generate_name", False)}
# get conversation
conversation = None
conversation_id = args.get('conversation_id')
conversation_id = args.get("conversation_id")
if conversation_id:
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
conversation = self._get_conversation_by_user(
app_model=app_model, conversation_id=conversation_id, user=user
)
# parse files
files = args['files'] if args.get('files') else []
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_extra_config,
user
)
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
else:
file_objs = []
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow
)
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
# get tracing instance
user_id = user.id if isinstance(user, Account) else user.session_id
@@ -136,7 +125,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
stream=stream,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager
trace_manager=trace_manager,
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
@@ -146,15 +135,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from=invoke_from,
application_generate_entity=application_generate_entity,
conversation=conversation,
stream=stream
stream=stream,
)
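The generate() declarations above use typing.overload so that callers passing stream=True are typed as receiving a Generator while stream=False returns a plain dict, and the single implementation carries the union return type dict[str, Any] | Generator[str, Any, None]. A minimal sketch of that typing pattern with placeholder logic:

```python
from collections.abc import Generator
from typing import Any, Literal, Union, overload


@overload
def generate(args: dict, stream: Literal[True] = True) -> Generator[str, Any, None]: ...


@overload
def generate(args: dict, stream: Literal[False]) -> dict: ...


def generate(args: dict, stream: bool = True) -> Union[dict, Generator[str, Any, None]]:
    """Stream chunks when stream=True, otherwise return the whole payload at once."""
    if stream:
        return (chunk for chunk in ("hello", "world"))
    return {"answer": "hello world"}
```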
def single_iteration_generate(self, app_model: App,
workflow: Workflow,
node_id: str,
user: Account,
args: dict,
stream: bool = True):
def single_iteration_generate(
self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True
) -> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
@@ -166,43 +152,29 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param stream: is stream
"""
if not node_id:
raise ValueError('node_id is required')
raise ValueError("node_id is required")
if args.get('inputs') is None:
raise ValueError('inputs is required')
extras = {
"auto_generate_conversation_name": False
}
# get conversation
conversation = None
conversation_id = args.get('conversation_id')
if conversation_id:
conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
if args.get("inputs") is None:
raise ValueError("inputs is required")
# convert to app config
app_config = AdvancedChatAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow
)
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
# init application generate entity
application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=app_config,
conversation_id=conversation.id if conversation else None,
conversation_id=None,
inputs={},
query='',
query="",
files=[],
user_id=user.id,
stream=stream,
invoke_from=InvokeFrom.DEBUGGER,
extras=extras,
extras={"auto_generate_conversation_name": False},
single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id,
inputs=args['inputs']
)
node_id=node_id, inputs=args["inputs"]
),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
@@ -211,32 +183,42 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
conversation=conversation,
stream=stream
conversation=None,
stream=stream,
)
def _generate(self, *,
workflow: Workflow,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Conversation | None = None,
stream: bool = True):
def _generate(
self,
*,
workflow: Workflow,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
conversation: Optional[Conversation] = None,
stream: bool = True,
) -> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
:param workflow: Workflow
:param user: account or end user
:param invoke_from: invoke from source
:param application_generate_entity: application generate entity
:param conversation: conversation
:param stream: is stream
"""
is_first_conversation = False
if not conversation:
is_first_conversation = True
# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity, conversation)
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
if is_first_conversation:
# update conversation features
conversation.override_model_configs = workflow.features
db.session.commit()
# db.session.refresh(conversation)
db.session.refresh(conversation)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
@@ -245,73 +227,21 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
message_id=message.id,
)
# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
# Create conversation variables if they don't exist.
conversation_variables = [
ConversationVariable.from_variable(
app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
# Convert database entities to variables.
conversation_variables = [item.to_variable() for item in conversation_variables]
session.commit()
# Increment dialogue count.
conversation.dialogue_count += 1
conversation_id = conversation.id
conversation_dialogue_count = conversation.dialogue_count
db.session.commit()
db.session.refresh(conversation)
inputs = application_generate_entity.inputs
query = application_generate_entity.query
files = application_generate_entity.files
user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id
# Create a variable pool.
system_inputs = {
SystemVariableKey.QUERY: query,
SystemVariableKey.FILES: files,
SystemVariableKey.CONVERSATION_ID: conversation_id,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)
contexts.workflow_variable_pool.set(variable_pool)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'message_id': message.id,
'context': contextvars.copy_context(),
})
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
"message_id": message.id,
"context": contextvars.copy_context(),
},
)
worker_thread.start()
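The worker thread above receives contextvars.copy_context() through its kwargs, and the worker re-applies each captured variable (var.set(val)) before doing any work, so request-scoped values such as the tenant id survive the thread hop. A small self-contained illustration of that hand-off (variable and function names are illustrative):

```python
import contextvars
import threading

tenant_id: contextvars.ContextVar[str] = contextvars.ContextVar("tenant_id")


def worker(context: contextvars.Context) -> None:
    # Re-apply every captured variable inside the new thread.
    for var, val in context.items():
        var.set(val)
    print("worker sees tenant:", tenant_id.get())


tenant_id.set("tenant-123")
thread = threading.Thread(target=worker, kwargs={"context": contextvars.copy_context()})
thread.start()
thread.join()
```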
@@ -326,16 +256,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
stream=stream,
)
return AdvancedChatAppGenerateResponseConverter.convert(
response=response,
invoke_from=invoke_from
)
return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(self, flask_app: Flask,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
message_id: str,
context: contextvars.Context) -> None:
def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
context: contextvars.Context,
) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
@@ -349,40 +280,30 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
var.set(val)
with flask_app.app_context():
try:
runner = AdvancedChatAppRunner()
if application_generate_entity.single_iteration_run:
single_iteration_run = application_generate_entity.single_iteration_run
runner.single_iteration_run(
app_id=application_generate_entity.app_config.app_id,
workflow_id=application_generate_entity.app_config.workflow_id,
queue_manager=queue_manager,
inputs=single_iteration_run.inputs,
node_id=single_iteration_run.node_id,
user_id=application_generate_entity.user_id
)
else:
# get message
message = self._get_message(message_id)
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
# chatbot app
runner = AdvancedChatAppRunner()
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
message=message
)
# chatbot app
runner = AdvancedChatAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
)
runner.run()
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG", "false").lower() == 'true':
if os.environ.get("DEBUG", "false").lower() == "true":
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:


@@ -25,10 +25,7 @@ def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str):
if not text_content or text_content.isspace():
return
return model_instance.invoke_tts(
content_text=text_content.strip(),
user="responding_tts",
tenant_id=tenant_id,
voice=voice
content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice
)
@@ -44,28 +41,26 @@ def _process_future(future_queue, audio_queue):
except Exception as e:
logging.getLogger(__name__).warning(e)
break
audio_queue.put(AudioTrunk("finish", b''))
audio_queue.put(AudioTrunk("finish", b""))
class AppGeneratorTTSPublisher:
def __init__(self, tenant_id: str, voice: str):
self.logger = logging.getLogger(__name__)
self.tenant_id = tenant_id
self.msg_text = ''
self.msg_text = ""
self._audio_queue = queue.Queue()
self._msg_queue = queue.Queue()
self.match = re.compile(r'[。.!?]')
self.match = re.compile(r"[。.!?]")
self.model_manager = ModelManager()
self.model_instance = self.model_manager.get_default_model_instance(
tenant_id=self.tenant_id,
model_type=ModelType.TTS
tenant_id=self.tenant_id, model_type=ModelType.TTS
)
self.voices = self.model_instance.get_tts_voices()
values = [voice.get('value') for voice in self.voices]
values = [voice.get("value") for voice in self.voices]
self.voice = voice
if not voice or voice not in values:
self.voice = self.voices[0].get('value')
self.voice = self.voices[0].get("value")
self.MAX_SENTENCE = 2
self._last_audio_event = None
self._runtime_thread = threading.Thread(target=self._runtime).start()
@@ -85,8 +80,9 @@ class AppGeneratorTTSPublisher:
message = self._msg_queue.get()
if message is None:
if self.msg_text and len(self.msg_text.strip()) > 0:
futures_result = self.executor.submit(_invoiceTTS, self.msg_text,
self.model_instance, self.tenant_id, self.voice)
futures_result = self.executor.submit(
_invoiceTTS, self.msg_text, self.model_instance, self.tenant_id, self.voice
)
future_queue.put(futures_result)
break
elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
@@ -94,21 +90,20 @@ class AppGeneratorTTSPublisher:
elif isinstance(message.event, QueueTextChunkEvent):
self.msg_text += message.event.text
elif isinstance(message.event, QueueNodeSucceededEvent):
self.msg_text += message.event.outputs.get('output', '')
self.msg_text += message.event.outputs.get("output", "")
self.last_message = message
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
self.MAX_SENTENCE += 1
text_content = ''.join(sentence_arr)
futures_result = self.executor.submit(_invoiceTTS, text_content,
self.model_instance,
self.tenant_id,
self.voice)
text_content = "".join(sentence_arr)
futures_result = self.executor.submit(
_invoiceTTS, text_content, self.model_instance, self.tenant_id, self.voice
)
future_queue.put(futures_result)
if text_tmp:
self.msg_text = text_tmp
else:
self.msg_text = ''
self.msg_text = ""
except Exception as e:
self.logger.warning(e)
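AppGeneratorTTSPublisher above accumulates streamed text in msg_text, cuts it into sentences on [。.!?], and submits each batch to the TTS model while carrying the unfinished tail into the next chunk. A minimal sketch of that split-and-carry logic (the regex mirrors the one in the diff; everything else is illustrative):

```python
import re

SENTENCE_END = re.compile(r"[。.!?]")


def extract_sentences(buffer: str) -> tuple[list[str], str]:
    """Return the finished sentences in buffer plus the unfinished remainder."""
    sentences: list[str] = []
    start = 0
    for match in SENTENCE_END.finditer(buffer):
        sentences.append(buffer[start : match.end()])
        start = match.end()
    return sentences, buffer[start:]


done, rest = extract_sentences("第一句。第二句。still typi")
print(done)  # ['第一句。', '第二句。']
print(rest)  # 'still typi'
```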


@@ -1,145 +1,197 @@
import logging
import os
import time
from collections.abc import Mapping
from typing import Any, Optional, cast
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
from core.app.entities.queue_entities import (
QueueAnnotationReplyEvent,
QueueStopEvent,
QueueTextChunkEvent,
)
from core.moderation.base import ModerationException
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models import App, Message, Workflow
from models.model import App, Conversation, EndUser, Message
from models.workflow import ConversationVariable, WorkflowType
logger = logging.getLogger(__name__)
class AdvancedChatAppRunner(AppRunner):
class AdvancedChatAppRunner(WorkflowBasedAppRunner):
"""
AdvancedChat Application Runner
"""
def run(
def __init__(
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
) -> None:
"""
Run application
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param conversation: conversation
:param message: message
"""
super().__init__(queue_manager)
self.application_generate_entity = application_generate_entity
self.conversation = conversation
self.message = message
def run(self) -> None:
"""
Run application
:return:
"""
app_config = application_generate_entity.app_config
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
raise ValueError('App not found')
raise ValueError("App not found")
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError('Workflow not initialized')
raise ValueError("Workflow not initialized")
inputs = application_generate_entity.inputs
query = application_generate_entity.query
user_id = None
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
# moderation
if self.handle_input_moderation(
queue_manager=queue_manager,
app_record=app_record,
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
message_id=message.id,
):
return
workflow_callbacks: list[WorkflowCallback] = []
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
workflow_callbacks.append(WorkflowLoggingCallback())
# annotation reply
if self.handle_annotation_reply(
app_record=app_record,
message=message,
query=query,
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,
):
return
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
)
else:
inputs = self.application_generate_entity.inputs
query = self.application_generate_entity.query
files = self.application_generate_entity.files
# moderation
if self.handle_input_moderation(
app_record=app_record,
app_generate_entity=self.application_generate_entity,
inputs=inputs,
query=query,
message_id=self.message.id,
):
return
# annotation reply
if self.handle_annotation_reply(
app_record=app_record,
message=self.message,
query=query,
app_generate_entity=self.application_generate_entity,
):
return
# Init conversation variables
stmt = select(ConversationVariable).where(
ConversationVariable.app_id == self.conversation.app_id,
ConversationVariable.conversation_id == self.conversation.id,
)
with Session(db.engine) as session:
conversation_variables = session.scalars(stmt).all()
if not conversation_variables:
# Create conversation variables if they don't exist.
conversation_variables = [
ConversationVariable.from_variable(
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
)
for variable in workflow.conversation_variables
]
session.add_all(conversation_variables)
# Convert database entities to variables.
conversation_variables = [item.to_variable() for item in conversation_variables]
session.commit()
# Increment dialogue count.
self.conversation.dialogue_count += 1
conversation_dialogue_count = self.conversation.dialogue_count
db.session.commit()
# Create a variable pool.
system_inputs = {
SystemVariableKey.QUERY: query,
SystemVariableKey.FILES: files,
SystemVariableKey.CONVERSATION_ID: self.conversation.id,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
}
# init variable pool
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
)
# init graph
graph = self._init_graph(graph_config=workflow.graph_dict)
db.session.close()
workflow_callbacks: list[WorkflowCallback] = [
WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
]
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
workflow_callbacks.append(WorkflowLoggingCallback())
# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
workflow=workflow,
user_id=application_generate_entity.user_id,
user_from=UserFrom.ACCOUNT
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
invoke_from=application_generate_entity.invoke_from,
workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(workflow.type),
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
)
generator = workflow_entry.run(
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth,
)
def single_iteration_run(
self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
) -> None:
"""
Single iteration run
"""
app_record = db.session.query(App).filter(App.id == app_id).first()
if not app_record:
raise ValueError('App not found')
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
if not workflow:
raise ValueError('Workflow not initialized')
workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.single_step_run_iteration_workflow_node(
workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow
for event in generator:
self._handle_event(workflow_entry, event)
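run() above assembles a VariablePool from several layers: system inputs (query, files, conversation id, user id, dialogue count), user inputs, workflow environment variables, and conversation variables. A rough, generic sketch of such a layered lookup (the class and field names below are made up and much simpler than the real selector-based pool):

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class SimpleVariablePool:
    """Look up variables across layers, from most to least specific."""

    system_variables: dict[str, Any] = field(default_factory=dict)
    user_inputs: dict[str, Any] = field(default_factory=dict)
    environment_variables: dict[str, Any] = field(default_factory=dict)
    conversation_variables: dict[str, Any] = field(default_factory=dict)

    def get(self, name: str) -> Any:
        for layer in (
            self.system_variables,
            self.user_inputs,
            self.conversation_variables,
            self.environment_variables,
        ):
            if name in layer:
                return layer[name]
        raise KeyError(name)


pool = SimpleVariablePool(system_variables={"query": "hi"}, user_inputs={"topic": "tts"})
print(pool.get("topic"))  # 'tts'
```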
def handle_input_moderation(
self,
queue_manager: AppQueueManager,
app_record: App,
app_generate_entity: AdvancedChatAppGenerateEntity,
inputs: Mapping[str, Any],
@@ -148,7 +200,6 @@ class AdvancedChatAppRunner(AppRunner):
) -> bool:
"""
Handle input moderation
:param queue_manager: application queue manager
:param app_record: app record
:param app_generate_entity: application generate entity
:param inputs: inputs
@@ -167,30 +218,19 @@ class AdvancedChatAppRunner(AppRunner):
message_id=message_id,
)
except ModerationException as e:
self._stream_output(
queue_manager=queue_manager,
text=str(e),
stream=app_generate_entity.stream,
stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION,
)
self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)
return True
return False
def handle_annotation_reply(
self,
app_record: App,
message: Message,
query: str,
queue_manager: AppQueueManager,
app_generate_entity: AdvancedChatAppGenerateEntity,
self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
) -> bool:
"""
Handle annotation reply
:param app_record: app record
:param message: message
:param query: query
:param queue_manager: application queue manager
:param app_generate_entity: application generate entity
"""
# annotation reply
@@ -203,37 +243,21 @@ class AdvancedChatAppRunner(AppRunner):
)
if annotation_reply:
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), PublishFrom.APPLICATION_MANAGER
)
self._publish_event(QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id))
self._stream_output(
queue_manager=queue_manager,
text=annotation_reply.content,
stream=app_generate_entity.stream,
stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY,
self._complete_with_stream_output(
text=annotation_reply.content, stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
)
return True
return False
def _stream_output(
self, queue_manager: AppQueueManager, text: str, stream: bool, stopped_by: QueueStopEvent.StopBy
) -> None:
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
"""
Direct output
:param queue_manager: application queue manager
:param text: text
:param stream: stream
:return:
"""
if stream:
index = 0
for token in text:
queue_manager.publish(QueueTextChunkEvent(text=token), PublishFrom.APPLICATION_MANAGER)
index += 1
time.sleep(0.01)
else:
queue_manager.publish(QueueTextChunkEvent(text=text), PublishFrom.APPLICATION_MANAGER)
self._publish_event(QueueTextChunkEvent(text=text))
queue_manager.publish(QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER)
self._publish_event(QueueStopEvent(stopped_by=stopped_by))
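_complete_with_stream_output above replaces the old per-token streaming loop: it publishes the whole answer as one text chunk and then a stop event, which is how moderation rejections and annotation replies short-circuit the run. A toy analogue of that short-circuit (the event classes here are stand-ins, not Dify's queue entities):

```python
import queue
from dataclasses import dataclass


@dataclass
class TextChunkEvent:
    text: str


@dataclass
class StopEvent:
    stopped_by: str


def complete_with_stream_output(q: queue.Queue, text: str, stopped_by: str) -> None:
    """Short-circuit the run: publish the whole answer, then signal a stop."""
    q.put(TextChunkEvent(text=text))
    q.put(StopEvent(stopped_by=stopped_by))


q: queue.Queue = queue.Queue()
complete_with_stream_output(q, "This input violates the moderation policy.", "input-moderation")
print([type(event).__name__ for event in (q.get(), q.get())])  # ['TextChunkEvent', 'StopEvent']
```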


@@ -28,15 +28,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
"""
blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
response = {
'event': 'message',
'task_id': blocking_response.task_id,
'id': blocking_response.data.id,
'message_id': blocking_response.data.message_id,
'conversation_id': blocking_response.data.conversation_id,
'mode': blocking_response.data.mode,
'answer': blocking_response.data.answer,
'metadata': blocking_response.data.metadata,
'created_at': blocking_response.data.created_at
"event": "message",
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
"conversation_id": blocking_response.data.conversation_id,
"mode": blocking_response.data.mode,
"answer": blocking_response.data.answer,
"metadata": blocking_response.data.metadata,
"created_at": blocking_response.data.created_at,
}
return response
@@ -50,13 +50,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
"""
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get('metadata', {})
response['metadata'] = cls._get_simple_metadata(metadata)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
return response
@classmethod
def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, Any, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -67,14 +69,14 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue
response_chunk = {
'event': sub_stream_response.event.value,
'conversation_id': chunk.conversation_id,
'message_id': chunk.message_id,
'created_at': chunk.created_at
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
@@ -85,7 +87,9 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield json.dumps(response_chunk)
@classmethod
def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, Any, None]:
"""
Convert stream simple response.
:param stream_response: stream response
@@ -96,20 +100,20 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue
response_chunk = {
'event': sub_stream_response.event.value,
'conversation_id': chunk.conversation_id,
'message_id': chunk.message_id,
'created_at': chunk.created_at
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
metadata = sub_stream_response_dict.get('metadata', {})
sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
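Both converters above walk the stream, pass pings straight through, wrap every other chunk in a shared envelope (event, conversation_id, message_id, created_at), and yield JSON strings. A compact, generic version of that envelope-and-serialize loop (the Chunk type and its fields are stand-ins for the real stream response objects):

```python
import json
import time
from collections.abc import Generator, Iterable
from dataclasses import dataclass


@dataclass
class Chunk:
    event: str
    conversation_id: str
    message_id: str
    payload: dict


def convert_stream(chunks: Iterable[Chunk]) -> Generator[str, None, None]:
    """Yield 'ping' keep-alives as-is and JSON-encode everything else with a shared envelope."""
    for chunk in chunks:
        if chunk.event == "ping":
            yield "ping"
            continue
        response_chunk = {
            "event": chunk.event,
            "conversation_id": chunk.conversation_id,
            "message_id": chunk.message_id,
            "created_at": int(time.time()),
        }
        response_chunk.update(chunk.payload)
        yield json.dumps(response_chunk)
```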


@@ -2,9 +2,8 @@ import json
import logging
import time
from collections.abc import Generator
from typing import Any, Optional, Union, cast
from typing import Any, Optional, Union
import contexts
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -22,6 +21,9 @@ from core.app.entities.queue_entities import (
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent,
QueueRetrieverResourcesEvent,
QueueStopEvent,
@@ -31,34 +33,28 @@ from core.app.entities.queue_entities import (
QueueWorkflowSucceededEvent,
)
from core.app.entities.task_entities import (
AdvancedChatTaskState,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ChatflowStreamGenerateRoute,
ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
StreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.file.file_obj import FileVar
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from events.message_event import message_was_created
from extensions.ext_database import db
from models.account import Account
from models.model import Conversation, EndUser, Message
from models.workflow import (
Workflow,
WorkflowNodeExecution,
WorkflowRunStatus,
)
@@ -69,22 +65,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_task_state: AdvancedChatTaskState
_task_state: WorkflowTaskState
_application_generate_entity: AdvancedChatAppGenerateEntity
_workflow: Workflow
_user: Union[Account, EndUser]
# Deprecated
_workflow_system_variables: dict[SystemVariableKey, Any]
_iteration_nested_relations: dict[str, list[str]]
def __init__(
self, application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool,
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool,
) -> None:
"""
Initialize AdvancedChatAppGenerateTaskPipeline.
@@ -106,7 +102,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._workflow = workflow
self._conversation = conversation
self._message = message
# Deprecated
self._workflow_system_variables = {
SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files,
@@ -114,12 +109,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
SystemVariableKey.USER_ID: user_id,
}
self._task_state = AdvancedChatTaskState(
usage=LLMUsage.empty_usage()
)
self._task_state = WorkflowTaskState()
self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
self._stream_generate_routes = self._get_stream_generate_routes()
self._conversation_name_generate_thread = None
def process(self):
@@ -133,13 +124,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation,
self._application_generate_entity.query
self._conversation, self._application_generate_entity.query
)
generator = self._wrapper_process_stream_response(
trace_manager=self._application_generate_entity.trace_manager
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream:
return self._to_stream_response(generator)
else:
@@ -156,7 +145,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(stream_response, MessageEndStreamResponse):
extras = {}
if stream_response.metadata:
extras['metadata'] = stream_response.metadata
extras["metadata"] = stream_response.metadata
return ChatbotAppBlockingResponse(
task_id=stream_response.task_id,
@@ -167,15 +156,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
message_id=self._message.id,
answer=self._task_state.answer,
created_at=int(self._message.created_at.timestamp()),
**extras
)
**extras,
),
)
else:
continue
raise Exception('Queue listening stopped unexpectedly.')
raise Exception("Queue listening stopped unexpectedly.")
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]:
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[ChatbotAppStreamResponse, Any, None]:
"""
To stream response.
:return:
@@ -185,7 +176,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
conversation_id=self._conversation.id,
message_id=self._message.id,
created_at=int(self._message.created_at.timestamp()),
stream_response=stream_response
stream_response=stream_response,
)
def _listenAudioMsg(self, publisher, task_id: str):
@@ -196,20 +187,24 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
publisher = None
def _wrapper_process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
tts_publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
'text_to_speech'].get('autoPlay') == 'enabled':
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
if (
features_dict.get("text_to_speech")
and features_dict["text_to_speech"].get("enabled")
and features_dict["text_to_speech"].get("autoPlay") == "enabled"
):
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
@@ -220,9 +215,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# timeout
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
try:
if not publisher:
if not tts_publisher:
break
audio_trunk = publisher.checkAndGetAudio()
audio_trunk = tts_publisher.checkAndGetAudio()
if audio_trunk is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
@@ -236,38 +231,38 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
except Exception as e:
logger.error(e)
break
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
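After the main stream finishes, the wrapper above keeps draining the TTS publisher for up to TTS_AUTO_PLAY_TIMEOUT seconds, sleeping briefly when no audio is ready, stopping on a "finish" trunk, and finally emitting an end-of-audio response. A simplified polling loop in that spirit (the constant values, queue item shape, and function name are assumptions, not the real implementation):

```python
import queue
import time

TTS_AUTO_PLAY_TIMEOUT = 5.0          # assumed value: seconds to keep waiting for trailing audio
TTS_AUTO_PLAY_YIELD_CPU_TIME = 0.02  # assumed value: 20 ms pause when the queue is empty


def drain_audio(audio_queue: queue.Queue) -> list[bytes]:
    """Collect remaining audio trunks until 'finish' arrives or the timeout elapses."""
    collected: list[bytes] = []
    start = time.time()
    while time.time() - start < TTS_AUTO_PLAY_TIMEOUT:
        try:
            status, audio = audio_queue.get_nowait()
        except queue.Empty:
            time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)  # release the CPU instead of busy-waiting
            continue
        if status == "finish":
            break
        collected.append(audio)
    return collected


q: queue.Queue = queue.Queue()
q.put(("chunk", b"\x00" * 640))
q.put(("finish", b""))
print(len(drain_audio(q)))  # 1
```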
def _process_stream_response(
self,
publisher: AppGeneratorTTSPublisher,
trace_manager: Optional[TraceQueueManager] = None
self,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
:return:
"""
for message in self._queue_manager.listen():
if (message.event
and getattr(message.event, 'metadata', None)
and message.event.metadata.get('is_answer_previous_node', False)
and publisher):
publisher.publish(message=message)
elif (hasattr(message.event, 'execution_metadata')
and message.event.execution_metadata
and message.event.execution_metadata.get('is_answer_previous_node', False)
and publisher):
publisher.publish(message=message)
event = message.event
# init fake graph runtime state
graph_runtime_state = None
workflow_run = None
if isinstance(event, QueueErrorEvent):
for queue_message in self._queue_manager.listen():
event = queue_message.event
if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
err = self._handle_error(event, self._message)
yield self._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
workflow_run = self._handle_workflow_start()
# override graph runtime state
graph_runtime_state = event.graph_runtime_state
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
# init workflow run
workflow_run = self._handle_workflow_run_start()
self._refetch_message()
self._message.workflow_run_id = workflow_run.id
db.session.commit()
@@ -275,137 +270,229 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
db.session.close()
yield self._workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance(event, QueueNodeStartedEvent):
workflow_node_execution = self._handle_node_start(event)
if not workflow_run:
raise Exception("Workflow run not initialized.")
# search stream_generate_routes if node id is answer start at node
if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes:
self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id]
# reset current route position to 0
self._task_state.current_stream_generate_state.current_route_position = 0
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
# generate stream outputs when node started
yield from self._generate_stream_outputs_when_node_started()
yield self._workflow_node_start_to_stream_response(
response = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
workflow_node_execution=workflow_node_execution,
)
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
workflow_node_execution = self._handle_node_finished(event)
# stream outputs when node finished
generator = self._generate_stream_outputs_when_node_finished()
if generator:
yield from generator
if response:
yield response
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)
yield self._workflow_node_finish_to_stream_response(
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
workflow_node_execution=workflow_node_execution,
)
if isinstance(event, QueueNodeFailedEvent):
yield from self._handle_iteration_exception(
task_id=self._application_generate_entity.task_id,
error=f'Child node failed: {event.error}'
if response:
yield response
elif isinstance(event, QueueNodeFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationStartEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
yield self._workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationNextEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
yield self._workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=json.dumps(event.outputs) if event.outputs else None,
conversation_id=self._conversation.id,
trace_manager=trace_manager,
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
elif isinstance(event, QueueWorkflowFailedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED,
error=event.error,
conversation_id=self._conversation.id,
trace_manager=trace_manager,
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
break
elif isinstance(event, QueueStopEvent):
if workflow_run and graph_runtime_state:
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.STOPPED,
error=event.get_stop_reason(),
conversation_id=self._conversation.id,
trace_manager=trace_manager,
)
elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
if isinstance(event, QueueIterationNextEvent):
# clear ran node execution infos of current iteration
iteration_relations = self._iteration_nested_relations.get(event.node_id)
if iteration_relations:
for node_id in iteration_relations:
self._task_state.ran_node_execution_infos.pop(node_id, None)
yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
self._handle_iteration_operation(event)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._handle_workflow_finished(
event, conversation_id=self._conversation.id, trace_manager=trace_manager
)
if workflow_run:
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
if workflow_run.status == WorkflowRunStatus.FAILED.value:
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
break
if isinstance(event, QueueStopEvent):
# Save message
self._save_message()
yield self._message_end_to_stream_response()
break
else:
self._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(),
PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
# Save message
self._save_message()
self._save_message(graph_runtime_state=graph_runtime_state)
yield self._message_end_to_stream_response()
break
elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event)
self._refetch_message()
self._message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
db.session.commit()
db.session.refresh(self._message)
db.session.close()
elif isinstance(event, QueueAnnotationReplyEvent):
self._handle_annotation_reply(event)
self._refetch_message()
self._message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
db.session.commit()
db.session.refresh(self._message)
db.session.close()
elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text
if delta_text is None:
continue
if not self._is_stream_out_support(
event=event
):
continue
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(delta_text)
if should_direct_answer:
continue
# only publish tts message at text chunk streaming
if tts_publisher:
tts_publisher.publish(message=queue_message)
self._task_state.answer += delta_text
yield self._message_to_stream_response(delta_text, self._message.id)
elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation
yield self._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
# Save message
self._save_message(graph_runtime_state=graph_runtime_state)
yield self._message_end_to_stream_response()
else:
continue
if publisher:
publisher.publish(None)
# publish None when task finished
if tts_publisher:
tts_publisher.publish(None)
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
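_process_stream_response above is one loop over the queue manager that dispatches on the event class: ping, error, workflow started/succeeded/failed, node started/succeeded/failed, parallel branch, iteration, stop, text chunk, and message end, each mapped to a stream response while the task state accumulates the answer. A stripped-down sketch of that dispatch style (the event classes here are placeholders, not Dify's queue entities):

```python
from collections.abc import Generator, Iterable
from dataclasses import dataclass


@dataclass
class PingEvent: ...


@dataclass
class TextChunkEvent:
    text: str


@dataclass
class StopEvent: ...


def process_stream(events: Iterable[object]) -> Generator[str, None, None]:
    """Dispatch queue events to stream responses, accumulating the answer as we go."""
    answer = ""
    for event in events:
        if isinstance(event, PingEvent):
            yield "ping"
        elif isinstance(event, TextChunkEvent):
            answer += event.text
            yield event.text
        elif isinstance(event, StopEvent):
            yield f"[end] {answer}"
            break


print(list(process_stream([TextChunkEvent("Hel"), PingEvent(), TextChunkEvent("lo"), StopEvent()])))
# ['Hel', 'ping', 'lo', '[end] Hello']
```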
def _save_message(self) -> None:
def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
"""
Save message.
:return:
"""
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
self._refetch_message()
self._message.answer = self._task_state.answer
self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None
if self._task_state.metadata and self._task_state.metadata.get('usage'):
usage = LLMUsage(**self._task_state.metadata['usage'])
self._message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
if graph_runtime_state and graph_runtime_state.llm_usage:
usage = graph_runtime_state.llm_usage
self._message.message_tokens = usage.prompt_tokens
self._message.message_unit_price = usage.prompt_unit_price
self._message.message_price_unit = usage.prompt_price_unit
@@ -422,7 +509,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
application_generate_entity=self._application_generate_entity,
conversation=self._conversation,
is_first_message=self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras
extras=self._application_generate_entity.extras,
)
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
@@ -432,331 +519,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
extras = {}
if self._task_state.metadata:
extras['metadata'] = self._task_state.metadata
extras["metadata"] = self._task_state.metadata.copy()
if "annotation_reply" in extras["metadata"]:
del extras["metadata"]["annotation_reply"]
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
id=self._message.id,
**extras
task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
)
def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]:
"""
Get stream generate routes.
:return:
"""
# find all answer nodes
graph = self._workflow.graph_dict
answer_node_configs = [
node for node in graph['nodes']
if node.get('data', {}).get('type') == NodeType.ANSWER.value
]
# parse stream output node value selectors of answer nodes
stream_generate_routes = {}
for node_config in answer_node_configs:
# get generate route for stream output
answer_node_id = node_config['id']
generate_route = AnswerNode.extract_generate_route_selectors(node_config)
start_node_ids = self._get_answer_start_at_node_ids(graph, answer_node_id)
if not start_node_ids:
continue
for start_node_id in start_node_ids:
stream_generate_routes[start_node_id] = ChatflowStreamGenerateRoute(
answer_node_id=answer_node_id,
generate_route=generate_route
)
return stream_generate_routes
def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \
-> list[str]:
"""
Get answer start at node id.
:param graph: graph
:param target_node_id: target node ID
:return:
"""
nodes = graph.get('nodes')
edges = graph.get('edges')
# fetch all ingoing edges from source node
ingoing_edges = []
for edge in edges:
if edge.get('target') == target_node_id:
ingoing_edges.append(edge)
if not ingoing_edges:
# check if it's the first node in the iteration
target_node = next((node for node in nodes if node.get('id') == target_node_id), None)
if not target_node:
return []
node_iteration_id = target_node.get('data', {}).get('iteration_id')
# get iteration start node id
for node in nodes:
if node.get('id') == node_iteration_id:
if node.get('data', {}).get('start_node_id') == target_node_id:
return [target_node_id]
return []
start_node_ids = []
for ingoing_edge in ingoing_edges:
source_node_id = ingoing_edge.get('source')
source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
if not source_node:
continue
node_type = source_node.get('data', {}).get('type')
node_iteration_id = source_node.get('data', {}).get('iteration_id')
iteration_start_node_id = None
if node_iteration_id:
iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
if node_type in [
NodeType.ANSWER.value,
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value,
NodeType.ITERATION.value,
NodeType.LOOP.value
]:
start_node_id = target_node_id
start_node_ids.append(start_node_id)
elif node_type == NodeType.START.value or \
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
start_node_id = source_node_id
start_node_ids.append(start_node_id)
else:
sub_start_node_ids = self._get_answer_start_at_node_ids(graph, source_node_id)
if sub_start_node_ids:
start_node_ids.extend(sub_start_node_ids)
return start_node_ids
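As a rough illustration of the traversal in the removed _get_answer_start_at_node_ids above (the toy graph, node-type strings, and helper below are invented for the example and simplify several branches, including iteration handling): the walk follows ingoing edges backwards from an answer node until it hits a start node or a branching predecessor.

def find_start_node_ids(graph: dict, target_node_id: str, stop_types: set[str]) -> list[str]:
    """Walk ingoing edges backwards from target_node_id; simplified sketch, not the original code."""
    nodes = {node["id"]: node for node in graph["nodes"]}
    ingoing = [edge for edge in graph["edges"] if edge["target"] == target_node_id]
    if not ingoing:
        return []
    start_node_ids: list[str] = []
    for edge in ingoing:
        source = nodes.get(edge["source"])
        if not source:
            continue
        node_type = source["data"]["type"]
        if node_type in stop_types:
            # Branching / answer-like predecessors: streaming starts at the target itself.
            start_node_ids.append(target_node_id)
        elif node_type == "start":
            start_node_ids.append(source["id"])
        else:
            # Keep walking backwards through plain nodes.
            start_node_ids.extend(find_start_node_ids(graph, source["id"], stop_types))
    return start_node_ids

toy_graph = {
    "nodes": [
        {"id": "start", "data": {"type": "start"}},
        {"id": "llm", "data": {"type": "llm"}},
        {"id": "answer", "data": {"type": "answer"}},
    ],
    "edges": [
        {"source": "start", "target": "llm"},
        {"source": "llm", "target": "answer"},
    ],
}
print(find_start_node_ids(toy_graph, "answer", stop_types={"answer", "if-else"}))  # ['start']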
def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
"""
Get iteration nested relations.
:param graph: graph
:return:
"""
nodes = graph.get('nodes')
iteration_ids = [node.get('id') for node in nodes
if node.get('data', {}).get('type') in [
NodeType.ITERATION.value,
NodeType.LOOP.value,
]]
return {
iteration_id: [
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
] for iteration_id in iteration_ids
}
def _generate_stream_outputs_when_node_started(self) -> Generator:
"""
Generate stream outputs.
:return:
"""
if self._task_state.current_stream_generate_state:
route_chunks = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position:
]
for route_chunk in route_chunks:
if route_chunk.type == 'text':
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(route_chunk.text)
if should_direct_answer:
continue
self._task_state.answer += route_chunk.text
yield self._message_to_stream_response(route_chunk.text, self._message.id)
else:
break
self._task_state.current_stream_generate_state.current_route_position += 1
# all route chunks are generated
if self._task_state.current_stream_generate_state.current_route_position == len(
self._task_state.current_stream_generate_state.generate_route
):
self._task_state.current_stream_generate_state = None
def _generate_stream_outputs_when_node_finished(self) -> Optional[Generator]:
"""
Generate stream outputs.
:return:
"""
if not self._task_state.current_stream_generate_state:
return
route_chunks = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position:]
for route_chunk in route_chunks:
if route_chunk.type == 'text':
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
self._task_state.answer += route_chunk.text
yield self._message_to_stream_response(route_chunk.text, self._message.id)
else:
value = None
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
if not value_selector:
self._task_state.current_stream_generate_state.current_route_position += 1
continue
route_chunk_node_id = value_selector[0]
if route_chunk_node_id == 'sys':
# system variable
value = contexts.workflow_variable_pool.get().get(value_selector)
if value:
value = value.text
elif route_chunk_node_id in self._iteration_nested_relations:
# it's an iteration variable
if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:
continue
iteration_state = self._iteration_state.current_iterations[route_chunk_node_id]
iterator = iteration_state.inputs
if not iterator:
continue
iterator_selector = iterator.get('iterator_selector', [])
if value_selector[1] == 'index':
value = iteration_state.current_index
elif value_selector[1] == 'item':
value = iterator_selector[iteration_state.current_index] if iteration_state.current_index < len(
iterator_selector
) else None
else:
# check chunk node id is before current node id or equal to current node id
if route_chunk_node_id not in self._task_state.ran_node_execution_infos:
break
latest_node_execution_info = self._task_state.latest_node_execution_info
# get route chunk node execution info
route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id]
if (route_chunk_node_execution_info.node_type == NodeType.LLM
and latest_node_execution_info.node_type == NodeType.LLM):
# only LLM support chunk stream output
self._task_state.current_stream_generate_state.current_route_position += 1
continue
# get route chunk node execution
route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id
).first()
outputs = route_chunk_node_execution.outputs_dict
# get value from outputs
value = None
for key in value_selector[1:]:
if not value:
value = outputs.get(key) if outputs else None
else:
value = value.get(key)
if value is not None:
text = ''
if isinstance(value, str | int | float):
text = str(value)
elif isinstance(value, FileVar):
# convert file to markdown
text = value.to_markdown()
elif isinstance(value, dict):
# handle files
file_vars = self._fetch_files_from_variable_value(value)
if file_vars:
file_var = file_vars[0]
try:
file_var_obj = FileVar(**file_var)
# convert file to markdown
text = file_var_obj.to_markdown()
except Exception as e:
logger.error(f'Error creating file var: {e}')
if not text:
# other types
text = json.dumps(value, ensure_ascii=False)
elif isinstance(value, list):
# handle files
file_vars = self._fetch_files_from_variable_value(value)
for file_var in file_vars:
try:
file_var_obj = FileVar(**file_var)
except Exception as e:
logger.error(f'Error creating file var: {e}')
continue
# convert file to markdown
text = file_var_obj.to_markdown() + ' '
text = text.strip()
if not text and value:
# other types
text = json.dumps(value, ensure_ascii=False)
if text:
self._task_state.answer += text
yield self._message_to_stream_response(text, self._message.id)
self._task_state.current_stream_generate_state.current_route_position += 1
# all route chunks are generated
if self._task_state.current_stream_generate_state.current_route_position == len(
self._task_state.current_stream_generate_state.generate_route
):
self._task_state.current_stream_generate_state = None
def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
"""
Check whether stream output is supported for this event
:param event: queue text chunk event
:return:
"""
if not event.metadata:
return True
if 'node_id' not in event.metadata:
return True
node_type = event.metadata.get('node_type')
stream_output_value_selector = event.metadata.get('value_selector')
if not stream_output_value_selector:
return False
if not self._task_state.current_stream_generate_state:
return False
route_chunk = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position]
if route_chunk.type != 'var':
return False
if node_type != NodeType.LLM:
# only LLM support chunk stream output
return False
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
# check chunk node id is before current node id or equal to current node id
if value_selector != stream_output_value_selector:
return False
return True
def _handle_output_moderation_chunk(self, text: str) -> bool:
"""
Handle output moderation chunk.
@@ -768,17 +539,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# stop subscribe new token when output moderation should direct output
self._task_state.answer = self._output_moderation_handler.get_final_output()
self._queue_manager.publish(
QueueTextChunkEvent(
text=self._task_state.answer
), PublishFrom.TASK_PIPELINE
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
)
self._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
PublishFrom.TASK_PIPELINE
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
)
return True
else:
self._output_moderation_handler.append_new_token(text)
return False
def _refetch_message(self) -> None:
"""
Refetch message.
:return:
"""
message = db.session.query(Message).filter(Message.id == self._message.id).first()
if message:
self._message = message
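The refetch above reloads the Message row before writing, so the save path does not depend on an instance that may have gone stale after earlier commits and session closes in the stream loop. A minimal standalone sketch of the same reload-before-write pattern with plain SQLAlchemy (the Note model and in-memory engine are invented for the example):

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Note(Base):  # illustrative model, not a Dify table
    __tablename__ = "notes"
    id = Column(Integer, primary_key=True)
    answer = Column(String, default="")

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Note(id=1, answer=""))
    session.commit()

def save_answer(note_id: int, answer: str) -> None:
    # Re-fetch the row in a fresh session instead of reusing a possibly detached instance.
    with Session(engine) as session:
        note = session.query(Note).filter(Note.id == note_id).first()
        if note:
            note.answer = answer
            session.commit()

save_answer(1, "streamed answer")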

View File

@@ -1,203 +0,0 @@
from typing import Any, Optional
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from models.workflow import Workflow
class WorkflowEventTriggerCallback(WorkflowCallback):
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
self._queue_manager = queue_manager
def on_workflow_run_started(self) -> None:
"""
Workflow run started
"""
self._queue_manager.publish(
QueueWorkflowStartedEvent(),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_succeeded(self) -> None:
"""
Workflow run succeeded
"""
self._queue_manager.publish(
QueueWorkflowSucceededEvent(),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_failed(self, error: str) -> None:
"""
Workflow run failed
"""
self._queue_manager.publish(
QueueWorkflowFailedEvent(
error=error
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_started(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> None:
"""
Workflow node execute started
"""
self._queue_manager.publish(
QueueNodeStartedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
node_run_index=node_run_index,
predecessor_node_id=predecessor_node_id
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_succeeded(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> None:
"""
Workflow node execute succeeded
"""
self._queue_manager.publish(
QueueNodeSucceededEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
inputs=inputs,
process_data=process_data,
outputs=outputs,
execution_metadata=execution_metadata
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_failed(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
error: str,
inputs: Optional[dict] = None,
outputs: Optional[dict] = None,
process_data: Optional[dict] = None) -> None:
"""
Workflow node execute failed
"""
self._queue_manager.publish(
QueueNodeFailedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
inputs=inputs,
outputs=outputs,
process_data=process_data,
error=error
),
PublishFrom.APPLICATION_MANAGER
)
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
"""
Publish text chunk
"""
self._queue_manager.publish(
QueueTextChunkEvent(
text=text,
metadata={
"node_id": node_id,
**metadata
}
), PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_started(self,
node_id: str,
node_type: NodeType,
node_run_index: int = 1,
node_data: Optional[BaseNodeData] = None,
inputs: dict = None,
predecessor_node_id: Optional[str] = None,
metadata: Optional[dict] = None) -> None:
"""
Publish iteration started
"""
self._queue_manager.publish(
QueueIterationStartEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
node_data=node_data,
inputs=inputs,
predecessor_node_id=predecessor_node_id,
metadata=metadata
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_next(self, node_id: str,
node_type: NodeType,
index: int,
node_run_index: int,
output: Optional[Any]) -> None:
"""
Publish iteration next
"""
self._queue_manager._publish(
QueueIterationNextEvent(
node_id=node_id,
node_type=node_type,
index=index,
node_run_index=node_run_index,
output=output
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_completed(self, node_id: str,
node_type: NodeType,
node_run_index: int,
outputs: dict) -> None:
"""
Publish iteration completed
"""
self._queue_manager._publish(
QueueIterationCompletedEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
outputs=outputs
),
PublishFrom.APPLICATION_MANAGER
)
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event
"""
self._queue_manager.publish(
event,
PublishFrom.APPLICATION_MANAGER
)

View File

@@ -28,15 +28,19 @@ class AgentChatAppConfig(EasyUIBasedAppConfig):
"""
Agent Chatbot App Config Entity.
"""
agent: Optional[AgentEntity] = None
class AgentChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App,
app_model_config: AppModelConfig,
conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None) -> AgentChatAppConfig:
def get_app_config(
cls,
app_model: App,
app_model_config: AppModelConfig,
conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None,
) -> AgentChatAppConfig:
"""
Convert app model config to agent chat app config
:param app_model: app model
@@ -66,22 +70,12 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
app_model_config_from=config_from,
app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict,
model=ModelConfigManager.convert(
config=config_dict
),
prompt_template=PromptTemplateConfigManager.convert(
config=config_dict
),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=config_dict
),
dataset=DatasetConfigManager.convert(
config=config_dict
),
agent=AgentConfigManager.convert(
config=config_dict
),
additional_features=cls.convert_features(config_dict, app_mode)
model=ModelConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
dataset=DatasetConfigManager.convert(config=config_dict),
agent=AgentConfigManager.convert(config=config_dict),
additional_features=cls.convert_features(config_dict, app_mode),
)
app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
@@ -128,7 +122,8 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
# suggested_questions_after_answer
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
config)
config
)
related_config_keys.extend(current_related_config_keys)
# speech_to_text
@@ -145,13 +140,15 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
# dataset configs
# dataset_query_variable
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode,
config)
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
tenant_id, app_mode, config
)
related_config_keys.extend(current_related_config_keys)
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id,
config)
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id, config
)
related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys))
@@ -170,10 +167,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
:param config: app model config args
"""
if not config.get("agent_mode"):
config["agent_mode"] = {
"enabled": False,
"tools": []
}
config["agent_mode"] = {"enabled": False, "tools": []}
if not isinstance(config["agent_mode"], dict):
raise ValueError("agent_mode must be of object type")
@@ -187,8 +181,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
if not config["agent_mode"].get("strategy"):
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value
if config["agent_mode"]["strategy"] not in [member.value for member in
list(PlanningStrategy.__members__.values())]:
if config["agent_mode"]["strategy"] not in [
member.value for member in list(PlanningStrategy.__members__.values())
]:
raise ValueError("strategy in agent_mode must be in the specified strategy list")
if not config["agent_mode"].get("tools"):
@@ -210,7 +205,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
raise ValueError("enabled in agent_mode.tools must be of boolean type")
if key == "dataset":
if 'id' not in tool_item:
if "id" not in tool_item:
raise ValueError("id is required in dataset")
try:

View File

@@ -30,7 +30,8 @@ logger = logging.getLogger(__name__)
class AgentChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
@@ -39,19 +40,17 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
) -> dict: ...
def generate(self, app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: bool = True) \
-> Union[dict, Generator[dict, None, None]]:
def generate(
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
) -> Union[dict, Generator[dict, None, None]]:
"""
Generate App response.
@@ -62,60 +61,48 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
:param stream: is stream
"""
if not stream:
raise ValueError('Agent Chat App does not support blocking mode')
raise ValueError("Agent Chat App does not support blocking mode")
if not args.get('query'):
raise ValueError('query is required')
if not args.get("query"):
raise ValueError("query is required")
query = args['query']
query = args["query"]
if not isinstance(query, str):
raise ValueError('query must be a string')
raise ValueError("query must be a string")
query = query.replace('\x00', '')
inputs = args['inputs']
query = query.replace("\x00", "")
inputs = args["inputs"]
extras = {
"auto_generate_conversation_name": args.get('auto_generate_name', True)
}
extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}
# get conversation
conversation = None
if args.get('conversation_id'):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
if args.get("conversation_id"):
conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
# get app model config
app_model_config = self._get_app_model_config(
app_model=app_model,
conversation=conversation
)
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
# validate override model config
override_model_config_dict = None
if args.get('model_config'):
if args.get("model_config"):
if invoke_from != InvokeFrom.DEBUGGER:
raise ValueError('Only in App debug mode can override model config')
raise ValueError("Only in App debug mode can override model config")
# validate config
override_model_config_dict = AgentChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id,
config=args.get('model_config')
tenant_id=app_model.tenant_id, config=args.get("model_config")
)
# always enable retriever resource in debugger mode
override_model_config_dict["retriever_resource"] = {
"enabled": True
}
override_model_config_dict["retriever_resource"] = {"enabled": True}
# parse files
files = args['files'] if args.get('files') else []
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_extra_config,
user
)
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
else:
file_objs = []
@@ -124,7 +111,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
app_model=app_model,
app_model_config=app_model_config,
conversation=conversation,
override_config_dict=override_model_config_dict
override_config_dict=override_model_config_dict,
)
# get tracing instance
@@ -145,14 +132,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
invoke_from=invoke_from,
extras=extras,
call_depth=0,
trace_manager=trace_manager
trace_manager=trace_manager,
)
# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity, conversation)
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
@@ -161,17 +145,20 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
message_id=message.id,
)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
})
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
"message_id": message.id,
},
)
worker_thread.start()
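The worker-thread setup above hands the real application object (current_app._get_current_object()) to the thread rather than the current_app proxy, because the proxy is bound to the calling context and would not resolve inside the new thread. A minimal sketch of that pattern outside Dify (the do_work function, route, and port are invented for the example):

import threading

from flask import Flask, current_app

app = Flask(__name__)

def do_work(flask_app: Flask, task_id: str) -> None:
    # Push an application context so code in this thread can use current_app,
    # app-bound extensions, database sessions, and so on.
    with flask_app.app_context():
        current_app.logger.info("running task %s", task_id)

@app.route("/kick")
def kick():
    worker = threading.Thread(
        target=do_work,
        kwargs={
            # Unwrap the proxy: the plain Flask object is safe to pass to a thread.
            "flask_app": current_app._get_current_object(),
            "task_id": "demo",
        },
    )
    worker.start()
    return "started"

if __name__ == "__main__":
    app.run(port=5001)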
@@ -185,13 +172,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
stream=stream,
)
return AgentChatAppGenerateResponseConverter.convert(
response=response,
invoke_from=invoke_from
)
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(
self, flask_app: Flask,
self,
flask_app: Flask,
application_generate_entity: AgentChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
@@ -224,14 +209,13 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:

View File

@@ -30,7 +30,8 @@ class AgentChatAppRunner(AppRunner):
"""
def run(
self, application_generate_entity: AgentChatAppGenerateEntity,
self,
application_generate_entity: AgentChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
@@ -65,7 +66,7 @@ class AgentChatAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query
query=query,
)
memory = None
@@ -73,13 +74,10 @@ class AgentChatAppRunner(AppRunner):
# get memory of conversation (read-only)
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model
model=application_generate_entity.model_conf.model,
)
memory = TokenBufferMemory(
conversation=conversation,
model_instance=model_instance
)
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
# organize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
@@ -91,7 +89,7 @@ class AgentChatAppRunner(AppRunner):
inputs=inputs,
files=files,
query=query,
memory=memory
memory=memory,
)
# moderation
@@ -103,7 +101,7 @@ class AgentChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
message_id=message.id
message_id=message.id,
)
except ModerationException as e:
self.direct_output(
@@ -111,7 +109,7 @@ class AgentChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text=str(e),
stream=application_generate_entity.stream
stream=application_generate_entity.stream,
)
return
@@ -122,13 +120,13 @@ class AgentChatAppRunner(AppRunner):
message=message,
query=query,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from
invoke_from=application_generate_entity.invoke_from,
)
if annotation_reply:
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
PublishFrom.APPLICATION_MANAGER
PublishFrom.APPLICATION_MANAGER,
)
self.direct_output(
@@ -136,7 +134,7 @@ class AgentChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text=annotation_reply.content,
stream=application_generate_entity.stream
stream=application_generate_entity.stream,
)
return
@@ -148,7 +146,7 @@ class AgentChatAppRunner(AppRunner):
app_id=app_record.id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
query=query,
)
# reorganize all inputs and template to prompt messages
@@ -161,14 +159,14 @@ class AgentChatAppRunner(AppRunner):
inputs=inputs,
files=files,
query=query,
memory=memory
memory=memory,
)
# check hosting moderation
hosting_moderation_result = self.check_hosting_moderation(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
prompt_messages=prompt_messages
prompt_messages=prompt_messages,
)
if hosting_moderation_result:
@@ -177,9 +175,9 @@ class AgentChatAppRunner(AppRunner):
agent_entity = app_config.agent
# load tool variables
tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id,
user_id=application_generate_entity.user_id,
tenant_id=app_config.tenant_id)
tool_conversation_variables = self._load_tool_variables(
conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id
)
# convert db variables to tool variables
tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)
@@ -187,7 +185,7 @@ class AgentChatAppRunner(AppRunner):
# init model instance
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model
model=application_generate_entity.model_conf.model,
)
prompt_message, _ = self.organize_prompt_messages(
app_record=app_record,
@@ -238,7 +236,7 @@ class AgentChatAppRunner(AppRunner):
prompt_messages=prompt_message,
variables_pool=tool_variables,
db_variables=tool_conversation_variables,
model_instance=model_instance
model_instance=model_instance,
)
invoke_result = runner.run(
@@ -252,17 +250,21 @@ class AgentChatAppRunner(AppRunner):
invoke_result=invoke_result,
queue_manager=queue_manager,
stream=application_generate_entity.stream,
agent=True
agent=True,
)
def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
"""
load tool variables from database
"""
tool_variables: ToolConversationVariables = db.session.query(ToolConversationVariables).filter(
ToolConversationVariables.conversation_id == conversation_id,
ToolConversationVariables.tenant_id == tenant_id
).first()
tool_variables: ToolConversationVariables = (
db.session.query(ToolConversationVariables)
.filter(
ToolConversationVariables.conversation_id == conversation_id,
ToolConversationVariables.tenant_id == tenant_id,
)
.first()
)
if tool_variables:
# save tool variables to session, so that we can update it later
@@ -273,34 +275,40 @@ class AgentChatAppRunner(AppRunner):
conversation_id=conversation_id,
user_id=user_id,
tenant_id=tenant_id,
variables_str='[]',
variables_str="[]",
)
db.session.add(tool_variables)
db.session.commit()
return tool_variables
def _convert_db_variables_to_tool_variables(self, db_variables: ToolConversationVariables) -> ToolRuntimeVariablePool:
def _convert_db_variables_to_tool_variables(
self, db_variables: ToolConversationVariables
) -> ToolRuntimeVariablePool:
"""
convert db variables to tool variables
"""
return ToolRuntimeVariablePool(**{
'conversation_id': db_variables.conversation_id,
'user_id': db_variables.user_id,
'tenant_id': db_variables.tenant_id,
'pool': db_variables.variables
})
return ToolRuntimeVariablePool(
**{
"conversation_id": db_variables.conversation_id,
"user_id": db_variables.user_id,
"tenant_id": db_variables.tenant_id,
"pool": db_variables.variables,
}
)
def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigWithCredentialsEntity,
message: Message) -> LLMUsage:
def _get_usage_of_all_agent_thoughts(
self, model_config: ModelConfigWithCredentialsEntity, message: Message
) -> LLMUsage:
"""
Get usage of all agent thoughts
:param model_config: model config
:param message: message
:return:
"""
agent_thoughts = (db.session.query(MessageAgentThought)
.filter(MessageAgentThought.message_id == message.id).all())
agent_thoughts = (
db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).all()
)
all_message_tokens = 0
all_answer_tokens = 0
@@ -312,8 +320,5 @@ class AgentChatAppRunner(AppRunner):
model_type_instance = cast(LargeLanguageModel, model_type_instance)
return model_type_instance._calc_response_usage(
model_config.model,
model_config.credentials,
all_message_tokens,
all_answer_tokens
model_config.model, model_config.credentials, all_message_tokens, all_answer_tokens
)

View File

@@ -23,15 +23,15 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
:return:
"""
response = {
'event': 'message',
'task_id': blocking_response.task_id,
'id': blocking_response.data.id,
'message_id': blocking_response.data.message_id,
'conversation_id': blocking_response.data.conversation_id,
'mode': blocking_response.data.mode,
'answer': blocking_response.data.answer,
'metadata': blocking_response.data.metadata,
'created_at': blocking_response.data.created_at
"event": "message",
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
"conversation_id": blocking_response.data.conversation_id,
"mode": blocking_response.data.mode,
"answer": blocking_response.data.answer,
"metadata": blocking_response.data.metadata,
"created_at": blocking_response.data.created_at,
}
return response
@@ -45,14 +45,15 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
"""
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get('metadata', {})
response['metadata'] = cls._get_simple_metadata(metadata)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
return response
@classmethod
def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_full_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
) -> Generator[str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -63,14 +64,14 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue
response_chunk = {
'event': sub_stream_response.event.value,
'conversation_id': chunk.conversation_id,
'message_id': chunk.message_id,
'created_at': chunk.created_at
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
@@ -81,8 +82,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield json.dumps(response_chunk)
@classmethod
def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_simple_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
) -> Generator[str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
@@ -93,20 +95,20 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue
response_chunk = {
'event': sub_stream_response.event.value,
'conversation_id': chunk.conversation_id,
'message_id': chunk.message_id,
'created_at': chunk.created_at
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
metadata = sub_stream_response_dict.get('metadata', {})
sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)

View File

@@ -13,32 +13,33 @@ class AppGenerateResponseConverter(ABC):
_blocking_response_type: type[AppBlockingResponse]
@classmethod
def convert(cls, response: Union[
AppBlockingResponse,
Generator[AppStreamResponse, Any, None]
], invoke_from: InvokeFrom):
def convert(
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
) -> dict[str, Any] | Generator[str, Any, None]:
if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)
else:
def _generate_full_response() -> Generator[str, Any, None]:
for chunk in cls.convert_stream_full_response(response):
if chunk == 'ping':
yield f'event: {chunk}\n\n'
if chunk == "ping":
yield f"event: {chunk}\n\n"
else:
yield f'data: {chunk}\n\n'
yield f"data: {chunk}\n\n"
return _generate_full_response()
else:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_simple_response(response)
else:
def _generate_simple_response() -> Generator[str, Any, None]:
for chunk in cls.convert_stream_simple_response(response):
if chunk == 'ping':
yield f'event: {chunk}\n\n'
if chunk == "ping":
yield f"event: {chunk}\n\n"
else:
yield f'data: {chunk}\n\n'
yield f"data: {chunk}\n\n"
return _generate_simple_response()
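The two branches above produce standard server-sent-events framing: keep-alive pings become an event: line with no payload, and everything else becomes a data: line carrying a JSON chunk. A self-contained sketch of the same framing (the sample chunks are invented):

import json
from collections.abc import Generator

def to_sse(chunks: list) -> Generator[str, None, None]:
    """Frame chunks for a text/event-stream response."""
    for chunk in chunks:
        if chunk == "ping":
            # Keep-alive: named event with no payload.
            yield "event: ping\n\n"
        else:
            # Regular payload: one JSON object per data frame.
            yield f"data: {json.dumps(chunk)}\n\n"

for frame in to_sse(["ping", {"event": "message", "answer": "hi"}]):
    print(frame, end="")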
@@ -54,14 +55,16 @@ class AppGenerateResponseConverter(ABC):
@classmethod
@abstractmethod
def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, None, None]:
raise NotImplementedError
@classmethod
@abstractmethod
def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[str, None, None]:
raise NotImplementedError
@classmethod
@@ -72,24 +75,26 @@ class AppGenerateResponseConverter(ABC):
:return:
"""
# show_retrieve_source
if 'retriever_resources' in metadata:
metadata['retriever_resources'] = []
for resource in metadata['retriever_resources']:
metadata['retriever_resources'].append({
'segment_id': resource['segment_id'],
'position': resource['position'],
'document_name': resource['document_name'],
'score': resource['score'],
'content': resource['content'],
})
if "retriever_resources" in metadata:
metadata["retriever_resources"] = []
for resource in metadata["retriever_resources"]:
metadata["retriever_resources"].append(
{
"segment_id": resource["segment_id"],
"position": resource["position"],
"document_name": resource["document_name"],
"score": resource["score"],
"content": resource["content"],
}
)
# show annotation reply
if 'annotation_reply' in metadata:
del metadata['annotation_reply']
if "annotation_reply" in metadata:
del metadata["annotation_reply"]
# show usage
if 'usage' in metadata:
del metadata['usage']
if "usage" in metadata:
del metadata["usage"]
return metadata
@@ -101,16 +106,16 @@ class AppGenerateResponseConverter(ABC):
:return:
"""
error_responses = {
ValueError: {'code': 'invalid_param', 'status': 400},
ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400},
ValueError: {"code": "invalid_param", "status": 400},
ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
QuotaExceededError: {
'code': 'provider_quota_exceeded',
'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
"Please go to Settings -> Model Provider to complete your own provider credentials.",
'status': 400
"code": "provider_quota_exceeded",
"message": "Your quota for Dify Hosted Model Provider has been exhausted. "
"Please go to Settings -> Model Provider to complete your own provider credentials.",
"status": 400,
},
ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400},
InvokeError: {'code': 'completion_request_error', 'status': 400}
ModelCurrentlyNotSupportError: {"code": "model_currently_not_support", "status": 400},
InvokeError: {"code": "completion_request_error", "status": 400},
}
# Determine the response based on the type of exception
@@ -120,13 +125,13 @@ class AppGenerateResponseConverter(ABC):
data = v
if data:
data.setdefault('message', getattr(e, 'description', str(e)))
data.setdefault("message", getattr(e, "description", str(e)))
else:
logging.error(e)
data = {
'code': 'internal_server_error',
'message': 'Internal Server Error, please contact support.',
'status': 500
"code": "internal_server_error",
"message": "Internal Server Error, please contact support.",
"status": 500,
}
return data

View File

@@ -16,10 +16,10 @@ class BaseAppGenerator:
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
user_input_value = inputs.get(var.variable)
if var.required and not user_input_value:
raise ValueError(f'{var.variable} is required in input form')
raise ValueError(f"{var.variable} is required in input form")
if not var.required and not user_input_value:
# TODO: should we return None here if the default value is None?
return var.default or ''
return var.default or ""
if (
var.type
in (
@@ -34,7 +34,7 @@ class BaseAppGenerator:
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
# may raise ValueError if user_input_value is not a valid number
try:
if '.' in user_input_value:
if "." in user_input_value:
return float(user_input_value)
else:
return int(user_input_value)
@@ -43,14 +43,14 @@ class BaseAppGenerator:
if var.type == VariableEntityType.SELECT:
options = var.options or []
if user_input_value not in options:
raise ValueError(f'{var.variable} in input form must be one of the following: {options}')
raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters')
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
return user_input_value
def _sanitize_value(self, value: Any) -> Any:
if isinstance(value, str):
return value.replace('\x00', '')
return value.replace("\x00", "")
return value
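A condensed, standalone rendering of the validation rules above (the SimpleVariable dataclass and type strings are invented for the example; the real VariableEntity carries more fields): required values must be present, numeric strings are coerced to int or float, select values must be among the configured options, and text inputs are length-checked.

from dataclasses import dataclass, field
from typing import Any, Optional

@dataclass
class SimpleVariable:  # illustrative stand-in for VariableEntity
    variable: str
    type: str  # "number" | "select" | "text-input"
    required: bool = False
    default: Optional[str] = None
    options: list = field(default_factory=list)
    max_length: Optional[int] = None

def validate_input(value: Any, var: SimpleVariable) -> Any:
    if var.required and not value:
        raise ValueError(f"{var.variable} is required in input form")
    if not var.required and not value:
        return var.default or ""
    if var.type == "number" and isinstance(value, str):
        # Coerce numeric strings; a ValueError propagates for non-numbers.
        return float(value) if "." in value else int(value)
    if var.type == "select" and value not in var.options:
        raise ValueError(f"{var.variable} must be one of {var.options}")
    if var.type == "text-input" and var.max_length and len(value) > var.max_length:
        raise ValueError(f"{var.variable} must be less than {var.max_length} characters")
    return value

print(validate_input("3.5", SimpleVariable(variable="score", type="number")))  # 3.5
print(validate_input("", SimpleVariable(variable="note", type="text-input")))  # ''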

View File

@@ -24,9 +24,7 @@ class PublishFrom(Enum):
class AppQueueManager:
def __init__(self, task_id: str,
user_id: str,
invoke_from: InvokeFrom) -> None:
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None:
if not user_id:
raise ValueError("user is required")
@@ -34,9 +32,10 @@ class AppQueueManager:
self._user_id = user_id
self._invoke_from = invoke_from
user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800,
f"{user_prefix}-{self._user_id}")
user_prefix = "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user"
redis_client.setex(
AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
)
q = queue.Queue()
@@ -66,8 +65,7 @@ class AppQueueManager:
# publish two messages to make sure the client can receive the stop signal
# and stop listening after the stop signal processed
self.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL),
PublishFrom.TASK_PIPELINE
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE
)
if elapsed_time // 10 > last_ping_time:
@@ -88,9 +86,7 @@ class AppQueueManager:
:param pub_from: publish from
:return:
"""
self.publish(QueueErrorEvent(
error=e
), pub_from)
self.publish(QueueErrorEvent(error=e), pub_from)
def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
"""
@@ -122,8 +118,8 @@ class AppQueueManager:
if result is None:
return
user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
if result.decode('utf-8') != f"{user_prefix}-{user_id}":
user_prefix = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user"
if result.decode("utf-8") != f"{user_prefix}-{user_id}":
return
stopped_cache_key = cls._generate_stopped_cache_key(task_id)
@@ -168,9 +164,11 @@ class AppQueueManager:
for item in data:
self._check_for_sqlalchemy_models(item)
else:
if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'):
raise TypeError("Critical Error: Passing SQLAlchemy Model instances "
"that cause thread safety issues is not allowed.")
if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"):
raise TypeError(
"Critical Error: Passing SQLAlchemy Model instances "
"that cause thread safety issues is not allowed."
)
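The guard above rejects ORM instances before they are enqueued, because a model object carries its _sa_instance_state and is tied to the session (and thread) that loaded it. A small standalone sketch of the same detection (the Item model and payloads are invented for the example):

from sqlalchemy import Column, Integer
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class Item(Base):  # illustrative model
    __tablename__ = "items"
    id = Column(Integer, primary_key=True)

def assert_thread_safe(data) -> None:
    """Reject SQLAlchemy model instances hidden anywhere in the payload."""
    if isinstance(data, dict):
        for value in data.values():
            assert_thread_safe(value)
    elif isinstance(data, list):
        for value in data:
            assert_thread_safe(value)
    elif hasattr(data, "_sa_instance_state"):
        raise TypeError("SQLAlchemy model instances must not cross thread boundaries")

assert_thread_safe({"ok": [1, "two"]})  # passes silently
try:
    assert_thread_safe({"bad": Item(id=1)})
except TypeError as exc:
    print(exc)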
class GenerateTaskStoppedException(Exception):

View File

@@ -1,6 +1,6 @@
import time
from collections.abc import Generator
from typing import TYPE_CHECKING, Optional, Union
from collections.abc import Generator, Mapping
from typing import TYPE_CHECKING, Any, Optional, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -31,12 +31,15 @@ if TYPE_CHECKING:
class AppRunner:
def get_pre_calculate_rest_tokens(self, app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["FileVar"],
query: Optional[str] = None) -> int:
def get_pre_calculate_rest_tokens(
self,
app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["FileVar"],
query: Optional[str] = None,
) -> int:
"""
Get pre calculate rest tokens
:param app_record: app record
@@ -49,18 +52,20 @@ class AppRunner:
"""
# Invoke model
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle,
model=model_config.model
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
max_tokens = (model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)) or 0
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
) or 0
if model_context_tokens is None:
return -1
@@ -75,36 +80,39 @@ class AppRunner:
prompt_template_entity=prompt_template_entity,
inputs=inputs,
files=files,
query=query
query=query,
)
prompt_tokens = model_instance.get_llm_num_tokens(
prompt_messages
)
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
rest_tokens = model_context_tokens - max_tokens - prompt_tokens
if rest_tokens < 0:
raise InvokeBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
"or shrink the max token, or switch to a llm with a larger token limit size.")
raise InvokeBadRequestError(
"Query or prefix prompt is too long, you can reduce the prefix prompt, "
"or shrink the max token, or switch to a llm with a larger token limit size."
)
return rest_tokens
def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity,
prompt_messages: list[PromptMessage]):
def recalc_llm_max_tokens(
self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]
):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle,
model=model_config.model
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
max_tokens = (model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)) or 0
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template)
) or 0
if model_context_tokens is None:
return -1
@@ -112,27 +120,28 @@ class AppRunner:
if max_tokens is None:
max_tokens = 0
prompt_tokens = model_instance.get_llm_num_tokens(
prompt_messages
)
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
if prompt_tokens + max_tokens > model_context_tokens:
max_tokens = max(model_context_tokens - prompt_tokens, 16)
for parameter_rule in model_config.model_schema.parameter_rules:
if (parameter_rule.name == 'max_tokens'
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
model_config.parameters[parameter_rule.name] = max_tokens
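A self-contained restatement of the clamp above with made-up numbers (the helper below is illustrative, not part of the codebase): the completion budget is trimmed so prompt plus completion stays inside the context window, with a 16-token floor.

def recalc_max_tokens(context_size: int, prompt_tokens: int, max_tokens: int) -> int:
    # Same clamp as above: keep the completion budget inside the context window,
    # but never drop below a 16-token floor.
    if prompt_tokens + max_tokens > context_size:
        return max(context_size - prompt_tokens, 16)
    return max_tokens

print(recalc_max_tokens(4096, 3000, 2048))  # 1096
print(recalc_max_tokens(4096, 4090, 2048))  # 16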
def organize_prompt_messages(self, app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["FileVar"],
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None) \
-> tuple[list[PromptMessage], Optional[list[str]]]:
def organize_prompt_messages(
self,
app_record: App,
model_config: ModelConfigWithCredentialsEntity,
prompt_template_entity: PromptTemplateEntity,
inputs: dict[str, str],
files: list["FileVar"],
query: Optional[str] = None,
context: Optional[str] = None,
memory: Optional[TokenBufferMemory] = None,
) -> tuple[list[PromptMessage], Optional[list[str]]]:
"""
Organize prompt messages
:param context:
@@ -152,60 +161,54 @@ class AppRunner:
app_mode=AppMode.value_of(app_record.mode),
prompt_template_entity=prompt_template_entity,
inputs=inputs,
query=query if query else '',
query=query if query else "",
files=files,
context=context,
memory=memory,
model_config=model_config
model_config=model_config,
)
else:
memory_config = MemoryConfig(
window=MemoryConfig.WindowConfig(
enabled=False
)
)
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
model_mode = ModelMode.value_of(model_config.mode)
if model_mode == ModelMode.COMPLETION:
advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template
prompt_template = CompletionModelPromptTemplate(
text=advanced_completion_prompt_template.prompt
)
prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt)
if advanced_completion_prompt_template.role_prefix:
memory_config.role_prefix = MemoryConfig.RolePrefix(
user=advanced_completion_prompt_template.role_prefix.user,
assistant=advanced_completion_prompt_template.role_prefix.assistant
assistant=advanced_completion_prompt_template.role_prefix.assistant,
)
else:
prompt_template = []
for message in prompt_template_entity.advanced_chat_prompt_template.messages:
prompt_template.append(ChatModelMessage(
text=message.text,
role=message.role
))
prompt_template.append(ChatModelMessage(text=message.text, role=message.role))
prompt_transform = AdvancedPromptTransform()
prompt_messages = prompt_transform.get_prompt(
prompt_template=prompt_template,
inputs=inputs,
query=query if query else '',
query=query if query else "",
files=files,
context=context,
memory_config=memory_config,
memory=memory,
model_config=model_config
model_config=model_config,
)
stop = model_config.stop
return prompt_messages, stop
def direct_output(self, queue_manager: AppQueueManager,
app_generate_entity: EasyUIBasedAppGenerateEntity,
prompt_messages: list,
text: str,
stream: bool,
usage: Optional[LLMUsage] = None) -> None:
def direct_output(
self,
queue_manager: AppQueueManager,
app_generate_entity: EasyUIBasedAppGenerateEntity,
prompt_messages: list,
text: str,
stream: bool,
usage: Optional[LLMUsage] = None,
) -> None:
"""
Direct output
:param queue_manager: application queue manager
@@ -222,17 +225,10 @@ class AppRunner:
chunk = LLMResultChunk(
model=app_generate_entity.model_conf.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(content=token)
)
delta=LLMResultChunkDelta(index=index, message=AssistantPromptMessage(content=token)),
)
queue_manager.publish(
QueueLLMChunkEvent(
chunk=chunk
), PublishFrom.APPLICATION_MANAGER
)
queue_manager.publish(QueueLLMChunkEvent(chunk=chunk), PublishFrom.APPLICATION_MANAGER)
index += 1
time.sleep(0.01)
@@ -242,15 +238,19 @@ class AppRunner:
model=app_generate_entity.model_conf.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text),
usage=usage if usage else LLMUsage.empty_usage()
usage=usage if usage else LLMUsage.empty_usage(),
),
), PublishFrom.APPLICATION_MANAGER
),
PublishFrom.APPLICATION_MANAGER,
)
def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
queue_manager: AppQueueManager,
stream: bool,
agent: bool = False) -> None:
def _handle_invoke_result(
self,
invoke_result: Union[LLMResult, Generator],
queue_manager: AppQueueManager,
stream: bool,
agent: bool = False,
) -> None:
"""
Handle invoke result
:param invoke_result: invoke result
@@ -260,21 +260,13 @@ class AppRunner:
:return:
"""
if not stream:
self._handle_invoke_result_direct(
invoke_result=invoke_result,
queue_manager=queue_manager,
agent=agent
)
self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
else:
self._handle_invoke_result_stream(
invoke_result=invoke_result,
queue_manager=queue_manager,
agent=agent
)
self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
def _handle_invoke_result_direct(self, invoke_result: LLMResult,
queue_manager: AppQueueManager,
agent: bool) -> None:
def _handle_invoke_result_direct(
self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool
) -> None:
"""
Handle invoke result direct
:param invoke_result: invoke result
@@ -285,12 +277,13 @@ class AppRunner:
queue_manager.publish(
QueueMessageEndEvent(
llm_result=invoke_result,
), PublishFrom.APPLICATION_MANAGER
),
PublishFrom.APPLICATION_MANAGER,
)
def _handle_invoke_result_stream(self, invoke_result: Generator,
queue_manager: AppQueueManager,
agent: bool) -> None:
def _handle_invoke_result_stream(
self, invoke_result: Generator, queue_manager: AppQueueManager, agent: bool
) -> None:
"""
Handle invoke result
:param invoke_result: invoke result
@@ -300,21 +293,13 @@ class AppRunner:
"""
model = None
prompt_messages = []
text = ''
text = ""
usage = None
for result in invoke_result:
if not agent:
queue_manager.publish(
QueueLLMChunkEvent(
chunk=result
), PublishFrom.APPLICATION_MANAGER
)
queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
else:
queue_manager.publish(
QueueAgentMessageEvent(
chunk=result
), PublishFrom.APPLICATION_MANAGER
)
queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
text += result.delta.message.content
@@ -331,25 +316,24 @@ class AppRunner:
usage = LLMUsage.empty_usage()
llm_result = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text),
usage=usage
model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), usage=usage
)
queue_manager.publish(
QueueMessageEndEvent(
llm_result=llm_result,
), PublishFrom.APPLICATION_MANAGER
),
PublishFrom.APPLICATION_MANAGER,
)
def moderation_for_inputs(
self, app_id: str,
tenant_id: str,
app_generate_entity: AppGenerateEntity,
inputs: dict,
query: str,
message_id: str,
self,
app_id: str,
tenant_id: str,
app_generate_entity: AppGenerateEntity,
inputs: Mapping[str, Any],
query: str,
message_id: str,
) -> tuple[bool, dict, str]:
"""
Process sensitive_word_avoidance.
@@ -367,14 +351,17 @@ class AppRunner:
tenant_id=tenant_id,
app_config=app_generate_entity.app_config,
inputs=inputs,
query=query if query else '',
query=query if query else "",
message_id=message_id,
trace_manager=app_generate_entity.trace_manager
trace_manager=app_generate_entity.trace_manager,
)
def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
queue_manager: AppQueueManager,
prompt_messages: list[PromptMessage]) -> bool:
def check_hosting_moderation(
self,
application_generate_entity: EasyUIBasedAppGenerateEntity,
queue_manager: AppQueueManager,
prompt_messages: list[PromptMessage],
) -> bool:
"""
Check hosting moderation
:param application_generate_entity: application generate entity
@@ -384,8 +371,7 @@ class AppRunner:
"""
hosting_moderation_feature = HostingModerationFeature()
moderation_result = hosting_moderation_feature.check(
application_generate_entity=application_generate_entity,
prompt_messages=prompt_messages
application_generate_entity=application_generate_entity, prompt_messages=prompt_messages
)
if moderation_result:
@@ -393,18 +379,20 @@ class AppRunner:
queue_manager=queue_manager,
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text="I apologize for any confusion, " \
"but I'm an AI assistant to be helpful, harmless, and honest.",
stream=application_generate_entity.stream
text="I apologize for any confusion, " "but I'm an AI assistant to be helpful, harmless, and honest.",
stream=application_generate_entity.stream,
)
return moderation_result
def fill_in_inputs_from_external_data_tools(self, tenant_id: str,
app_id: str,
external_data_tools: list[ExternalDataVariableEntity],
inputs: dict,
query: str) -> dict:
def fill_in_inputs_from_external_data_tools(
self,
tenant_id: str,
app_id: str,
external_data_tools: list[ExternalDataVariableEntity],
inputs: dict,
query: str,
) -> dict:
"""
Fill in variable inputs from external data tools if exists.
@@ -417,18 +405,12 @@ class AppRunner:
"""
external_data_fetch_feature = ExternalDataFetch()
return external_data_fetch_feature.fetch(
tenant_id=tenant_id,
app_id=app_id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
tenant_id=tenant_id, app_id=app_id, external_data_tools=external_data_tools, inputs=inputs, query=query
)
def query_app_annotations_to_reply(self, app_record: App,
message: Message,
query: str,
user_id: str,
invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
def query_app_annotations_to_reply(
self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
) -> Optional[MessageAnnotation]:
"""
Query app annotations to reply
:param app_record: app record
@@ -440,9 +422,5 @@ class AppRunner:
"""
annotation_reply_feature = AnnotationReplyFeature()
return annotation_reply_feature.query(
app_record=app_record,
message=message,
query=query,
user_id=user_id,
invoke_from=invoke_from
app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from
)

View File
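Aside: the stream branch in the hunks above accumulates per-chunk deltas into a single final LLMResult before publishing the end event. A minimal, self-contained sketch of that accumulation pattern, using hypothetical Chunk/Result stand-ins rather than Dify's actual entity classes:

    from collections.abc import Iterator
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Chunk:  # hypothetical stand-in for an LLM stream chunk
        model: str
        text: str

    @dataclass
    class Result:  # hypothetical stand-in for the final LLMResult
        model: Optional[str]
        text: str

    def collect_stream(chunks: Iterator[Chunk]) -> Result:
        """Accumulate streamed deltas into one final result."""
        model: Optional[str] = None
        text = ""
        for chunk in chunks:
            model = chunk.model   # remember the model reported by the stream
            text += chunk.text    # append each delta to the running answer
        return Result(model=model, text=text)

    print(collect_stream(iter([Chunk("model-x", "Hel"), Chunk("model-x", "lo")])))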

@@ -22,15 +22,19 @@ class ChatAppConfig(EasyUIBasedAppConfig):
"""
Chatbot App Config Entity.
"""
pass
class ChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App,
app_model_config: AppModelConfig,
conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None) -> ChatAppConfig:
def get_app_config(
cls,
app_model: App,
app_model_config: AppModelConfig,
conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None,
) -> ChatAppConfig:
"""
Convert app model config to chat app config
:param app_model: app model
@@ -51,7 +55,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
config_dict = app_model_config_dict.copy()
else:
if not override_config_dict:
raise Exception('override_config_dict is required when config_from is ARGS')
raise Exception("override_config_dict is required when config_from is ARGS")
config_dict = override_config_dict
@@ -63,19 +67,11 @@ class ChatAppConfigManager(BaseAppConfigManager):
app_model_config_from=config_from,
app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict,
model=ModelConfigManager.convert(
config=config_dict
),
prompt_template=PromptTemplateConfigManager.convert(
config=config_dict
),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=config_dict
),
dataset=DatasetConfigManager.convert(
config=config_dict
),
additional_features=cls.convert_features(config_dict, app_mode)
model=ModelConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
dataset=DatasetConfigManager.convert(config=config_dict),
additional_features=cls.convert_features(config_dict, app_mode),
)
app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
@@ -113,8 +109,9 @@ class ChatAppConfigManager(BaseAppConfigManager):
related_config_keys.extend(current_related_config_keys)
# dataset_query_variable
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode,
config)
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
tenant_id, app_mode, config
)
related_config_keys.extend(current_related_config_keys)
# opening_statement
@@ -123,7 +120,8 @@ class ChatAppConfigManager(BaseAppConfigManager):
# suggested_questions_after_answer
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
config)
config
)
related_config_keys.extend(current_related_config_keys)
# speech_to_text
@@ -139,8 +137,9 @@ class ChatAppConfigManager(BaseAppConfigManager):
related_config_keys.extend(current_related_config_keys)
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id,
config)
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id, config
)
related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys))

View File
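Aside: the validation hunks above all share one contract — each manager's validate_and_set_defaults returns the (possibly defaulted) config plus the config keys it owns, and the caller accumulates and deduplicates those keys. A rough sketch of that contract with hypothetical validators, not the real managers:

    def validate_opening_statement(config: dict) -> tuple[dict, list[str]]:
        # hypothetical validator: fill a default and report which keys it owns
        config.setdefault("opening_statement", "")
        return config, ["opening_statement"]

    def validate_speech_to_text(config: dict) -> tuple[dict, list[str]]:
        config.setdefault("speech_to_text", {"enabled": False})
        return config, ["speech_to_text"]

    def validate_all(config: dict) -> tuple[dict, list[str]]:
        related_config_keys: list[str] = []
        for validator in (validate_opening_statement, validate_speech_to_text):
            config, keys = validator(config)
            related_config_keys.extend(keys)
        # deduplicate, as the config managers above do
        return config, list(set(related_config_keys))

    print(validate_all({}))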

@@ -30,7 +30,8 @@ logger = logging.getLogger(__name__)
class ChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
@@ -39,7 +40,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
@@ -47,7 +49,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
) -> dict: ...
def generate(
self, app_model: App,
self,
app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
@@ -62,58 +65,46 @@ class ChatAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
if not args.get('query'):
raise ValueError('query is required')
if not args.get("query"):
raise ValueError("query is required")
query = args['query']
query = args["query"]
if not isinstance(query, str):
raise ValueError('query must be a string')
raise ValueError("query must be a string")
query = query.replace('\x00', '')
inputs = args['inputs']
query = query.replace("\x00", "")
inputs = args["inputs"]
extras = {
"auto_generate_conversation_name": args.get('auto_generate_name', True)
}
extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}
# get conversation
conversation = None
if args.get('conversation_id'):
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
if args.get("conversation_id"):
conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
# get app model config
app_model_config = self._get_app_model_config(
app_model=app_model,
conversation=conversation
)
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
# validate override model config
override_model_config_dict = None
if args.get('model_config'):
if args.get("model_config"):
if invoke_from != InvokeFrom.DEBUGGER:
raise ValueError('Only in App debug mode can override model config')
raise ValueError("Only in App debug mode can override model config")
# validate config
override_model_config_dict = ChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id,
config=args.get('model_config')
tenant_id=app_model.tenant_id, config=args.get("model_config")
)
# always enable retriever resource in debugger mode
override_model_config_dict["retriever_resource"] = {
"enabled": True
}
override_model_config_dict["retriever_resource"] = {"enabled": True}
# parse files
files = args['files'] if args.get('files') else []
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_extra_config,
user
)
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
else:
file_objs = []
@@ -122,7 +113,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
app_model=app_model,
app_model_config=app_model_config,
conversation=conversation,
override_config_dict=override_model_config_dict
override_config_dict=override_model_config_dict,
)
# get tracing instance
@@ -141,14 +132,11 @@ class ChatAppGenerator(MessageBasedAppGenerator):
stream=stream,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager
trace_manager=trace_manager,
)
# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity, conversation)
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
@@ -157,17 +145,20 @@ class ChatAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
message_id=message.id,
)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'conversation_id': conversation.id,
'message_id': message.id,
})
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
"message_id": message.id,
},
)
worker_thread.start()
@@ -181,16 +172,16 @@ class ChatAppGenerator(MessageBasedAppGenerator):
stream=stream,
)
return ChatAppGenerateResponseConverter.convert(
response=response,
invoke_from=invoke_from
)
return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(self, flask_app: Flask,
application_generate_entity: ChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str) -> None:
def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: ChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation_id: str,
message_id: str,
) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
@@ -212,20 +203,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message
message=message,
)
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:

View File
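Aside: the generate signatures above use typing.overload so that a call with stream=True is typed as returning a generator while stream=False is typed as returning a dict, with one runtime implementation behind both. A minimal sketch of that overload pattern with a hypothetical function:

    from collections.abc import Generator
    from typing import Literal, Union, overload

    @overload
    def generate(stream: Literal[True] = True) -> Generator[str, None, None]: ...
    @overload
    def generate(stream: Literal[False] = False) -> dict: ...

    def generate(stream: bool = True) -> Union[dict, Generator[str, None, None]]:
        # single runtime implementation shared by both overloads
        if stream:
            return (token for token in ("hello", "world"))
        return {"answer": "hello world"}

    print(list(generate(True)), generate(False))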

@@ -24,10 +24,13 @@ class ChatAppRunner(AppRunner):
Chat Application Runner
"""
def run(self, application_generate_entity: ChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message) -> None:
def run(
self,
application_generate_entity: ChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
) -> None:
"""
Run application
:param application_generate_entity: application generate entity
@@ -58,7 +61,7 @@ class ChatAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query
query=query,
)
memory = None
@@ -66,13 +69,10 @@ class ChatAppRunner(AppRunner):
# get memory of conversation (read-only)
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model
model=application_generate_entity.model_conf.model,
)
memory = TokenBufferMemory(
conversation=conversation,
model_instance=model_instance
)
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
# organize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
@@ -84,7 +84,7 @@ class ChatAppRunner(AppRunner):
inputs=inputs,
files=files,
query=query,
memory=memory
memory=memory,
)
# moderation
@@ -96,7 +96,7 @@ class ChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
message_id=message.id
message_id=message.id,
)
except ModerationException as e:
self.direct_output(
@@ -104,7 +104,7 @@ class ChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text=str(e),
stream=application_generate_entity.stream
stream=application_generate_entity.stream,
)
return
@@ -115,13 +115,13 @@ class ChatAppRunner(AppRunner):
message=message,
query=query,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from
invoke_from=application_generate_entity.invoke_from,
)
if annotation_reply:
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
PublishFrom.APPLICATION_MANAGER
PublishFrom.APPLICATION_MANAGER,
)
self.direct_output(
@@ -129,7 +129,7 @@ class ChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text=annotation_reply.content,
stream=application_generate_entity.stream
stream=application_generate_entity.stream,
)
return
@@ -141,7 +141,7 @@ class ChatAppRunner(AppRunner):
app_id=app_record.id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
query=query,
)
# get context from datasets
@@ -152,7 +152,7 @@ class ChatAppRunner(AppRunner):
app_record.id,
message.id,
application_generate_entity.user_id,
application_generate_entity.invoke_from
application_generate_entity.invoke_from,
)
dataset_retrieval = DatasetRetrieval(application_generate_entity)
@@ -181,29 +181,26 @@ class ChatAppRunner(AppRunner):
files=files,
query=query,
context=context,
memory=memory
memory=memory,
)
# check hosting moderation
hosting_moderation_result = self.check_hosting_moderation(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
prompt_messages=prompt_messages
prompt_messages=prompt_messages,
)
if hosting_moderation_result:
return
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recalc_llm_max_tokens(
model_config=application_generate_entity.model_conf,
prompt_messages=prompt_messages
)
self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)
# Invoke model
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model
model=application_generate_entity.model_conf.model,
)
db.session.close()
@@ -218,7 +215,5 @@ class ChatAppRunner(AppRunner):
# handle invoke result
self._handle_invoke_result(
invoke_result=invoke_result,
queue_manager=queue_manager,
stream=application_generate_entity.stream
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
)

View File

@@ -23,15 +23,15 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
:return:
"""
response = {
'event': 'message',
'task_id': blocking_response.task_id,
'id': blocking_response.data.id,
'message_id': blocking_response.data.message_id,
'conversation_id': blocking_response.data.conversation_id,
'mode': blocking_response.data.mode,
'answer': blocking_response.data.answer,
'metadata': blocking_response.data.metadata,
'created_at': blocking_response.data.created_at
"event": "message",
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
"conversation_id": blocking_response.data.conversation_id,
"mode": blocking_response.data.mode,
"answer": blocking_response.data.answer,
"metadata": blocking_response.data.metadata,
"created_at": blocking_response.data.created_at,
}
return response
@@ -45,14 +45,15 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
"""
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get('metadata', {})
response['metadata'] = cls._get_simple_metadata(metadata)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
return response
@classmethod
def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_full_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
) -> Generator[str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -63,14 +64,14 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue
response_chunk = {
'event': sub_stream_response.event.value,
'conversation_id': chunk.conversation_id,
'message_id': chunk.message_id,
'created_at': chunk.created_at
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
@@ -81,8 +82,9 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield json.dumps(response_chunk)
@classmethod
def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_simple_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
) -> Generator[str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
@@ -93,20 +95,20 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue
response_chunk = {
'event': sub_stream_response.event.value,
'conversation_id': chunk.conversation_id,
'message_id': chunk.message_id,
'created_at': chunk.created_at
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
metadata = sub_stream_response_dict.get('metadata', {})
sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)

View File
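Aside: both converters above walk a generator of stream-response chunks, build a small dict per chunk, and yield it as a JSON string, passing a bare "ping" keep-alive straight through. A stripped-down sketch under those assumptions, with a hypothetical chunk class:

    import json
    from collections.abc import Iterator
    from dataclasses import dataclass

    @dataclass
    class StreamChunk:  # hypothetical stand-in for a stream response chunk
        event: str
        message_id: str
        payload: dict

    def convert_stream(chunks: Iterator[StreamChunk]) -> Iterator[str]:
        for chunk in chunks:
            if chunk.event == "ping":
                yield "ping"  # keep-alive frames are passed through as-is
                continue
            response_chunk = {"event": chunk.event, "message_id": chunk.message_id}
            response_chunk.update(chunk.payload)
            yield json.dumps(response_chunk)

    for line in convert_stream(iter([StreamChunk("ping", "", {}),
                                     StreamChunk("message", "m1", {"answer": "hi"})])):
        print(line)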

@@ -17,14 +17,15 @@ class CompletionAppConfig(EasyUIBasedAppConfig):
"""
Completion App Config Entity.
"""
pass
class CompletionAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App,
app_model_config: AppModelConfig,
override_config_dict: Optional[dict] = None) -> CompletionAppConfig:
def get_app_config(
cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: Optional[dict] = None
) -> CompletionAppConfig:
"""
Convert app model config to completion app config
:param app_model: app model
@@ -51,19 +52,11 @@ class CompletionAppConfigManager(BaseAppConfigManager):
app_model_config_from=config_from,
app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict,
model=ModelConfigManager.convert(
config=config_dict
),
prompt_template=PromptTemplateConfigManager.convert(
config=config_dict
),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=config_dict
),
dataset=DatasetConfigManager.convert(
config=config_dict
),
additional_features=cls.convert_features(config_dict, app_mode)
model=ModelConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
dataset=DatasetConfigManager.convert(config=config_dict),
additional_features=cls.convert_features(config_dict, app_mode),
)
app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
@@ -101,8 +94,9 @@ class CompletionAppConfigManager(BaseAppConfigManager):
related_config_keys.extend(current_related_config_keys)
# dataset_query_variable
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode,
config)
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
tenant_id, app_mode, config
)
related_config_keys.extend(current_related_config_keys)
# text_to_speech
@@ -114,8 +108,9 @@ class CompletionAppConfigManager(BaseAppConfigManager):
related_config_keys.extend(current_related_config_keys)
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id,
config)
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id, config
)
related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys))

View File

@@ -32,7 +32,8 @@ logger = logging.getLogger(__name__)
class CompletionAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
@@ -41,19 +42,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
) -> dict: ...
def generate(self, app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: bool = True) \
-> Union[dict, Generator[str, None, None]]:
def generate(
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
) -> Union[dict, Generator[str, None, None]]:
"""
Generate App response.
@@ -63,12 +62,12 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
query = args['query']
query = args["query"]
if not isinstance(query, str):
raise ValueError('query must be a string')
raise ValueError("query must be a string")
query = query.replace('\x00', '')
inputs = args['inputs']
query = query.replace("\x00", "")
inputs = args["inputs"]
extras = {}
@@ -76,41 +75,31 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
conversation = None
# get app model config
app_model_config = self._get_app_model_config(
app_model=app_model,
conversation=conversation
)
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
# validate override model config
override_model_config_dict = None
if args.get('model_config'):
if args.get("model_config"):
if invoke_from != InvokeFrom.DEBUGGER:
raise ValueError('Only in App debug mode can override model config')
raise ValueError("Only in App debug mode can override model config")
# validate config
override_model_config_dict = CompletionAppConfigManager.config_validate(
tenant_id=app_model.tenant_id,
config=args.get('model_config')
tenant_id=app_model.tenant_id, config=args.get("model_config")
)
# parse files
files = args['files'] if args.get('files') else []
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_extra_config,
user
)
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
else:
file_objs = []
# convert to app config
app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config,
override_config_dict=override_model_config_dict
app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
)
# get tracing instance
@@ -128,14 +117,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
stream=stream,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager
trace_manager=trace_manager,
)
# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity)
(conversation, message) = self._init_generate_records(application_generate_entity)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
@@ -144,16 +130,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
message_id=message.id,
)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'message_id': message.id,
})
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"message_id": message.id,
},
)
worker_thread.start()
@@ -167,15 +156,15 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
stream=stream,
)
return CompletionAppGenerateResponseConverter.convert(
response=response,
invoke_from=invoke_from
)
return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def _generate_worker(self, flask_app: Flask,
application_generate_entity: CompletionAppGenerateEntity,
queue_manager: AppQueueManager,
message_id: str) -> None:
def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: CompletionAppGenerateEntity,
queue_manager: AppQueueManager,
message_id: str,
) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
@@ -194,20 +183,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
message=message
message=message,
)
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
@@ -216,12 +204,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
finally:
db.session.close()
def generate_more_like_this(self, app_model: App,
message_id: str,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
stream: bool = True) \
-> Union[dict, Generator[str, None, None]]:
def generate_more_like_this(
self,
app_model: App,
message_id: str,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[str, None, None]]:
"""
Generate App response.
@@ -231,13 +221,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
).first()
message = (
db.session.query(Message)
.filter(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
.first()
)
if not message:
raise MessageNotExistsError()
@@ -250,29 +244,23 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
app_model_config = message.app_model_config
override_model_config_dict = app_model_config.to_dict()
model_dict = override_model_config_dict['model']
completion_params = model_dict.get('completion_params')
completion_params['temperature'] = 0.9
model_dict['completion_params'] = completion_params
override_model_config_dict['model'] = model_dict
model_dict = override_model_config_dict["model"]
completion_params = model_dict.get("completion_params")
completion_params["temperature"] = 0.9
model_dict["completion_params"] = completion_params
override_model_config_dict["model"] = model_dict
# parse files
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(
message.files,
file_extra_config,
user
)
file_objs = message_file_parser.validate_and_transform_files_arg(message.files, file_extra_config, user)
else:
file_objs = []
# convert to app config
app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config,
override_config_dict=override_model_config_dict
app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
)
# init application generate entity
@@ -286,14 +274,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
extras={}
extras={},
)
# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity)
(conversation, message) = self._init_generate_records(application_generate_entity)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
@@ -302,16 +287,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
message_id=message.id,
)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'message_id': message.id,
})
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"message_id": message.id,
},
)
worker_thread.start()
@@ -325,7 +313,4 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
stream=stream,
)
return CompletionAppGenerateResponseConverter.convert(
response=response,
invoke_from=invoke_from
)
return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

View File
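Aside: the generator above hands the real Flask app object (current_app._get_current_object()) to a worker thread and re-enters flask_app.app_context() inside it, because current_app is only a proxy bound to the calling thread. A minimal sketch of that hand-off, assuming a plain Flask app rather than Dify's:

    import threading
    from flask import Flask, current_app

    app = Flask(__name__)

    def worker(flask_app: Flask, message_id: str) -> None:
        # re-enter the app context so app-bound extensions keep working in this thread
        with flask_app.app_context():
            print(f"processing {message_id} in app {current_app.name}")

    with app.app_context():
        thread = threading.Thread(
            target=worker,
            kwargs={
                "flask_app": current_app._get_current_object(),  # unwrap the proxy
                "message_id": "m1",
            },
        )
        thread.start()
        thread.join()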

@@ -22,9 +22,9 @@ class CompletionAppRunner(AppRunner):
Completion Application Runner
"""
def run(self, application_generate_entity: CompletionAppGenerateEntity,
queue_manager: AppQueueManager,
message: Message) -> None:
def run(
self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message
) -> None:
"""
Run application
:param application_generate_entity: application generate entity
@@ -54,7 +54,7 @@ class CompletionAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query
query=query,
)
# organize all inputs and template to prompt messages
@@ -65,7 +65,7 @@ class CompletionAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query
query=query,
)
# moderation
@@ -77,7 +77,7 @@ class CompletionAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
message_id=message.id
message_id=message.id,
)
except ModerationException as e:
self.direct_output(
@@ -85,7 +85,7 @@ class CompletionAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text=str(e),
stream=application_generate_entity.stream
stream=application_generate_entity.stream,
)
return
@@ -97,7 +97,7 @@ class CompletionAppRunner(AppRunner):
app_id=app_record.id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
query=query,
)
# get context from datasets
@@ -108,7 +108,7 @@ class CompletionAppRunner(AppRunner):
app_record.id,
message.id,
application_generate_entity.user_id,
application_generate_entity.invoke_from
application_generate_entity.invoke_from,
)
dataset_config = app_config.dataset
@@ -126,7 +126,7 @@ class CompletionAppRunner(AppRunner):
invoke_from=application_generate_entity.invoke_from,
show_retrieve_source=app_config.additional_features.show_retrieve_source,
hit_callback=hit_callback,
message_id=message.id
message_id=message.id,
)
# reorganize all inputs and template to prompt messages
@@ -139,29 +139,26 @@ class CompletionAppRunner(AppRunner):
inputs=inputs,
files=files,
query=query,
context=context
context=context,
)
# check hosting moderation
hosting_moderation_result = self.check_hosting_moderation(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
prompt_messages=prompt_messages
prompt_messages=prompt_messages,
)
if hosting_moderation_result:
return
# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recalc_llm_max_tokens(
model_config=application_generate_entity.model_conf,
prompt_messages=prompt_messages
)
self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)
# Invoke model
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model
model=application_generate_entity.model_conf.model,
)
db.session.close()
@@ -176,8 +173,5 @@ class CompletionAppRunner(AppRunner):
# handle invoke result
self._handle_invoke_result(
invoke_result=invoke_result,
queue_manager=queue_manager,
stream=application_generate_entity.stream
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
)

View File

@@ -23,14 +23,14 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
:return:
"""
response = {
'event': 'message',
'task_id': blocking_response.task_id,
'id': blocking_response.data.id,
'message_id': blocking_response.data.message_id,
'mode': blocking_response.data.mode,
'answer': blocking_response.data.answer,
'metadata': blocking_response.data.metadata,
'created_at': blocking_response.data.created_at
"event": "message",
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
"mode": blocking_response.data.mode,
"answer": blocking_response.data.answer,
"metadata": blocking_response.data.metadata,
"created_at": blocking_response.data.created_at,
}
return response
@@ -44,14 +44,15 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
"""
response = cls.convert_blocking_full_response(blocking_response)
metadata = response.get('metadata', {})
response['metadata'] = cls._get_simple_metadata(metadata)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)
return response
@classmethod
def convert_stream_full_response(cls, stream_response: Generator[CompletionAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_full_response(
cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
) -> Generator[str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -62,13 +63,13 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue
response_chunk = {
'event': sub_stream_response.event.value,
'message_id': chunk.message_id,
'created_at': chunk.created_at
"event": sub_stream_response.event.value,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
@@ -79,8 +80,9 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
yield json.dumps(response_chunk)
@classmethod
def convert_stream_simple_response(cls, stream_response: Generator[CompletionAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_simple_response(
cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
) -> Generator[str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
@@ -91,19 +93,19 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue
response_chunk = {
'event': sub_stream_response.event.value,
'message_id': chunk.message_id,
'created_at': chunk.created_at
"event": sub_stream_response.event.value,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
metadata = sub_stream_response_dict.get('metadata', {})
sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)

View File

@@ -35,23 +35,23 @@ logger = logging.getLogger(__name__)
class MessageBasedAppGenerator(BaseAppGenerator):
def _handle_response(
self, application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity
],
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool = False,
self,
application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity,
],
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool = False,
) -> Union[
ChatbotAppBlockingResponse,
CompletionAppBlockingResponse,
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
]:
"""
Handle response.
@@ -70,7 +70,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
conversation=conversation,
message=message,
user=user,
stream=stream
stream=stream,
)
try:
@@ -82,12 +82,13 @@ class MessageBasedAppGenerator(BaseAppGenerator):
logger.exception(e)
raise e
def _get_conversation_by_user(self, app_model: App, conversation_id: str,
user: Union[Account, EndUser]) -> Conversation:
def _get_conversation_by_user(
self, app_model: App, conversation_id: str, user: Union[Account, EndUser]
) -> Conversation:
conversation_filter = [
Conversation.id == conversation_id,
Conversation.app_id == app_model.id,
Conversation.status == 'normal'
Conversation.status == "normal",
]
if isinstance(user, Account):
@@ -100,19 +101,18 @@ class MessageBasedAppGenerator(BaseAppGenerator):
if not conversation:
raise ConversationNotExistsError()
if conversation.status != 'normal':
if conversation.status != "normal":
raise ConversationCompletedError()
return conversation
def _get_app_model_config(self, app_model: App,
conversation: Optional[Conversation] = None) \
-> AppModelConfig:
def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
if conversation:
app_model_config = db.session.query(AppModelConfig).filter(
AppModelConfig.id == conversation.app_model_config_id,
AppModelConfig.app_id == app_model.id
).first()
app_model_config = (
db.session.query(AppModelConfig)
.filter(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
.first()
)
if not app_model_config:
raise AppModelConfigBrokenError()
@@ -127,15 +127,16 @@ class MessageBasedAppGenerator(BaseAppGenerator):
return app_model_config
def _init_generate_records(self,
application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity
],
conversation: Optional[Conversation] = None) \
-> tuple[Conversation, Message]:
def _init_generate_records(
self,
application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity,
],
conversation: Optional[Conversation] = None,
) -> tuple[Conversation, Message]:
"""
Initialize generate records
:param application_generate_entity: application generate entity
@@ -148,10 +149,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
end_user_id = None
account_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
from_source = 'api'
from_source = "api"
end_user_id = application_generate_entity.user_id
else:
from_source = 'console'
from_source = "console"
account_id = application_generate_entity.user_id
if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity):
@@ -164,8 +165,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
model_provider = application_generate_entity.model_conf.provider
model_id = application_generate_entity.model_conf.model
override_model_configs = None
if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS \
and app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]:
if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in [
AppMode.AGENT_CHAT,
AppMode.CHAT,
AppMode.COMPLETION,
]:
override_model_configs = app_config.app_model_config_dict
# get conversation introduction
@@ -179,12 +183,12 @@ class MessageBasedAppGenerator(BaseAppGenerator):
model_id=model_id,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=app_config.app_mode.value,
name='New conversation',
name="New conversation",
inputs=application_generate_entity.inputs,
introduction=introduction,
system_instruction="",
system_instruction_tokens=0,
status='normal',
status="normal",
invoke_from=application_generate_entity.invoke_from.value,
from_source=from_source,
from_end_user_id=end_user_id,
@@ -216,11 +220,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
answer_price_unit=0,
provider_response_latency=0,
total_price=0,
currency='USD',
currency="USD",
invoke_from=application_generate_entity.invoke_from.value,
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id
from_account_id=account_id,
)
db.session.add(message)
@@ -232,10 +236,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
message_id=message.id,
type=file.type.value,
transfer_method=file.transfer_method.value,
belongs_to='user',
belongs_to="user",
url=file.url,
upload_file_id=file.related_id,
created_by_role=('account' if account_id else 'end_user'),
created_by_role=("account" if account_id else "end_user"),
created_by=account_id or end_user_id,
)
db.session.add(message_file)
@@ -269,11 +273,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param conversation_id: conversation id
:return: conversation
"""
conversation = (
db.session.query(Conversation)
.filter(Conversation.id == conversation_id)
.first()
)
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
if not conversation:
raise ConversationNotExistsError()
@@ -286,10 +286,6 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param message_id: message id
:return: message
"""
message = (
db.session.query(Message)
.filter(Message.id == message_id)
.first()
)
message = db.session.query(Message).filter(Message.id == message_id).first()
return message

View File
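Aside: several hunks above reformat chained SQLAlchemy calls — short chains collapse onto one line, longer ones are wrapped in parentheses with one call per line. A small, self-contained illustration of the parenthesized style, using a hypothetical minimal model and an in-memory SQLite database:

    from sqlalchemy import String, create_engine
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class Message(Base):  # hypothetical minimal model, not Dify's
        __tablename__ = "messages"
        id: Mapped[str] = mapped_column(String, primary_key=True)
        app_id: Mapped[str] = mapped_column(String)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(Message(id="m1", app_id="a1"))
        session.commit()
        # one call per line inside parentheses, matching the reformatted style above
        message = (
            session.query(Message)
            .filter(Message.id == "m1", Message.app_id == "a1")
            .first()
        )
        print(message.id if message else None)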

@@ -12,12 +12,9 @@ from core.app.entities.queue_entities import (
class MessageBasedAppQueueManager(AppQueueManager):
def __init__(self, task_id: str,
user_id: str,
invoke_from: InvokeFrom,
conversation_id: str,
app_mode: str,
message_id: str) -> None:
def __init__(
self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str
) -> None:
super().__init__(task_id, user_id, invoke_from)
self._conversation_id = str(conversation_id)
@@ -30,7 +27,7 @@ class MessageBasedAppQueueManager(AppQueueManager):
message_id=self._message_id,
conversation_id=self._conversation_id,
app_mode=self._app_mode,
event=event
event=event,
)
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
@@ -45,17 +42,15 @@ class MessageBasedAppQueueManager(AppQueueManager):
message_id=self._message_id,
conversation_id=self._conversation_id,
app_mode=self._app_mode,
event=event
event=event,
)
self._q.put(message)
if isinstance(event, QueueStopEvent
| QueueErrorEvent
| QueueMessageEndEvent
| QueueAdvancedChatMessageEndEvent):
if isinstance(
event, QueueStopEvent | QueueErrorEvent | QueueMessageEndEvent | QueueAdvancedChatMessageEndEvent
):
self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise GenerateTaskStoppedException()

View File
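Aside: the queue managers above now pass a PEP 604 union (A | B | C) directly to isinstance, which Python 3.10+ accepts in place of a tuple of types. A tiny sketch with hypothetical event classes:

    class QueueStopEvent: ...
    class QueueErrorEvent: ...
    class QueueMessageEndEvent: ...
    class QueuePingEvent: ...

    def should_stop_listening(event: object) -> bool:
        # Python 3.10+: a union type works directly as the isinstance() classinfo
        return isinstance(event, QueueStopEvent | QueueErrorEvent | QueueMessageEndEvent)

    print(should_stop_listening(QueueStopEvent()), should_stop_listening(QueuePingEvent()))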

@@ -12,6 +12,7 @@ class WorkflowAppConfig(WorkflowUIBasedAppConfig):
"""
Workflow App Config Entity.
"""
pass
@@ -26,13 +27,9 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
app_id=app_model.id,
app_mode=app_mode,
workflow_id=workflow.id,
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=features_dict
),
variables=WorkflowVariablesConfigManager.convert(
workflow=workflow
),
additional_features=cls.convert_features(features_dict, app_mode)
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict),
variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
additional_features=cls.convert_features(features_dict, app_mode),
)
return app_config
@@ -50,8 +47,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
# file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
config=config,
is_vision=False
config=config, is_vision=False
)
related_config_keys.extend(current_related_config_keys)
@@ -61,9 +57,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id=tenant_id,
config=config,
only_structure_validate=only_structure_validate
tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
)
related_config_keys.extend(current_related_config_keys)

View File

@@ -4,7 +4,7 @@ import os
import threading
import uuid
from collections.abc import Generator
from typing import Literal, Union, overload
from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@@ -34,32 +34,40 @@ logger = logging.getLogger(__name__)
class WorkflowAppGenerator(BaseAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: Literal[True] = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Generator[str, None, None]: ...
@overload
def generate(
self, app_model: App,
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> dict: ...
def generate(
self, app_model: App,
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
):
"""
Generate App response.
@@ -71,27 +79,21 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
:param call_depth: call depth
:param workflow_thread_pool_id: workflow thread pool id
"""
inputs = args['inputs']
inputs = args["inputs"]
# parse files
files = args['files'] if args.get('files') else []
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_extra_config,
user
)
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
else:
file_objs = []
# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow
)
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
# get tracing instance
user_id = user.id if isinstance(user, Account) else user.session_id
@@ -107,7 +109,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
stream=stream,
invoke_from=invoke_from,
call_depth=call_depth,
trace_manager=trace_manager
trace_manager=trace_manager,
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
@@ -118,16 +120,20 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
stream=stream,
workflow_thread_pool_id=workflow_thread_pool_id,
)
def _generate(
self, app_model: App,
self,
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
application_generate_entity: WorkflowAppGenerateEntity,
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[str, None, None]]:
workflow_thread_pool_id: Optional[str] = None,
) -> dict[str, Any] | Generator[str, None, None]:
"""
Generate App response.
@@ -137,22 +143,27 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
:param stream: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
# init queue manager
queue_manager = WorkflowAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
app_mode=app_model.mode
app_mode=app_model.mode,
)
# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'context': contextvars.copy_context()
})
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"context": contextvars.copy_context(),
"workflow_thread_pool_id": workflow_thread_pool_id,
},
)
worker_thread.start()
@@ -165,17 +176,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
stream=stream,
)
return WorkflowAppGenerateResponseConverter.convert(
response=response,
invoke_from=invoke_from
)
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def single_iteration_generate(self, app_model: App,
workflow: Workflow,
node_id: str,
user: Account,
args: dict,
stream: bool = True):
def single_iteration_generate(
self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True
) -> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.
@@ -187,20 +192,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param stream: is stream
"""
if not node_id:
raise ValueError('node_id is required')
raise ValueError("node_id is required")
if args.get('inputs') is None:
raise ValueError('inputs is required')
extras = {
"auto_generate_conversation_name": False
}
if args.get("inputs") is None:
raise ValueError("inputs is required")
# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow
)
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
@@ -211,11 +209,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
user_id=user.id,
stream=stream,
invoke_from=InvokeFrom.DEBUGGER,
extras=extras,
extras={"auto_generate_conversation_name": False},
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id,
inputs=args['inputs']
)
node_id=node_id, inputs=args["inputs"]
),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)
@@ -225,18 +222,23 @@ class WorkflowAppGenerator(BaseAppGenerator):
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
stream=stream
stream=stream,
)
def _generate_worker(self, flask_app: Flask,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
context: contextvars.Context) -> None:
def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
context: contextvars.Context,
workflow_thread_pool_id: Optional[str] = None,
) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param workflow_thread_pool_id: workflow thread pool id
:return:
"""
for var, val in context.items():
@@ -244,50 +246,40 @@ class WorkflowAppGenerator(BaseAppGenerator):
with flask_app.app_context():
try:
# workflow app
runner = WorkflowAppRunner()
if application_generate_entity.single_iteration_run:
single_iteration_run = application_generate_entity.single_iteration_run
runner.single_iteration_run(
app_id=application_generate_entity.app_config.app_id,
workflow_id=application_generate_entity.app_config.workflow_id,
queue_manager=queue_manager,
inputs=single_iteration_run.inputs,
node_id=single_iteration_run.node_id,
user_id=application_generate_entity.user_id
)
else:
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager
)
runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id,
)
runner.run()
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == "true":
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.remove()
db.session.close()
def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool = False) -> Union[
WorkflowAppBlockingResponse,
Generator[WorkflowAppStreamResponse, None, None]
]:
def _handle_response(
self,
application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
Handle response.
:param application_generate_entity: application generate entity
@@ -303,7 +295,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow=workflow,
queue_manager=queue_manager,
user=user,
stream=stream
stream=stream,
)
try:


@@ -12,10 +12,7 @@ from core.app.entities.queue_entities import (
class WorkflowAppQueueManager(AppQueueManager):
def __init__(self, task_id: str,
user_id: str,
invoke_from: InvokeFrom,
app_mode: str) -> None:
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
super().__init__(task_id, user_id, invoke_from)
self._app_mode = app_mode
@@ -27,19 +24,18 @@ class WorkflowAppQueueManager(AppQueueManager):
:param pub_from:
:return:
"""
message = WorkflowQueueMessage(
task_id=self._task_id,
app_mode=self._app_mode,
event=event
)
message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event)
self._q.put(message)
if isinstance(event, QueueStopEvent
| QueueErrorEvent
| QueueMessageEndEvent
| QueueWorkflowSucceededEvent
| QueueWorkflowFailedEvent):
if isinstance(
event,
QueueStopEvent
| QueueErrorEvent
| QueueMessageEndEvent
| QueueWorkflowSucceededEvent
| QueueWorkflowFailedEvent,
):
self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
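Note that the reformatted check above relies on Python 3.10+ union syntax inside isinstance. A minimal standalone sketch of the same pattern, using toy event classes rather than the real queue entities:

    # Toy classes, for illustration only; the real events live in core.app.entities.queue_entities.
    class QueueStopEvent: ...
    class QueueErrorEvent: ...
    class QueueTextChunkEvent: ...

    def should_stop_listening(event) -> bool:
        # PEP 604 unions work directly in isinstance() on Python >= 3.10,
        # which is what the reformatted publish() uses instead of a tuple.
        return isinstance(event, QueueStopEvent | QueueErrorEvent)

    assert should_stop_listening(QueueStopEvent()) is True
    assert should_stop_listening(QueueTextChunkEvent()) is False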


@@ -4,129 +4,125 @@ from typing import Optional, cast
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App, EndUser
from models.workflow import Workflow
from models.workflow import WorkflowType
logger = logging.getLogger(__name__)
class WorkflowAppRunner:
class WorkflowAppRunner(WorkflowBasedAppRunner):
"""
Workflow Application Runner
"""
def run(self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager) -> None:
def __init__(
self,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
workflow_thread_pool_id: Optional[str] = None,
) -> None:
"""
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param workflow_thread_pool_id: workflow thread pool id
"""
self.application_generate_entity = application_generate_entity
self.queue_manager = queue_manager
self.workflow_thread_pool_id = workflow_thread_pool_id
def run(self) -> None:
"""
Run application
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:return:
"""
app_config = application_generate_entity.app_config
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
user_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = application_generate_entity.user_id
user_id = self.application_generate_entity.user_id
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
raise ValueError('App not found')
raise ValueError("App not found")
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError('Workflow not initialized')
inputs = application_generate_entity.inputs
files = application_generate_entity.files
raise ValueError("Workflow not initialized")
db.session.close()
workflow_callbacks: list[WorkflowCallback] = [
WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
]
if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
workflow_callbacks: list[WorkflowCallback] = []
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
workflow_callbacks.append(WorkflowLoggingCallback())
# Create a variable pool.
system_inputs = {
SystemVariableKey.FILES: files,
SystemVariableKey.USER_ID: user_id,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
)
# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
)
else:
inputs = self.application_generate_entity.inputs
files = self.application_generate_entity.files
# Create a variable pool.
system_inputs = {
SystemVariableKey.FILES: files,
SystemVariableKey.USER_ID: user_id,
}
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
)
# init graph
graph = self._init_graph(graph_config=workflow.graph_dict)
# RUN WORKFLOW
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.run_workflow(
workflow=workflow,
user_id=application_generate_entity.user_id,
user_from=UserFrom.ACCOUNT
if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER,
invoke_from=application_generate_entity.invoke_from,
callbacks=workflow_callbacks,
call_depth=application_generate_entity.call_depth,
workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(workflow.type),
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
thread_pool_id=self.workflow_thread_pool_id,
)
def single_iteration_run(
self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
) -> None:
"""
Single iteration run
"""
app_record = db.session.query(App).filter(App.id == app_id).first()
if not app_record:
raise ValueError('App not found')
generator = workflow_entry.run(callbacks=workflow_callbacks)
if not app_record.workflow_id:
raise ValueError('Workflow not initialized')
workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
if not workflow:
raise ValueError('Workflow not initialized')
workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
workflow_engine_manager = WorkflowEngineManager()
workflow_engine_manager.single_step_run_iteration_workflow_node(
workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow
for event in generator:
self._handle_event(workflow_entry, event)
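Taken together, the refactor means the worker thread only has to construct the runner and call run(); everything else (single-iteration handling, graph init, event publishing) now lives inside WorkflowAppRunner. A minimal sketch of that call pattern, with the entity and queue manager construction elided (names as in the diff; illustrative only, not runnable outside the codebase):

    # Hypothetical driver code; in this changeset it lives in WorkflowAppGenerator._generate_worker.
    runner = WorkflowAppRunner(
        application_generate_entity=application_generate_entity,  # WorkflowAppGenerateEntity built by the generator
        queue_manager=queue_manager,                               # WorkflowAppQueueManager
        workflow_thread_pool_id=None,                               # optional shared thread pool id
    )
    runner.run()  # graph engine events are translated into queue events via _handle_event()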


@@ -35,8 +35,9 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
return cls.convert_blocking_full_response(blocking_response)
@classmethod
def convert_stream_full_response(cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_full_response(
cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
) -> Generator[str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -47,12 +48,12 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue
response_chunk = {
'event': sub_stream_response.event.value,
'workflow_run_id': chunk.workflow_run_id,
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
@@ -63,8 +64,9 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
yield json.dumps(response_chunk)
@classmethod
def convert_stream_simple_response(cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_simple_response(
cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
) -> Generator[str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
@@ -75,12 +77,12 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue
response_chunk = {
'event': sub_stream_response.event.value,
'workflow_run_id': chunk.workflow_run_id,
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
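Both converters yield plain strings: either the literal "ping" keep-alive or a JSON-encoded chunk carrying "event" and "workflow_run_id". A minimal sketch of how a caller might consume that output (the consume helper and the surrounding transport are assumptions, not part of the diff):

    import json

    def consume(converted_stream):
        # converted_stream is what convert_stream_full_response / convert_stream_simple_response yield
        for item in converted_stream:
            if item == "ping":
                continue  # keep-alive marker, nothing to parse
            chunk = json.loads(item)
            print(chunk["event"], chunk.get("workflow_run_id"))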


@@ -1,3 +1,4 @@
import json
import logging
import time
from collections.abc import Generator
@@ -15,10 +16,12 @@ from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueMessageReplaceEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent,
QueueStopEvent,
QueueTextChunkEvent,
@@ -32,19 +35,16 @@ from core.app.entities.task_entities import (
MessageAudioStreamResponse,
StreamResponse,
TextChunkStreamResponse,
TextReplaceStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStreamGenerateNodes,
WorkflowStartStreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.end.end_node import EndNode
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser
@@ -52,8 +52,8 @@ from models.workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
WorkflowNodeExecution,
WorkflowRun,
WorkflowRunStatus,
)
logger = logging.getLogger(__name__)
@@ -63,18 +63,21 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
"""
WorkflowAppGenerateTaskPipeline is a class that generates stream output and manages state for the workflow application.
"""
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any]
_iteration_nested_relations: dict[str, list[str]]
def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool) -> None:
def __init__(
self,
application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool,
) -> None:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
@@ -93,14 +96,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
self._workflow = workflow
self._workflow_system_variables = {
SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_id
SystemVariableKey.USER_ID: user_id,
}
self._task_state = WorkflowTaskState(
iteration_nested_node_ids=[]
)
self._stream_generate_nodes = self._get_stream_generate_nodes()
self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
self._task_state = WorkflowTaskState()
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@@ -111,16 +110,13 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
db.session.refresh(self._user)
db.session.close()
generator = self._wrapper_process_stream_response(
trace_manager=self._application_generate_entity.trace_manager
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
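A caller of process() therefore branches on whether it received a blocking response or a stream. A minimal consumption sketch (the pipeline variable and print calls are caller-side assumptions; the attribute names come from the diff):

    # Hypothetical caller; process() is the method shown above.
    result = pipeline.process()
    if isinstance(result, WorkflowAppBlockingResponse):
        print(result.data.outputs)            # blocking mode: one final payload
    else:
        for item in result:                    # streaming mode: WorkflowAppStreamResponse items
            print(item.workflow_run_id, item.stream_response.event)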
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \
-> WorkflowAppBlockingResponse:
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse:
"""
To blocking response.
:return:
@@ -129,43 +125,42 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, WorkflowFinishStreamResponse):
workflow_run = db.session.query(WorkflowRun).filter(
WorkflowRun.id == self._task_state.workflow_run_id).first()
response = WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=workflow_run.id,
workflow_run_id=stream_response.data.id,
data=WorkflowAppBlockingResponse.Data(
id=workflow_run.id,
workflow_id=workflow_run.workflow_id,
status=workflow_run.status,
outputs=workflow_run.outputs_dict,
error=workflow_run.error,
elapsed_time=workflow_run.elapsed_time,
total_tokens=workflow_run.total_tokens,
total_steps=workflow_run.total_steps,
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(workflow_run.finished_at.timestamp())
)
id=stream_response.data.id,
workflow_id=stream_response.data.workflow_id,
status=stream_response.data.status,
outputs=stream_response.data.outputs,
error=stream_response.data.error,
elapsed_time=stream_response.data.elapsed_time,
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
created_at=int(stream_response.data.created_at),
finished_at=int(stream_response.data.finished_at),
),
)
return response
else:
continue
raise Exception('Queue listening stopped unexpectedly.')
raise Exception("Queue listening stopped unexpectedly.")
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \
-> Generator[WorkflowAppStreamResponse, None, None]:
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[WorkflowAppStreamResponse, None, None]:
"""
To stream response.
:return:
"""
workflow_run_id = None
for stream_response in generator:
yield WorkflowAppStreamResponse(
workflow_run_id=self._task_state.workflow_run_id,
stream_response=stream_response
)
if isinstance(stream_response, WorkflowStartStreamResponse):
workflow_run_id = stream_response.workflow_run_id
yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)
def _listenAudioMsg(self, publisher, task_id: str):
if not publisher:
@@ -175,20 +170,24 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:
publisher = None
def _wrapper_process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
tts_publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict
if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
'text_to_speech'].get('autoPlay') == 'enabled':
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
if (
features_dict.get("text_to_speech")
and features_dict["text_to_speech"].get("enabled")
and features_dict["text_to_speech"].get("autoPlay") == "enabled"
):
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
@@ -198,9 +197,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
start_listener_time = time.time()
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
try:
if not publisher:
if not tts_publisher:
break
audio_trunk = publisher.checkAndGetAudio()
audio_trunk = tts_publisher.checkAndGetAudio()
if audio_trunk is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
@@ -213,105 +212,176 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
except Exception as e:
logger.error(e)
break
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self,
publisher: AppGeneratorTTSPublisher,
trace_manager: Optional[TraceQueueManager] = None
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
:return:
"""
for message in self._queue_manager.listen():
if publisher:
publisher.publish(message=message)
event = message.event
graph_runtime_state = None
workflow_run = None
if isinstance(event, QueueErrorEvent):
for queue_message in self._queue_manager.listen():
event = queue_message.event
if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
err = self._handle_error(event)
yield self._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
workflow_run = self._handle_workflow_start()
# override graph runtime state
graph_runtime_state = event.graph_runtime_state
# init workflow run
workflow_run = self._handle_workflow_run_start()
yield self._workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance(event, QueueNodeStartedEvent):
workflow_node_execution = self._handle_node_start(event)
if not workflow_run:
raise Exception("Workflow run not initialized.")
# search stream_generate_routes if node id is answer start at node
if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_nodes:
self._task_state.current_stream_generate_state = self._stream_generate_nodes[event.node_id]
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
# generate stream outputs when node started
yield from self._generate_stream_outputs_when_node_started()
yield self._workflow_node_start_to_stream_response(
response = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
workflow_node_execution=workflow_node_execution,
)
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
workflow_node_execution = self._handle_node_finished(event)
yield self._workflow_node_finish_to_stream_response(
if response:
yield response
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
workflow_node_execution=workflow_node_execution,
)
if isinstance(event, QueueNodeFailedEvent):
yield from self._handle_iteration_exception(
task_id=self._application_generate_entity.task_id,
error=f'Child node failed: {event.error}'
)
elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
if isinstance(event, QueueIterationNextEvent):
# clear ran node execution infos of current iteration
iteration_relations = self._iteration_nested_relations.get(event.node_id)
if iteration_relations:
for node_id in iteration_relations:
self._task_state.ran_node_execution_infos.pop(node_id, None)
if response:
yield response
elif isinstance(event, QueueNodeFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
self._handle_iteration_operation(event)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._handle_workflow_finished(
event, trace_manager=trace_manager
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationStartEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
yield self._workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationNextEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
yield self._workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=json.dumps(event.outputs)
if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs
else None,
conversation_id=None,
trace_manager=trace_manager,
)
# save workflow app log
self._save_workflow_app_log(workflow_run)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowRunStatus.STOPPED,
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
)
# save workflow app log
self._save_workflow_app_log(workflow_run)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text
if delta_text is None:
continue
if not self._is_stream_out_support(
event=event
):
continue
# only publish tts message at text chunk streaming
if tts_publisher:
tts_publisher.publish(message=queue_message)
self._task_state.answer += delta_text
yield self._text_chunk_to_stream_response(delta_text)
elif isinstance(event, QueueMessageReplaceEvent):
yield self._text_replace_to_stream_response(event.text)
elif isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
else:
continue
if publisher:
publisher.publish(None)
if tts_publisher:
tts_publisher.publish(None)
def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
"""
@@ -329,15 +399,15 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# not save log for debugging
return
workflow_app_log = WorkflowAppLog(
tenant_id=workflow_run.tenant_id,
app_id=workflow_run.app_id,
workflow_id=workflow_run.workflow_id,
workflow_run_id=workflow_run.id,
created_from=created_from.value,
created_by_role=('account' if isinstance(self._user, Account) else 'end_user'),
created_by=self._user.id,
)
workflow_app_log = WorkflowAppLog()
workflow_app_log.tenant_id = workflow_run.tenant_id
workflow_app_log.app_id = workflow_run.app_id
workflow_app_log.workflow_id = workflow_run.workflow_id
workflow_app_log.workflow_run_id = workflow_run.id
workflow_app_log.created_from = created_from.value
workflow_app_log.created_by_role = "account" if isinstance(self._user, Account) else "end_user"
workflow_app_log.created_by = self._user.id
db.session.add(workflow_app_log)
db.session.commit()
db.session.close()
@@ -349,185 +419,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
:return:
"""
response = TextChunkStreamResponse(
task_id=self._application_generate_entity.task_id,
data=TextChunkStreamResponse.Data(text=text)
task_id=self._application_generate_entity.task_id, data=TextChunkStreamResponse.Data(text=text)
)
return response
def _text_replace_to_stream_response(self, text: str) -> TextReplaceStreamResponse:
"""
Text replace to stream response.
:param text: text
:return:
"""
return TextReplaceStreamResponse(
task_id=self._application_generate_entity.task_id,
text=TextReplaceStreamResponse.Data(text=text)
)
def _get_stream_generate_nodes(self) -> dict[str, WorkflowStreamGenerateNodes]:
"""
Get stream generate nodes.
:return:
"""
# find all answer nodes
graph = self._workflow.graph_dict
end_node_configs = [
node for node in graph['nodes']
if node.get('data', {}).get('type') == NodeType.END.value
]
# parse stream output node value selectors of end nodes
stream_generate_routes = {}
for node_config in end_node_configs:
# get generate route for stream output
end_node_id = node_config['id']
generate_nodes = EndNode.extract_generate_nodes(graph, node_config)
start_node_ids = self._get_end_start_at_node_ids(graph, end_node_id)
if not start_node_ids:
continue
for start_node_id in start_node_ids:
stream_generate_routes[start_node_id] = WorkflowStreamGenerateNodes(
end_node_id=end_node_id,
stream_node_ids=generate_nodes
)
return stream_generate_routes
def _get_end_start_at_node_ids(self, graph: dict, target_node_id: str) \
-> list[str]:
"""
Get end start at node id.
:param graph: graph
:param target_node_id: target node ID
:return:
"""
nodes = graph.get('nodes')
edges = graph.get('edges')
# fetch all ingoing edges from source node
ingoing_edges = []
for edge in edges:
if edge.get('target') == target_node_id:
ingoing_edges.append(edge)
if not ingoing_edges:
return []
start_node_ids = []
for ingoing_edge in ingoing_edges:
source_node_id = ingoing_edge.get('source')
source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
if not source_node:
continue
node_type = source_node.get('data', {}).get('type')
node_iteration_id = source_node.get('data', {}).get('iteration_id')
iteration_start_node_id = None
if node_iteration_id:
iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')
if node_type in [
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value
]:
start_node_id = target_node_id
start_node_ids.append(start_node_id)
elif node_type == NodeType.START.value or \
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
start_node_id = source_node_id
start_node_ids.append(start_node_id)
else:
sub_start_node_ids = self._get_end_start_at_node_ids(graph, source_node_id)
if sub_start_node_ids:
start_node_ids.extend(sub_start_node_ids)
return start_node_ids
def _generate_stream_outputs_when_node_started(self) -> Generator:
"""
Generate stream outputs.
:return:
"""
if self._task_state.current_stream_generate_state:
stream_node_ids = self._task_state.current_stream_generate_state.stream_node_ids
for node_id, node_execution_info in self._task_state.ran_node_execution_infos.items():
if node_id not in stream_node_ids:
continue
node_execution_info = self._task_state.ran_node_execution_infos[node_id]
# get chunk node execution
route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == node_execution_info.workflow_node_execution_id).first()
if not route_chunk_node_execution:
continue
outputs = route_chunk_node_execution.outputs_dict
if not outputs:
continue
# get value from outputs
text = outputs.get('text')
if text:
self._task_state.answer += text
yield self._text_chunk_to_stream_response(text)
db.session.close()
def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
"""
Is stream out support
:param event: queue text chunk event
:return:
"""
if not event.metadata:
return False
if 'node_id' not in event.metadata:
return False
node_id = event.metadata.get('node_id')
node_type = event.metadata.get('node_type')
stream_output_value_selector = event.metadata.get('value_selector')
if not stream_output_value_selector:
return False
if not self._task_state.current_stream_generate_state:
return False
if node_id not in self._task_state.current_stream_generate_state.stream_node_ids:
return False
if node_type != NodeType.LLM:
# only LLM support chunk stream output
return False
return True
def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
"""
Get iteration nested relations.
:param graph: graph
:return:
"""
nodes = graph.get('nodes')
iteration_ids = [node.get('id') for node in nodes
if node.get('data', {}).get('type') in [
NodeType.ITERATION.value,
NodeType.LOOP.value,
]]
return {
iteration_id: [
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
] for iteration_id in iteration_ids
}


@@ -1,200 +0,0 @@
from typing import Any, Optional
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from models.workflow import Workflow
class WorkflowEventTriggerCallback(WorkflowCallback):
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
self._queue_manager = queue_manager
def on_workflow_run_started(self) -> None:
"""
Workflow run started
"""
self._queue_manager.publish(
QueueWorkflowStartedEvent(),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_succeeded(self) -> None:
"""
Workflow run succeeded
"""
self._queue_manager.publish(
QueueWorkflowSucceededEvent(),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_run_failed(self, error: str) -> None:
"""
Workflow run failed
"""
self._queue_manager.publish(
QueueWorkflowFailedEvent(
error=error
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_started(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> None:
"""
Workflow node execute started
"""
self._queue_manager.publish(
QueueNodeStartedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
node_run_index=node_run_index,
predecessor_node_id=predecessor_node_id
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_succeeded(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> None:
"""
Workflow node execute succeeded
"""
self._queue_manager.publish(
QueueNodeSucceededEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
inputs=inputs,
process_data=process_data,
outputs=outputs,
execution_metadata=execution_metadata
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_node_execute_failed(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
error: str,
inputs: Optional[dict] = None,
outputs: Optional[dict] = None,
process_data: Optional[dict] = None) -> None:
"""
Workflow node execute failed
"""
self._queue_manager.publish(
QueueNodeFailedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
inputs=inputs,
outputs=outputs,
process_data=process_data,
error=error
),
PublishFrom.APPLICATION_MANAGER
)
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
"""
Publish text chunk
"""
self._queue_manager.publish(
QueueTextChunkEvent(
text=text,
metadata={
"node_id": node_id,
**metadata
}
), PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_started(self,
node_id: str,
node_type: NodeType,
node_run_index: int = 1,
node_data: Optional[BaseNodeData] = None,
inputs: dict = None,
predecessor_node_id: Optional[str] = None,
metadata: Optional[dict] = None) -> None:
"""
Publish iteration started
"""
self._queue_manager.publish(
QueueIterationStartEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
node_data=node_data,
inputs=inputs,
predecessor_node_id=predecessor_node_id,
metadata=metadata
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_next(self, node_id: str,
node_type: NodeType,
index: int,
node_run_index: int,
output: Optional[Any]) -> None:
"""
Publish iteration next
"""
self._queue_manager.publish(
QueueIterationNextEvent(
node_id=node_id,
node_type=node_type,
index=index,
node_run_index=node_run_index,
output=output
),
PublishFrom.APPLICATION_MANAGER
)
def on_workflow_iteration_completed(self, node_id: str,
node_type: NodeType,
node_run_index: int,
outputs: dict) -> None:
"""
Publish iteration completed
"""
self._queue_manager.publish(
QueueIterationCompletedEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
outputs=outputs
),
PublishFrom.APPLICATION_MANAGER
)
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event
"""
pass


@@ -0,0 +1,371 @@
from collections.abc import Mapping
from typing import Any, Optional, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueueRetrieverResourcesEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ParallelBranchRunFailedEvent,
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.iteration.entities import IterationNodeData
from core.workflow.nodes.node_mapping import node_classes
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App
from models.workflow import Workflow
class WorkflowBasedAppRunner(AppRunner):
def __init__(self, queue_manager: AppQueueManager):
self.queue_manager = queue_manager
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
"""
Init graph
"""
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# init graph
graph = Graph.init(graph_config=graph_config)
if not graph:
raise ValueError("graph not found in workflow")
return graph
def _get_graph_and_variable_pool_of_single_iteration(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict,
) -> tuple[Graph, VariablePool]:
"""
Get variable pool of single iteration
"""
# fetch workflow graph
graph_config = workflow.graph_dict
if not graph_config:
raise ValueError("workflow graph not found")
graph_config = cast(dict[str, Any], graph_config)
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# filter nodes only in iteration
node_configs = [
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
]
graph_config["nodes"] = node_configs
node_ids = [node.get("id") for node in node_configs]
# filter edges only in iteration
edge_configs = [
edge
for edge in graph_config.get("edges", [])
if (edge.get("source") is None or edge.get("source") in node_ids)
and (edge.get("target") is None or edge.get("target") in node_ids)
]
graph_config["edges"] = edge_configs
# init graph
graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
if not graph:
raise ValueError("graph not found in workflow")
# fetch node config from node id
iteration_node_config = None
for node in node_configs:
if node.get("id") == node_id:
iteration_node_config = node
break
if not iteration_node_config:
raise ValueError("iteration node id not found in workflow graph")
# Get node class
node_type = NodeType.value_of(iteration_node_config.get("data", {}).get("type"))
node_cls = node_classes.get(node_type)
node_cls = cast(type[BaseNode], node_cls)
# init variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={},
environment_variables=workflow.environment_variables,
)
try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, config=iteration_node_config
)
except NotImplementedError:
variable_mapping = {}
WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
node_type=node_type,
node_data=IterationNodeData(**iteration_node_config.get("data", {})),
)
return graph, variable_pool
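The node/edge filtering above simply keeps the iteration node plus its children and drops any edge that leaves that set. A small self-contained sketch of the same filtering applied to a toy graph dict (toy ids, not a real Dify graph):

    graph_config = {
        "nodes": [
            {"id": "iter-1", "data": {"type": "iteration"}},
            {"id": "child-a", "data": {"iteration_id": "iter-1"}},
            {"id": "other", "data": {}},
        ],
        "edges": [
            {"source": "iter-1", "target": "child-a"},
            {"source": "other", "target": "iter-1"},
        ],
    }

    node_id = "iter-1"
    # keep the iteration node itself and every node nested inside it
    node_configs = [
        node
        for node in graph_config["nodes"]
        if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
    ]
    node_ids = [node["id"] for node in node_configs]
    # keep only edges whose endpoints stay inside the filtered node set
    edge_configs = [
        edge
        for edge in graph_config["edges"]
        if (edge.get("source") is None or edge.get("source") in node_ids)
        and (edge.get("target") is None or edge.get("target") in node_ids)
    ]

    assert node_ids == ["iter-1", "child-a"]
    assert edge_configs == [{"source": "iter-1", "target": "child-a"}]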
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None:
"""
Handle event
:param workflow_entry: workflow entry
:param event: event
"""
if isinstance(event, GraphRunStartedEvent):
self._publish_event(
QueueWorkflowStartedEvent(graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state)
)
elif isinstance(event, GraphRunSucceededEvent):
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
elif isinstance(event, GraphRunFailedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.error))
elif isinstance(event, NodeRunStartedEvent):
self._publish_event(
QueueNodeStartedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
node_run_index=event.route_node_state.index,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeRunSucceededEvent):
self._publish_event(
QueueNodeSucceededEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeRunFailedEvent):
self._publish_event(
QueueNodeFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error",
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeRunStreamChunkEvent):
self._publish_event(
QueueTextChunkEvent(
text=event.chunk_content,
from_variable_selector=event.from_variable_selector,
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeRunRetrieverResourceEvent):
self._publish_event(
QueueRetrieverResourcesEvent(
retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id
)
)
elif isinstance(event, ParallelBranchRunStartedEvent):
self._publish_event(
QueueParallelBranchRunStartedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, ParallelBranchRunSucceededEvent):
self._publish_event(
QueueParallelBranchRunSucceededEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, ParallelBranchRunFailedEvent):
self._publish_event(
QueueParallelBranchRunFailedEvent(
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
in_iteration_id=event.in_iteration_id,
error=event.error,
)
)
elif isinstance(event, IterationRunStartedEvent):
self._publish_event(
QueueIterationStartEvent(
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
predecessor_node_id=event.predecessor_node_id,
metadata=event.metadata,
)
)
elif isinstance(event, IterationRunNextEvent):
self._publish_event(
QueueIterationNextEvent(
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_iteration_output,
)
)
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
self._publish_event(
QueueIterationCompletedEvent(
node_execution_id=event.iteration_id,
node_id=event.iteration_node_id,
node_type=event.iteration_node_type,
node_data=event.iteration_node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
outputs=event.outputs,
metadata=event.metadata,
steps=event.steps,
error=event.error if isinstance(event, IterationRunFailedEvent) else None,
)
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow
def _publish_event(self, event: AppQueueEvent) -> None:
self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)


@@ -1,10 +1,24 @@
from typing import Optional
from core.app.entities.queue_entities import AppQueueEvent
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ParallelBranchRunFailedEvent,
ParallelBranchRunStartedEvent,
ParallelBranchRunSucceededEvent,
)
_TEXT_COLOR_MAPPING = {
"blue": "36;1",
@@ -16,138 +30,184 @@ _TEXT_COLOR_MAPPING = {
class WorkflowLoggingCallback(WorkflowCallback):
def __init__(self) -> None:
self.current_node_id = None
def on_workflow_run_started(self) -> None:
"""
Workflow run started
"""
self.print_text("\n[on_workflow_run_started]", color='pink')
def on_event(self, event: GraphEngineEvent) -> None:
if isinstance(event, GraphRunStartedEvent):
self.print_text("\n[GraphRunStartedEvent]", color="pink")
elif isinstance(event, GraphRunSucceededEvent):
self.print_text("\n[GraphRunSucceededEvent]", color="green")
elif isinstance(event, GraphRunFailedEvent):
self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red")
elif isinstance(event, NodeRunStartedEvent):
self.on_workflow_node_execute_started(event=event)
elif isinstance(event, NodeRunSucceededEvent):
self.on_workflow_node_execute_succeeded(event=event)
elif isinstance(event, NodeRunFailedEvent):
self.on_workflow_node_execute_failed(event=event)
elif isinstance(event, NodeRunStreamChunkEvent):
self.on_node_text_chunk(event=event)
elif isinstance(event, ParallelBranchRunStartedEvent):
self.on_workflow_parallel_started(event=event)
elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
self.on_workflow_parallel_completed(event=event)
elif isinstance(event, IterationRunStartedEvent):
self.on_workflow_iteration_started(event=event)
elif isinstance(event, IterationRunNextEvent):
self.on_workflow_iteration_next(event=event)
elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
self.on_workflow_iteration_completed(event=event)
else:
self.print_text(f"\n[{event.__class__.__name__}]", color="blue")
def on_workflow_run_succeeded(self) -> None:
"""
Workflow run succeeded
"""
self.print_text("\n[on_workflow_run_succeeded]", color='green')
def on_workflow_run_failed(self, error: str) -> None:
"""
Workflow run failed
"""
self.print_text("\n[on_workflow_run_failed]", color='red')
def on_workflow_node_execute_started(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> None:
def on_workflow_node_execute_started(self, event: NodeRunStartedEvent) -> None:
"""
Workflow node execute started
"""
self.print_text("\n[on_workflow_node_execute_started]", color='yellow')
self.print_text(f"Node ID: {node_id}", color='yellow')
self.print_text(f"Type: {node_type.value}", color='yellow')
self.print_text(f"Index: {node_run_index}", color='yellow')
if predecessor_node_id:
self.print_text(f"Predecessor Node ID: {predecessor_node_id}", color='yellow')
self.print_text("\n[NodeRunStartedEvent]", color="yellow")
self.print_text(f"Node ID: {event.node_id}", color="yellow")
self.print_text(f"Node Title: {event.node_data.title}", color="yellow")
self.print_text(f"Type: {event.node_type.value}", color="yellow")
def on_workflow_node_execute_succeeded(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> None:
def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent) -> None:
"""
Workflow node execute succeeded
"""
self.print_text("\n[on_workflow_node_execute_succeeded]", color='green')
self.print_text(f"Node ID: {node_id}", color='green')
self.print_text(f"Type: {node_type.value}", color='green')
self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='green')
self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='green')
self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='green')
self.print_text(f"Metadata: {jsonable_encoder(execution_metadata) if execution_metadata else ''}",
color='green')
route_node_state = event.route_node_state
def on_workflow_node_execute_failed(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
error: str,
inputs: Optional[dict] = None,
outputs: Optional[dict] = None,
process_data: Optional[dict] = None) -> None:
self.print_text("\n[NodeRunSucceededEvent]", color="green")
self.print_text(f"Node ID: {event.node_id}", color="green")
self.print_text(f"Node Title: {event.node_data.title}", color="green")
self.print_text(f"Type: {event.node_type.value}", color="green")
if route_node_state.node_run_result:
node_run_result = route_node_state.node_run_result
self.print_text(
f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", color="green"
)
self.print_text(
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
color="green",
)
self.print_text(
f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
color="green",
)
self.print_text(
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}",
color="green",
)
def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent) -> None:
"""
Workflow node execute failed
"""
self.print_text("\n[on_workflow_node_execute_failed]", color='red')
self.print_text(f"Node ID: {node_id}", color='red')
self.print_text(f"Type: {node_type.value}", color='red')
self.print_text(f"Error: {error}", color='red')
self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='red')
self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='red')
self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='red')
route_node_state = event.route_node_state
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
self.print_text("\n[NodeRunFailedEvent]", color="red")
self.print_text(f"Node ID: {event.node_id}", color="red")
self.print_text(f"Node Title: {event.node_data.title}", color="red")
self.print_text(f"Type: {event.node_type.value}", color="red")
if route_node_state.node_run_result:
node_run_result = route_node_state.node_run_result
self.print_text(f"Error: {node_run_result.error}", color="red")
self.print_text(
f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", color="red"
)
self.print_text(
f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
color="red",
)
self.print_text(
f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", color="red"
)
def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None:
"""
Publish text chunk
"""
route_node_state = event.route_node_state
if not self.current_node_id or self.current_node_id != route_node_state.node_id:
self.current_node_id = route_node_state.node_id
self.print_text("\n[NodeRunStreamChunkEvent]")
self.print_text(f"Node ID: {route_node_state.node_id}")
node_run_result = route_node_state.node_run_result
if node_run_result:
self.print_text(
f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}"
)
self.print_text(event.chunk_content, color="pink", end="")
def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent) -> None:
"""
Publish parallel started
"""
self.print_text("\n[ParallelBranchRunStartedEvent]", color="blue")
self.print_text(f"Parallel ID: {event.parallel_id}", color="blue")
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color="blue")
if event.in_iteration_id:
self.print_text(f"Iteration ID: {event.in_iteration_id}", color="blue")
def on_workflow_parallel_completed(
self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
) -> None:
"""
Publish parallel completed
"""
if isinstance(event, ParallelBranchRunSucceededEvent):
color = "blue"
elif isinstance(event, ParallelBranchRunFailedEvent):
color = "red"
self.print_text(
"\n[ParallelBranchRunSucceededEvent]"
if isinstance(event, ParallelBranchRunSucceededEvent)
else "\n[ParallelBranchRunFailedEvent]",
color=color,
)
self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
if event.in_iteration_id:
self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color)
if isinstance(event, ParallelBranchRunFailedEvent):
self.print_text(f"Error: {event.error}", color=color)
def on_workflow_iteration_started(self, event: IterationRunStartedEvent) -> None:
"""
Publish iteration started
"""
self.print_text("\n[on_workflow_iteration_started]", color='blue')
self.print_text(f"Node ID: {node_id}", color='blue')
self.print_text("\n[IterationRunStartedEvent]", color="blue")
self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
def on_workflow_iteration_next(self, event: IterationRunNextEvent) -> None:
"""
Publish iteration next
"""
self.print_text("\n[IterationRunNextEvent]", color="blue")
self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
self.print_text(f"Iteration Index: {event.index}", color="blue")
def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent) -> None:
"""
Publish iteration completed
"""
self.print_text(
"\n[IterationRunSucceededEvent]"
if isinstance(event, IterationRunSucceededEvent)
else "\n[IterationRunFailedEvent]",
color="blue",
)
self.print_text(f"Node ID: {event.iteration_id}", color="blue")
def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event
"""
self.print_text("\n[on_workflow_event]", color='blue')
self.print_text(f"Event: {jsonable_encoder(event)}", color='blue')
def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None:
"""Print text with highlighting and no end characters."""
text_to_print = self._get_colored_text(text, color) if color else text
print(f"{text_to_print}", end=end)
def _get_colored_text(self, text: str, color: str) -> str:
"""Get colored text."""

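The refactor above replaces long positional-argument callbacks with per-event objects. A minimal, self-contained sketch of that dispatch pattern (all names here are illustrative stand-ins, not Dify's actual engine API):

```python
from dataclasses import dataclass

@dataclass
class NodeStarted:  # illustrative stand-in for NodeRunStartedEvent
    node_id: str
    title: str

class LoggingCallback:
    def on_node_started(self, event: NodeStarted) -> None:
        # The event object carries everything the handler needs.
        print(f"[NodeStarted] {event.node_id} ({event.title})")

def dispatch(events, callback: LoggingCallback) -> None:
    # A real engine would emit many event types and route on isinstance().
    for event in events:
        if isinstance(event, NodeStarted):
            callback.on_node_started(event)

dispatch([NodeStarted(node_id="start", title="Start")], LoggingCallback())
```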
View File

@@ -15,13 +15,14 @@ class InvokeFrom(Enum):
"""
Invoke From.
"""
SERVICE_API = "service-api"
WEB_APP = "web-app"
EXPLORE = "explore"
DEBUGGER = "debugger"
@classmethod
def value_of(cls, value: str) -> "InvokeFrom":
"""
Get value of given mode.
@@ -31,7 +32,7 @@ class InvokeFrom(Enum):
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid invoke from value {value}")
def to_source(self) -> str:
"""
@@ -40,21 +41,22 @@ class InvokeFrom(Enum):
:return: source
"""
if self == InvokeFrom.WEB_APP:
return "web_app"
elif self == InvokeFrom.DEBUGGER:
return "dev"
elif self == InvokeFrom.EXPLORE:
return "explore_app"
elif self == InvokeFrom.SERVICE_API:
return "api"
return "dev"
class ModelConfigWithCredentialsEntity(BaseModel):
"""
Model Config With Credentials Entity.
"""
provider: str
model: str
model_schema: AIModelEntity
@@ -72,6 +74,7 @@ class AppGenerateEntity(BaseModel):
"""
App Generate Entity.
"""
task_id: str
# app config
@@ -102,6 +105,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
"""
Chat Application Generate Entity.
"""
# app config
app_config: EasyUIBasedAppConfig
model_conf: ModelConfigWithCredentialsEntity
@@ -116,6 +120,7 @@ class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
"""
Chat Application Generate Entity.
"""
conversation_id: Optional[str] = None
@@ -123,6 +128,7 @@ class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
"""
Completion Application Generate Entity.
"""
pass
@@ -130,6 +136,7 @@ class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
"""
Agent Chat Application Generate Entity.
"""
conversation_id: Optional[str] = None
@@ -137,6 +144,7 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
"""
Advanced Chat Application Generate Entity.
"""
# app config
app_config: WorkflowUIBasedAppConfig
@@ -147,15 +155,18 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
"""
Single Iteration Run Entity.
"""
node_id: str
inputs: dict
single_iteration_run: Optional[SingleIterationRunEntity] = None
class WorkflowAppGenerateEntity(AppGenerateEntity):
"""
Workflow Application Generate Entity.
"""
# app config
app_config: WorkflowUIBasedAppConfig
@@ -163,6 +174,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
"""
Single Iteration Run Entity.
"""
node_id: str
inputs: dict

View File

@@ -1,3 +1,4 @@
from datetime import datetime
from enum import Enum
from typing import Any, Optional
@@ -5,13 +6,15 @@ from pydantic import BaseModel, field_validator
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
class QueueEvent(str, Enum):
"""
QueueEvent enum
"""
LLM_CHUNK = "llm_chunk"
TEXT_CHUNK = "text_chunk"
AGENT_MESSAGE = "agent_message"
@@ -31,6 +34,9 @@ class QueueEvent(str, Enum):
ANNOTATION_REPLY = "annotation_reply"
AGENT_THOUGHT = "agent_thought"
MESSAGE_FILE = "message_file"
PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
ERROR = "error"
PING = "ping"
STOP = "stop"
@@ -38,46 +44,73 @@ class QueueEvent(str, Enum):
class AppQueueEvent(BaseModel):
"""
QueueEvent abstract entity
"""
event: QueueEvent
class QueueLLMChunkEvent(AppQueueEvent):
"""
QueueLLMChunkEvent entity
Only for basic mode apps
"""
event: QueueEvent = QueueEvent.LLM_CHUNK
chunk: LLMResultChunk
class QueueIterationStartEvent(AppQueueEvent):
"""
QueueIterationStartEvent entity
"""
event: QueueEvent = QueueEvent.ITERATION_START
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime
node_run_index: int
inputs: Optional[dict[str, Any]] = None
predecessor_node_id: Optional[str] = None
metadata: Optional[dict[str, Any]] = None
class QueueIterationNextEvent(AppQueueEvent):
"""
QueueIterationNextEvent entity
"""
event: QueueEvent = QueueEvent.ITERATION_NEXT
index: int
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
node_run_index: int
output: Optional[Any] = None  # output for the current iteration
@field_validator("output", mode="before")
@classmethod
def set_output(cls, v):
"""
@@ -87,41 +120,66 @@ class QueueIterationNextEvent(AppQueueEvent):
return None
if isinstance(v, int | float | str | bool | dict | list):
return v
raise ValueError("output must be a valid type")
class QueueIterationCompletedEvent(AppQueueEvent):
"""
QueueIterationCompletedEvent entity
"""
event: QueueEvent = QueueEvent.ITERATION_COMPLETED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime
node_run_index: int
inputs: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
steps: int = 0
error: Optional[str] = None
class QueueTextChunkEvent(AppQueueEvent):
"""
QueueTextChunkEvent entity
"""
event: QueueEvent = QueueEvent.TEXT_CHUNK
text: str
metadata: Optional[dict] = None
from_variable_selector: Optional[list[str]] = None
"""from variable selector"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
class QueueAgentMessageEvent(AppQueueEvent):
"""
QueueMessageEvent entity
"""
event: QueueEvent = QueueEvent.AGENT_MESSAGE
chunk: LLMResultChunk
class QueueMessageReplaceEvent(AppQueueEvent):
"""
QueueMessageReplaceEvent entity
"""
event: QueueEvent = QueueEvent.MESSAGE_REPLACE
text: str
@@ -130,14 +188,18 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
"""
QueueRetrieverResourcesEvent entity
"""
event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
retriever_resources: list[dict]
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
class QueueAnnotationReplyEvent(AppQueueEvent):
"""
QueueAnnotationReplyEvent entity
"""
event: QueueEvent = QueueEvent.ANNOTATION_REPLY
message_annotation_id: str
@@ -146,6 +208,7 @@ class QueueMessageEndEvent(AppQueueEvent):
"""
QueueMessageEndEvent entity
"""
event: QueueEvent = QueueEvent.MESSAGE_END
llm_result: Optional[LLMResult] = None
@@ -154,6 +217,7 @@ class QueueAdvancedChatMessageEndEvent(AppQueueEvent):
"""
QueueAdvancedChatMessageEndEvent entity
"""
event: QueueEvent = QueueEvent.ADVANCED_CHAT_MESSAGE_END
@@ -161,20 +225,25 @@ class QueueWorkflowStartedEvent(AppQueueEvent):
"""
QueueWorkflowStartedEvent entity
"""
event: QueueEvent = QueueEvent.WORKFLOW_STARTED
graph_runtime_state: GraphRuntimeState
class QueueWorkflowSucceededEvent(AppQueueEvent):
"""
QueueWorkflowSucceededEvent entity
"""
event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
outputs: Optional[dict[str, Any]] = None
class QueueWorkflowFailedEvent(AppQueueEvent):
"""
QueueWorkflowFailedEvent entity
"""
event: QueueEvent = QueueEvent.WORKFLOW_FAILED
error: str
@@ -183,29 +252,55 @@ class QueueNodeStartedEvent(AppQueueEvent):
"""
QueueNodeStartedEvent entity
"""
event: QueueEvent = QueueEvent.NODE_STARTED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
node_run_index: int = 1
predecessor_node_id: Optional[str] = None
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
class QueueNodeSucceededEvent(AppQueueEvent):
"""
QueueNodeSucceededEvent entity
"""
event: QueueEvent = QueueEvent.NODE_SUCCEEDED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
error: Optional[str] = None
@@ -214,15 +309,28 @@ class QueueNodeFailedEvent(AppQueueEvent):
"""
QueueNodeFailedEvent entity
"""
event: QueueEvent = QueueEvent.NODE_FAILED
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
error: str
@@ -231,6 +339,7 @@ class QueueAgentThoughtEvent(AppQueueEvent):
"""
QueueAgentThoughtEvent entity
"""
event: QueueEvent = QueueEvent.AGENT_THOUGHT
agent_thought_id: str
@@ -239,6 +348,7 @@ class QueueMessageFileEvent(AppQueueEvent):
"""
QueueAgentThoughtEvent entity
"""
event: QueueEvent = QueueEvent.MESSAGE_FILE
message_file_id: str
@@ -247,6 +357,7 @@ class QueueErrorEvent(AppQueueEvent):
"""
QueueErrorEvent entity
"""
event: QueueEvent = QueueEvent.ERROR
error: Any = None
@@ -255,6 +366,7 @@ class QueuePingEvent(AppQueueEvent):
"""
QueuePingEvent entity
"""
event: QueueEvent = QueueEvent.PING
@@ -262,10 +374,12 @@ class QueueStopEvent(AppQueueEvent):
"""
QueueStopEvent entity
"""
class StopBy(Enum):
"""
Stop by enum
"""
USER_MANUAL = "user-manual"
ANNOTATION_REPLY = "annotation-reply"
OUTPUT_MODERATION = "output-moderation"
@@ -274,11 +388,25 @@ class QueueStopEvent(AppQueueEvent):
event: QueueEvent = QueueEvent.STOP
stopped_by: StopBy
def get_stop_reason(self) -> str:
"""
To stop reason
"""
reason_mapping = {
QueueStopEvent.StopBy.USER_MANUAL: "Stopped by user.",
QueueStopEvent.StopBy.ANNOTATION_REPLY: "Stopped by annotation reply.",
QueueStopEvent.StopBy.OUTPUT_MODERATION: "Stopped by output moderation.",
QueueStopEvent.StopBy.INPUT_MODERATION: "Stopped by input moderation.",
}
return reason_mapping.get(self.stopped_by, "Stopped by unknown reason.")
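A small usage sketch of the new get_stop_reason() helper (fields as defined above):

```python
stop_event = QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL)
print(stop_event.get_stop_reason())  # -> "Stopped by user."
```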
class QueueMessage(BaseModel):
"""
QueueMessage abstract entity
"""
task_id: str
app_mode: str
event: AppQueueEvent
@@ -288,6 +416,7 @@ class MessageQueueMessage(QueueMessage):
"""
MessageQueueMessage entity
"""
message_id: str
conversation_id: str
@@ -296,4 +425,57 @@ class WorkflowQueueMessage(QueueMessage):
"""
WorkflowQueueMessage entity
"""
pass
class QueueParallelBranchRunStartedEvent(AppQueueEvent):
"""
QueueParallelBranchRunStartedEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED
parallel_id: str
parallel_start_node_id: str
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
"""
QueueParallelBranchRunSucceededEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED
parallel_id: str
parallel_start_node_id: str
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
class QueueParallelBranchRunFailedEvent(AppQueueEvent):
"""
QueueParallelBranchRunFailedEvent entity
"""
event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED
parallel_id: str
parallel_start_node_id: str
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
error: str
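Construction sketch for one of the new parallel-branch queue events; only the required fields are passed here, the optional parent/iteration ids default to None:

```python
branch_started = QueueParallelBranchRunStartedEvent(
    parallel_id="parallel-1",
    parallel_start_node_id="node-a",
)
print(branch_started.event)  # -> QueueEvent.PARALLEL_BRANCH_RUN_STARTED
```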

View File

@@ -3,44 +3,16 @@ from typing import Any, Optional
from pydantic import BaseModel, ConfigDict
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.entities import GenerateRouteChunk
from models.workflow import WorkflowNodeExecutionStatus
class WorkflowStreamGenerateNodes(BaseModel):
"""
WorkflowStreamGenerateNodes entity
"""
end_node_id: str
stream_node_ids: list[str]
class ChatflowStreamGenerateRoute(BaseModel):
"""
ChatflowStreamGenerateRoute entity
"""
answer_node_id: str
generate_route: list[GenerateRouteChunk]
current_route_position: int = 0
class NodeExecutionInfo(BaseModel):
"""
NodeExecutionInfo entity
"""
workflow_node_execution_id: str
node_type: NodeType
start_at: float
class TaskState(BaseModel):
"""
TaskState entity
"""
metadata: dict = {}
@@ -48,6 +20,7 @@ class EasyUITaskState(TaskState):
"""
EasyUITaskState entity
"""
llm_result: LLMResult
@@ -55,34 +28,15 @@ class WorkflowTaskState(TaskState):
"""
WorkflowTaskState entity
"""
answer: str = ""
workflow_run_id: Optional[str] = None
start_at: Optional[float] = None
total_tokens: int = 0
total_steps: int = 0
ran_node_execution_infos: dict[str, NodeExecutionInfo] = {}
latest_node_execution_info: Optional[NodeExecutionInfo] = None
current_stream_generate_state: Optional[WorkflowStreamGenerateNodes] = None
iteration_nested_node_ids: list[str] = None
class AdvancedChatTaskState(WorkflowTaskState):
"""
AdvancedChatTaskState entity
"""
usage: LLMUsage
current_stream_generate_state: Optional[ChatflowStreamGenerateRoute] = None
class StreamEvent(Enum):
"""
Stream event
"""
PING = "ping"
ERROR = "error"
MESSAGE = "message"
@@ -97,6 +51,8 @@ class StreamEvent(Enum):
WORKFLOW_FINISHED = "workflow_finished"
NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished"
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
ITERATION_STARTED = "iteration_started"
ITERATION_NEXT = "iteration_next"
ITERATION_COMPLETED = "iteration_completed"
@@ -108,6 +64,7 @@ class StreamResponse(BaseModel):
"""
StreamResponse entity
"""
event: StreamEvent
task_id: str
@@ -119,6 +76,7 @@ class ErrorStreamResponse(StreamResponse):
"""
ErrorStreamResponse entity
"""
event: StreamEvent = StreamEvent.ERROR
err: Exception
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -128,6 +86,7 @@ class MessageStreamResponse(StreamResponse):
"""
MessageStreamResponse entity
"""
event: StreamEvent = StreamEvent.MESSAGE
id: str
answer: str
@@ -137,6 +96,7 @@ class MessageAudioStreamResponse(StreamResponse):
"""
MessageStreamResponse entity
"""
event: StreamEvent = StreamEvent.TTS_MESSAGE
audio: str
@@ -145,6 +105,7 @@ class MessageAudioEndStreamResponse(StreamResponse):
"""
MessageStreamResponse entity
"""
event: StreamEvent = StreamEvent.TTS_MESSAGE_END
audio: str
@@ -153,6 +114,7 @@ class MessageEndStreamResponse(StreamResponse):
"""
MessageEndStreamResponse entity
"""
event: StreamEvent = StreamEvent.MESSAGE_END
id: str
metadata: dict = {}
@@ -162,6 +124,7 @@ class MessageFileStreamResponse(StreamResponse):
"""
MessageFileStreamResponse entity
"""
event: StreamEvent = StreamEvent.MESSAGE_FILE
id: str
type: str
@@ -173,6 +136,7 @@ class MessageReplaceStreamResponse(StreamResponse):
"""
MessageReplaceStreamResponse entity
"""
event: StreamEvent = StreamEvent.MESSAGE_REPLACE
answer: str
@@ -181,6 +145,7 @@ class AgentThoughtStreamResponse(StreamResponse):
"""
AgentThoughtStreamResponse entity
"""
event: StreamEvent = StreamEvent.AGENT_THOUGHT
id: str
position: int
@@ -196,6 +161,7 @@ class AgentMessageStreamResponse(StreamResponse):
"""
AgentMessageStreamResponse entity
"""
event: StreamEvent = StreamEvent.AGENT_MESSAGE
id: str
answer: str
@@ -210,6 +176,7 @@ class WorkflowStartStreamResponse(StreamResponse):
"""
Data entity
"""
id: str
workflow_id: str
sequence_number: int
@@ -230,6 +197,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
"""
Data entity
"""
id: str
workflow_id: str
sequence_number: int
@@ -258,6 +226,7 @@ class NodeStartStreamResponse(StreamResponse):
"""
Data entity
"""
id: str
node_id: str
node_type: str
@@ -267,6 +236,11 @@ class NodeStartStreamResponse(StreamResponse):
inputs: Optional[dict] = None
created_at: int
extras: dict = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
event: StreamEvent = StreamEvent.NODE_STARTED
workflow_run_id: str
@@ -286,8 +260,13 @@ class NodeStartStreamResponse(StreamResponse):
"predecessor_node_id": self.data.predecessor_node_id,
"inputs": None,
"created_at": self.data.created_at,
"extras": {}
}
"extras": {},
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id,
},
}
@@ -300,6 +279,7 @@ class NodeFinishStreamResponse(StreamResponse):
"""
Data entity
"""
id: str
node_id: str
node_type: str
@@ -316,6 +296,11 @@ class NodeFinishStreamResponse(StreamResponse):
created_at: int
finished_at: int
files: Optional[list[dict]] = []
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
event: StreamEvent = StreamEvent.NODE_FINISHED
workflow_run_id: str
@@ -342,11 +327,62 @@ class NodeFinishStreamResponse(StreamResponse):
"execution_metadata": None,
"created_at": self.data.created_at,
"finished_at": self.data.finished_at,
"files": []
}
"files": [],
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id,
},
}
class ParallelBranchStartStreamResponse(StreamResponse):
"""
ParallelBranchStartStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
parallel_id: str
parallel_branch_id: str
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
created_at: int
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED
workflow_run_id: str
data: Data
class ParallelBranchFinishedStreamResponse(StreamResponse):
"""
ParallelBranchFinishedStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
parallel_id: str
parallel_branch_id: str
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
status: str
error: Optional[str] = None
created_at: int
event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED
workflow_run_id: str
data: Data
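A matching construction sketch for the new parallel-branch stream responses (task_id comes from the StreamResponse base class; the values are placeholders):

```python
resp = ParallelBranchStartStreamResponse(
    task_id="task-1",
    workflow_run_id="run-1",
    data=ParallelBranchStartStreamResponse.Data(
        parallel_id="parallel-1",
        parallel_branch_id="node-a",
        created_at=1700000000,
    ),
)
print(resp.event)  # -> StreamEvent.PARALLEL_BRANCH_STARTED
```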
class IterationNodeStartStreamResponse(StreamResponse):
"""
NodeStartStreamResponse entity
@@ -356,6 +392,7 @@ class IterationNodeStartStreamResponse(StreamResponse):
"""
Data entity
"""
id: str
node_id: str
node_type: str
@@ -364,6 +401,8 @@ class IterationNodeStartStreamResponse(StreamResponse):
extras: dict = {}
metadata: dict = {}
inputs: dict = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
event: StreamEvent = StreamEvent.ITERATION_STARTED
workflow_run_id: str
@@ -379,6 +418,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
"""
Data entity
"""
id: str
node_id: str
node_type: str
@@ -387,6 +427,8 @@ class IterationNodeNextStreamResponse(StreamResponse):
created_at: int
pre_iteration_output: Optional[Any] = None
extras: dict = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
event: StreamEvent = StreamEvent.ITERATION_NEXT
workflow_run_id: str
@@ -402,14 +444,15 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
"""
Data entity
"""
id: str
node_id: str
node_type: str
title: str
outputs: Optional[dict] = None
created_at: int
extras: Optional[dict] = None
inputs: Optional[dict] = None
status: WorkflowNodeExecutionStatus
error: Optional[str] = None
elapsed_time: float
@@ -417,6 +460,8 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
execution_metadata: Optional[dict] = None
finished_at: int
steps: int
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
event: StreamEvent = StreamEvent.ITERATION_COMPLETED
workflow_run_id: str
@@ -432,6 +477,7 @@ class TextChunkStreamResponse(StreamResponse):
"""
Data entity
"""
text: str
event: StreamEvent = StreamEvent.TEXT_CHUNK
@@ -447,6 +493,7 @@ class TextReplaceStreamResponse(StreamResponse):
"""
Data entity
"""
text: str
event: StreamEvent = StreamEvent.TEXT_REPLACE
@@ -457,6 +504,7 @@ class PingStreamResponse(StreamResponse):
"""
PingStreamResponse entity
"""
event: StreamEvent = StreamEvent.PING
@@ -464,6 +512,7 @@ class AppStreamResponse(BaseModel):
"""
AppStreamResponse entity
"""
stream_response: StreamResponse
@@ -471,6 +520,7 @@ class ChatbotAppStreamResponse(AppStreamResponse):
"""
ChatbotAppStreamResponse entity
"""
conversation_id: str
message_id: str
created_at: int
@@ -480,6 +530,7 @@ class CompletionAppStreamResponse(AppStreamResponse):
"""
CompletionAppStreamResponse entity
"""
message_id: str
created_at: int
@@ -488,13 +539,15 @@ class WorkflowAppStreamResponse(AppStreamResponse):
"""
WorkflowAppStreamResponse entity
"""
workflow_run_id: Optional[str] = None
class AppBlockingResponse(BaseModel):
"""
AppBlockingResponse entity
"""
task_id: str
def to_dict(self) -> dict:
@@ -510,6 +563,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
"""
Data entity
"""
id: str
mode: str
conversation_id: str
@@ -530,6 +584,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse):
"""
Data entity
"""
id: str
mode: str
message_id: str
@@ -549,6 +604,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
"""
Data entity
"""
id: str
workflow_id: str
status: str
@@ -562,25 +618,3 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
workflow_run_id: str
data: Data
class WorkflowIterationState(BaseModel):
"""
WorkflowIterationState entity
"""
class Data(BaseModel):
"""
Data entity
"""
parent_iteration_id: Optional[str] = None
iteration_id: str
current_index: int
iteration_steps_boundary: list[int] = None
node_execution_id: str
started_at: float
inputs: dict = None
total_tokens: int = 0
node_data: BaseNodeData
current_iterations: dict[str, Data] = None

View File

@@ -13,11 +13,9 @@ logger = logging.getLogger(__name__)
class AnnotationReplyFeature:
def query(
self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
) -> Optional[MessageAnnotation]:
"""
Query app annotations to reply
:param app_record: app record
@@ -27,8 +25,9 @@ class AnnotationReplyFeature:
:param invoke_from: invoke from
:return:
"""
annotation_setting = (
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_record.id).first()
)
if not annotation_setting:
return None
@@ -41,55 +40,50 @@ class AnnotationReplyFeature:
embedding_model_name = collection_binding_detail.model_name
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name, embedding_model_name, "annotation"
)
dataset = Dataset(
id=app_record.id,
tenant_id=app_record.tenant_id,
indexing_technique="high_quality",
embedding_model_provider=embedding_provider_name,
embedding_model=embedding_model_name,
collection_binding_id=dataset_collection_binding.id,
)
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
documents = vector.search_by_vector(
query=query, top_k=1, score_threshold=score_threshold, filter={"group_id": [dataset.id]}
)
if documents:
annotation_id = documents[0].metadata["annotation_id"]
score = documents[0].metadata["score"]
annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
if annotation:
if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]:
from_source = "api"
else:
from_source = "console"
# insert annotation history
AppAnnotationService.add_annotation_history(
annotation.id,
app_record.id,
annotation.question,
annotation.content,
query,
user_id,
message.id,
from_source,
score,
)
return annotation
except Exception as e:
logger.warning(f"Query annotation failed, exception: {str(e)}.")
return None
return None

View File

@@ -8,8 +8,9 @@ logger = logging.getLogger(__name__)
class HostingModerationFeature:
def check(
self, application_generate_entity: EasyUIBasedAppGenerateEntity, prompt_messages: list[PromptMessage]
) -> bool:
"""
Check hosting moderation
:param application_generate_entity: application generate entity
@@ -23,9 +24,6 @@ class HostingModerationFeature:
if isinstance(prompt_message.content, str):
text += prompt_message.content + "\n"
moderation_result = moderation.check_moderation(model_config, text)
return moderation_result

View File

@@ -19,7 +19,7 @@ class RateLimit:
_ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes
_instance_dict = {}
def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int):
if client_id not in cls._instance_dict:
instance = super().__new__(cls)
cls._instance_dict[client_id] = instance
@@ -27,13 +27,13 @@ class RateLimit:
def __init__(self, client_id: str, max_active_requests: int):
self.max_active_requests = max_active_requests
if hasattr(self, "initialized"):
return
self.initialized = True
self.client_id = client_id
self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id)
self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id)
self.last_recalculate_time = float("-inf")
self.flush_cache(use_local_value=True)
def flush_cache(self, use_local_value=False):
@@ -46,7 +46,7 @@ class RateLimit:
pipe.execute()
else:
with redis_client.pipeline() as pipe:
self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode("utf-8"))
redis_client.expire(self.max_active_requests_key, timedelta(days=1))
# flush max active requests (in-transit request list)
@@ -54,8 +54,11 @@ class RateLimit:
return
request_details = redis_client.hgetall(self.active_requests_key)
redis_client.expire(self.active_requests_key, timedelta(days=1))
timeout_requests = [
k
for k, v in request_details.items()
if time.time() - float(v.decode("utf-8")) > RateLimit._REQUEST_MAX_ALIVE_TIME
]
if timeout_requests:
redis_client.hdel(self.active_requests_key, *timeout_requests)
@@ -69,8 +72,10 @@ class RateLimit:
active_requests_count = redis_client.hlen(self.active_requests_key)
if active_requests_count >= self.max_active_requests:
raise AppInvokeQuotaExceededError("Too many requests. Please try again later. The current maximum "
"concurrent requests allowed is {}.".format(self.max_active_requests))
raise AppInvokeQuotaExceededError(
"Too many requests. Please try again later. The current maximum "
"concurrent requests allowed is {}.".format(self.max_active_requests)
)
redis_client.hset(self.active_requests_key, request_id, str(time.time()))
return request_id
@@ -116,5 +121,5 @@ class RateLimitGenerator:
if not self.closed:
self.closed = True
self.rate_limit.exit(self.request_id)
if self.generator is not None and hasattr(self.generator, "close"):
self.generator.close()

View File

@@ -25,25 +25,25 @@ from .variables import (
)
__all__ = [
"IntegerVariable",
"FloatVariable",
"ObjectVariable",
"SecretVariable",
"StringVariable",
"ArrayAnyVariable",
"Variable",
"SegmentType",
"SegmentGroup",
"Segment",
"NoneSegment",
"NoneVariable",
"IntegerSegment",
"FloatSegment",
"ObjectSegment",
"ArrayAnySegment",
"StringSegment",
"ArrayStringVariable",
"ArrayNumberVariable",
"ArrayObjectVariable",
"ArraySegment",
]

View File

@@ -28,12 +28,12 @@ from .variables import (
def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
if (value_type := mapping.get("value_type")) is None:
raise VariableError("missing value type")
if not mapping.get("name"):
raise VariableError("missing name")
if (value := mapping.get("value")) is None:
raise VariableError("missing value")
match value_type:
case SegmentType.STRING:
result = StringVariable.model_validate(mapping)
@@ -44,7 +44,7 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
case SegmentType.NUMBER if isinstance(value, float):
result = FloatVariable.model_validate(mapping)
case SegmentType.NUMBER if not isinstance(value, float | int):
raise VariableError(f"invalid number value {value}")
case SegmentType.OBJECT if isinstance(value, dict):
result = ObjectVariable.model_validate(mapping)
case SegmentType.ARRAY_STRING if isinstance(value, list):
@@ -54,9 +54,9 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
result = ArrayObjectVariable.model_validate(mapping)
case _:
raise VariableError(f"not supported value type {value_type}")
if result.size > dify_config.MAX_VARIABLE_SIZE:
raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
return result
@@ -73,4 +73,4 @@ def build_segment(value: Any, /) -> Segment:
return ObjectSegment(value=value)
if isinstance(value, list):
return ArrayAnySegment(value=value)
raise ValueError(f"not supported value {value}")

View File

@@ -4,14 +4,14 @@ from core.workflow.entities.variable_pool import VariablePool
from . import SegmentGroup, factory
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
def convert_template(*, template: str, variable_pool: VariablePool):
parts = re.split(VARIABLE_PATTERN, template)
segments = []
for part in filter(lambda x: x, parts):
if "." in part and (value := variable_pool.get(part.split("."))):
segments.append(value)
else:
segments.append(factory.build_segment(part))
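For reference, a standalone check of what VARIABLE_PATTERN captures when convert_template splits a prompt template:

```python
import re

VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")

parts = re.split(VARIABLE_PATTERN, "Hello {{#start.name#}}!")
print(parts)  # -> ['Hello ', 'start.name', '!']
```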

View File

@@ -8,15 +8,15 @@ class SegmentGroup(Segment):
@property
def text(self):
return "".join([segment.text for segment in self.value])
@property
def log(self):
return "".join([segment.log for segment in self.value])
@property
def markdown(self):
return "".join([segment.markdown for segment in self.value])
def to_object(self):
return [segment.to_object() for segment in self.value]

View File

@@ -14,13 +14,13 @@ class Segment(BaseModel):
value_type: SegmentType
value: Any
@field_validator("value_type")
def validate_value_type(cls, value):
"""
This validator checks if the provided value is equal to the default value of the 'value_type' field.
If the value is different, a ValueError is raised.
"""
if value != cls.model_fields["value_type"].default:
raise ValueError("Cannot modify 'value_type'")
return value
@@ -50,15 +50,15 @@ class NoneSegment(Segment):
@property
def text(self) -> str:
return "null"
@property
def log(self) -> str:
return "null"
@property
def markdown(self) -> str:
return "null"
class StringSegment(Segment):
@@ -76,24 +76,21 @@ class IntegerSegment(Segment):
value: int
class ObjectSegment(Segment):
value_type: SegmentType = SegmentType.OBJECT
value: Mapping[str, Any]
@property
def text(self) -> str:
return json.dumps(self.model_dump()["value"], ensure_ascii=False)
@property
def log(self) -> str:
return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2)
@property
def markdown(self) -> str:
return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2)
class ArraySegment(Segment):
@@ -101,11 +98,11 @@ class ArraySegment(Segment):
def markdown(self) -> str:
items = []
for item in self.value:
if hasattr(item, "to_markdown"):
items.append(item.to_markdown())
else:
items.append(str(item))
return "\n".join(items)
class ArrayAnySegment(ArraySegment):
@@ -126,4 +123,3 @@ class ArrayNumberSegment(ArraySegment):
class ArrayObjectSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_OBJECT
value: Sequence[Mapping[str, Any]]
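Behaviour sketch for the ObjectSegment properties defined above (a standalone illustration, not project code):

```python
seg = ObjectSegment(value={"name": "dify", "count": 2})
print(seg.text)      # compact JSON: {"name": "dify", "count": 2}
print(seg.markdown)  # the same JSON pretty-printed with indent=2
```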

View File

@@ -2,14 +2,14 @@ from enum import Enum
class SegmentType(str, Enum):
NONE = "none"
NUMBER = "number"
STRING = "string"
SECRET = "secret"
ARRAY_ANY = "array[any]"
ARRAY_STRING = "array[string]"
ARRAY_NUMBER = "array[number]"
ARRAY_OBJECT = "array[object]"
OBJECT = "object"
GROUP = "group"

View File

@@ -23,11 +23,11 @@ class Variable(Segment):
"""
id: str = Field(
default="",
description="Unique identity for variable. It's only used by environment variables now.",
)
name: str
description: str = Field(default="", description="Description of the variable.")
class StringVariable(StringSegment, Variable):
@@ -62,7 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable):
pass
class SecretVariable(StringVariable):
value_type: SegmentType = SegmentType.SECRET

View File

@@ -32,10 +32,13 @@ class BasedGenerateTaskPipeline:
_task_state: TaskState
_application_generate_entity: AppGenerateEntity
def __init__(
self,
application_generate_entity: AppGenerateEntity,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool,
) -> None:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
@@ -61,35 +64,39 @@ class BasedGenerateTaskPipeline:
e = event.error
if isinstance(e, InvokeAuthorizationError):
err = InvokeAuthorizationError("Incorrect API key provided")
elif isinstance(e, InvokeError) or isinstance(e, ValueError):
err = e
else:
err = Exception(e.description if getattr(e, "description", None) is not None else str(e))
if message:
refetch_message = db.session.query(Message).filter(Message.id == message.id).first()
if refetch_message:
err_desc = self._error_to_desc(err)
refetch_message.status = "error"
refetch_message.error = err_desc
db.session.commit()
return err
def _error_to_desc(self, e: Exception) -> str:
"""
Error to desc.
:param e: exception
:return:
"""
if isinstance(e, QuotaExceededError):
return ("Your quota for Dify Hosted Model Provider has been exhausted. "
"Please go to Settings -> Model Provider to complete your own provider credentials.")
return (
"Your quota for Dify Hosted Model Provider has been exhausted. "
"Please go to Settings -> Model Provider to complete your own provider credentials."
)
message = getattr(e, "description", str(e))
if not message:
message = "Internal Server Error, please contact support."
return message
@@ -99,10 +106,7 @@ class BasedGenerateTaskPipeline:
:param e: exception
:return:
"""
return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e)
def _ping_stream_response(self) -> PingStreamResponse:
"""
@@ -123,11 +127,8 @@ class BasedGenerateTaskPipeline:
return OutputModeration(
tenant_id=app_config.tenant_id,
app_id=app_config.app_id,
rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config),
queue_manager=self._queue_manager,
)
def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
@@ -141,8 +142,7 @@ class BasedGenerateTaskPipeline:
self._output_moderation_handler.stop_thread()
completion = self._output_moderation_handler.moderation_completion(
completion=completion, public_event=False
)
self._output_moderation_handler = None

View File

@@ -64,23 +64,21 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
"""
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_task_state: EasyUITaskState
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
def __init__(
self,
application_generate_entity: Union[
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity
],
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool,
) -> None:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
@@ -101,18 +99,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
model=self._model_config.model,
prompt_messages=[],
message=AssistantPromptMessage(content=""),
usage=LLMUsage.empty_usage(),
)
)
self._conversation_name_generate_thread = None
def process(
self,
) -> Union[
ChatbotAppBlockingResponse,
CompletionAppBlockingResponse,
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
]:
"""
Process generate task pipeline.
@@ -125,22 +123,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation, self._application_generate_entity.query
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream:
return self._to_stream_response(generator)
else:
return self._to_blocking_response(generator)
def _to_blocking_response(
self, generator: Generator[StreamResponse, None, None]
) -> Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]:
"""
Process blocking response.
:return:
@@ -149,11 +143,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, MessageEndStreamResponse):
extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)}
if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata
if self._conversation.mode == AppMode.COMPLETION.value:
response = CompletionAppBlockingResponse(
@@ -164,8 +156,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
message_id=self._message.id,
answer=self._task_state.llm_result.message.content,
created_at=int(self._message.created_at.timestamp()),
**extras,
),
)
else:
response = ChatbotAppBlockingResponse(
@@ -177,18 +169,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
message_id=self._message.id,
answer=self._task_state.llm_result.message.content,
created_at=int(self._message.created_at.timestamp()),
**extras,
),
)
return response
else:
continue
raise Exception("Queue listening stopped unexpectedly.")
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]:
"""
To stream response.
:return:
@@ -198,14 +191,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
yield CompletionAppStreamResponse(
message_id=self._message.id,
created_at=int(self._message.created_at.timestamp()),
stream_response=stream_response,
)
else:
yield ChatbotAppStreamResponse(
conversation_id=self._conversation.id,
message_id=self._message.id,
created_at=int(self._message.created_at.timestamp()),
stream_response=stream_response,
)
def _listenAudioMsg(self, publisher, task_id: str):
@@ -217,15 +210,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
def _wrapper_process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
tenant_id = self._application_generate_entity.app_config.tenant_id
task_id = self._application_generate_entity.task_id
publisher = None
text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech")
if (
text_to_speech_dict
and text_to_speech_dict.get("autoPlay") == "enabled"
and text_to_speech_dict.get("enabled")
):
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(publisher, task_id)
@@ -250,14 +247,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
break
else:
start_listener_time = time.time()
yield MessageAudioStreamResponse(audio=audio.audio, task_id=task_id)
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
def _process_stream_response(
self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
@@ -333,9 +327,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()
def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> None:
"""
Save message.
:return:
@@ -347,31 +339,32 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
self._model_config.mode, self._task_state.llm_result.prompt_messages
)
self._message.message_tokens = usage.prompt_tokens
self._message.message_unit_price = usage.prompt_unit_price
self._message.message_price_unit = usage.prompt_price_unit
self._message.answer = (
PromptTemplateParser.remove_template_variables(llm_result.message.content.strip())
if llm_result.message.content
else ""
)
self._message.answer_tokens = usage.completion_tokens
self._message.answer_unit_price = usage.completion_unit_price
self._message.answer_price_unit = usage.completion_price_unit
self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.total_price = usage.total_price
self._message.currency = usage.currency
self._message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)
db.session.commit()
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation.id, message_id=self._message.id
)
)
@@ -379,11 +372,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._message,
application_generate_entity=self._application_generate_entity,
conversation=self._conversation,
is_first_message=self._application_generate_entity.app_config.app_mode in [
AppMode.AGENT_CHAT,
AppMode.CHAT
] and self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras
is_first_message=self._application_generate_entity.app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT]
and self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras,
)
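_save_message above records usage, measures latency against a monotonic clock, and serializes metadata only when it is present (the template-variable stripping step is left out here). A small plain-Python illustration, with a stand-in Message dataclass:

import json
import time
from dataclasses import dataclass
from typing import Optional


@dataclass
class Message:
    answer: str = ""
    provider_response_latency: float = 0.0
    message_metadata: Optional[str] = None


def finish_message(message: Message, content: Optional[str], metadata: dict, start_at: float) -> Message:
    # Strip surrounding whitespace only when there is content; otherwise store "".
    message.answer = content.strip() if content else ""
    # Latency is measured from pipeline start using a monotonic clock.
    message.provider_response_latency = time.perf_counter() - start_at
    # Serialize metadata only when non-empty; keep the column NULL-like otherwise.
    message.message_metadata = json.dumps(metadata) if metadata else None
    return message


start = time.perf_counter()
msg = finish_message(Message(), "  hello  ", {"usage": {"total_tokens": 42}}, start)
print(msg.answer, msg.message_metadata)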
def _handle_stop(self, event: QueueStopEvent) -> None:
@@ -395,22 +386,17 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
model = model_config.model
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle,
model=model_config.model
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
# calculate num tokens
prompt_tokens = 0
if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
prompt_tokens = model_instance.get_llm_num_tokens(
self._task_state.llm_result.prompt_messages
)
prompt_tokens = model_instance.get_llm_num_tokens(self._task_state.llm_result.prompt_messages)
completion_tokens = 0
if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
completion_tokens = model_instance.get_llm_num_tokens(
[self._task_state.llm_result.message]
)
completion_tokens = model_instance.get_llm_num_tokens([self._task_state.llm_result.message])
credentials = model_config.credentials
@@ -418,10 +404,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
model,
credentials,
prompt_tokens,
completion_tokens
model, credentials, prompt_tokens, completion_tokens
)
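On stop, token usage is reconstructed conditionally: prompt tokens are skipped when an annotation reply ended the run, and completion tokens are only counted for a manual stop. A rough sketch of that branching; the StopBy enum and whitespace token counter below are simplified stand-ins for the real queue event and model tokenizer:

from enum import Enum


class StopBy(Enum):
    USER_MANUAL = "user-manual"
    ANNOTATION_REPLY = "annotation-reply"
    OUTPUT_MODERATION = "output-moderation"


def count_tokens(texts: list[str]) -> int:
    # Crude stand-in for model_instance.get_llm_num_tokens(): whitespace token count.
    return sum(len(t.split()) for t in texts)


def usage_on_stop(stopped_by: StopBy, prompt_texts: list[str], answer_text: str) -> tuple[int, int]:
    prompt_tokens = 0
    if stopped_by != StopBy.ANNOTATION_REPLY:
        prompt_tokens = count_tokens(prompt_texts)

    completion_tokens = 0
    if stopped_by == StopBy.USER_MANUAL:
        completion_tokens = count_tokens([answer_text])

    return prompt_tokens, completion_tokens


print(usage_on_stop(StopBy.USER_MANUAL, ["What is Dify?"], "Dify is an LLM app platform."))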
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
@@ -429,16 +412,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
Message end to stream response.
:return:
"""
self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage)
self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage)
extras = {}
if self._task_state.metadata:
extras['metadata'] = self._task_state.metadata
extras["metadata"] = self._task_state.metadata
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
id=self._message.id,
**extras
task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
)
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
@@ -449,9 +430,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
:return:
"""
return AgentMessageStreamResponse(
task_id=self._application_generate_entity.task_id,
id=message_id,
answer=answer
task_id=self._application_generate_entity.task_id, id=message_id, answer=answer
)
def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Optional[AgentThoughtStreamResponse]:
@@ -461,9 +440,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
:return:
"""
agent_thought: MessageAgentThought = (
db.session.query(MessageAgentThought)
.filter(MessageAgentThought.id == event.agent_thought_id)
.first()
db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first()
)
db.session.refresh(agent_thought)
db.session.close()
@@ -478,7 +455,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
tool=agent_thought.tool,
tool_labels=agent_thought.tool_labels,
tool_input=agent_thought.tool_input,
message_files=agent_thought.files
message_files=agent_thought.files,
)
return None
@@ -500,15 +477,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
prompt_messages=self._task_state.llm_result.prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
)
message=AssistantPromptMessage(content=self._task_state.llm_result.message.content),
),
)
), PublishFrom.TASK_PIPELINE
),
PublishFrom.TASK_PIPELINE,
)
self._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
PublishFrom.TASK_PIPELINE
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
)
return True
else:



@@ -8,7 +8,6 @@ from core.app.entities.app_invoke_entities import (
AgentChatAppGenerateEntity,
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.queue_entities import (
QueueAnnotationReplyEvent,
@@ -16,11 +15,11 @@ from core.app.entities.queue_entities import (
QueueRetrieverResourcesEvent,
)
from core.app.entities.task_entities import (
AdvancedChatTaskState,
EasyUITaskState,
MessageFileStreamResponse,
MessageReplaceStreamResponse,
MessageStreamResponse,
WorkflowTaskState,
)
from core.llm_generator.llm_generator import LLMGenerator
from core.tools.tool_file_manager import ToolFileManager
@@ -31,12 +30,9 @@ from services.annotation_service import AppAnnotationService
class MessageCycleManage:
_application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity
]
_task_state: Union[EasyUITaskState, AdvancedChatTaskState]
_task_state: Union[EasyUITaskState, WorkflowTaskState]
def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]:
"""
@@ -45,17 +41,23 @@ class MessageCycleManage:
:param query: query
:return: thread
"""
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
return None
is_first_message = self._application_generate_entity.conversation_id is None
extras = self._application_generate_entity.extras
auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True)
auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True)
if auto_generate_conversation_name and is_first_message:
# start generate thread
thread = Thread(target=self._generate_conversation_name_worker, kwargs={
'flask_app': current_app._get_current_object(),
'conversation_id': conversation.id,
'query': query
})
thread = Thread(
target=self._generate_conversation_name_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"conversation_id": conversation.id,
"query": query,
},
)
thread.start()
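Conversation-name generation runs on a background thread that re-enters the Flask application context, since threads do not inherit it. A minimal sketch of that hand-off (assumes Flask is installed; the worker body is illustrative only):

from threading import Thread

from flask import Flask, current_app

app = Flask(__name__)


def generate_name_worker(flask_app: Flask, conversation_id: str, query: str) -> None:
    # The thread has no app/request context of its own, so push one explicitly.
    with flask_app.app_context():
        print(f"[{current_app.name}] naming conversation {conversation_id} from query: {query!r}")


with app.app_context():
    # _get_current_object() unwraps the proxy so the real app object can cross the thread boundary.
    flask_app = current_app._get_current_object()

thread = Thread(
    target=generate_name_worker,
    kwargs={"flask_app": flask_app, "conversation_id": "conv-1", "query": "hello"},
)
thread.start()
thread.join()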
@@ -63,17 +65,13 @@ class MessageCycleManage:
return None
def _generate_conversation_name_worker(self,
flask_app: Flask,
conversation_id: str,
query: str):
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
with flask_app.app_context():
# get conversation and message
conversation = (
db.session.query(Conversation)
.filter(Conversation.id == conversation_id)
.first()
)
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
if not conversation:
return
if conversation.mode != AppMode.COMPLETION.value:
app_model = conversation.app
@@ -100,12 +98,9 @@ class MessageCycleManage:
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
if annotation:
account = annotation.account
self._task_state.metadata['annotation_reply'] = {
'id': annotation.id,
'account': {
'id': annotation.account_id,
'name': account.name if account else 'Dify user'
}
self._task_state.metadata["annotation_reply"] = {
"id": annotation.id,
"account": {"id": annotation.account_id, "name": account.name if account else "Dify user"},
}
return annotation
@@ -119,28 +114,7 @@ class MessageCycleManage:
:return:
"""
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
self._task_state.metadata['retriever_resources'] = event.retriever_resources
def _get_response_metadata(self) -> dict:
"""
Get response metadata by invoke from.
:return:
"""
metadata = {}
# show_retrieve_source
if 'retriever_resources' in self._task_state.metadata:
metadata['retriever_resources'] = self._task_state.metadata['retriever_resources']
# show annotation reply
if 'annotation_reply' in self._task_state.metadata:
metadata['annotation_reply'] = self._task_state.metadata['annotation_reply']
# show usage
if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
metadata['usage'] = self._task_state.metadata['usage']
return metadata
self._task_state.metadata["retriever_resources"] = event.retriever_resources
def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
"""
@@ -148,27 +122,23 @@ class MessageCycleManage:
:param event: event
:return:
"""
message_file: MessageFile = (
db.session.query(MessageFile)
.filter(MessageFile.id == event.message_file_id)
.first()
)
message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first()
if message_file:
# get tool file id
tool_file_id = message_file.url.split('/')[-1]
tool_file_id = message_file.url.split("/")[-1]
# trim extension
tool_file_id = tool_file_id.split('.')[0]
tool_file_id = tool_file_id.split(".")[0]
# get extension
if '.' in message_file.url:
if "." in message_file.url:
extension = f'.{message_file.url.split(".")[-1]}'
if len(extension) > 10:
extension = '.bin'
extension = ".bin"
else:
extension = '.bin'
extension = ".bin"
# add sign url to local file
if message_file.url.startswith('http'):
if message_file.url.startswith("http"):
url = message_file.url
else:
url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension)
@@ -177,8 +147,8 @@ class MessageCycleManage:
task_id=self._application_generate_entity.task_id,
id=message_file.id,
type=message_file.type,
belongs_to=message_file.belongs_to or 'user',
url=url
belongs_to=message_file.belongs_to or "user",
url=url,
)
return None
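The message-file hunk derives a tool file id and extension from the stored URL, falls back to ".bin" for missing or implausibly long extensions, and only signs non-HTTP (local) files. A self-contained sketch of that parsing; sign_file below is a hypothetical stand-in for ToolFileManager.sign_file:

def resolve_file_url(url: str) -> str:
    # Tool file id is the last path segment without its extension.
    tool_file_id = url.split("/")[-1].split(".")[0]

    # Extension comes from the URL; missing or oversized extensions fall back to ".bin".
    if "." in url:
        extension = f'.{url.split(".")[-1]}'
        if len(extension) > 10:
            extension = ".bin"
    else:
        extension = ".bin"

    # Remote files are already addressable; local files get a signed URL instead.
    if url.startswith("http"):
        return url
    return sign_file(tool_file_id=tool_file_id, extension=extension)


def sign_file(tool_file_id: str, extension: str) -> str:
    # Hypothetical signer standing in for ToolFileManager.sign_file().
    return f"/files/tools/{tool_file_id}{extension}?signature=stub"


print(resolve_file_url("https://example.com/files/abc123.png"))
print(resolve_file_url("tools/abc123.verylongextension"))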
@@ -190,11 +160,7 @@ class MessageCycleManage:
:param message_id: message id
:return:
"""
return MessageStreamResponse(
task_id=self._application_generate_entity.task_id,
id=message_id,
answer=answer
)
return MessageStreamResponse(task_id=self._application_generate_entity.task_id, id=message_id, answer=answer)
def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse:
"""
@@ -202,7 +168,4 @@ class MessageCycleManage:
:param answer: answer
:return:
"""
return MessageReplaceStreamResponse(
task_id=self._application_generate_entity.task_id,
answer=answer
)
return MessageReplaceStreamResponse(task_id=self._application_generate_entity.task_id, answer=answer)


@@ -1,33 +1,41 @@
import json
import time
from datetime import datetime, timezone
from typing import Optional, Union, cast
from typing import Any, Optional, Union, cast
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
)
from core.app.entities.task_entities import (
NodeExecutionInfo,
IterationNodeCompletedStreamResponse,
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
ParallelBranchFinishedStreamResponse,
ParallelBranchStartStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.workflow_iteration_cycle_manage import WorkflowIterationCycleManage
from core.file.file_obj import FileVar
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser
@@ -41,54 +49,56 @@ from models.workflow import (
WorkflowRunStatus,
WorkflowRunTriggeredFrom,
)
from services.workflow_service import WorkflowService
class WorkflowCycleManage(WorkflowIterationCycleManage):
def _init_workflow_run(self, workflow: Workflow,
triggered_from: WorkflowRunTriggeredFrom,
user: Union[Account, EndUser],
user_inputs: dict,
system_inputs: Optional[dict] = None) -> WorkflowRun:
"""
Init workflow run
:param workflow: Workflow instance
:param triggered_from: triggered from
:param user: account or end user
:param user_inputs: user variables inputs
:param system_inputs: system inputs, like: query, files
:return:
"""
max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \
.filter(WorkflowRun.tenant_id == workflow.tenant_id) \
.filter(WorkflowRun.app_id == workflow.app_id) \
.scalar() or 0
class WorkflowCycleManage:
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: WorkflowTaskState
_workflow_system_variables: dict[SystemVariableKey, Any]
def _handle_workflow_run_start(self) -> WorkflowRun:
max_sequence = (
db.session.query(db.func.max(WorkflowRun.sequence_number))
.filter(WorkflowRun.tenant_id == self._workflow.tenant_id)
.filter(WorkflowRun.app_id == self._workflow.app_id)
.scalar()
or 0
)
new_sequence_number = max_sequence + 1
inputs = {**user_inputs}
for key, value in (system_inputs or {}).items():
if key.value == 'conversation':
inputs = {**self._application_generate_entity.inputs}
for key, value in (self._workflow_system_variables or {}).items():
if key.value == "conversation":
continue
inputs[f'sys.{key.value}'] = value
inputs = WorkflowEngineManager.handle_special_values(inputs)
inputs[f"sys.{key.value}"] = value
inputs = WorkflowEntry.handle_special_values(inputs)
triggered_from = (
WorkflowRunTriggeredFrom.DEBUGGING
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
else WorkflowRunTriggeredFrom.APP_RUN
)
# init workflow run
workflow_run = WorkflowRun(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
sequence_number=new_sequence_number,
workflow_id=workflow.id,
type=workflow.type,
triggered_from=triggered_from.value,
version=workflow.version,
graph=workflow.graph,
inputs=json.dumps(inputs),
status=WorkflowRunStatus.RUNNING.value,
created_by_role=(CreatedByRole.ACCOUNT.value
if isinstance(user, Account) else CreatedByRole.END_USER.value),
created_by=user.id
workflow_run = WorkflowRun()
workflow_run.tenant_id = self._workflow.tenant_id
workflow_run.app_id = self._workflow.app_id
workflow_run.sequence_number = new_sequence_number
workflow_run.workflow_id = self._workflow.id
workflow_run.type = self._workflow.type
workflow_run.triggered_from = triggered_from.value
workflow_run.version = self._workflow.version
workflow_run.graph = self._workflow.graph
workflow_run.inputs = json.dumps(inputs)
workflow_run.status = WorkflowRunStatus.RUNNING.value
workflow_run.created_by_role = (
CreatedByRole.ACCOUNT.value if isinstance(self._user, Account) else CreatedByRole.END_USER.value
)
workflow_run.created_by = self._user.id
db.session.add(workflow_run)
db.session.commit()
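_handle_workflow_run_start folds the system variables into the run inputs under a "sys." prefix while skipping the conversation key (the WorkflowEntry.handle_special_values pass and the sequence-number query are omitted here). A reduced sketch with a stand-in SystemVariableKey enum:

from enum import Enum
from typing import Any


class SystemVariableKey(Enum):
    QUERY = "query"
    FILES = "files"
    CONVERSATION_ID = "conversation"


def build_run_inputs(user_inputs: dict[str, Any], system_inputs: dict[SystemVariableKey, Any]) -> dict[str, Any]:
    inputs = {**user_inputs}
    for key, value in (system_inputs or {}).items():
        if key.value == "conversation":
            # The conversation id is tracked elsewhere and is not stored with the run inputs.
            continue
        inputs[f"sys.{key.value}"] = value
    return inputs


print(
    build_run_inputs(
        {"topic": "databases"},
        {SystemVariableKey.QUERY: "what is an index?", SystemVariableKey.CONVERSATION_ID: "conv-1"},
    )
)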
@@ -97,33 +107,37 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
return workflow_run
def _workflow_run_success(
self, workflow_run: WorkflowRun,
def _handle_workflow_run_success(
self,
workflow_run: WorkflowRun,
start_at: float,
total_tokens: int,
total_steps: int,
outputs: Optional[str] = None,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowRun:
"""
Workflow run success
:param workflow_run: workflow run
:param start_at: start time
:param total_tokens: total tokens
:param total_steps: total steps
:param outputs: outputs
:param conversation_id: conversation id
:return:
"""
workflow_run = self._refetch_workflow_run(workflow_run.id)
workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
workflow_run.outputs = outputs
workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id)
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
db.session.refresh(workflow_run)
db.session.close()
if trace_manager:
trace_manager.add_trace_task(
@@ -135,34 +149,64 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
)
)
db.session.close()
return workflow_run
def _workflow_run_failed(
self, workflow_run: WorkflowRun,
def _handle_workflow_run_failed(
self,
workflow_run: WorkflowRun,
start_at: float,
total_tokens: int,
total_steps: int,
status: WorkflowRunStatus,
error: str,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowRun:
"""
Workflow run failed
:param workflow_run: workflow run
:param start_at: start time
:param total_tokens: total tokens
:param total_steps: total steps
:param status: status
:param error: error message
:return:
"""
workflow_run = self._refetch_workflow_run(workflow_run.id)
workflow_run.status = status.value
workflow_run.error = error
workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id)
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.commit()
running_workflow_node_executions = (
db.session.query(WorkflowNodeExecution)
.filter(
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
WorkflowNodeExecution.app_id == workflow_run.app_id,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
)
.all()
)
for workflow_node_execution in running_workflow_node_executions:
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow_node_execution.elapsed_time = (
workflow_node_execution.finished_at - workflow_node_execution.created_at
).total_seconds()
db.session.commit()
db.session.refresh(workflow_run)
db.session.close()
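When a workflow run fails, node executions still marked RUNNING are closed out with the same error, and elapsed time is taken from created_at to finished_at. A plain-Python sketch of that sweep, with a dataclass standing in for the ORM row:

from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Optional


@dataclass
class NodeExecution:
    node_id: str
    status: str
    created_at: datetime
    finished_at: Optional[datetime] = None
    error: Optional[str] = None
    elapsed_time: float = 0.0


def fail_running_executions(executions: list[NodeExecution], error: str) -> None:
    now = datetime.now(timezone.utc).replace(tzinfo=None)  # naive UTC, matching the stored timestamps
    for execution in executions:
        if execution.status != "running":
            continue
        execution.status = "failed"
        execution.error = error
        execution.finished_at = now
        execution.elapsed_time = (execution.finished_at - execution.created_at).total_seconds()


runs = [NodeExecution("llm", "running", datetime.now(timezone.utc).replace(tzinfo=None))]
fail_running_executions(runs, "Workflow stopped.")
print(runs[0].status, runs[0].elapsed_time)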
@@ -178,39 +222,26 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
return workflow_run
def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun,
node_id: str,
node_type: NodeType,
node_title: str,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> WorkflowNodeExecution:
"""
Init workflow node execution from workflow run
:param workflow_run: workflow run
:param node_id: node id
:param node_type: node type
:param node_title: node title
:param node_run_index: run index
:param predecessor_node_id: predecessor node id if exists
:return:
"""
def _handle_node_execution_start(
self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
) -> WorkflowNodeExecution:
# init workflow node execution
workflow_node_execution = WorkflowNodeExecution(
tenant_id=workflow_run.tenant_id,
app_id=workflow_run.app_id,
workflow_id=workflow_run.workflow_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
workflow_run_id=workflow_run.id,
predecessor_node_id=predecessor_node_id,
index=node_run_index,
node_id=node_id,
node_type=node_type.value,
title=node_title,
status=WorkflowNodeExecutionStatus.RUNNING.value,
created_by_role=workflow_run.created_by_role,
created_by=workflow_run.created_by,
created_at=datetime.now(timezone.utc).replace(tzinfo=None)
)
workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.tenant_id = workflow_run.tenant_id
workflow_node_execution.app_id = workflow_run.app_id
workflow_node_execution.workflow_id = workflow_run.workflow_id
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
workflow_node_execution.workflow_run_id = workflow_run.id
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
workflow_node_execution.index = event.node_run_index
workflow_node_execution.node_execution_id = event.node_execution_id
workflow_node_execution.node_id = event.node_id
workflow_node_execution.node_type = event.node_type.value
workflow_node_execution.title = event.node_data.title
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
workflow_node_execution.created_by_role = workflow_run.created_by_role
workflow_node_execution.created_by = workflow_run.created_by
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
db.session.add(workflow_node_execution)
db.session.commit()
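_handle_node_execution_start now copies fields from the QueueNodeStartedEvent onto a fresh record by attribute assignment instead of constructor kwargs. A compact illustration of that event-to-record mapping; both classes below are simplified stand-ins:

from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Optional


@dataclass
class NodeStartedEvent:
    node_execution_id: str
    node_id: str
    node_type: str
    title: str
    node_run_index: int
    predecessor_node_id: Optional[str] = None


class NodeExecutionRecord:
    """Stand-in ORM row populated field by field, as in the refactor."""


def start_node_execution(event: NodeStartedEvent) -> NodeExecutionRecord:
    record = NodeExecutionRecord()
    record.node_execution_id = event.node_execution_id
    record.node_id = event.node_id
    record.node_type = event.node_type
    record.title = event.title
    record.index = event.node_run_index
    record.predecessor_node_id = event.predecessor_node_id
    record.status = "running"
    record.created_at = datetime.now(timezone.utc).replace(tzinfo=None)  # naive UTC timestamp
    return record


row = start_node_execution(NodeStartedEvent("exec-1", "llm-1", "llm", "LLM", 1))
print(row.node_id, row.status)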
@@ -219,33 +250,26 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
return workflow_node_execution
def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution,
start_at: float,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> WorkflowNodeExecution:
def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
"""
Workflow node execution success
:param workflow_node_execution: workflow node execution
:param start_at: start time
:param inputs: inputs
:param process_data: process data
:param outputs: outputs
:param execution_metadata: execution metadata
:param event: queue node succeeded event
:return:
"""
inputs = WorkflowEngineManager.handle_special_values(inputs)
outputs = WorkflowEngineManager.handle_special_values(outputs)
workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = WorkflowEntry.handle_special_values(event.outputs)
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \
if execution_metadata else None
workflow_node_execution.execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
)
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
db.session.commit()
db.session.refresh(workflow_node_execution)
@@ -253,33 +277,24 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
return workflow_node_execution
def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution,
start_at: float,
error: str,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None
) -> WorkflowNodeExecution:
def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param workflow_node_execution: workflow node execution
:param start_at: start time
:param error: error message
:param event: queue node failed event
:return:
"""
inputs = WorkflowEngineManager.handle_special_values(inputs)
outputs = WorkflowEngineManager.handle_special_values(outputs)
workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = WorkflowEntry.handle_special_values(event.outputs)
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
workflow_node_execution.error = event.error
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \
if execution_metadata else None
workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()
db.session.commit()
db.session.refresh(workflow_node_execution)
@@ -287,8 +302,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
return workflow_node_execution
def _workflow_start_to_stream_response(self, task_id: str,
workflow_run: WorkflowRun) -> WorkflowStartStreamResponse:
#################################################
# to stream responses #
#################################################
def _workflow_start_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun
) -> WorkflowStartStreamResponse:
"""
Workflow start to stream response.
:param task_id: task id
@@ -302,13 +322,14 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
id=workflow_run.id,
workflow_id=workflow_run.workflow_id,
sequence_number=workflow_run.sequence_number,
inputs=workflow_run.inputs_dict,
created_at=int(workflow_run.created_at.timestamp())
)
inputs=workflow_run.inputs_dict or {},
created_at=int(workflow_run.created_at.timestamp()),
),
)
def _workflow_finish_to_stream_response(self, task_id: str,
workflow_run: WorkflowRun) -> WorkflowFinishStreamResponse:
def _workflow_finish_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun
) -> WorkflowFinishStreamResponse:
"""
Workflow finish to stream response.
:param task_id: task id
@@ -348,14 +369,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
created_by=created_by,
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(workflow_run.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict)
)
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict or {}),
),
)
def _workflow_node_start_to_stream_response(self, event: QueueNodeStartedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution) \
-> NodeStartStreamResponse:
def _workflow_node_start_to_stream_response(
self, event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution
) -> Optional[NodeStartStreamResponse]:
"""
Workflow node start to stream response.
:param event: queue node started event
@@ -363,6 +383,9 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
:param workflow_node_execution: workflow node execution
:return:
"""
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
return None
response = NodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
@@ -374,29 +397,42 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
index=workflow_node_execution.index,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs_dict,
created_at=int(workflow_node_execution.created_at.timestamp())
)
created_at=int(workflow_node_execution.created_at.timestamp()),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
),
)
# extras logic
if event.node_type == NodeType.TOOL:
node_data = cast(ToolNodeData, event.node_data)
response.data.extras['icon'] = ToolManager.get_tool_icon(
response.data.extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self._application_generate_entity.app_config.tenant_id,
provider_type=node_data.provider_type,
provider_id=node_data.provider_id
provider_id=node_data.provider_id,
)
return response
def _workflow_node_finish_to_stream_response(self, task_id: str, workflow_node_execution: WorkflowNodeExecution) \
-> NodeFinishStreamResponse:
def _workflow_node_finish_to_stream_response(
self,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
"""
Workflow node finish to stream response.
:param event: queue node succeeded or failed event
:param task_id: task id
:param workflow_node_execution: workflow node execution
:return:
"""
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
return None
return NodeFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
@@ -416,181 +452,153 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
execution_metadata=workflow_node_execution.execution_metadata_dict,
created_at=int(workflow_node_execution.created_at.timestamp()),
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict)
)
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
),
)
def _handle_workflow_start(self) -> WorkflowRun:
self._task_state.start_at = time.perf_counter()
workflow_run = self._init_workflow_run(
workflow=self._workflow,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
else WorkflowRunTriggeredFrom.APP_RUN,
user=self._user,
user_inputs=self._application_generate_entity.inputs,
system_inputs=self._workflow_system_variables
def _workflow_parallel_branch_start_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse:
"""
Workflow parallel branch start to stream response
:param task_id: task id
:param workflow_run: workflow run
:param event: parallel branch run started event
:return:
"""
return ParallelBranchStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=ParallelBranchStartStreamResponse.Data(
parallel_id=event.parallel_id,
parallel_branch_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
created_at=int(time.time()),
),
)
self._task_state.workflow_run_id = workflow_run.id
db.session.close()
return workflow_run
def _handle_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
workflow_node_execution = self._init_node_execution_from_workflow_run(
workflow_run=workflow_run,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_data.title,
node_run_index=event.node_run_index,
predecessor_node_id=event.predecessor_node_id
def _workflow_parallel_branch_finished_to_stream_response(
self,
task_id: str,
workflow_run: WorkflowRun,
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
) -> ParallelBranchFinishedStreamResponse:
"""
Workflow parallel branch finished to stream response
:param task_id: task id
:param workflow_run: workflow run
:param event: parallel branch run succeeded or failed event
:return:
"""
return ParallelBranchFinishedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=ParallelBranchFinishedStreamResponse.Data(
parallel_id=event.parallel_id,
parallel_branch_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed",
error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
created_at=int(time.time()),
),
)
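The parallel-branch finish response collapses two event types into one payload, choosing the status string and optional error by isinstance checks. A brief sketch with stand-in event classes:

import time
from dataclasses import dataclass
from typing import Union


@dataclass
class ParallelBranchRunSucceededEvent:
    parallel_id: str


@dataclass
class ParallelBranchRunFailedEvent:
    parallel_id: str
    error: str


def branch_finished_payload(
    event: Union[ParallelBranchRunSucceededEvent, ParallelBranchRunFailedEvent],
) -> dict:
    return {
        "parallel_id": event.parallel_id,
        "status": "succeeded" if isinstance(event, ParallelBranchRunSucceededEvent) else "failed",
        "error": event.error if isinstance(event, ParallelBranchRunFailedEvent) else None,
        "created_at": int(time.time()),
    }


print(branch_finished_payload(ParallelBranchRunFailedEvent("p-1", "node timed out")))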
latest_node_execution_info = NodeExecutionInfo(
workflow_node_execution_id=workflow_node_execution.id,
node_type=event.node_type,
start_at=time.perf_counter()
def _workflow_iteration_start_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent
) -> IterationNodeStartStreamResponse:
"""
Workflow iteration start to stream response
:param task_id: task id
:param workflow_run: workflow run
:param event: iteration start event
:return:
"""
return IterationNodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=IterationNodeStartStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
metadata=event.metadata or {},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)
self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
self._task_state.latest_node_execution_info = latest_node_execution_info
def _workflow_iteration_next_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
) -> IterationNodeNextStreamResponse:
"""
Workflow iteration next to stream response
:param task_id: task id
:param workflow_run: workflow run
:param event: iteration next event
:return:
"""
return IterationNodeNextStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=IterationNodeNextStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
index=event.index,
pre_iteration_output=event.output,
created_at=int(time.time()),
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)
self._task_state.total_steps += 1
db.session.close()
return workflow_node_execution
def _handle_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution:
current_node_execution = self._task_state.ran_node_execution_infos[event.node_id]
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first()
execution_metadata = event.execution_metadata if isinstance(event, QueueNodeSucceededEvent) else None
if self._iteration_state and self._iteration_state.current_iterations:
if not execution_metadata:
execution_metadata = {}
current_iteration_data = None
for iteration_node_id in self._iteration_state.current_iterations:
data = self._iteration_state.current_iterations[iteration_node_id]
if data.parent_iteration_id == None:
current_iteration_data = data
break
if current_iteration_data:
execution_metadata[NodeRunMetadataKey.ITERATION_ID] = current_iteration_data.iteration_id
execution_metadata[NodeRunMetadataKey.ITERATION_INDEX] = current_iteration_data.current_index
if isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._workflow_node_execution_success(
workflow_node_execution=workflow_node_execution,
start_at=current_node_execution.start_at,
inputs=event.inputs,
process_data=event.process_data,
def _workflow_iteration_completed_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent
) -> IterationNodeCompletedStreamResponse:
"""
Workflow iteration completed to stream response
:param task_id: task id
:param workflow_run: workflow run
:param event: iteration completed event
:return:
"""
return IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=IterationNodeCompletedStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
outputs=event.outputs,
execution_metadata=execution_metadata
)
if execution_metadata and execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
self._task_state.total_tokens += (
int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
if self._iteration_state:
for iteration_node_id in self._iteration_state.current_iterations:
data = self._iteration_state.current_iterations[iteration_node_id]
if execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
data.total_tokens += int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))
if workflow_node_execution.node_type == NodeType.LLM.value:
outputs = workflow_node_execution.outputs_dict
usage_dict = outputs.get('usage', {})
self._task_state.metadata['usage'] = usage_dict
else:
workflow_node_execution = self._workflow_node_execution_failed(
workflow_node_execution=workflow_node_execution,
start_at=current_node_execution.start_at,
error=event.error,
inputs=event.inputs,
process_data=event.process_data,
outputs=event.outputs,
execution_metadata=execution_metadata
)
db.session.close()
return workflow_node_execution
def _handle_workflow_finished(
self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None
) -> Optional[WorkflowRun]:
workflow_run = db.session.query(WorkflowRun).filter(
WorkflowRun.id == self._task_state.workflow_run_id).first()
if not workflow_run:
return None
if conversation_id is None:
conversation_id = self._application_generate_entity.inputs.get('sys.conversation_id')
if isinstance(event, QueueStopEvent):
workflow_run = self._workflow_run_failed(
workflow_run=workflow_run,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
status=WorkflowRunStatus.STOPPED,
error='Workflow stopped.',
conversation_id=conversation_id,
trace_manager=trace_manager
)
latest_node_execution_info = self._task_state.latest_node_execution_info
if latest_node_execution_info:
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == latest_node_execution_info.workflow_node_execution_id).first()
if (workflow_node_execution
and workflow_node_execution.status == WorkflowNodeExecutionStatus.RUNNING.value):
self._workflow_node_execution_failed(
workflow_node_execution=workflow_node_execution,
start_at=latest_node_execution_info.start_at,
error='Workflow stopped.'
)
elif isinstance(event, QueueWorkflowFailedEvent):
workflow_run = self._workflow_run_failed(
workflow_run=workflow_run,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
status=WorkflowRunStatus.FAILED,
error=event.error,
conversation_id=conversation_id,
trace_manager=trace_manager
)
else:
if self._task_state.latest_node_execution_info:
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first()
outputs = workflow_node_execution.outputs
else:
outputs = None
workflow_run = self._workflow_run_success(
workflow_run=workflow_run,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
outputs=outputs,
conversation_id=conversation_id,
trace_manager=trace_manager
)
self._task_state.workflow_run_id = workflow_run.id
db.session.close()
return workflow_run
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
status=WorkflowNodeExecutionStatus.SUCCEEDED,
error=None,
elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
execution_metadata=event.metadata,
finished_at=int(time.time()),
steps=event.steps,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)
def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]:
"""
@@ -641,9 +649,45 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
return None
if isinstance(value, dict):
if '__variant' in value and value['__variant'] == FileVar.__name__:
if "__variant" in value and value["__variant"] == FileVar.__name__:
return value
elif isinstance(value, FileVar):
return value.to_dict()
return None
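_fetch_files_from_node_outputs keeps values that are either serialized file dicts carrying a "__variant" marker or live FileVar objects. A standalone sketch of that detection; the FileVar class here is a hypothetical minimal version:

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FileVar:
    """Hypothetical minimal file variable with a dict form, mirroring to_dict() in the snippet."""

    url: str

    def to_dict(self) -> dict:
        return {"__variant": FileVar.__name__, "url": self.url}


def get_file_var_from_value(value: Any) -> Optional[dict]:
    if not value:
        return None
    if isinstance(value, dict):
        # Already-serialized file values carry a __variant marker naming the class.
        if "__variant" in value and value["__variant"] == FileVar.__name__:
            return value
    elif isinstance(value, FileVar):
        return value.to_dict()
    return None


print(get_file_var_from_value(FileVar(url="/files/a.png")))
print(get_file_var_from_value({"__variant": "FileVar", "url": "/files/b.png"}))
print(get_file_var_from_value("plain text"))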
def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
"""
Refetch workflow run
:param workflow_run_id: workflow run id
:return:
"""
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
if not workflow_run:
raise Exception(f"Workflow run not found: {workflow_run_id}")
return workflow_run
def _refetch_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
"""
Refetch workflow node execution
:param node_execution_id: workflow node execution id
:return:
"""
workflow_node_execution = (
db.session.query(WorkflowNodeExecution)
.filter(
WorkflowNodeExecution.tenant_id == self._application_generate_entity.app_config.tenant_id,
WorkflowNodeExecution.app_id == self._application_generate_entity.app_config.app_id,
WorkflowNodeExecution.workflow_id == self._workflow.id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.node_execution_id == node_execution_id,
)
.first()
)
if not workflow_node_execution:
raise Exception(f"Workflow node execution not found: {node_execution_id}")
return workflow_node_execution
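The refetch helpers re-query by id and raise when nothing is found, so success and failure handlers never see a None row. A tiny sketch of that contract, with an in-memory dict standing in for the database session:

class WorkflowRunNotFoundError(Exception):
    pass


_RUNS = {"run-1": {"id": "run-1", "status": "running"}}


def refetch_workflow_run(workflow_run_id: str) -> dict:
    """Return the latest copy of the run, or raise so callers can assume it exists."""
    workflow_run = _RUNS.get(workflow_run_id)
    if not workflow_run:
        raise WorkflowRunNotFoundError(f"Workflow run not found: {workflow_run_id}")
    return workflow_run


print(refetch_workflow_run("run-1")["status"])
try:
    refetch_workflow_run("missing")
except WorkflowRunNotFoundError as exc:
    print(exc)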


@@ -1,16 +0,0 @@
from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
from core.workflow.enums import SystemVariableKey
from models.account import Account
from models.model import EndUser
from models.workflow import Workflow
class WorkflowCycleStateManager:
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
_workflow_system_variables: dict[SystemVariableKey, Any]


@@ -1,290 +0,0 @@
import json
import time
from collections.abc import Generator
from datetime import datetime, timezone
from typing import Optional, Union
from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
)
from core.app.entities.task_entities import (
IterationNodeCompletedStreamResponse,
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
NodeExecutionInfo,
WorkflowIterationState,
)
from core.app.task_pipeline.workflow_cycle_state_manager import WorkflowCycleStateManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
from models.workflow import (
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
WorkflowNodeExecutionTriggeredFrom,
WorkflowRun,
)
class WorkflowIterationCycleManage(WorkflowCycleStateManager):
_iteration_state: WorkflowIterationState = None
def _init_iteration_state(self) -> WorkflowIterationState:
if not self._iteration_state:
self._iteration_state = WorkflowIterationState(
current_iterations={}
)
def _handle_iteration_to_stream_response(self, task_id: str, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) \
-> Union[IterationNodeStartStreamResponse, IterationNodeNextStreamResponse, IterationNodeCompletedStreamResponse]:
"""
Handle iteration to stream response
:param task_id: task id
:param event: iteration event
:return:
"""
if isinstance(event, QueueIterationStartEvent):
return IterationNodeStartStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeStartStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
created_at=int(time.time()),
extras={},
inputs=event.inputs,
metadata=event.metadata
)
)
elif isinstance(event, QueueIterationNextEvent):
current_iteration = self._iteration_state.current_iterations[event.node_id]
return IterationNodeNextStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeNextStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=current_iteration.node_data.title,
index=event.index,
pre_iteration_output=event.output,
created_at=int(time.time()),
extras={}
)
)
elif isinstance(event, QueueIterationCompletedEvent):
current_iteration = self._iteration_state.current_iterations[event.node_id]
return IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeCompletedStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=current_iteration.node_data.title,
outputs=event.outputs,
created_at=int(time.time()),
extras={},
inputs=current_iteration.inputs,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
error=None,
elapsed_time=time.perf_counter() - current_iteration.started_at,
total_tokens=current_iteration.total_tokens,
execution_metadata={
'total_tokens': current_iteration.total_tokens,
},
finished_at=int(time.time()),
steps=current_iteration.current_index
)
)
def _init_iteration_execution_from_workflow_run(self,
workflow_run: WorkflowRun,
node_id: str,
node_type: NodeType,
node_title: str,
node_run_index: int = 1,
inputs: Optional[dict] = None,
predecessor_node_id: Optional[str] = None
) -> WorkflowNodeExecution:
workflow_node_execution = WorkflowNodeExecution(
tenant_id=workflow_run.tenant_id,
app_id=workflow_run.app_id,
workflow_id=workflow_run.workflow_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
workflow_run_id=workflow_run.id,
predecessor_node_id=predecessor_node_id,
index=node_run_index,
node_id=node_id,
node_type=node_type.value,
inputs=json.dumps(inputs) if inputs else None,
title=node_title,
status=WorkflowNodeExecutionStatus.RUNNING.value,
created_by_role=workflow_run.created_by_role,
created_by=workflow_run.created_by,
execution_metadata=json.dumps({
'started_run_index': node_run_index + 1,
'current_index': 0,
'steps_boundary': [],
}),
created_at=datetime.now(timezone.utc).replace(tzinfo=None)
)
db.session.add(workflow_node_execution)
db.session.commit()
db.session.refresh(workflow_node_execution)
db.session.close()
return workflow_node_execution
def _handle_iteration_operation(self, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) -> WorkflowNodeExecution:
if isinstance(event, QueueIterationStartEvent):
return self._handle_iteration_started(event)
elif isinstance(event, QueueIterationNextEvent):
return self._handle_iteration_next(event)
elif isinstance(event, QueueIterationCompletedEvent):
return self._handle_iteration_completed(event)
def _handle_iteration_started(self, event: QueueIterationStartEvent) -> WorkflowNodeExecution:
self._init_iteration_state()
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
workflow_node_execution = self._init_iteration_execution_from_workflow_run(
workflow_run=workflow_run,
node_id=event.node_id,
node_type=NodeType.ITERATION,
node_title=event.node_data.title,
node_run_index=event.node_run_index,
inputs=event.inputs,
predecessor_node_id=event.predecessor_node_id
)
latest_node_execution_info = NodeExecutionInfo(
workflow_node_execution_id=workflow_node_execution.id,
node_type=NodeType.ITERATION,
start_at=time.perf_counter()
)
self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
self._task_state.latest_node_execution_info = latest_node_execution_info
self._iteration_state.current_iterations[event.node_id] = WorkflowIterationState.Data(
parent_iteration_id=None,
iteration_id=event.node_id,
current_index=0,
iteration_steps_boundary=[],
node_execution_id=workflow_node_execution.id,
started_at=time.perf_counter(),
inputs=event.inputs,
total_tokens=0,
node_data=event.node_data
)
db.session.close()
return workflow_node_execution
def _handle_iteration_next(self, event: QueueIterationNextEvent) -> WorkflowNodeExecution:
if event.node_id not in self._iteration_state.current_iterations:
return
current_iteration = self._iteration_state.current_iterations[event.node_id]
current_iteration.current_index = event.index
current_iteration.iteration_steps_boundary.append(event.node_run_index)
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_iteration.node_execution_id
).first()
original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
if original_node_execution_metadata:
original_node_execution_metadata['current_index'] = event.index
original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)
db.session.commit()
db.session.close()
def _handle_iteration_completed(self, event: QueueIterationCompletedEvent):
if event.node_id not in self._iteration_state.current_iterations:
return
current_iteration = self._iteration_state.current_iterations[event.node_id]
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_iteration.node_execution_id
).first()
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
workflow_node_execution.outputs = json.dumps(WorkflowEngineManager.handle_special_values(event.outputs)) if event.outputs else None
workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at
original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
if original_node_execution_metadata:
original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)
db.session.commit()
# remove current iteration
self._iteration_state.current_iterations.pop(event.node_id, None)
# set latest node execution info
latest_node_execution_info = NodeExecutionInfo(
workflow_node_execution_id=workflow_node_execution.id,
node_type=NodeType.ITERATION,
start_at=time.perf_counter()
)
self._task_state.latest_node_execution_info = latest_node_execution_info
db.session.close()
def _handle_iteration_exception(self, task_id: str, error: str) -> Generator[IterationNodeCompletedStreamResponse, None, None]:
"""
Handle iteration exception
"""
if not self._iteration_state or not self._iteration_state.current_iterations:
return
for node_id, current_iteration in self._iteration_state.current_iterations.items():
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_iteration.node_execution_id
).first()
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at
db.session.commit()
db.session.close()
yield IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeCompletedStreamResponse.Data(
id=node_id,
node_id=node_id,
node_type=NodeType.ITERATION.value,
title=current_iteration.node_data.title,
outputs={},
created_at=int(time.time()),
extras={},
inputs=current_iteration.inputs,
status=WorkflowNodeExecutionStatus.FAILED,
error=error,
elapsed_time=time.perf_counter() - current_iteration.started_at,
total_tokens=current_iteration.total_tokens,
execution_metadata={
'total_tokens': current_iteration.total_tokens,
},
finished_at=int(time.time()),
steps=current_iteration.current_index
)
)


@@ -16,31 +16,32 @@ _TEXT_COLOR_MAPPING = {
"red": "31;1",
}
def get_colored_text(text: str, color: str) -> str:
"""Get colored text."""
color_str = _TEXT_COLOR_MAPPING[color]
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
def print_text(
text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
) -> None:
def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None) -> None:
"""Print text with highlighting and no end characters."""
text_to_print = get_colored_text(text, color) if color else text
print(text_to_print, end=end, file=file)
if file:
file.flush() # ensure all printed content are written to file
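get_colored_text wraps text in ANSI escape sequences keyed by a small color map, and print_text flushes explicitly when a file handle is given. A compact runnable version of the same idea; entries other than red are filled in with plausible codes, since only the red mapping is visible in this hunk:

import sys
from typing import Optional, TextIO

_TEXT_COLOR_MAPPING = {"blue": "36;1", "yellow": "33;1", "green": "32;1", "red": "31;1"}


def get_colored_text(text: str, color: str) -> str:
    """Wrap text in ANSI color codes (bold/italic variant, reset at the end)."""
    color_str = _TEXT_COLOR_MAPPING[color]
    return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"


def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None) -> None:
    text_to_print = get_colored_text(text, color) if color else text
    print(text_to_print, end=end, file=file)
    if file:
        file.flush()  # make sure buffered output reaches the stream immediately


print_text("[on_agent_start]\n", color="green", file=sys.stdout)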
class DifyAgentCallbackHandler(BaseModel):
"""Callback Handler that prints to std out."""
color: Optional[str] = ''
color: Optional[str] = ""
current_loop: int = 1
def __init__(self, color: Optional[str] = None) -> None:
super().__init__()
"""Initialize callback handler."""
# use a specific color is not specified
self.color = color or 'green'
self.color = color or "green"
self.current_loop = 1
def on_tool_start(
@@ -58,7 +59,7 @@ class DifyAgentCallbackHandler(BaseModel):
tool_outputs: Sequence[ToolInvokeMessage],
message_id: Optional[str] = None,
timer: Optional[Any] = None,
trace_manager: Optional[TraceQueueManager] = None
trace_manager: Optional[TraceQueueManager] = None,
) -> None:
"""If not the final action, print out observation."""
print_text("\n[on_tool_end]\n", color=self.color)
@@ -79,26 +80,21 @@ class DifyAgentCallbackHandler(BaseModel):
)
)
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
"""Do nothing."""
print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='red')
print_text("\n[on_tool_error] Error: " + str(error) + "\n", color="red")
def on_agent_start(
self, thought: str
) -> None:
def on_agent_start(self, thought: str) -> None:
"""Run on agent start."""
if thought:
print_text("\n[on_agent_start] \nCurrent Loop: " + \
str(self.current_loop) + \
"\nThought: " + thought + "\n", color=self.color)
print_text(
"\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\nThought: " + thought + "\n",
color=self.color,
)
else:
print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color)
def on_agent_finish(
self, color: Optional[str] = None, **kwargs: Any
) -> None:
def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any) -> None:
"""Run on agent end."""
print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color)
@@ -107,9 +103,9 @@ class DifyAgentCallbackHandler(BaseModel):
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != "true"
@property
def ignore_chat_model(self) -> bool:
"""Whether to ignore chat model callbacks."""
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != "true"

View File

@@ -1,4 +1,3 @@
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
@@ -11,11 +10,9 @@ from models.model import DatasetRetrieverResource
 class DatasetIndexToolCallbackHandler:
     """Callback handler for dataset tool."""
-    def __init__(self, queue_manager: AppQueueManager,
-                 app_id: str,
-                 message_id: str,
-                 user_id: str,
-                 invoke_from: InvokeFrom) -> None:
+    def __init__(
+        self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, invoke_from: InvokeFrom
+    ) -> None:
         self._queue_manager = queue_manager
         self._app_id = app_id
         self._message_id = message_id
@@ -29,11 +26,12 @@ class DatasetIndexToolCallbackHandler:
         dataset_query = DatasetQuery(
             dataset_id=dataset_id,
             content=query,
-            source='app',
+            source="app",
             source_app_id=self._app_id,
-            created_by_role=('account'
-                             if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
-            created_by=self._user_id
+            created_by_role=(
+                "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"
+            ),
+            created_by=self._user_id,
         )
         db.session.add(dataset_query)
@@ -43,18 +41,15 @@ class DatasetIndexToolCallbackHandler:
         """Handle tool end."""
         for document in documents:
             query = db.session.query(DocumentSegment).filter(
-                DocumentSegment.index_node_id == document.metadata['doc_id']
+                DocumentSegment.index_node_id == document.metadata["doc_id"]
             )
             # if 'dataset_id' in document.metadata:
-            if 'dataset_id' in document.metadata:
-                query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])
+            if "dataset_id" in document.metadata:
+                query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
             # add hit count to document segment
-            query.update(
-                {DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
-                synchronize_session=False
-            )
+            query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
             db.session.commit()
@@ -64,26 +59,25 @@ class DatasetIndexToolCallbackHandler:
             for item in resource:
                 dataset_retriever_resource = DatasetRetrieverResource(
                     message_id=self._message_id,
-                    position=item.get('position'),
-                    dataset_id=item.get('dataset_id'),
-                    dataset_name=item.get('dataset_name'),
-                    document_id=item.get('document_id'),
-                    document_name=item.get('document_name'),
-                    data_source_type=item.get('data_source_type'),
-                    segment_id=item.get('segment_id'),
-                    score=item.get('score') if 'score' in item else None,
-                    hit_count=item.get('hit_count') if 'hit_count' else None,
-                    word_count=item.get('word_count') if 'word_count' in item else None,
-                    segment_position=item.get('segment_position') if 'segment_position' in item else None,
-                    index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None,
-                    content=item.get('content'),
-                    retriever_from=item.get('retriever_from'),
-                    created_by=self._user_id
+                    position=item.get("position"),
+                    dataset_id=item.get("dataset_id"),
+                    dataset_name=item.get("dataset_name"),
+                    document_id=item.get("document_id"),
+                    document_name=item.get("document_name"),
+                    data_source_type=item.get("data_source_type"),
+                    segment_id=item.get("segment_id"),
+                    score=item.get("score") if "score" in item else None,
+                    hit_count=item.get("hit_count") if "hit_count" else None,
+                    word_count=item.get("word_count") if "word_count" in item else None,
+                    segment_position=item.get("segment_position") if "segment_position" in item else None,
+                    index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None,
+                    content=item.get("content"),
+                    retriever_from=item.get("retriever_from"),
+                    created_by=self._user_id,
                 )
                 db.session.add(dataset_retriever_resource)
                 db.session.commit()
         self._queue_manager.publish(
-            QueueRetrieverResourcesEvent(retriever_resources=resource),
-            PublishFrom.APPLICATION_MANAGER
+            QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
         )
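
The on_tool_end hunk reformatted above relies on SQLAlchemy's query-level update() to bump hit_count without loading each segment row. A runnable sketch of that pattern against a throwaway in-memory model (the DocumentSegment class, engine, and metadata dict below are simplified stand-ins, not Dify's actual models):

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class DocumentSegment(Base):  # simplified stand-in for the real model
    __tablename__ = "document_segments"
    id = Column(Integer, primary_key=True)
    index_node_id = Column(String, index=True)
    dataset_id = Column(String, nullable=True)
    hit_count = Column(Integer, default=0)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(DocumentSegment(index_node_id="node-1", dataset_id="ds-1", hit_count=0))
    session.commit()

    # Same shape as the handler: filter by the retrieved document's metadata,
    # then bump the counter in a single UPDATE without fetching the row.
    metadata = {"doc_id": "node-1", "dataset_id": "ds-1"}
    query = session.query(DocumentSegment).filter(DocumentSegment.index_node_id == metadata["doc_id"])
    if "dataset_id" in metadata:
        query = query.filter(DocumentSegment.dataset_id == metadata["dataset_id"])
    query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
    session.commit()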

View File

@@ -2,4 +2,4 @@ from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackH
 class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler):
-    """Callback Handler that prints to std out."""
+    """Callback Handler that prints to std out."""

View File

@@ -29,9 +29,13 @@ class CacheEmbedding(Embeddings):
         embedding_queue_indices = []
         for i, text in enumerate(texts):
             hash = helper.generate_text_hash(text)
-            embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model,
-                                                              hash=hash,
-                                                              provider_name=self._model_instance.provider).first()
+            embedding = (
+                db.session.query(Embedding)
+                .filter_by(
+                    model_name=self._model_instance.model, hash=hash, provider_name=self._model_instance.provider
+                )
+                .first()
+            )
             if embedding:
                 text_embeddings[i] = embedding.get_embedding()
             else:
@@ -41,17 +45,18 @@ class CacheEmbedding(Embeddings):
             embedding_queue_embeddings = []
             try:
                 model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
-                model_schema = model_type_instance.get_model_schema(self._model_instance.model,
-                                                                    self._model_instance.credentials)
-                max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \
-                    if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1
+                model_schema = model_type_instance.get_model_schema(
+                    self._model_instance.model, self._model_instance.credentials
+                )
+                max_chunks = (
+                    model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
+                    if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
+                    else 1
+                )
                 for i in range(0, len(embedding_queue_texts), max_chunks):
-                    batch_texts = embedding_queue_texts[i:i + max_chunks]
+                    batch_texts = embedding_queue_texts[i : i + max_chunks]
-                    embedding_result = self._model_instance.invoke_text_embedding(
-                        texts=batch_texts,
-                        user=self._user
-                    )
+                    embedding_result = self._model_instance.invoke_text_embedding(texts=batch_texts, user=self._user)
                     for vector in embedding_result.embeddings:
                         try:
@@ -60,16 +65,18 @@ class CacheEmbedding(Embeddings):
                         except IntegrityError:
                             db.session.rollback()
                         except Exception as e:
-                            logging.exception('Failed transform embedding: ', e)
+                            logging.exception("Failed transform embedding: ", e)
                 cache_embeddings = []
                 try:
                     for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
                         text_embeddings[i] = embedding
                         hash = helper.generate_text_hash(texts[i])
                         if hash not in cache_embeddings:
-                            embedding_cache = Embedding(model_name=self._model_instance.model,
-                                                        hash=hash,
-                                                        provider_name=self._model_instance.provider)
+                            embedding_cache = Embedding(
+                                model_name=self._model_instance.model,
+                                hash=hash,
+                                provider_name=self._model_instance.provider,
+                            )
                             embedding_cache.set_embedding(embedding)
                             db.session.add(embedding_cache)
                             cache_embeddings.append(hash)
@@ -78,7 +85,7 @@ class CacheEmbedding(Embeddings):
                     db.session.rollback()
             except Exception as ex:
                 db.session.rollback()
-                logger.error('Failed to embed documents: ', ex)
+                logger.error("Failed to embed documents: ", ex)
                 raise ex
         return text_embeddings
@@ -87,16 +94,13 @@ class CacheEmbedding(Embeddings):
         """Embed query text."""
         # use doc embedding cache or store if not exists
         hash = helper.generate_text_hash(text)
-        embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
+        embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}"
         embedding = redis_client.get(embedding_cache_key)
         if embedding:
             redis_client.expire(embedding_cache_key, 600)
             return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
         try:
-            embedding_result = self._model_instance.invoke_text_embedding(
-                texts=[text],
-                user=self._user
-            )
+            embedding_result = self._model_instance.invoke_text_embedding(texts=[text], user=self._user)
             embedding_results = embedding_result.embeddings[0]
             embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
@@ -116,6 +120,6 @@ class CacheEmbedding(Embeddings):
         except IntegrityError:
             db.session.rollback()
         except:
-            logging.exception('Failed to add embedding to redis')
+            logging.exception("Failed to add embedding to redis")
         return embedding_results
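
The embed_query path reformatted above is a small read-through cache: the key combines provider, model, and a hash of the text; hits are refreshed for another 600 seconds; misses are embedded, normalized, base64-encoded, and written back. A self-contained sketch of that flow (FakeRedis and the embed() stub stand in for redis_client and the real embedding model so the snippet runs on its own):

import base64
import hashlib
import time

import numpy as np


class FakeRedis:
    """In-memory stand-in for redis_client so the sketch is self-contained."""

    def __init__(self):
        self._store = {}

    def get(self, key):
        value = self._store.get(key)
        return value[0] if value and value[1] > time.time() else None

    def setex(self, key, ttl, value):
        self._store[key] = (value, time.time() + ttl)

    def expire(self, key, ttl):
        if key in self._store:
            self._store[key] = (self._store[key][0], time.time() + ttl)


redis_client = FakeRedis()


def embed(text: str) -> list[float]:
    """Placeholder embedding model; the real code calls the provider's API."""
    rng = np.random.default_rng(abs(hash(text)) % (2**32))
    vector = rng.random(8)
    return (vector / np.linalg.norm(vector)).tolist()  # normalized, as in the diff


def embed_query(provider: str, model: str, text: str) -> list[float]:
    text_hash = hashlib.sha256(text.encode()).hexdigest()
    cache_key = f"{provider}_{model}_{text_hash}"

    cached = redis_client.get(cache_key)
    if cached:
        redis_client.expire(cache_key, 600)  # sliding 10-minute TTL, as in the diff
        return list(np.frombuffer(base64.b64decode(cached), dtype="float"))

    embedding = embed(text)
    # Store as base64-encoded raw float bytes, mirroring the cached_embedding code.
    encoded = base64.b64encode(np.array(embedding).tobytes()).decode("utf-8")
    redis_client.setex(cache_key, 600, encoded)
    return embedding


first = embed_query("my_provider", "my_model", "hello world")
second = embed_query("my_provider", "my_model", "hello world")  # served from the cache
assert np.allclose(first, second)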

Some files were not shown because too many files have changed in this diff.