Mirror of https://github.com/langgenius/dify.git, synced 2026-01-10 00:04:14 +00:00

Compare commits: 80 commits
Commit SHA1s (the author and date columns were empty in this mirror):
ffd4bf8bf0, bb3002b173, d4dc54447a, d109881410, d1605952b0, 2cf1187b32, 178730266d, dabfd74622, 5da0182800, ed37439ef7, af92f19291, 86f7f245e4, 2d690801d1, fede54be77, 85ff82a694, c8df92d0eb, 144d30d7ef, 4313d92e6b, 0695543f63, 0bec6a037c, 3ff9a1f24a, a771eea4f6, 61a0ca9e0d, 551b33c8e5, fa34b9aed6, bbb609179f, a27d4d58ec, 50d92f0fd4, a15791e788, 954580a4af, ab7d79275e, d3658166fb, 54b72bdd0a, d28446301f, 9050f92e5b, feefeb44d7, d542b15cc0, 2d7954c7da, b1918dae5e, 031a0b576d, 0cef25ef8c, cdb08be951, 900fd82a92, 44f963f281, 01858e1caf, 2060db8e11, 9ded063417, d72da2777c, 89aede80cc, e0d3cd91c6, 1a054ac1f4, 3230f4a0ec, dadca0f91a, d489b8b3e0, bd0992275c, 3e7597f2bd, 0e71f6db84, f6b9982c23, fb113a9479, 15791510c8, 0f72a8e89d, 14af87527f, 83e84865be, c2a3c5a748, 83494cb4f5, 0bc19c3fbf, 571415d1a4, 7b2cf8215f, fee4d3f6ca, 161cc0cda9, 71bff9fcf3, 80d14c9b22, c5bdf08558, 596f160a1e, d8b6c053a2, 4b262cae58, 1a5116cba0, 01581dd35f, 7fdd964379, 0cfcc97e9d
.github/workflows/style.yml (vendored): 8 changes
@@ -20,7 +20,7 @@ jobs:
 
       - name: Check changed files
         id: changed-files
-        uses: tj-actions/changed-files@v44
+        uses: tj-actions/changed-files@v45
         with:
           files: api/**
 
@@ -66,7 +66,7 @@ jobs:
 
       - name: Check changed files
         id: changed-files
-        uses: tj-actions/changed-files@v44
+        uses: tj-actions/changed-files@v45
         with:
           files: web/**
 
@@ -97,7 +97,7 @@ jobs:
 
       - name: Check changed files
         id: changed-files
-        uses: tj-actions/changed-files@v44
+        uses: tj-actions/changed-files@v45
         with:
           files: |
             **.sh
@@ -107,7 +107,7 @@ jobs:
             dev/**
 
       - name: Super-linter
-        uses: super-linter/super-linter/slim@v6
+        uses: super-linter/super-linter/slim@v7
         if: steps.changed-files.outputs.any_changed == 'true'
         env:
           BASH_SEVERITY: warning
.github/workflows/translate-i18n-base-on-english.yml (vendored, new file): 54 lines
@@ -0,0 +1,54 @@
+name: Check i18n Files and Create PR
+
+on:
+  pull_request:
+    types: [closed]
+    branches: [main]
+
+jobs:
+  check-and-update:
+    if: github.event.pull_request.merged == true
+    runs-on: ubuntu-latest
+    defaults:
+      run:
+        working-directory: web
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 2 # last 2 commits
+
+      - name: Check for file changes in i18n/en-US
+        id: check_files
+        run: |
+          recent_commit_sha=$(git rev-parse HEAD)
+          second_recent_commit_sha=$(git rev-parse HEAD~1)
+          changed_files=$(git diff --name-only $recent_commit_sha $second_recent_commit_sha -- 'i18n/en-US/*.ts')
+          echo "Changed files: $changed_files"
+          if [ -n "$changed_files" ]; then
+            echo "FILES_CHANGED=true" >> $GITHUB_ENV
+          else
+            echo "FILES_CHANGED=false" >> $GITHUB_ENV
+          fi
+
+      - name: Set up Node.js
+        if: env.FILES_CHANGED == 'true'
+        uses: actions/setup-node@v2
+        with:
+          node-version: 'lts/*'
+
+      - name: Install dependencies
+        if: env.FILES_CHANGED == 'true'
+        run: yarn install --frozen-lockfile
+
+      - name: Run npm script
+        if: env.FILES_CHANGED == 'true'
+        run: npm run auto-gen-i18n
+
+      - name: Create Pull Request
+        if: env.FILES_CHANGED == 'true'
+        uses: peter-evans/create-pull-request@v6
+        with:
+          commit-message: Update i18n files based on en-US changes
+          title: 'chore: translate i18n files'
+          body: This PR was automatically created to update i18n files based on changes in en-US locale.
+          branch: chore/automated-i18n-updates
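The change-detection step in this workflow reduces to a single git diff between the two most recent commits. Below is a minimal Python sketch of the same check, assuming it runs inside a checkout with at least two commits fetched; the helper name en_us_files_changed is illustrative, not part of the repository.

import subprocess

def en_us_files_changed() -> bool:
    """Return True if any i18n/en-US/*.ts file differs between HEAD and HEAD~1."""
    result = subprocess.run(
        ["git", "diff", "--name-only", "HEAD", "HEAD~1", "--", "i18n/en-US/*.ts"],
        capture_output=True,
        text=True,
        check=True,
    )
    changed = [line for line in result.stdout.splitlines() if line.strip()]
    print(f"Changed files: {changed}")
    return bool(changed)

if __name__ == "__main__":
    # Mirrors the FILES_CHANGED flag the workflow writes to $GITHUB_ENV.
    print("FILES_CHANGED=true" if en_us_files_changed() else "FILES_CHANGED=false")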
LICENSE: 2 changes
@@ -4,7 +4,7 @@ Dify is licensed under the Apache License 2.0, with the following additional con
 
 1. Dify may be utilized commercially, including as a backend service for other applications or as an application development platform for enterprises. Should the conditions below be met, a commercial license must be obtained from the producer:
 
-a. Multi-tenant SaaS service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
+a. Multi-tenant service: Unless explicitly authorized by Dify in writing, you may not use the Dify source code to operate a multi-tenant environment.
     - Tenant Definition: Within the context of Dify, one tenant corresponds to one workspace. The workspace provides a separated area for each tenant's data and configurations.
 
 b. LOGO and copyright information: In the process of using Dify's frontend components, you may not remove or modify the LOGO or copyright information in the Dify console or applications. This restriction is inapplicable to uses of Dify that do not involve its frontend components.
@@ -39,7 +39,7 @@ DB_DATABASE=dify
 
 # Storage configuration
 # use for store upload files, private keys...
-# storage type: local, s3, azure-blob, google-storage
+# storage type: local, s3, azure-blob, google-storage, tencent-cos, huawei-obs, volcengine-tos
 STORAGE_TYPE=local
 STORAGE_LOCAL_PATH=storage
 S3_USE_AWS_MANAGED_IAM=false
@@ -73,6 +73,12 @@ TENCENT_COS_SECRET_ID=your-secret-id
 TENCENT_COS_REGION=your-region
 TENCENT_COS_SCHEME=your-scheme
 
+# Huawei OBS Storage Configuration
+HUAWEI_OBS_BUCKET_NAME=your-bucket-name
+HUAWEI_OBS_SECRET_KEY=your-secret-key
+HUAWEI_OBS_ACCESS_KEY=your-access-key
+HUAWEI_OBS_SERVER=your-server-url
+
 # OCI Storage configuration
 OCI_ENDPOINT=your-endpoint
 OCI_BUCKET_NAME=your-bucket-name
@@ -80,6 +86,13 @@ OCI_ACCESS_KEY=your-access-key
 OCI_SECRET_KEY=your-secret-key
 OCI_REGION=your-region
 
+# Volcengine tos Storage configuration
+VOLCENGINE_TOS_ENDPOINT=your-endpoint
+VOLCENGINE_TOS_BUCKET_NAME=your-bucket-name
+VOLCENGINE_TOS_ACCESS_KEY=your-access-key
+VOLCENGINE_TOS_SECRET_KEY=your-secret-key
+VOLCENGINE_TOS_REGION=your-region
+
 # CORS configuration
 WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
 CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
@@ -101,11 +114,10 @@ QDRANT_GRPC_ENABLED=false
 QDRANT_GRPC_PORT=6334
 
 # Milvus configuration
-MILVUS_HOST=127.0.0.1
-MILVUS_PORT=19530
+MILVUS_URI=http://127.0.0.1:19530
+MILVUS_TOKEN=
 MILVUS_USER=root
 MILVUS_PASSWORD=Milvus
-MILVUS_SECURE=false
 
 # MyScale configuration
 MYSCALE_HOST=127.0.0.1
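The new env keys extend the set of backends a deployment can select through STORAGE_TYPE. A hedged sketch of the kind of env-driven dispatch this implies; the backend class names below are placeholders for illustration, not Dify's actual classes.

import os

STORAGE_BACKENDS = {
    "local": "LocalStorage",
    "s3": "S3Storage",
    "azure-blob": "AzureBlobStorage",
    "google-storage": "GoogleCloudStorage",
    "tencent-cos": "TencentCosStorage",
    "huawei-obs": "HuaweiObsStorage",          # new in this change set
    "volcengine-tos": "VolcengineTosStorage",  # new in this change set
}

def resolve_storage_backend() -> str:
    # STORAGE_TYPE defaults to "local", matching the env example above.
    storage_type = os.getenv("STORAGE_TYPE", "local")
    try:
        return STORAGE_BACKENDS[storage_type]
    except KeyError:
        raise ValueError(f"unsupported storage type: {storage_type}")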
@@ -55,7 +55,7 @@ RUN apt-get update \
     && echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
     && apt-get update \
     # For Security
-    && apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.2-2 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
+    && apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1 expat=2.6.3-1 libldap-2.5-0=2.5.18+dfsg-3 perl=5.38.2-5 libsqlite3-0=3.46.0-1 \
    && apt-get autoremove -y \
     && rm -rf /var/lib/apt/lists/*
@@ -46,7 +46,7 @@ class CodeExecutionSandboxConfig(BaseSettings):
     """
 
     CODE_EXECUTION_ENDPOINT: HttpUrl = Field(
-        description="endpoint URL of code execution servcie",
+        description="endpoint URL of code execution service",
         default="http://sandbox:8194",
     )
 
@@ -415,7 +415,7 @@ class MailConfig(BaseSettings):
     """
 
     MAIL_TYPE: Optional[str] = Field(
-        description="Mail provider type name, default to None, availabile values are `smtp` and `resend`.",
+        description="Mail provider type name, default to None, available values are `smtp` and `resend`.",
        default=None,
    )
@@ -1,7 +1,7 @@
 from typing import Any, Optional
 from urllib.parse import quote_plus
 
-from pydantic import Field, NonNegativeInt, PositiveInt, computed_field
+from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
 from pydantic_settings import BaseSettings
 
 from configs.middleware.cache.redis_config import RedisConfig
@@ -9,8 +9,10 @@ from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorag
 from configs.middleware.storage.amazon_s3_storage_config import S3StorageConfig
 from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorageConfig
 from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
+from configs.middleware.storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
 from configs.middleware.storage.oci_storage_config import OCIStorageConfig
 from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
+from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
 from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
 from configs.middleware.vdb.chroma_config import ChromaConfig
 from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
@@ -157,6 +159,21 @@ class CeleryConfig(DatabaseConfig):
         default=None,
     )
 
+    CELERY_USE_SENTINEL: Optional[bool] = Field(
+        description="Whether to use Redis Sentinel mode",
+        default=False,
+    )
+
+    CELERY_SENTINEL_MASTER_NAME: Optional[str] = Field(
+        description="Redis Sentinel master name",
+        default=None,
+    )
+
+    CELERY_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
+        description="Redis Sentinel socket timeout",
+        default=0.1,
+    )
+
     @computed_field
     @property
     def CELERY_RESULT_BACKEND(self) -> str | None:
@@ -184,6 +201,8 @@ class MiddlewareConfig(
     AzureBlobStorageConfig,
     GoogleCloudStorageConfig,
     TencentCloudCOSStorageConfig,
+    HuaweiCloudOBSStorageConfig,
+    VolcengineTOSStorageConfig,
     S3StorageConfig,
     OCIStorageConfig,
     # configs of vdb and vdb providers
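MiddlewareConfig picks up the two new providers simply by listing their config classes as extra base classes; pydantic merges fields from every base. A minimal sketch of that composition pattern, assuming pydantic v2 with pydantic-settings installed; the classes and fields below are simplified stand-ins, not Dify's real config classes.

from typing import Optional

from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings

class HuaweiObsConfig(BaseModel):
    HUAWEI_OBS_BUCKET_NAME: Optional[str] = Field(default=None)

class VolcengineTosConfig(BaseModel):
    VOLCENGINE_TOS_BUCKET_NAME: Optional[str] = Field(default=None)

# Multiple inheritance merges every mixin's fields into one settings object,
# so adding a storage provider means one new mixin class plus one extra base
# class on the composed config.
class MiddlewareConfig(HuaweiObsConfig, VolcengineTosConfig, BaseSettings):
    STORAGE_TYPE: str = Field(default="local")

config = MiddlewareConfig()  # field values are read from the environment
print(config.STORAGE_TYPE)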
api/configs/middleware/cache/redis_config.py (vendored): 32 changes
@@ -1,6 +1,6 @@
 from typing import Optional
 
-from pydantic import Field, NonNegativeInt, PositiveInt
+from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt
 from pydantic_settings import BaseSettings
 
 
@@ -38,3 +38,33 @@ class RedisConfig(BaseSettings):
         description="whether to use SSL for Redis connection",
         default=False,
     )
+
+    REDIS_USE_SENTINEL: Optional[bool] = Field(
+        description="Whether to use Redis Sentinel mode",
+        default=False,
+    )
+
+    REDIS_SENTINELS: Optional[str] = Field(
+        description="Redis Sentinel nodes",
+        default=None,
+    )
+
+    REDIS_SENTINEL_SERVICE_NAME: Optional[str] = Field(
+        description="Redis Sentinel service name",
+        default=None,
+    )
+
+    REDIS_SENTINEL_USERNAME: Optional[str] = Field(
+        description="Redis Sentinel username",
+        default=None,
+    )
+
+    REDIS_SENTINEL_PASSWORD: Optional[str] = Field(
+        description="Redis Sentinel password",
+        default=None,
+    )
+
+    REDIS_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
+        description="Redis Sentinel socket timeout",
+        default=0.1,
+    )
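These settings map naturally onto redis-py's Sentinel client. A sketch of how they might be consumed, assuming REDIS_SENTINELS is a comma-separated "host:port" list; this is illustrative wiring, not Dify's actual code.

from redis.sentinel import Sentinel

REDIS_SENTINELS = "10.0.0.1:26379,10.0.0.2:26379"
REDIS_SENTINEL_SERVICE_NAME = "mymaster"
REDIS_SENTINEL_SOCKET_TIMEOUT = 0.1

# Parse "host:port,host:port" into the (host, port) tuples Sentinel expects.
nodes = [
    (host, int(port))
    for host, port in (pair.split(":") for pair in REDIS_SENTINELS.split(","))
]
sentinel = Sentinel(nodes, socket_timeout=REDIS_SENTINEL_SOCKET_TIMEOUT)

# Sentinel resolves the current master for the named service and returns a
# regular Redis client bound to it; failover is handled transparently.
master = sentinel.master_for(
    REDIS_SENTINEL_SERVICE_NAME, socket_timeout=REDIS_SENTINEL_SOCKET_TIMEOUT
)
master.set("healthcheck", "ok")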
api/configs/middleware/storage/huawei_obs_storage_config.py (new file): 29 lines
@@ -0,0 +1,29 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field
+
+
+class HuaweiCloudOBSStorageConfig(BaseModel):
+    """
+    Huawei Cloud OBS storage configs
+    """
+
+    HUAWEI_OBS_BUCKET_NAME: Optional[str] = Field(
+        description="Huawei Cloud OBS bucket name",
+        default=None,
+    )
+
+    HUAWEI_OBS_ACCESS_KEY: Optional[str] = Field(
+        description="Huawei Cloud OBS Access key",
+        default=None,
+    )
+
+    HUAWEI_OBS_SECRET_KEY: Optional[str] = Field(
+        description="Huawei Cloud OBS Secret key",
+        default=None,
+    )
+
+    HUAWEI_OBS_SERVER: Optional[str] = Field(
+        description="Huawei Cloud OBS server URL",
+        default=None,
+    )
@@ -0,0 +1,34 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field
+
+
+class VolcengineTOSStorageConfig(BaseModel):
+    """
+    Volcengine tos storage configs
+    """
+
+    VOLCENGINE_TOS_BUCKET_NAME: Optional[str] = Field(
+        description="Volcengine TOS Bucket Name",
+        default=None,
+    )
+
+    VOLCENGINE_TOS_ACCESS_KEY: Optional[str] = Field(
+        description="Volcengine TOS Access Key",
+        default=None,
+    )
+
+    VOLCENGINE_TOS_SECRET_KEY: Optional[str] = Field(
+        description="Volcengine TOS Secret Key",
+        default=None,
+    )
+
+    VOLCENGINE_TOS_ENDPOINT: Optional[str] = Field(
+        description="Volcengine TOS Endpoint URL",
+        default=None,
+    )
+
+    VOLCENGINE_TOS_REGION: Optional[str] = Field(
+        description="Volcengine TOS Region",
+        default=None,
+    )
@@ -1,6 +1,6 @@
 from typing import Optional
 
-from pydantic import Field, PositiveInt
+from pydantic import Field
 from pydantic_settings import BaseSettings
 
 
@@ -9,14 +9,14 @@ class MilvusConfig(BaseSettings):
     Milvus configs
     """
 
-    MILVUS_HOST: Optional[str] = Field(
-        description="Milvus host",
-        default=None,
+    MILVUS_URI: Optional[str] = Field(
+        description="Milvus uri",
+        default="http://127.0.0.1:19530",
     )
 
-    MILVUS_PORT: PositiveInt = Field(
-        description="Milvus RestFul API port",
-        default=9091,
+    MILVUS_TOKEN: Optional[str] = Field(
+        description="Milvus token",
+        default=None,
     )
 
     MILVUS_USER: Optional[str] = Field(
@@ -29,11 +29,6 @@ class MilvusConfig(BaseSettings):
         default=None,
     )
 
-    MILVUS_SECURE: bool = Field(
-        description="whether to use SSL connection for Milvus",
-        default=False,
-    )
-
     MILVUS_DATABASE: str = Field(
         description="Milvus database, default to `default`",
         default="default",
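The Milvus change replaces host/port/secure with a single URI plus token, matching pymilvus's connection API. A sketch of connecting the new way, assuming a local Milvus at the default URI.

from pymilvus import connections

MILVUS_URI = "http://127.0.0.1:19530"
MILVUS_TOKEN = ""  # "user:password" or an API key when auth is enabled

# One URI replaces MILVUS_HOST/MILVUS_PORT/MILVUS_SECURE: the scheme
# (http vs https) now carries the TLS choice that MILVUS_SECURE used to.
connections.connect(alias="default", uri=MILVUS_URI, token=MILVUS_TOKEN)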
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
 
     CURRENT_VERSION: str = Field(
         description="Dify version",
-        default="0.7.3",
+        default="0.8.0",
     )
 
     COMMIT_SHA: str = Field(
File diff suppressed because one or more lines are too long
@@ -173,18 +173,21 @@ class ChatConversationApi(Resource):
 
         if args["keyword"]:
             keyword_filter = "%{}%".format(args["keyword"])
-            message_subquery = (
-                db.session.query(Message.conversation_id)
-                .filter(or_(Message.query.ilike(keyword_filter), Message.answer.ilike(keyword_filter)))
-                .subquery()
-            )
-            query = query.join(subquery, subquery.c.conversation_id == Conversation.id).filter(
-                or_(
-                    Conversation.id.in_(message_subquery),
-                    Conversation.name.ilike(keyword_filter),
-                    Conversation.introduction.ilike(keyword_filter),
-                    subquery.c.from_end_user_session_id.ilike(keyword_filter),
-                ),
-            )
+            query = (
+                query.join(
+                    Message,
+                    Message.conversation_id == Conversation.id,
+                )
+                .join(subquery, subquery.c.conversation_id == Conversation.id)
+                .filter(
+                    or_(
+                        Message.query.ilike(keyword_filter),
+                        Message.answer.ilike(keyword_filter),
+                        Conversation.name.ilike(keyword_filter),
+                        Conversation.introduction.ilike(keyword_filter),
+                        subquery.c.from_end_user_session_id.ilike(keyword_filter),
+                    ),
+                )
+            )
 
         account = current_user
@@ -198,7 +201,11 @@ class ChatConversationApi(Resource):
             start_datetime_timezone = timezone.localize(start_datetime)
             start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
 
-            query = query.where(Conversation.created_at >= start_datetime_utc)
+            match args["sort_by"]:
+                case "updated_at" | "-updated_at":
+                    query = query.where(Conversation.updated_at >= start_datetime_utc)
+                case "created_at" | "-created_at" | _:
+                    query = query.where(Conversation.created_at >= start_datetime_utc)
 
         if args["end"]:
             end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
@@ -207,7 +214,11 @@ class ChatConversationApi(Resource):
             end_datetime_timezone = timezone.localize(end_datetime)
             end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
 
-            query = query.where(Conversation.created_at < end_datetime_utc)
+            match args["sort_by"]:
+                case "updated_at" | "-updated_at":
+                    query = query.where(Conversation.updated_at <= end_datetime_utc)
+                case "created_at" | "-created_at" | _:
+                    query = query.where(Conversation.created_at <= end_datetime_utc)
 
         if args["annotation_status"] == "annotated":
             query = query.options(joinedload(Conversation.message_annotations)).join(
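The new date filters dispatch on sort_by with structural pattern matching (Python 3.10+); the trailing wildcard in the last case makes created_at the fallback for unrecognized values. A standalone sketch with a stubbed query object; FakeQuery is illustrative, not the SQLAlchemy API.

from datetime import datetime

class FakeQuery:
    def __init__(self):
        self.filters = []
    def where(self, clause):
        self.filters.append(clause)
        return self

def apply_start_filter(query: FakeQuery, sort_by: str, start_utc: datetime) -> FakeQuery:
    match sort_by:
        case "updated_at" | "-updated_at":
            return query.where(f"updated_at >= {start_utc.isoformat()}")
        case "created_at" | "-created_at" | _:
            # wildcard as the final alternative: unknown sort keys fall back
            # to filtering on created_at
            return query.where(f"created_at >= {start_utc.isoformat()}")

q = apply_start_filter(FakeQuery(), "-updated_at", datetime(2024, 9, 1))
print(q.filters)  # ['updated_at >= 2024-09-01T00:00:00']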
@@ -18,7 +18,7 @@ from core.model_runtime.entities.model_entities import ModelType
 from core.provider_manager import ProviderManager
 from core.rag.datasource.vdb.vector_type import VectorType
 from core.rag.extractor.entity.extract_setting import ExtractSetting
-from core.rag.retrieval.retrival_methods import RetrievalMethod
+from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from extensions.ext_database import db
 from fields.app_fields import related_app_list
 from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
@@ -302,6 +302,8 @@ class DatasetInitApi(Resource):
             "doc_language", type=str, default="English", required=False, nullable=False, location="json"
         )
         parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
+        parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
+        parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
         args = parser.parse_args()
 
         # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
@@ -309,6 +311,8 @@ class DatasetInitApi(Resource):
             raise Forbidden()
 
         if args["indexing_technique"] == "high_quality":
+            if args["embedding_model"] is None or args["embedding_model_provider"] is None:
+                raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
             try:
                 model_manager = ModelManager()
                 model_manager.get_default_model_instance(
@@ -13,7 +13,7 @@ from services.tag_service import TagService
 
 
 def _validate_name(name):
-    if not name or len(name) < 1 or len(name) > 40:
+    if not name or len(name) < 1 or len(name) > 50:
         raise ValueError("Name must be between 1 to 50 characters.")
     return name
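The old bound (40) contradicted the error message, which already said 50, so names of 41 to 50 characters were rejected with a message claiming they were valid. A quick illustration of the corrected boundary, using the fixed function verbatim:

def _validate_name(name):
    if not name or len(name) < 1 or len(name) > 50:
        raise ValueError("Name must be between 1 to 50 characters.")
    return name

_validate_name("a" * 50)   # accepted now; rejected before this change
# _validate_name("a" * 51) would raise ValueError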
@@ -36,6 +36,10 @@ class SegmentApi(DatasetApiResource):
         document = DocumentService.get_document(dataset.id, document_id)
         if not document:
             raise NotFound("Document not found.")
+        if document.indexing_status != "completed":
+            raise NotFound("Document is not completed.")
+        if not document.enabled:
+            raise NotFound("Document is disabled.")
         # check embedding model setting
         if dataset.indexing_technique == "high_quality":
             try:
@@ -63,7 +67,7 @@ class SegmentApi(DatasetApiResource):
             segments = SegmentService.multi_create_segment(args["segments"], document, dataset)
             return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
         else:
-            return {"error": "Segemtns is required"}, 400
+            return {"error": "Segments is required"}, 400
 
     def get(self, tenant_id, dataset_id, document_id):
         """Create single segment."""
@@ -1 +1 @@
-import core.moderation.base
+import core.moderation.base
@@ -1,6 +1,7 @@
 import json
 import logging
 import uuid
+from collections.abc import Mapping, Sequence
 from datetime import datetime, timezone
 from typing import Optional, Union, cast
 
@@ -45,22 +46,25 @@ from models.tools import ToolConversationVariables
 
 logger = logging.getLogger(__name__)
 
 
 class BaseAgentRunner(AppRunner):
-    def __init__(self, tenant_id: str,
-                 application_generate_entity: AgentChatAppGenerateEntity,
-                 conversation: Conversation,
-                 app_config: AgentChatAppConfig,
-                 model_config: ModelConfigWithCredentialsEntity,
-                 config: AgentEntity,
-                 queue_manager: AppQueueManager,
-                 message: Message,
-                 user_id: str,
-                 memory: Optional[TokenBufferMemory] = None,
-                 prompt_messages: Optional[list[PromptMessage]] = None,
-                 variables_pool: Optional[ToolRuntimeVariablePool] = None,
-                 db_variables: Optional[ToolConversationVariables] = None,
-                 model_instance: ModelInstance = None
-                 ) -> None:
+    def __init__(
+        self,
+        tenant_id: str,
+        application_generate_entity: AgentChatAppGenerateEntity,
+        conversation: Conversation,
+        app_config: AgentChatAppConfig,
+        model_config: ModelConfigWithCredentialsEntity,
+        config: AgentEntity,
+        queue_manager: AppQueueManager,
+        message: Message,
+        user_id: str,
+        memory: Optional[TokenBufferMemory] = None,
+        prompt_messages: Optional[list[PromptMessage]] = None,
+        variables_pool: Optional[ToolRuntimeVariablePool] = None,
+        db_variables: Optional[ToolConversationVariables] = None,
+        model_instance: ModelInstance = None,
+    ) -> None:
         """
         Agent runner
         :param tenant_id: tenant id
@@ -88,9 +92,7 @@ class BaseAgentRunner(AppRunner):
         self.message = message
         self.user_id = user_id
         self.memory = memory
-        self.history_prompt_messages = self.organize_agent_history(
-            prompt_messages=prompt_messages or []
-        )
+        self.history_prompt_messages = self.organize_agent_history(prompt_messages=prompt_messages or [])
         self.variables_pool = variables_pool
         self.db_variables_pool = db_variables
         self.model_instance = model_instance
@@ -111,12 +113,16 @@ class BaseAgentRunner(AppRunner):
             retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
             return_resource=app_config.additional_features.show_retrieve_source,
             invoke_from=application_generate_entity.invoke_from,
-            hit_callback=hit_callback
+            hit_callback=hit_callback,
         )
         # get how many agent thoughts have been created
-        self.agent_thought_count = db.session.query(MessageAgentThought).filter(
-            MessageAgentThought.message_id == self.message.id,
-        ).count()
+        self.agent_thought_count = (
+            db.session.query(MessageAgentThought)
+            .filter(
+                MessageAgentThought.message_id == self.message.id,
+            )
+            .count()
+        )
         db.session.close()
 
         # check if model supports stream tool call
@@ -135,25 +141,26 @@ class BaseAgentRunner(AppRunner):
         self.query = None
         self._current_thoughts: list[PromptMessage] = []
 
-    def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
-            -> AgentChatAppGenerateEntity:
+    def _repack_app_generate_entity(
+        self, app_generate_entity: AgentChatAppGenerateEntity
+    ) -> AgentChatAppGenerateEntity:
         """
         Repack app generate entity
         """
         if app_generate_entity.app_config.prompt_template.simple_prompt_template is None:
-            app_generate_entity.app_config.prompt_template.simple_prompt_template = ''
+            app_generate_entity.app_config.prompt_template.simple_prompt_template = ""
 
         return app_generate_entity
 
     def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
         """
-            convert tool to prompt message tool
+        convert tool to prompt message tool
         """
         tool_entity = ToolManager.get_agent_tool_runtime(
             tenant_id=self.tenant_id,
             app_id=self.app_config.app_id,
             agent_tool=tool,
-            invoke_from=self.application_generate_entity.invoke_from
+            invoke_from=self.application_generate_entity.invoke_from,
         )
         tool_entity.load_variables(self.variables_pool)
 
@@ -164,7 +171,7 @@ class BaseAgentRunner(AppRunner):
                 "type": "object",
                 "properties": {},
                 "required": [],
-            }
+            },
         )
 
         parameters = tool_entity.get_all_runtime_parameters()
@@ -177,19 +184,19 @@ class BaseAgentRunner(AppRunner):
             if parameter.type == ToolParameter.ToolParameterType.SELECT:
                 enum = [option.value for option in parameter.options]
 
-            message_tool.parameters['properties'][parameter.name] = {
+            message_tool.parameters["properties"][parameter.name] = {
                 "type": parameter_type,
-                "description": parameter.llm_description or '',
+                "description": parameter.llm_description or "",
             }
 
             if len(enum) > 0:
-                message_tool.parameters['properties'][parameter.name]['enum'] = enum
+                message_tool.parameters["properties"][parameter.name]["enum"] = enum
 
             if parameter.required:
-                message_tool.parameters['required'].append(parameter.name)
+                message_tool.parameters["required"].append(parameter.name)
 
         return message_tool, tool_entity
 
     def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
         """
         convert dataset retriever tool to prompt message tool
@@ -201,24 +208,24 @@ class BaseAgentRunner(AppRunner):
                 "type": "object",
                 "properties": {},
                 "required": [],
-            }
+            },
         )
 
         for parameter in tool.get_runtime_parameters():
-            parameter_type = 'string'
-
-            prompt_tool.parameters['properties'][parameter.name] = {
+            parameter_type = "string"
+
+            prompt_tool.parameters["properties"][parameter.name] = {
                 "type": parameter_type,
-                "description": parameter.llm_description or '',
+                "description": parameter.llm_description or "",
             }
 
             if parameter.required:
-                if parameter.name not in prompt_tool.parameters['required']:
-                    prompt_tool.parameters['required'].append(parameter.name)
+                if parameter.name not in prompt_tool.parameters["required"]:
+                    prompt_tool.parameters["required"].append(parameter.name)
 
         return prompt_tool
 
-    def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
+    def _init_prompt_tools(self) -> tuple[Mapping[str, Tool], Sequence[PromptMessageTool]]:
         """
         Init tools
         """
@@ -261,51 +268,51 @@ class BaseAgentRunner(AppRunner):
             enum = []
             if parameter.type == ToolParameter.ToolParameterType.SELECT:
                 enum = [option.value for option in parameter.options]
 
-            prompt_tool.parameters['properties'][parameter.name] = {
+            prompt_tool.parameters["properties"][parameter.name] = {
                 "type": parameter_type,
-                "description": parameter.llm_description or '',
+                "description": parameter.llm_description or "",
             }
 
             if len(enum) > 0:
-                prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
+                prompt_tool.parameters["properties"][parameter.name]["enum"] = enum
 
             if parameter.required:
-                if parameter.name not in prompt_tool.parameters['required']:
-                    prompt_tool.parameters['required'].append(parameter.name)
+                if parameter.name not in prompt_tool.parameters["required"]:
+                    prompt_tool.parameters["required"].append(parameter.name)
 
         return prompt_tool
 
-    def create_agent_thought(self, message_id: str, message: str,
-                             tool_name: str, tool_input: str, messages_ids: list[str]
-                             ) -> MessageAgentThought:
+    def create_agent_thought(
+        self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str]
+    ) -> MessageAgentThought:
         """
         Create agent thought
         """
         thought = MessageAgentThought(
             message_id=message_id,
             message_chain_id=None,
-            thought='',
+            thought="",
             tool=tool_name,
-            tool_labels_str='{}',
-            tool_meta_str='{}',
+            tool_labels_str="{}",
+            tool_meta_str="{}",
             tool_input=tool_input,
             message=message,
             message_token=0,
             message_unit_price=0,
             message_price_unit=0,
-            message_files=json.dumps(messages_ids) if messages_ids else '',
-            answer='',
-            observation='',
+            message_files=json.dumps(messages_ids) if messages_ids else "",
+            answer="",
+            observation="",
             answer_token=0,
             answer_unit_price=0,
             answer_price_unit=0,
             tokens=0,
             total_price=0,
             position=self.agent_thought_count + 1,
-            currency='USD',
+            currency="USD",
             latency=0,
-            created_by_role='account',
+            created_by_role="account",
             created_by=self.user_id,
         )
 
@@ -318,22 +325,22 @@ class BaseAgentRunner(AppRunner):
 
         return thought
 
-    def save_agent_thought(self,
-                           agent_thought: MessageAgentThought,
-                           tool_name: str,
-                           tool_input: Union[str, dict],
-                           thought: str,
-                           observation: Union[str, dict],
-                           tool_invoke_meta: Union[str, dict],
-                           answer: str,
-                           messages_ids: list[str],
-                           llm_usage: LLMUsage = None) -> MessageAgentThought:
+    def save_agent_thought(
+        self,
+        agent_thought: MessageAgentThought,
+        tool_name: str,
+        tool_input: Union[str, dict],
+        thought: str,
+        observation: Union[str, dict],
+        tool_invoke_meta: Union[str, dict],
+        answer: str,
+        messages_ids: list[str],
+        llm_usage: LLMUsage = None,
+    ) -> MessageAgentThought:
         """
         Save agent thought
         """
-        agent_thought = db.session.query(MessageAgentThought).filter(
-            MessageAgentThought.id == agent_thought.id
-        ).first()
+        agent_thought = db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
 
         if thought is not None:
             agent_thought.thought = thought
@@ -356,7 +363,7 @@ class BaseAgentRunner(AppRunner):
                     observation = json.dumps(observation, ensure_ascii=False)
                 except Exception as e:
                     observation = json.dumps(observation)
-
+
             agent_thought.observation = observation
 
         if answer is not None:
@@ -364,7 +371,7 @@ class BaseAgentRunner(AppRunner):
 
         if messages_ids is not None and len(messages_ids) > 0:
             agent_thought.message_files = json.dumps(messages_ids)
-
+
         if llm_usage:
             agent_thought.message_token = llm_usage.prompt_tokens
             agent_thought.message_price_unit = llm_usage.prompt_price_unit
@@ -377,7 +384,7 @@ class BaseAgentRunner(AppRunner):
 
         # check if tool labels is not empty
         labels = agent_thought.tool_labels or {}
-        tools = agent_thought.tool.split(';') if agent_thought.tool else []
+        tools = agent_thought.tool.split(";") if agent_thought.tool else []
         for tool in tools:
             if not tool:
                 continue
@@ -386,7 +393,7 @@ class BaseAgentRunner(AppRunner):
             if tool_label:
                 labels[tool] = tool_label.to_dict()
             else:
-                labels[tool] = {'en_US': tool, 'zh_Hans': tool}
+                labels[tool] = {"en_US": tool, "zh_Hans": tool}
 
         agent_thought.tool_labels_str = json.dumps(labels)
 
@@ -401,14 +408,18 @@ class BaseAgentRunner(AppRunner):
 
         db.session.commit()
         db.session.close()
-
+
     def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
         """
        convert tool variables to db variables
         """
-        db_variables = db.session.query(ToolConversationVariables).filter(
-            ToolConversationVariables.conversation_id == self.message.conversation_id,
-        ).first()
+        db_variables = (
+            db.session.query(ToolConversationVariables)
+            .filter(
+                ToolConversationVariables.conversation_id == self.message.conversation_id,
+            )
+            .first()
+        )
 
         db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
         db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
@@ -425,9 +436,14 @@ class BaseAgentRunner(AppRunner):
             if isinstance(prompt_message, SystemPromptMessage):
                 result.append(prompt_message)
 
-        messages: list[Message] = db.session.query(Message).filter(
-            Message.conversation_id == self.message.conversation_id,
-        ).order_by(Message.created_at.asc()).all()
+        messages: list[Message] = (
+            db.session.query(Message)
+            .filter(
+                Message.conversation_id == self.message.conversation_id,
+            )
+            .order_by(Message.created_at.asc())
+            .all()
+        )
 
         for message in messages:
             if message.id == self.message.id:
@@ -439,13 +455,13 @@ class BaseAgentRunner(AppRunner):
             for agent_thought in agent_thoughts:
                 tools = agent_thought.tool
                 if tools:
-                    tools = tools.split(';')
+                    tools = tools.split(";")
                     tool_calls: list[AssistantPromptMessage.ToolCall] = []
                     tool_call_response: list[ToolPromptMessage] = []
                     try:
                         tool_inputs = json.loads(agent_thought.tool_input)
                     except Exception as e:
-                        tool_inputs = { tool: {} for tool in tools }
+                        tool_inputs = {tool: {} for tool in tools}
                     try:
                         tool_responses = json.loads(agent_thought.observation)
                     except Exception as e:
@@ -454,27 +470,33 @@ class BaseAgentRunner(AppRunner):
                     for tool in tools:
                         # generate a uuid for tool call
                         tool_call_id = str(uuid.uuid4())
-                        tool_calls.append(AssistantPromptMessage.ToolCall(
-                            id=tool_call_id,
-                            type='function',
-                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
-                                name=tool,
-                                arguments=json.dumps(tool_inputs.get(tool, {})),
-                            )
-                        ))
-                        tool_call_response.append(ToolPromptMessage(
-                            content=tool_responses.get(tool, agent_thought.observation),
-                            name=tool,
-                            tool_call_id=tool_call_id,
-                        ))
+                        tool_calls.append(
+                            AssistantPromptMessage.ToolCall(
+                                id=tool_call_id,
+                                type="function",
+                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                    name=tool,
+                                    arguments=json.dumps(tool_inputs.get(tool, {})),
+                                ),
+                            )
+                        )
+                        tool_call_response.append(
+                            ToolPromptMessage(
+                                content=tool_responses.get(tool, agent_thought.observation),
+                                name=tool,
+                                tool_call_id=tool_call_id,
+                            )
+                        )
 
-                    result.extend([
-                        AssistantPromptMessage(
-                            content=agent_thought.thought,
-                            tool_calls=tool_calls,
-                        ),
-                        *tool_call_response
-                    ])
+                    result.extend(
+                        [
+                            AssistantPromptMessage(
+                                content=agent_thought.thought,
+                                tool_calls=tool_calls,
+                            ),
+                            *tool_call_response,
+                        ]
+                    )
                 if not tools:
                     result.append(AssistantPromptMessage(content=agent_thought.thought))
             else:
@@ -496,10 +518,7 @@ class BaseAgentRunner(AppRunner):
                 file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
 
                 if file_extra_config:
-                    file_objs = message_file_parser.transform_message_files(
-                        files,
-                        file_extra_config
-                    )
+                    file_objs = message_file_parser.transform_message_files(files, file_extra_config)
                 else:
                     file_objs = []
 
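Most of this file's diff is mechanical restyling (double quotes, exploded call chains), but the repeated hunks share one pattern worth isolating: turning a list of tool parameters into a JSON-schema-style tool definition, with enums drawn from SELECT options and required names collected once. A simplified sketch with stand-in types; Param is illustrative, not Dify's ToolParameter class.

from dataclasses import dataclass, field

@dataclass
class Param:
    name: str
    llm_description: str = ""
    required: bool = False
    options: list[str] = field(default_factory=list)  # non-empty means a SELECT parameter

def build_tool_schema(parameters: list[Param]) -> dict:
    schema = {"type": "object", "properties": {}, "required": []}
    for p in parameters:
        # every parameter becomes a string-typed property with its description
        schema["properties"][p.name] = {"type": "string", "description": p.llm_description or ""}
        if p.options:
            schema["properties"][p.name]["enum"] = p.options
        if p.required and p.name not in schema["required"]:
            schema["required"].append(p.name)
    return schema

print(build_tool_schema([Param("city", "City name", required=True)]))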
@@ -25,17 +25,19 @@ from models.model import Message
 
 class CotAgentRunner(BaseAgentRunner, ABC):
     _is_first_iteration = True
-    _ignore_observation_providers = ['wenxin']
+    _ignore_observation_providers = ["wenxin"]
     _historic_prompt_messages: list[PromptMessage] = None
     _agent_scratchpad: list[AgentScratchpadUnit] = None
     _instruction: str = None
     _query: str = None
     _prompt_messages_tools: list[PromptMessage] = None
 
-    def run(self, message: Message,
-            query: str,
-            inputs: dict[str, str],
-            ) -> Union[Generator, LLMResult]:
+    def run(
+        self,
+        message: Message,
+        query: str,
+        inputs: dict[str, str],
+    ) -> Union[Generator, LLMResult]:
         """
         Run Cot agent application
         """
@@ -46,17 +48,16 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         trace_manager = app_generate_entity.trace_manager
 
         # check model mode
-        if 'Observation' not in app_generate_entity.model_conf.stop:
+        if "Observation" not in app_generate_entity.model_conf.stop:
             if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
-                app_generate_entity.model_conf.stop.append('Observation')
+                app_generate_entity.model_conf.stop.append("Observation")
 
         app_config = self.app_config
 
         # init instruction
         inputs = inputs or {}
         instruction = app_config.prompt_template.simple_prompt_template
-        self._instruction = self._fill_in_inputs_from_external_data_tools(
-            instruction, inputs)
+        self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
 
         iteration_step = 1
         max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1
@@ -65,16 +66,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
 
         function_call_state = True
-        llm_usage = {
-            'usage': None
-        }
-        final_answer = ''
+        llm_usage = {"usage": None}
+        final_answer = ""
 
         def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
-            if not final_llm_usage_dict['usage']:
-                final_llm_usage_dict['usage'] = usage
+            if not final_llm_usage_dict["usage"]:
+                final_llm_usage_dict["usage"] = usage
             else:
-                llm_usage = final_llm_usage_dict['usage']
+                llm_usage = final_llm_usage_dict["usage"]
                 llm_usage.prompt_tokens += usage.prompt_tokens
                 llm_usage.completion_tokens += usage.completion_tokens
                 llm_usage.prompt_price += usage.prompt_price
@@ -94,17 +93,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
             message_file_ids = []
 
             agent_thought = self.create_agent_thought(
-                message_id=message.id,
-                message='',
-                tool_name='',
-                tool_input='',
-                messages_ids=message_file_ids
+                message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
             )
 
             if iteration_step > 1:
-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )
 
             # recalc llm max tokens
             prompt_messages = self._organize_prompt_messages()
@@ -125,21 +120,20 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 raise ValueError("failed to invoke llm")
 
             usage_dict = {}
-            react_chunks = CotAgentOutputParser.handle_react_stream_output(
-                chunks, usage_dict)
+            react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
             scratchpad = AgentScratchpadUnit(
-                agent_response='',
-                thought='',
-                action_str='',
-                observation='',
+                agent_response="",
+                thought="",
+                action_str="",
+                observation="",
                 action=None,
             )
 
             # publish agent thought if it's first iteration
             if iteration_step == 1:
-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )
 
             for chunk in react_chunks:
                 if isinstance(chunk, AgentScratchpadUnit.Action):
@@ -154,61 +148,51 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     yield LLMResultChunk(
                         model=self.model_config.model,
                         prompt_messages=prompt_messages,
-                        system_fingerprint='',
-                        delta=LLMResultChunkDelta(
-                            index=0,
-                            message=AssistantPromptMessage(
-                                content=chunk
-                            ),
-                            usage=None
-                        )
+                        system_fingerprint="",
+                        delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
                     )
 
-            scratchpad.thought = scratchpad.thought.strip(
-            ) or 'I am thinking about how to help you'
+            scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
             self._agent_scratchpad.append(scratchpad)
 
             # get llm usage
-            if 'usage' in usage_dict:
-                increase_usage(llm_usage, usage_dict['usage'])
+            if "usage" in usage_dict:
+                increase_usage(llm_usage, usage_dict["usage"])
             else:
-                usage_dict['usage'] = LLMUsage.empty_usage()
+                usage_dict["usage"] = LLMUsage.empty_usage()
 
             self.save_agent_thought(
                 agent_thought=agent_thought,
-                tool_name=scratchpad.action.action_name if scratchpad.action else '',
-                tool_input={
-                    scratchpad.action.action_name: scratchpad.action.action_input
-                } if scratchpad.action else {},
+                tool_name=scratchpad.action.action_name if scratchpad.action else "",
+                tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
                 tool_invoke_meta={},
                 thought=scratchpad.thought,
-                observation='',
+                observation="",
                 answer=scratchpad.agent_response,
                 messages_ids=[],
-                llm_usage=usage_dict['usage']
+                llm_usage=usage_dict["usage"],
             )
 
             if not scratchpad.is_final():
-                self.queue_manager.publish(QueueAgentThoughtEvent(
-                    agent_thought_id=agent_thought.id
-                ), PublishFrom.APPLICATION_MANAGER)
+                self.queue_manager.publish(
+                    QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                )
 
             if not scratchpad.action:
                 # failed to extract action, return final answer directly
-                final_answer = ''
+                final_answer = ""
             else:
                 if scratchpad.action.action_name.lower() == "final answer":
                     # action is final answer, return final answer directly
                     try:
                         if isinstance(scratchpad.action.action_input, dict):
-                            final_answer = json.dumps(
-                                scratchpad.action.action_input)
+                            final_answer = json.dumps(scratchpad.action.action_input)
                         elif isinstance(scratchpad.action.action_input, str):
                             final_answer = scratchpad.action.action_input
                         else:
-                            final_answer = f'{scratchpad.action.action_input}'
+                            final_answer = f"{scratchpad.action.action_input}"
                     except json.JSONDecodeError:
-                        final_answer = f'{scratchpad.action.action_input}'
+                        final_answer = f"{scratchpad.action.action_input}"
                 else:
                     function_call_state = True
                     # action is tool call, invoke tool
@@ -224,21 +208,18 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                     self.save_agent_thought(
                         agent_thought=agent_thought,
                         tool_name=scratchpad.action.action_name,
-                        tool_input={
-                            scratchpad.action.action_name: scratchpad.action.action_input},
+                        tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
                         thought=scratchpad.thought,
-                        observation={
-                            scratchpad.action.action_name: tool_invoke_response},
-                        tool_invoke_meta={
-                            scratchpad.action.action_name: tool_invoke_meta.to_dict()},
+                        observation={scratchpad.action.action_name: tool_invoke_response},
+                        tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
                         answer=scratchpad.agent_response,
                         messages_ids=message_file_ids,
-                        llm_usage=usage_dict['usage']
+                        llm_usage=usage_dict["usage"],
                     )
 
-                    self.queue_manager.publish(QueueAgentThoughtEvent(
-                        agent_thought_id=agent_thought.id
-                    ), PublishFrom.APPLICATION_MANAGER)
+                    self.queue_manager.publish(
+                        QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
+                    )
 
                 # update prompt tool message
                 for prompt_tool in self._prompt_messages_tools:
@@ -250,44 +231,45 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 model=model_instance.model,
                 prompt_messages=prompt_messages,
                 delta=LLMResultChunkDelta(
-                    index=0,
-                    message=AssistantPromptMessage(
-                        content=final_answer
-                    ),
-                    usage=llm_usage['usage']
+                    index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
                 ),
-                system_fingerprint=''
+                system_fingerprint="",
             )
 
         # save agent thought
         self.save_agent_thought(
             agent_thought=agent_thought,
-            tool_name='',
+            tool_name="",
             tool_input={},
             tool_invoke_meta={},
             thought=final_answer,
             observation={},
             answer=final_answer,
-            messages_ids=[]
+            messages_ids=[],
         )
 
         self.update_db_variables(self.variables_pool, self.db_variables_pool)
         # publish end event
-        self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
-            model=model_instance.model,
-            prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(
-                content=final_answer
-            ),
-            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
-            system_fingerprint=''
-        )), PublishFrom.APPLICATION_MANAGER)
+        self.queue_manager.publish(
+            QueueMessageEndEvent(
+                llm_result=LLMResult(
+                    model=model_instance.model,
+                    prompt_messages=prompt_messages,
+                    message=AssistantPromptMessage(content=final_answer),
+                    usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
+                    system_fingerprint="",
+                )
+            ),
+            PublishFrom.APPLICATION_MANAGER,
+        )
 
-    def _handle_invoke_action(self, action: AgentScratchpadUnit.Action,
-                              tool_instances: dict[str, Tool],
-                              message_file_ids: list[str],
-                              trace_manager: Optional[TraceQueueManager] = None
-                              ) -> tuple[str, ToolInvokeMeta]:
+    def _handle_invoke_action(
+        self,
+        action: AgentScratchpadUnit.Action,
+        tool_instances: dict[str, Tool],
+        message_file_ids: list[str],
+        trace_manager: Optional[TraceQueueManager] = None,
+    ) -> tuple[str, ToolInvokeMeta]:
         """
         handle invoke action
         :param action: action
@@ -326,13 +308,12 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         # publish files
         for message_file_id, save_as in message_files:
             if save_as:
-                self.variables_pool.set_file(
-                    tool_name=tool_call_name, value=message_file_id, name=save_as)
+                self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)
 
             # publish message file
-            self.queue_manager.publish(QueueMessageFileEvent(
-                message_file_id=message_file_id
-            ), PublishFrom.APPLICATION_MANAGER)
+            self.queue_manager.publish(
+                QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
+            )
             # add message file ids
             message_file_ids.append(message_file_id)
 
@@ -342,10 +323,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         """
         convert dict to action
         """
-        return AgentScratchpadUnit.Action(
-            action_name=action['action'],
-            action_input=action['action_input']
-        )
+        return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])
 
     def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
         """
@@ -353,7 +331,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
         """
         for key, value in inputs.items():
             try:
-                instruction = instruction.replace(f'{{{{{key}}}}}', str(value))
+                instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
             except Exception as e:
                 continue
 
@@ -370,14 +348,14 @@ class CotAgentRunner(BaseAgentRunner, ABC):
     @abstractmethod
     def _organize_prompt_messages(self) -> list[PromptMessage]:
         """
-            organize prompt messages
+        organize prompt messages
         """
 
     def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
         """
-            format assistant message
+        format assistant message
         """
-        message = ''
+        message = ""
         for scratchpad in agent_scratchpad:
             if scratchpad.is_final():
                 message += f"Final Answer: {scratchpad.agent_response}"
@@ -390,9 +368,11 @@ class CotAgentRunner(BaseAgentRunner, ABC):
 
         return message
 
-    def _organize_historic_prompt_messages(self, current_session_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+    def _organize_historic_prompt_messages(
+        self, current_session_messages: list[PromptMessage] = None
+    ) -> list[PromptMessage]:
         """
-            organize historic prompt messages
+        organize historic prompt messages
         """
         result: list[PromptMessage] = []
         scratchpads: list[AgentScratchpadUnit] = []
@@ -403,8 +383,8 @@ class CotAgentRunner(BaseAgentRunner, ABC):
             if not current_scratchpad:
                 current_scratchpad = AgentScratchpadUnit(
                     agent_response=message.content,
-                    thought=message.content or 'I am thinking about how to help you',
-                    action_str='',
+                    thought=message.content or "I am thinking about how to help you",
+                    action_str="",
                     action=None,
                     observation=None,
                 )
@@ -413,12 +393,9 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 try:
                     current_scratchpad.action = AgentScratchpadUnit.Action(
                         action_name=message.tool_calls[0].function.name,
-                        action_input=json.loads(
-                            message.tool_calls[0].function.arguments)
-                    )
-                    current_scratchpad.action_str = json.dumps(
-                        current_scratchpad.action.to_dict()
+                        action_input=json.loads(message.tool_calls[0].function.arguments),
                     )
+                    current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
                 except:
                     pass
             elif isinstance(message, ToolPromptMessage):
@@ -426,23 +403,19 @@ class CotAgentRunner(BaseAgentRunner, ABC):
                 current_scratchpad.observation = message.content
             elif isinstance(message, UserPromptMessage):
                 if scratchpads:
-                    result.append(AssistantPromptMessage(
-                        content=self._format_assistant_message(scratchpads)
-                    ))
+                    result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
                     scratchpads = []
                     current_scratchpad = None
 
                 result.append(message)
 
         if scratchpads:
-            result.append(AssistantPromptMessage(
-                content=self._format_assistant_message(scratchpads)
-            ))
+            result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
 
         historic_prompts = AgentHistoryPromptTransform(
             model_config=self.model_config,
             prompt_messages=current_session_messages or [],
            history_messages=result,
-            memory=self.memory
+            memory=self.memory,
         ).get_prompt()
         return historic_prompts
@@ -19,14 +19,15 @@ class CotChatAgentRunner(CotAgentRunner):
         prompt_entity = self.app_config.agent.prompt
         first_prompt = prompt_entity.first_prompt
 
-        system_prompt = first_prompt \
-            .replace("{{instruction}}", self._instruction) \
-            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
-            .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
+        system_prompt = (
+            first_prompt.replace("{{instruction}}", self._instruction)
+            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
+            .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
+        )
 
         return SystemPromptMessage(content=system_prompt)
 
-    def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
+    def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
         """
         Organize user query
         """
@@ -43,7 +44,7 @@ class CotChatAgentRunner(CotAgentRunner):
 
     def _organize_prompt_messages(self) -> list[PromptMessage]:
         """
-            Organize
+        Organize
         """
         # organize system prompt
         system_message = self._organize_system_prompt()
@@ -53,7 +54,7 @@ class CotChatAgentRunner(CotAgentRunner):
         if not agent_scratchpad:
             assistant_messages = []
         else:
-            assistant_message = AssistantPromptMessage(content='')
+            assistant_message = AssistantPromptMessage(content="")
             for unit in agent_scratchpad:
                 if unit.is_final():
                     assistant_message.content += f"Final Answer: {unit.agent_response}"
@@ -71,18 +72,15 @@ class CotChatAgentRunner(CotAgentRunner):
 
         if assistant_messages:
             # organize historic prompt messages
-            historic_messages = self._organize_historic_prompt_messages([
-                system_message,
-                *query_messages,
-                *assistant_messages,
-                UserPromptMessage(content='continue')
-            ])
+            historic_messages = self._organize_historic_prompt_messages(
+                [system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
+            )
             messages = [
                 system_message,
                 *historic_messages,
                 *query_messages,
                 *assistant_messages,
-                UserPromptMessage(content='continue')
+                UserPromptMessage(content="continue"),
             ]
         else:
            # organize historic prompt messages
@@ -13,10 +13,12 @@ class CotCompletionAgentRunner(CotAgentRunner):
         prompt_entity = self.app_config.agent.prompt
         first_prompt = prompt_entity.first_prompt
 
-        system_prompt = first_prompt.replace("{{instruction}}", self._instruction) \
-            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools))) \
-            .replace("{{tool_names}}", ', '.join([tool.name for tool in self._prompt_messages_tools]))
+        system_prompt = (
+            first_prompt.replace("{{instruction}}", self._instruction)
+            .replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
+            .replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
+        )
 
         return system_prompt
 
     def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] = None) -> str:
@@ -46,7 +48,7 @@ class CotCompletionAgentRunner(CotAgentRunner):
 
         # organize current assistant messages
         agent_scratchpad = self._agent_scratchpad
-        assistant_prompt = ''
+        assistant_prompt = ""
         for unit in agent_scratchpad:
             if unit.is_final():
                 assistant_prompt += f"Final Answer: {unit.agent_response}"
@@ -61,9 +63,10 @@ class CotCompletionAgentRunner(CotAgentRunner):
         query_prompt = f"Question: {self._query}"
 
         # join all messages
-        prompt = system_prompt \
-            .replace("{{historic_messages}}", historic_prompt) \
-            .replace("{{agent_scratchpad}}", assistant_prompt) \
-            .replace("{{query}}", query_prompt)
+        prompt = (
+            system_prompt.replace("{{historic_messages}}", historic_prompt)
+            .replace("{{agent_scratchpad}}", assistant_prompt)
+            .replace("{{query}}", query_prompt)
+        )
 
-        return [UserPromptMessage(content=prompt)]
+        return [UserPromptMessage(content=prompt)]
|
||||
@@ -8,6 +8,7 @@ class AgentToolEntity(BaseModel):
"""
Agent Tool Entity.
"""

provider_type: Literal["builtin", "api", "workflow"]
provider_id: str
tool_name: str
@@ -18,6 +19,7 @@ class AgentPromptEntity(BaseModel):
"""
Agent Prompt Entity.
"""

first_prompt: str
next_iteration: str

@@ -31,6 +33,7 @@ class AgentScratchpadUnit(BaseModel):
"""
Action Entity.
"""

action_name: str
action_input: Union[dict, str]

@@ -39,8 +42,8 @@ class AgentScratchpadUnit(BaseModel):
Convert to dictionary.
"""
return {
'action': self.action_name,
'action_input': self.action_input,
"action": self.action_name,
"action_input": self.action_input,
}

agent_response: Optional[str] = None
@@ -54,10 +57,10 @@ class AgentScratchpadUnit(BaseModel):
Check if the scratchpad unit is final.
"""
return self.action is None or (
'final' in self.action.action_name.lower() and
'answer' in self.action.action_name.lower()
"final" in self.action.action_name.lower() and "answer" in self.action.action_name.lower()
)


class AgentEntity(BaseModel):
"""
Agent Entity.
@@ -67,8 +70,9 @@ class AgentEntity(BaseModel):
"""
Agent Strategy.
"""
CHAIN_OF_THOUGHT = 'chain-of-thought'
FUNCTION_CALLING = 'function-calling'

CHAIN_OF_THOUGHT = "chain-of-thought"
FUNCTION_CALLING = "function-calling"

provider: str
model: str

@@ -24,11 +24,9 @@ from models.model import Message

logger = logging.getLogger(__name__)

class FunctionCallAgentRunner(BaseAgentRunner):

def run(self,
message: Message, query: str, **kwargs: Any
) -> Generator[LLMResultChunk, None, None]:
class FunctionCallAgentRunner(BaseAgentRunner):
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
"""
Run FunctionCall agent application
"""
@@ -45,19 +43,17 @@ class FunctionCallAgentRunner(BaseAgentRunner):

# continue to run until there is not any tool call
function_call_state = True
llm_usage = {
'usage': None
}
final_answer = ''
llm_usage = {"usage": None}
final_answer = ""

# get tracing instance
trace_manager = app_generate_entity.trace_manager


def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
if not final_llm_usage_dict['usage']:
final_llm_usage_dict['usage'] = usage
if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage
else:
llm_usage = final_llm_usage_dict['usage']
llm_usage = final_llm_usage_dict["usage"]
llm_usage.prompt_tokens += usage.prompt_tokens
llm_usage.completion_tokens += usage.completion_tokens
llm_usage.prompt_price += usage.prompt_price
@@ -75,11 +71,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):

message_file_ids = []
agent_thought = self.create_agent_thought(
message_id=message.id,
message='',
tool_name='',
tool_input='',
messages_ids=message_file_ids
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
)

# recalc llm max tokens
@@ -99,11 +91,11 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_calls: list[tuple[str, str, dict[str, Any]]] = []

# save full response
response = ''
response = ""

# save tool call names and inputs
tool_call_names = ''
tool_call_inputs = ''
tool_call_names = ""
tool_call_inputs = ""

current_llm_usage = None

@@ -111,24 +103,22 @@ class FunctionCallAgentRunner(BaseAgentRunner):
is_first_chunk = True
for chunk in chunks:
if is_first_chunk:
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
is_first_chunk = False
# check if there is any tool call
if self.check_tool_calls(chunk):
function_call_state = True
tool_calls.extend(self.extract_tool_calls(chunk))
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
}, ensure_ascii=False)
tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
)
except json.JSONDecodeError as e:
# ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
})
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

if chunk.delta.message and chunk.delta.message.content:
if isinstance(chunk.delta.message.content, list):
@@ -148,16 +138,14 @@ class FunctionCallAgentRunner(BaseAgentRunner):
if self.check_blocking_tool_calls(result):
function_call_state = True
tool_calls.extend(self.extract_blocking_tool_calls(result))
tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls])
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
}, ensure_ascii=False)
tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
)
except json.JSONDecodeError as e:
# ensure ascii to avoid encoding error
tool_call_inputs = json.dumps({
tool_call[1]: tool_call[2] for tool_call in tool_calls
})
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})

if result.usage:
increase_usage(llm_usage, result.usage)
@@ -171,12 +159,12 @@ class FunctionCallAgentRunner(BaseAgentRunner):
response += result.message.content

if not result.message.content:
result.message.content = ''
result.message.content = ""

self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)

self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)

yield LLMResultChunk(
model=model_instance.model,
prompt_messages=result.prompt_messages,
@@ -185,32 +173,29 @@ class FunctionCallAgentRunner(BaseAgentRunner):
index=0,
message=result.message,
usage=result.usage,
)
),
)

assistant_message = AssistantPromptMessage(
content='',
tool_calls=[]
)
assistant_message = AssistantPromptMessage(content="", tool_calls=[])
if tool_calls:
assistant_message.tool_calls=[
assistant_message.tool_calls = [
AssistantPromptMessage.ToolCall(
id=tool_call[0],
type='function',
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_call[1],
arguments=json.dumps(tool_call[2], ensure_ascii=False)
)
) for tool_call in tool_calls
name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
),
)
for tool_call in tool_calls
]
else:
assistant_message.content = response


self._current_thoughts.append(assistant_message)

# save thought
self.save_agent_thought(
agent_thought=agent_thought,
agent_thought=agent_thought,
tool_name=tool_call_names,
tool_input=tool_call_inputs,
thought=response,
@@ -218,13 +203,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
observation=None,
answer=response,
messages_ids=[],
llm_usage=current_llm_usage
llm_usage=current_llm_usage,
)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)

final_answer += response + '\n'
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)

final_answer += response + "\n"

# call tools
tool_responses = []
@@ -235,7 +220,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": f"there is not a tool named {tool_call_name}",
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict()
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
}
else:
# invoke tool
@@ -255,50 +240,49 @@ class FunctionCallAgentRunner(BaseAgentRunner):
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file_id, name=save_as)

# publish message file
self.queue_manager.publish(QueueMessageFileEvent(
message_file_id=message_file_id
), PublishFrom.APPLICATION_MANAGER)
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
)
# add message file ids
message_file_ids.append(message_file_id)


tool_response = {
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": tool_invoke_response,
"meta": tool_invoke_meta.to_dict()
"meta": tool_invoke_meta.to_dict(),
}


tool_responses.append(tool_response)
if tool_response['tool_response'] is not None:
if tool_response["tool_response"] is not None:
self._current_thoughts.append(
ToolPromptMessage(
content=tool_response['tool_response'],
content=tool_response["tool_response"],
tool_call_id=tool_call_id,
name=tool_call_name,
)
)
)

if len(tool_responses) > 0:
# save agent thought
self.save_agent_thought(
agent_thought=agent_thought,
agent_thought=agent_thought,
tool_name=None,
tool_input=None,
thought=None,
thought=None,
tool_invoke_meta={
tool_response['tool_call_name']: tool_response['meta']
for tool_response in tool_responses
tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
},
observation={
tool_response['tool_call_name']: tool_response['tool_response']
tool_response["tool_call_name"]: tool_response["tool_response"]
for tool_response in tool_responses
},
answer=None,
messages_ids=message_file_ids
messages_ids=message_file_ids,
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought.id), PublishFrom.APPLICATION_MANAGER
)
self.queue_manager.publish(QueueAgentThoughtEvent(
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)

# update prompt tool
for prompt_tool in prompt_messages_tools:
@@ -308,15 +292,18 @@ class FunctionCallAgentRunner(BaseAgentRunner):

self.update_db_variables(self.variables_pool, self.db_variables_pool)
# publish end event
self.queue_manager.publish(QueueMessageEndEvent(llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(
content=final_answer
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] if llm_usage["usage"] else LLMUsage.empty_usage(),
system_fingerprint="",
)
),
usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
system_fingerprint=''
)), PublishFrom.APPLICATION_MANAGER)
PublishFrom.APPLICATION_MANAGER,
)

def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
"""
@@ -325,7 +312,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
if llm_result_chunk.delta.message.tool_calls:
return True
return False


def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
"""
Check if there is any blocking tool call in llm result
@@ -334,7 +321,9 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return True
return False

def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
def extract_tool_calls(
self, llm_result_chunk: LLMResultChunk
) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
"""
Extract tool calls from llm result chunk

@@ -344,17 +333,19 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_calls = []
for prompt_message in llm_result_chunk.delta.message.tool_calls:
args = {}
if prompt_message.function.arguments != '':
if prompt_message.function.arguments != "":
args = json.loads(prompt_message.function.arguments)

tool_calls.append((
prompt_message.id,
prompt_message.function.name,
args,
))
tool_calls.append(
(
prompt_message.id,
prompt_message.function.name,
args,
)
)

return tool_calls


def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
"""
Extract blocking tool calls from llm result
@@ -365,18 +356,22 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_calls = []
for prompt_message in llm_result.message.tool_calls:
args = {}
if prompt_message.function.arguments != '':
if prompt_message.function.arguments != "":
args = json.loads(prompt_message.function.arguments)

tool_calls.append((
prompt_message.id,
prompt_message.function.name,
args,
))
tool_calls.append(
(
prompt_message.id,
prompt_message.function.name,
args,
)
)

return tool_calls

def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
def _init_system_message(
self, prompt_template: str, prompt_messages: list[PromptMessage] = None
) -> list[PromptMessage]:
"""
Initialize system message
"""
@@ -384,13 +379,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
return [
SystemPromptMessage(content=prompt_template),
]


if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))

return prompt_messages

def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
Organize user query
"""
@@ -404,7 +399,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
prompt_messages.append(UserPromptMessage(content=query))

return prompt_messages


def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
As for now, gpt supports both fc and vision at the first iteration.
@@ -415,17 +410,21 @@ class FunctionCallAgentRunner(BaseAgentRunner):
for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage):
if isinstance(prompt_message.content, list):
prompt_message.content = '\n'.join([
content.data if content.type == PromptMessageContentType.TEXT else
'[image]' if content.type == PromptMessageContentType.IMAGE else
'[file]'
for content in prompt_message.content
])
prompt_message.content = "\n".join(
[
content.data
if content.type == PromptMessageContentType.TEXT
else "[image]"
if content.type == PromptMessageContentType.IMAGE
else "[file]"
for content in prompt_message.content
]
)

return prompt_messages

def _organize_prompt_messages(self):
prompt_template = self.app_config.prompt_template.simple_prompt_template or ''
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
query_prompt_messages = self._organize_user_query(self.query, [])

@@ -433,14 +432,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
model_config=self.model_config,
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
history_messages=self.history_prompt_messages,
memory=self.memory
memory=self.memory,
).get_prompt()

prompt_messages = [
*self.history_prompt_messages,
*query_prompt_messages,
*self._current_thoughts
]
prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
if len(self._current_thoughts) != 0:
# clear messages after the first iteration
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)

@@ -9,8 +9,9 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk

class CotAgentOutputParser:
@classmethod
def handle_react_stream_output(cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict) -> \
Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
def handle_react_stream_output(
cls, llm_response: Generator[LLMResultChunk, None, None], usage_dict: dict
) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]:
def parse_action(json_str):
try:
action = json.loads(json_str)
@@ -22,7 +23,7 @@ class CotAgentOutputParser:
action = action[0]

for key, value in action.items():
if 'input' in key.lower():
if "input" in key.lower():
action_input = value
else:
action_name = value
@@ -33,37 +34,37 @@ class CotAgentOutputParser:
action_input=action_input,
)
else:
return json_str or ''
return json_str or ""
except:
return json_str or ''

return json_str or ""

def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
code_blocks = re.findall(r"```(.*?)```", code_block, re.DOTALL)
if not code_blocks:
return
for block in code_blocks:
json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
json_text = re.sub(r"^[a-zA-Z]+\n", "", block.strip(), flags=re.MULTILINE)
yield parse_action(json_text)

code_block_cache = ''

code_block_cache = ""
code_block_delimiter_count = 0
in_code_block = False
json_cache = ''
json_cache = ""
json_quote_count = 0
in_json = False
got_json = False

action_cache = ''
action_str = 'action:'
action_cache = ""
action_str = "action:"
action_idx = 0

thought_cache = ''
thought_str = 'thought:'
thought_cache = ""
thought_str = "thought:"
thought_idx = 0

for response in llm_response:
if response.delta.usage:
usage_dict['usage'] = response.delta.usage
usage_dict["usage"] = response.delta.usage
response = response.delta.message.content
if not isinstance(response, str):
continue
@@ -72,24 +73,24 @@ class CotAgentOutputParser:
index = 0
while index < len(response):
steps = 1
delta = response[index:index+steps]
last_character = response[index-1] if index > 0 else ''
delta = response[index : index + steps]
last_character = response[index - 1] if index > 0 else ""

if delta == '`':
if delta == "`":
code_block_cache += delta
code_block_delimiter_count += 1
else:
if not in_code_block:
if code_block_delimiter_count > 0:
yield code_block_cache
code_block_cache = ''
code_block_cache = ""
else:
code_block_cache += delta
code_block_delimiter_count = 0

if not in_code_block and not in_json:
if delta.lower() == action_str[action_idx] and action_idx == 0:
if last_character not in ['\n', ' ', '']:
if last_character not in ["\n", " ", ""]:
index += steps
yield delta
continue
@@ -97,7 +98,7 @@ class CotAgentOutputParser:
action_cache += delta
action_idx += 1
if action_idx == len(action_str):
action_cache = ''
action_cache = ""
action_idx = 0
index += steps
continue
@@ -105,18 +106,18 @@ class CotAgentOutputParser:
action_cache += delta
action_idx += 1
if action_idx == len(action_str):
action_cache = ''
action_cache = ""
action_idx = 0
index += steps
continue
else:
if action_cache:
yield action_cache
action_cache = ''
action_cache = ""
action_idx = 0


if delta.lower() == thought_str[thought_idx] and thought_idx == 0:
if last_character not in ['\n', ' ', '']:
if last_character not in ["\n", " ", ""]:
index += steps
yield delta
continue
@@ -124,7 +125,7 @@ class CotAgentOutputParser:
thought_cache += delta
thought_idx += 1
if thought_idx == len(thought_str):
thought_cache = ''
thought_cache = ""
thought_idx = 0
index += steps
continue
@@ -132,31 +133,31 @@ class CotAgentOutputParser:
thought_cache += delta
thought_idx += 1
if thought_idx == len(thought_str):
thought_cache = ''
thought_cache = ""
thought_idx = 0
index += steps
continue
else:
if thought_cache:
yield thought_cache
thought_cache = ''
thought_cache = ""
thought_idx = 0

if code_block_delimiter_count == 3:
if in_code_block:
yield from extra_json_from_code_block(code_block_cache)
code_block_cache = ''

code_block_cache = ""

in_code_block = not in_code_block
code_block_delimiter_count = 0

if not in_code_block:
# handle single json
if delta == '{':
if delta == "{":
json_quote_count += 1
in_json = True
json_cache += delta
elif delta == '}':
elif delta == "}":
json_cache += delta
if json_quote_count > 0:
json_quote_count -= 1
@@ -172,12 +173,12 @@ class CotAgentOutputParser:
if got_json:
got_json = False
yield parse_action(json_cache)
json_cache = ''
json_cache = ""
json_quote_count = 0
in_json = False


if not in_code_block and not in_json:
yield delta.replace('`', '')
yield delta.replace("`", "")

index += steps

@@ -186,4 +187,3 @@ class CotAgentOutputParser:

if json_cache:
yield parse_action(json_cache)


@@ -91,14 +91,14 @@ Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use
ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES = ""

REACT_PROMPT_TEMPLATES = {
'english': {
'chat': {
'prompt': ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
'agent_scratchpad': ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES
"english": {
"chat": {
"prompt": ENGLISH_REACT_CHAT_PROMPT_TEMPLATES,
"agent_scratchpad": ENGLISH_REACT_CHAT_AGENT_SCRATCHPAD_TEMPLATES,
},
"completion": {
"prompt": ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
"agent_scratchpad": ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES,
},
'completion': {
'prompt': ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES,
'agent_scratchpad': ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES
}
}
}
}

@@ -26,34 +26,24 @@ class BaseAppConfigManager:
config_dict = dict(config_dict.items())

additional_features = AppAdditionalFeatures()
additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(
config=config_dict
)
additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict)

additional_features.file_upload = FileUploadConfigManager.convert(
config=config_dict,
is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT]
config=config_dict, is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT]
)

additional_features.opening_statement, additional_features.suggested_questions = \
OpeningStatementConfigManager.convert(
config=config_dict
)
additional_features.opening_statement, additional_features.suggested_questions = (
OpeningStatementConfigManager.convert(config=config_dict)
)

additional_features.suggested_questions_after_answer = SuggestedQuestionsAfterAnswerConfigManager.convert(
config=config_dict
)

additional_features.more_like_this = MoreLikeThisConfigManager.convert(
config=config_dict
)
additional_features.more_like_this = MoreLikeThisConfigManager.convert(config=config_dict)

additional_features.speech_to_text = SpeechToTextConfigManager.convert(
config=config_dict
)
additional_features.speech_to_text = SpeechToTextConfigManager.convert(config=config_dict)

additional_features.text_to_speech = TextToSpeechConfigManager.convert(
config=config_dict
)
additional_features.text_to_speech = TextToSpeechConfigManager.convert(config=config_dict)

return additional_features

@@ -7,25 +7,24 @@ from core.moderation.factory import ModerationFactory
class SensitiveWordAvoidanceConfigManager:
@classmethod
def convert(cls, config: dict) -> Optional[SensitiveWordAvoidanceEntity]:
sensitive_word_avoidance_dict = config.get('sensitive_word_avoidance')
sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance")
if not sensitive_word_avoidance_dict:
return None

if sensitive_word_avoidance_dict.get('enabled'):
if sensitive_word_avoidance_dict.get("enabled"):
return SensitiveWordAvoidanceEntity(
type=sensitive_word_avoidance_dict.get('type'),
config=sensitive_word_avoidance_dict.get('config'),
type=sensitive_word_avoidance_dict.get("type"),
config=sensitive_word_avoidance_dict.get("config"),
)
else:
return None

@classmethod
def validate_and_set_defaults(cls, tenant_id, config: dict, only_structure_validate: bool = False) \
-> tuple[dict, list[str]]:
def validate_and_set_defaults(
cls, tenant_id, config: dict, only_structure_validate: bool = False
) -> tuple[dict, list[str]]:
if not config.get("sensitive_word_avoidance"):
config["sensitive_word_avoidance"] = {
"enabled": False
}
config["sensitive_word_avoidance"] = {"enabled": False}

if not isinstance(config["sensitive_word_avoidance"], dict):
raise ValueError("sensitive_word_avoidance must be of dict type")
@@ -41,10 +40,6 @@ class SensitiveWordAvoidanceConfigManager:
typ = config["sensitive_word_avoidance"]["type"]
sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"]

ModerationFactory.validate_config(
name=typ,
tenant_id=tenant_id,
config=sensitive_word_avoidance_config
)
ModerationFactory.validate_config(name=typ, tenant_id=tenant_id, config=sensitive_word_avoidance_config)

return config, ["sensitive_word_avoidance"]

@@ -12,67 +12,70 @@ class AgentConfigManager:

:param config: model config args
"""
if 'agent_mode' in config and config['agent_mode'] \
and 'enabled' in config['agent_mode']:
if "agent_mode" in config and config["agent_mode"] and "enabled" in config["agent_mode"]:
agent_dict = config.get("agent_mode", {})
agent_strategy = agent_dict.get("strategy", "cot")

agent_dict = config.get('agent_mode', {})
agent_strategy = agent_dict.get('strategy', 'cot')

if agent_strategy == 'function_call':
if agent_strategy == "function_call":
strategy = AgentEntity.Strategy.FUNCTION_CALLING
elif agent_strategy == 'cot' or agent_strategy == 'react':
elif agent_strategy == "cot" or agent_strategy == "react":
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
else:
# old configs, try to detect default strategy
if config['model']['provider'] == 'openai':
if config["model"]["provider"] == "openai":
strategy = AgentEntity.Strategy.FUNCTION_CALLING
else:
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT

agent_tools = []
for tool in agent_dict.get('tools', []):
for tool in agent_dict.get("tools", []):
keys = tool.keys()
if len(keys) >= 4:
if "enabled" not in tool or not tool["enabled"]:
continue

agent_tool_properties = {
'provider_type': tool['provider_type'],
'provider_id': tool['provider_id'],
'tool_name': tool['tool_name'],
'tool_parameters': tool.get('tool_parameters', {})
"provider_type": tool["provider_type"],
"provider_id": tool["provider_id"],
"tool_name": tool["tool_name"],
"tool_parameters": tool.get("tool_parameters", {}),
}

agent_tools.append(AgentToolEntity(**agent_tool_properties))

if 'strategy' in config['agent_mode'] and \
config['agent_mode']['strategy'] not in ['react_router', 'router']:
agent_prompt = agent_dict.get('prompt', None) or {}
if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in [
"react_router",
"router",
]:
agent_prompt = agent_dict.get("prompt", None) or {}
# check model mode
model_mode = config.get('model', {}).get('mode', 'completion')
if model_mode == 'completion':
model_mode = config.get("model", {}).get("mode", "completion")
if model_mode == "completion":
agent_prompt_entity = AgentPromptEntity(
first_prompt=agent_prompt.get('first_prompt',
REACT_PROMPT_TEMPLATES['english']['completion']['prompt']),
next_iteration=agent_prompt.get('next_iteration',
REACT_PROMPT_TEMPLATES['english']['completion'][
'agent_scratchpad']),
first_prompt=agent_prompt.get(
"first_prompt", REACT_PROMPT_TEMPLATES["english"]["completion"]["prompt"]
),
next_iteration=agent_prompt.get(
"next_iteration", REACT_PROMPT_TEMPLATES["english"]["completion"]["agent_scratchpad"]
),
)
else:
agent_prompt_entity = AgentPromptEntity(
first_prompt=agent_prompt.get('first_prompt',
REACT_PROMPT_TEMPLATES['english']['chat']['prompt']),
next_iteration=agent_prompt.get('next_iteration',
REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']),
first_prompt=agent_prompt.get(
"first_prompt", REACT_PROMPT_TEMPLATES["english"]["chat"]["prompt"]
),
next_iteration=agent_prompt.get(
"next_iteration", REACT_PROMPT_TEMPLATES["english"]["chat"]["agent_scratchpad"]
),
)

return AgentEntity(
provider=config['model']['provider'],
model=config['model']['name'],
provider=config["model"]["provider"],
model=config["model"]["name"],
strategy=strategy,
prompt=agent_prompt_entity,
tools=agent_tools,
max_iteration=agent_dict.get('max_iteration', 5)
max_iteration=agent_dict.get("max_iteration", 5),
)

return None

@@ -15,39 +15,38 @@ class DatasetConfigManager:
:param config: model config args
"""
dataset_ids = []
if 'datasets' in config.get('dataset_configs', {}):
datasets = config.get('dataset_configs', {}).get('datasets', {
'strategy': 'router',
'datasets': []
})
if "datasets" in config.get("dataset_configs", {}):
datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []})

for dataset in datasets.get('datasets', []):
for dataset in datasets.get("datasets", []):
keys = list(dataset.keys())
if len(keys) == 0 or keys[0] != 'dataset':
if len(keys) == 0 or keys[0] != "dataset":
continue

dataset = dataset['dataset']
dataset = dataset["dataset"]

if 'enabled' not in dataset or not dataset['enabled']:
if "enabled" not in dataset or not dataset["enabled"]:
continue

dataset_id = dataset.get('id', None)
dataset_id = dataset.get("id", None)
if dataset_id:
dataset_ids.append(dataset_id)

if 'agent_mode' in config and config['agent_mode'] \
and 'enabled' in config['agent_mode'] \
and config['agent_mode']['enabled']:
if (
"agent_mode" in config
and config["agent_mode"]
and "enabled" in config["agent_mode"]
and config["agent_mode"]["enabled"]
):
agent_dict = config.get("agent_mode", {})

agent_dict = config.get('agent_mode', {})

for tool in agent_dict.get('tools', []):
for tool in agent_dict.get("tools", []):
keys = tool.keys()
if len(keys) == 1:
# old standard
key = list(tool.keys())[0]

if key != 'dataset':
if key != "dataset":
continue

tool_item = tool[key]
@@ -55,30 +54,28 @@ class DatasetConfigManager:
if "enabled" not in tool_item or not tool_item["enabled"]:
continue

dataset_id = tool_item['id']
dataset_id = tool_item["id"]
dataset_ids.append(dataset_id)

if len(dataset_ids) == 0:
return None

# dataset configs
if 'dataset_configs' in config and config.get('dataset_configs'):
dataset_configs = config.get('dataset_configs')
if "dataset_configs" in config and config.get("dataset_configs"):
dataset_configs = config.get("dataset_configs")
else:
dataset_configs = {
'retrieval_model': 'multiple'
}
query_variable = config.get('dataset_query_variable')
dataset_configs = {"retrieval_model": "multiple"}
query_variable = config.get("dataset_query_variable")

if dataset_configs['retrieval_model'] == 'single':
if dataset_configs["retrieval_model"] == "single":
return DatasetEntity(
dataset_ids=dataset_ids,
retrieve_config=DatasetRetrieveConfigEntity(
query_variable=query_variable,
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs['retrieval_model']
)
)
dataset_configs["retrieval_model"]
),
),
)
else:
return DatasetEntity(
@@ -86,15 +83,15 @@ class DatasetConfigManager:
retrieve_config=DatasetRetrieveConfigEntity(
query_variable=query_variable,
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs['retrieval_model']
dataset_configs["retrieval_model"]
),
top_k=dataset_configs.get('top_k', 4),
score_threshold=dataset_configs.get('score_threshold'),
reranking_model=dataset_configs.get('reranking_model'),
weights=dataset_configs.get('weights'),
reranking_enabled=dataset_configs.get('reranking_enabled', True),
rerank_mode=dataset_configs.get('reranking_mode', 'reranking_model'),
)
top_k=dataset_configs.get("top_k", 4),
score_threshold=dataset_configs.get("score_threshold"),
reranking_model=dataset_configs.get("reranking_model"),
weights=dataset_configs.get("weights"),
reranking_enabled=dataset_configs.get("reranking_enabled", True),
rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"),
),
)

@classmethod
@@ -111,13 +108,10 @@ class DatasetConfigManager:

# dataset_configs
if not config.get("dataset_configs"):
config["dataset_configs"] = {'retrieval_model': 'single'}
config["dataset_configs"] = {"retrieval_model": "single"}

if not config["dataset_configs"].get("datasets"):
config["dataset_configs"]["datasets"] = {
"strategy": "router",
"datasets": []
}
config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}

if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
@@ -125,8 +119,9 @@ class DatasetConfigManager:
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")

need_manual_query_datasets = (config.get("dataset_configs")
and config["dataset_configs"].get("datasets", {}).get("datasets"))
need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
"datasets", {}
).get("datasets")

if need_manual_query_datasets and app_mode == AppMode.COMPLETION:
# Only check when mode is completion
@@ -148,10 +143,7 @@ class DatasetConfigManager:
"""
# Extract dataset config for legacy compatibility
if not config.get("agent_mode"):
config["agent_mode"] = {
"enabled": False,
"tools": []
}
config["agent_mode"] = {"enabled": False, "tools": []}

if not isinstance(config["agent_mode"], dict):
raise ValueError("agent_mode must be of object type")
@@ -188,7 +180,7 @@ class DatasetConfigManager:
if not isinstance(tool_item["enabled"], bool):
raise ValueError("enabled in agent_mode.tools must be of boolean type")

if 'id' not in tool_item:
if "id" not in tool_item:
raise ValueError("id is required in dataset")

try:

@@ -11,9 +11,7 @@ from core.provider_manager import ProviderManager

class ModelConfigConverter:
@classmethod
def convert(cls, app_config: EasyUIBasedAppConfig,
skip_check: bool = False) \
-> ModelConfigWithCredentialsEntity:
def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
"""
Convert app model config dict to entity.
:param app_config: app config
@@ -25,9 +23,7 @@ class ModelConfigConverter:

provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=app_config.tenant_id,
provider=model_config.provider,
model_type=ModelType.LLM
tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM
)

provider_name = provider_model_bundle.configuration.provider.provider
@@ -38,8 +34,7 @@ class ModelConfigConverter:

# check model credentials
model_credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=ModelType.LLM,
model=model_config.model
model_type=ModelType.LLM, model=model_config.model
)

if model_credentials is None:
@@ -51,8 +46,7 @@ class ModelConfigConverter:
if not skip_check:
# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_config.model,
model_type=ModelType.LLM
model=model_config.model, model_type=ModelType.LLM
)

if provider_model is None:
@@ -69,24 +63,18 @@ class ModelConfigConverter:
# model config
completion_params = model_config.parameters
stop = []
if 'stop' in completion_params:
stop = completion_params['stop']
del completion_params['stop']
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]

# get model mode
model_mode = model_config.mode
if not model_mode:
mode_enum = model_type_instance.get_model_mode(
model=model_config.model,
credentials=model_credentials
)
mode_enum = model_type_instance.get_model_mode(model=model_config.model, credentials=model_credentials)

model_mode = mode_enum.value

model_schema = model_type_instance.get_model_schema(
model_config.model,
model_credentials
)
model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)

if not skip_check and not model_schema:
raise ValueError(f"Model {model_name} not exist.")

@@ -13,23 +13,23 @@ class ModelConfigManager:
:param config: model config args
"""
# model config
model_config = config.get('model')
model_config = config.get("model")

if not model_config:
raise ValueError("model is required")

completion_params = model_config.get('completion_params')
completion_params = model_config.get("completion_params")
stop = []
if 'stop' in completion_params:
stop = completion_params['stop']
del completion_params['stop']
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]

# get model mode
model_mode = model_config.get('mode')
model_mode = model_config.get("mode")

return ModelConfigEntity(
provider=config['model']['provider'],
model=config['model']['name'],
provider=config["model"]["provider"],
model=config["model"]["name"],
mode=model_mode,
parameters=completion_params,
stop=stop,
@@ -43,7 +43,7 @@ class ModelConfigManager:
:param tenant_id: tenant id
:param config: app model config args
"""
if 'model' not in config:
if "model" not in config:
raise ValueError("model is required")

if not isinstance(config["model"], dict):
@@ -52,17 +52,16 @@ class ModelConfigManager:
# model.provider
provider_entities = model_provider_factory.get_providers()
model_provider_names = [provider.provider for provider in provider_entities]
if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names:
if "provider" not in config["model"] or config["model"]["provider"] not in model_provider_names:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")

# model.name
if 'name' not in config["model"]:
if "name" not in config["model"]:
raise ValueError("model.name is required")

provider_manager = ProviderManager()
models = provider_manager.get_configurations(tenant_id).get_models(
provider=config["model"]["provider"],
model_type=ModelType.LLM
provider=config["model"]["provider"], model_type=ModelType.LLM
)

if not models:
@@ -80,12 +79,12 @@ class ModelConfigManager:

# model.mode
if model_mode:
config['model']["mode"] = model_mode
config["model"]["mode"] = model_mode
else:
config['model']["mode"] = "completion"
config["model"]["mode"] = "completion"

# model.completion_params
if 'completion_params' not in config["model"]:
if "completion_params" not in config["model"]:
raise ValueError("model.completion_params is required")

config["model"]["completion_params"] = cls.validate_model_completion_params(
@@ -101,7 +100,7 @@ class ModelConfigManager:
raise ValueError("model.completion_params must be of object type")

# stop
if 'stop' not in cp:
if "stop" not in cp:
cp["stop"] = []
elif not isinstance(cp["stop"], list):
raise ValueError("stop in model.completion_params must be of list type")

@@ -14,39 +14,33 @@ class PromptTemplateConfigManager:
if not config.get("prompt_type"):
raise ValueError("prompt_type is required")

prompt_type = PromptTemplateEntity.PromptType.value_of(config['prompt_type'])
prompt_type = PromptTemplateEntity.PromptType.value_of(config["prompt_type"])
if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
simple_prompt_template = config.get("pre_prompt", "")
return PromptTemplateEntity(
prompt_type=prompt_type,
simple_prompt_template=simple_prompt_template
)
return PromptTemplateEntity(prompt_type=prompt_type, simple_prompt_template=simple_prompt_template)
else:
advanced_chat_prompt_template = None
chat_prompt_config = config.get("chat_prompt_config", {})
if chat_prompt_config:
chat_prompt_messages = []
for message in chat_prompt_config.get("prompt", []):
chat_prompt_messages.append({
"text": message["text"],
"role": PromptMessageRole.value_of(message["role"])
})
chat_prompt_messages.append(
{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
)

advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
messages=chat_prompt_messages
)
advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)

advanced_completion_prompt_template = None
completion_prompt_config = config.get("completion_prompt_config", {})
if completion_prompt_config:
completion_prompt_template_params = {
'prompt': completion_prompt_config['prompt']['text'],
"prompt": completion_prompt_config["prompt"]["text"],
}

if 'conversation_histories_role' in completion_prompt_config:
completion_prompt_template_params['role_prefix'] = {
'user': completion_prompt_config['conversation_histories_role']['user_prefix'],
'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix']
if "conversation_histories_role" in completion_prompt_config:
completion_prompt_template_params["role_prefix"] = {
"user": completion_prompt_config["conversation_histories_role"]["user_prefix"],
"assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"],
}

advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
@@ -56,7 +50,7 @@ class PromptTemplateConfigManager:
return PromptTemplateEntity(
prompt_type=prompt_type,
advanced_chat_prompt_template=advanced_chat_prompt_template,
advanced_completion_prompt_template=advanced_completion_prompt_template
advanced_completion_prompt_template=advanced_completion_prompt_template,
)

@classmethod
@@ -72,7 +66,7 @@ class PromptTemplateConfigManager:
config["prompt_type"] = PromptTemplateEntity.PromptType.SIMPLE.value

prompt_type_vals = [typ.value for typ in PromptTemplateEntity.PromptType]
if config['prompt_type'] not in prompt_type_vals:
if config["prompt_type"] not in prompt_type_vals:
raise ValueError(f"prompt_type must be in {prompt_type_vals}")

# chat_prompt_config
@@ -89,27 +83,28 @@ class PromptTemplateConfigManager:
if not isinstance(config["completion_prompt_config"], dict):
raise ValueError("completion_prompt_config must be of object type")

if config['prompt_type'] == PromptTemplateEntity.PromptType.ADVANCED.value:
if not config['chat_prompt_config'] and not config['completion_prompt_config']:
raise ValueError("chat_prompt_config or completion_prompt_config is required "
"when prompt_type is advanced")
if config["prompt_type"] == PromptTemplateEntity.PromptType.ADVANCED.value:
if not config["chat_prompt_config"] and not config["completion_prompt_config"]:
raise ValueError(
"chat_prompt_config or completion_prompt_config is required " "when prompt_type is advanced"
)

model_mode_vals = [mode.value for mode in ModelMode]
if config['model']["mode"] not in model_mode_vals:
if config["model"]["mode"] not in model_mode_vals:
raise ValueError(f"model.mode must be in {model_mode_vals} when prompt_type is advanced")

if app_mode == AppMode.CHAT and config['model']["mode"] == ModelMode.COMPLETION.value:
user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
if app_mode == AppMode.CHAT and config["model"]["mode"] == ModelMode.COMPLETION.value:
user_prefix = config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"]
assistant_prefix = config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"]

if not user_prefix:
config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human'
config["completion_prompt_config"]["conversation_histories_role"]["user_prefix"] = "Human"

if not assistant_prefix:
config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
config["completion_prompt_config"]["conversation_histories_role"]["assistant_prefix"] = "Assistant"

if config['model']["mode"] == ModelMode.CHAT.value:
prompt_list = config['chat_prompt_config']['prompt']
if config["model"]["mode"] == ModelMode.CHAT.value:
prompt_list = config["chat_prompt_config"]["prompt"]

if len(prompt_list) > 10:
raise ValueError("prompt messages must be less than 10")

@@ -16,32 +16,30 @@ class BasicVariablesConfigManager:
         variable_entities = []

         # old external_data_tools
-        external_data_tools = config.get('external_data_tools', [])
+        external_data_tools = config.get("external_data_tools", [])
         for external_data_tool in external_data_tools:
-            if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
+            if "enabled" not in external_data_tool or not external_data_tool["enabled"]:
                 continue

             external_data_variables.append(
                 ExternalDataVariableEntity(
-                    variable=external_data_tool['variable'],
-                    type=external_data_tool['type'],
-                    config=external_data_tool['config']
+                    variable=external_data_tool["variable"],
+                    type=external_data_tool["type"],
+                    config=external_data_tool["config"],
                 )
             )

         # variables and external_data_tools
-        for variables in config.get('user_input_form', []):
+        for variables in config.get("user_input_form", []):
             variable_type = list(variables.keys())[0]
             if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
                 variable = variables[variable_type]
-                if 'config' not in variable:
+                if "config" not in variable:
                     continue

                 external_data_variables.append(
                     ExternalDataVariableEntity(
-                        variable=variable['variable'],
-                        type=variable['type'],
-                        config=variable['config']
+                        variable=variable["variable"], type=variable["type"], config=variable["config"]
                     )
                 )
             elif variable_type in [
@@ -54,13 +52,13 @@ class BasicVariablesConfigManager:
                 variable_entities.append(
                     VariableEntity(
                         type=variable_type,
-                        variable=variable.get('variable'),
-                        description=variable.get('description'),
-                        label=variable.get('label'),
-                        required=variable.get('required', False),
-                        max_length=variable.get('max_length'),
-                        options=variable.get('options'),
-                        default=variable.get('default'),
+                        variable=variable.get("variable"),
+                        description=variable.get("description"),
+                        label=variable.get("label"),
+                        required=variable.get("required", False),
+                        max_length=variable.get("max_length"),
+                        options=variable.get("options"),
+                        default=variable.get("default"),
                     )
                 )

@@ -103,13 +101,13 @@ class BasicVariablesConfigManager:
                 raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'")

             form_item = item[key]
-            if 'label' not in form_item:
+            if "label" not in form_item:
                 raise ValueError("label is required in user_input_form")

             if not isinstance(form_item["label"], str):
                 raise ValueError("label in user_input_form must be of string type")

-            if 'variable' not in form_item:
+            if "variable" not in form_item:
                 raise ValueError("variable is required in user_input_form")

             if not isinstance(form_item["variable"], str):
@@ -117,26 +115,24 @@ class BasicVariablesConfigManager:

             pattern = re.compile(r"^(?!\d)[\u4e00-\u9fa5A-Za-z0-9_\U0001F300-\U0001F64F\U0001F680-\U0001F6FF]{1,100}$")
             if pattern.match(form_item["variable"]) is None:
-                raise ValueError("variable in user_input_form must be a string, "
-                                 "and cannot start with a number")
+                raise ValueError("variable in user_input_form must be a string, " "and cannot start with a number")

             variables.append(form_item["variable"])

-            if 'required' not in form_item or not form_item["required"]:
+            if "required" not in form_item or not form_item["required"]:
                 form_item["required"] = False

             if not isinstance(form_item["required"], bool):
                 raise ValueError("required in user_input_form must be of boolean type")

             if key == "select":
-                if 'options' not in form_item or not form_item["options"]:
+                if "options" not in form_item or not form_item["options"]:
                     form_item["options"] = []

                 if not isinstance(form_item["options"], list):
                     raise ValueError("options in user_input_form must be a list of strings")

-                if "default" in form_item and form_item['default'] \
-                        and form_item["default"] not in form_item["options"]:
+                if "default" in form_item and form_item["default"] and form_item["default"] not in form_item["options"]:
                     raise ValueError("default value in user_input_form must be in the options list")

         return config, ["user_input_form"]
@@ -168,10 +164,6 @@ class BasicVariablesConfigManager:
             typ = tool["type"]
             config = tool["config"]

-            ExternalDataToolFactory.validate_config(
-                name=typ,
-                tenant_id=tenant_id,
-                config=config
-            )
+            ExternalDataToolFactory.validate_config(name=typ, tenant_id=tenant_id, config=config)

         return config, ["external_data_tools"]
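For orientation, the shape of the `user_input_form` data this validator accepts can be inferred from the checks above. A minimal sketch (field names come from the code; the values are invented for illustration, not taken from the source):

# Hypothetical example assembled from the validation rules above;
# the authoritative schema lives in Dify's app config, this is only a sketch.
user_input_form = [
    {
        "text-input": {
            "label": "Your name",     # required, must be a string
            "variable": "user_name",  # required, must not start with a digit
            "required": False,        # defaults to False when omitted
        }
    },
    {
        "select": {
            "label": "Language",
            "variable": "language",
            "options": ["en", "zh"],  # required list for 'select'
            "default": "en",          # must be one of the options
        }
    },
]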
@@ -12,6 +12,7 @@ class ModelConfigEntity(BaseModel):
     """
     Model Config Entity.
     """
+
     provider: str
     model: str
     mode: Optional[str] = None
@@ -23,6 +24,7 @@ class AdvancedChatMessageEntity(BaseModel):
     """
     Advanced Chat Message Entity.
     """
+
     text: str
     role: PromptMessageRole

@@ -31,6 +33,7 @@ class AdvancedChatPromptTemplateEntity(BaseModel):
     """
     Advanced Chat Prompt Template Entity.
     """
+
     messages: list[AdvancedChatMessageEntity]


@@ -43,6 +46,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel):
     """
     Role Prefix Entity.
     """
+
     user: str
     assistant: str

@@ -60,11 +64,12 @@ class PromptTemplateEntity(BaseModel):
         Prompt Type.
         'simple', 'advanced'
         """
-        SIMPLE = 'simple'
-        ADVANCED = 'advanced'
+
+        SIMPLE = "simple"
+        ADVANCED = "advanced"

         @classmethod
-        def value_of(cls, value: str) -> 'PromptType':
+        def value_of(cls, value: str) -> "PromptType":
             """
             Get value of given mode.

@@ -74,7 +79,7 @@ class PromptTemplateEntity(BaseModel):
             for mode in cls:
                 if mode.value == value:
                     return mode
-            raise ValueError(f'invalid prompt type value {value}')
+            raise ValueError(f"invalid prompt type value {value}")

     prompt_type: PromptType
     simple_prompt_template: Optional[str] = None
@@ -110,6 +115,7 @@ class ExternalDataVariableEntity(BaseModel):
     """
     External Data Variable Entity.
     """
+
     variable: str
     type: str
     config: dict[str, Any] = {}
@@ -125,11 +131,12 @@ class DatasetRetrieveConfigEntity(BaseModel):
         Dataset Retrieve Strategy.
         'single' or 'multiple'
         """
-        SINGLE = 'single'
-        MULTIPLE = 'multiple'
+
+        SINGLE = "single"
+        MULTIPLE = "multiple"

         @classmethod
-        def value_of(cls, value: str) -> 'RetrieveStrategy':
+        def value_of(cls, value: str) -> "RetrieveStrategy":
             """
             Get value of given mode.

@@ -139,25 +146,24 @@ class DatasetRetrieveConfigEntity(BaseModel):
             for mode in cls:
                 if mode.value == value:
                     return mode
-            raise ValueError(f'invalid retrieve strategy value {value}')
+            raise ValueError(f"invalid retrieve strategy value {value}")

     query_variable: Optional[str] = None  # Only when app mode is completion

     retrieve_strategy: RetrieveStrategy
     top_k: Optional[int] = None
-    score_threshold: Optional[float] = .0
-    rerank_mode: Optional[str] = 'reranking_model'
+    score_threshold: Optional[float] = 0.0
+    rerank_mode: Optional[str] = "reranking_model"
     reranking_model: Optional[dict] = None
     weights: Optional[dict] = None
     reranking_enabled: Optional[bool] = True


-
-
 class DatasetEntity(BaseModel):
     """
     Dataset Config Entity.
     """
+
     dataset_ids: list[str]
     retrieve_config: DatasetRetrieveConfigEntity

@@ -166,6 +172,7 @@ class SensitiveWordAvoidanceEntity(BaseModel):
     """
     Sensitive Word Avoidance Entity.
     """
+
     type: str
     config: dict[str, Any] = {}

@@ -174,6 +181,7 @@ class TextToSpeechEntity(BaseModel):
     """
     Sensitive Word Avoidance Entity.
     """
+
     enabled: bool
     voice: Optional[str] = None
     language: Optional[str] = None
@@ -183,12 +191,11 @@ class TracingConfigEntity(BaseModel):
     """
     Tracing Config Entity.
     """
+
     enabled: bool
     tracing_provider: str


-
-
 class AppAdditionalFeatures(BaseModel):
     file_upload: Optional[FileExtraConfig] = None
     opening_statement: Optional[str] = None
@@ -200,10 +207,12 @@ class AppAdditionalFeatures(BaseModel):
     text_to_speech: Optional[TextToSpeechEntity] = None
     trace_config: Optional[TracingConfigEntity] = None

+
 class AppConfig(BaseModel):
     """
     Application Config Entity.
     """
+
     tenant_id: str
     app_id: str
     app_mode: AppMode
@@ -216,15 +225,17 @@ class EasyUIBasedAppModelConfigFrom(Enum):
     """
     App Model Config From.
     """
-    ARGS = 'args'
-    APP_LATEST_CONFIG = 'app-latest-config'
-    CONVERSATION_SPECIFIC_CONFIG = 'conversation-specific-config'
+
+    ARGS = "args"
+    APP_LATEST_CONFIG = "app-latest-config"
+    CONVERSATION_SPECIFIC_CONFIG = "conversation-specific-config"


 class EasyUIBasedAppConfig(AppConfig):
     """
     Easy UI Based App Config Entity.
     """
+
     app_model_config_from: EasyUIBasedAppModelConfigFrom
     app_model_config_id: str
     app_model_config_dict: dict
@@ -238,4 +249,5 @@ class WorkflowUIBasedAppConfig(AppConfig):
     """
     Workflow UI Based App Config Entity.
     """
+
     workflow_id: str
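The `value_of` classmethods reformatted above all follow the same Enum-lookup idiom: scan the members and raise on unknown input rather than returning None. A standalone sketch of the pattern, using one of the enums from this file:

from enum import Enum


class PromptType(Enum):
    SIMPLE = "simple"
    ADVANCED = "advanced"

    @classmethod
    def value_of(cls, value: str) -> "PromptType":
        # Return the member whose value matches, raising on unknown input.
        for mode in cls:
            if mode.value == value:
                return mode
        raise ValueError(f"invalid prompt type value {value}")


assert PromptType.value_of("simple") is PromptType.SIMPLE

Plain `PromptType("simple")` would behave much the same; the explicit classmethod mainly buys a friendlier error message.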
@@ -13,21 +13,19 @@ class FileUploadConfigManager:
         :param config: model config args
         :param is_vision: if True, the feature is vision feature
         """
-        file_upload_dict = config.get('file_upload')
+        file_upload_dict = config.get("file_upload")
         if file_upload_dict:
-            if file_upload_dict.get('image'):
-                if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
+            if file_upload_dict.get("image"):
+                if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]:
                     image_config = {
-                        'number_limits': file_upload_dict['image']['number_limits'],
-                        'transfer_methods': file_upload_dict['image']['transfer_methods']
+                        "number_limits": file_upload_dict["image"]["number_limits"],
+                        "transfer_methods": file_upload_dict["image"]["transfer_methods"],
                     }

                     if is_vision:
-                        image_config['detail'] = file_upload_dict['image']['detail']
+                        image_config["detail"] = file_upload_dict["image"]["detail"]

-                    return FileExtraConfig(
-                        image_config=image_config
-                    )
+                    return FileExtraConfig(image_config=image_config)

         return None

@@ -49,21 +47,21 @@ class FileUploadConfigManager:
         if not config["file_upload"].get("image"):
             config["file_upload"]["image"] = {"enabled": False}

-        if config['file_upload']['image']['enabled']:
-            number_limits = config['file_upload']['image']['number_limits']
+        if config["file_upload"]["image"]["enabled"]:
+            number_limits = config["file_upload"]["image"]["number_limits"]
             if number_limits < 1 or number_limits > 6:
                 raise ValueError("number_limits must be in [1, 6]")

             if is_vision:
-                detail = config['file_upload']['image']['detail']
-                if detail not in ['high', 'low']:
+                detail = config["file_upload"]["image"]["detail"]
+                if detail not in ["high", "low"]:
                     raise ValueError("detail must be in ['high', 'low']")

-            transfer_methods = config['file_upload']['image']['transfer_methods']
+            transfer_methods = config["file_upload"]["image"]["transfer_methods"]
             if not isinstance(transfer_methods, list):
                 raise ValueError("transfer_methods must be of list type")
             for method in transfer_methods:
-                if method not in ['remote_url', 'local_file']:
+                if method not in ["remote_url", "local_file"]:
                     raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")

         return config, ["file_upload"]
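Taken together, the checks above pin down what a valid `file_upload` fragment looks like. An illustrative config that would pass them (values invented, not from the source):

# Hypothetical config fragment satisfying FileUploadConfigManager's checks.
config = {
    "file_upload": {
        "image": {
            "enabled": True,
            "number_limits": 3,  # must be in [1, 6]
            "detail": "high",    # only checked when is_vision=True
            "transfer_methods": ["remote_url", "local_file"],
        }
    }
}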
@@ -7,9 +7,9 @@ class MoreLikeThisConfigManager:
         :param config: model config args
         """
         more_like_this = False
-        more_like_this_dict = config.get('more_like_this')
+        more_like_this_dict = config.get("more_like_this")
         if more_like_this_dict:
-            if more_like_this_dict.get('enabled'):
+            if more_like_this_dict.get("enabled"):
                 more_like_this = True

         return more_like_this
@@ -22,9 +22,7 @@ class MoreLikeThisConfigManager:
         :param config: app model config args
         """
         if not config.get("more_like_this"):
-            config["more_like_this"] = {
-                "enabled": False
-            }
+            config["more_like_this"] = {"enabled": False}

         if not isinstance(config["more_like_this"], dict):
             raise ValueError("more_like_this must be of dict type")
@@ -1,5 +1,3 @@
-
-
 class OpeningStatementConfigManager:
     @classmethod
     def convert(cls, config: dict) -> tuple[str, list]:
@@ -9,10 +7,10 @@ class OpeningStatementConfigManager:
         :param config: model config args
         """
         # opening statement
-        opening_statement = config.get('opening_statement')
+        opening_statement = config.get("opening_statement")

         # suggested questions
-        suggested_questions_list = config.get('suggested_questions')
+        suggested_questions_list = config.get("suggested_questions")

         return opening_statement, suggested_questions_list
@@ -2,9 +2,9 @@ class RetrievalResourceConfigManager:
     @classmethod
     def convert(cls, config: dict) -> bool:
         show_retrieve_source = False
-        retriever_resource_dict = config.get('retriever_resource')
+        retriever_resource_dict = config.get("retriever_resource")
         if retriever_resource_dict:
-            if retriever_resource_dict.get('enabled'):
+            if retriever_resource_dict.get("enabled"):
                 show_retrieve_source = True

         return show_retrieve_source
@@ -17,9 +17,7 @@ class RetrievalResourceConfigManager:
         :param config: app model config args
         """
         if not config.get("retriever_resource"):
-            config["retriever_resource"] = {
-                "enabled": False
-            }
+            config["retriever_resource"] = {"enabled": False}

         if not isinstance(config["retriever_resource"], dict):
             raise ValueError("retriever_resource must be of dict type")
@@ -7,9 +7,9 @@ class SpeechToTextConfigManager:
         :param config: model config args
         """
         speech_to_text = False
-        speech_to_text_dict = config.get('speech_to_text')
+        speech_to_text_dict = config.get("speech_to_text")
         if speech_to_text_dict:
-            if speech_to_text_dict.get('enabled'):
+            if speech_to_text_dict.get("enabled"):
                 speech_to_text = True

         return speech_to_text
@@ -22,9 +22,7 @@ class SpeechToTextConfigManager:
         :param config: app model config args
         """
         if not config.get("speech_to_text"):
-            config["speech_to_text"] = {
-                "enabled": False
-            }
+            config["speech_to_text"] = {"enabled": False}

         if not isinstance(config["speech_to_text"], dict):
             raise ValueError("speech_to_text must be of dict type")
@@ -7,9 +7,9 @@ class SuggestedQuestionsAfterAnswerConfigManager:
         :param config: model config args
         """
         suggested_questions_after_answer = False
-        suggested_questions_after_answer_dict = config.get('suggested_questions_after_answer')
+        suggested_questions_after_answer_dict = config.get("suggested_questions_after_answer")
         if suggested_questions_after_answer_dict:
-            if suggested_questions_after_answer_dict.get('enabled'):
+            if suggested_questions_after_answer_dict.get("enabled"):
                 suggested_questions_after_answer = True

         return suggested_questions_after_answer
@@ -22,15 +22,15 @@ class SuggestedQuestionsAfterAnswerConfigManager:
         :param config: app model config args
         """
         if not config.get("suggested_questions_after_answer"):
-            config["suggested_questions_after_answer"] = {
-                "enabled": False
-            }
+            config["suggested_questions_after_answer"] = {"enabled": False}

         if not isinstance(config["suggested_questions_after_answer"], dict):
             raise ValueError("suggested_questions_after_answer must be of dict type")

-        if "enabled" not in config["suggested_questions_after_answer"] or not \
-                config["suggested_questions_after_answer"]["enabled"]:
+        if (
+            "enabled" not in config["suggested_questions_after_answer"]
+            or not config["suggested_questions_after_answer"]["enabled"]
+        ):
             config["suggested_questions_after_answer"]["enabled"] = False

         if not isinstance(config["suggested_questions_after_answer"]["enabled"], bool):
@@ -10,13 +10,13 @@ class TextToSpeechConfigManager:
         :param config: model config args
         """
         text_to_speech = None
-        text_to_speech_dict = config.get('text_to_speech')
+        text_to_speech_dict = config.get("text_to_speech")
         if text_to_speech_dict:
-            if text_to_speech_dict.get('enabled'):
+            if text_to_speech_dict.get("enabled"):
                 text_to_speech = TextToSpeechEntity(
-                    enabled=text_to_speech_dict.get('enabled'),
-                    voice=text_to_speech_dict.get('voice'),
-                    language=text_to_speech_dict.get('language'),
+                    enabled=text_to_speech_dict.get("enabled"),
+                    voice=text_to_speech_dict.get("voice"),
+                    language=text_to_speech_dict.get("language"),
                 )

         return text_to_speech
@@ -29,11 +29,7 @@ class TextToSpeechConfigManager:
         :param config: app model config args
         """
         if not config.get("text_to_speech"):
-            config["text_to_speech"] = {
-                "enabled": False,
-                "voice": "",
-                "language": ""
-            }
+            config["text_to_speech"] = {"enabled": False, "voice": "", "language": ""}

         if not isinstance(config["text_to_speech"], dict):
             raise ValueError("text_to_speech must be of dict type")
@@ -1,4 +1,3 @@
-
 from core.app.app_config.base_app_config_manager import BaseAppConfigManager
 from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
 from core.app.app_config.entities import WorkflowUIBasedAppConfig
@@ -19,13 +18,13 @@ class AdvancedChatAppConfig(WorkflowUIBasedAppConfig):
     """
     Advanced Chatbot App Config Entity.
     """
+
     pass


 class AdvancedChatAppConfigManager(BaseAppConfigManager):
     @classmethod
-    def get_app_config(cls, app_model: App,
-                       workflow: Workflow) -> AdvancedChatAppConfig:
+    def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
         features_dict = workflow.features_dict

         app_mode = AppMode.value_of(app_model.mode)
@@ -34,13 +33,9 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
             app_id=app_model.id,
             app_mode=app_mode,
             workflow_id=workflow.id,
-            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
-                config=features_dict
-            ),
-            variables=WorkflowVariablesConfigManager.convert(
-                workflow=workflow
-            ),
-            additional_features=cls.convert_features(features_dict, app_mode)
+            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict),
+            variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
+            additional_features=cls.convert_features(features_dict, app_mode),
         )

         return app_config
@@ -58,8 +53,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):

         # file upload validation
         config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
-            config=config,
-            is_vision=False
+            config=config, is_vision=False
         )
         related_config_keys.extend(current_related_config_keys)

@@ -69,7 +63,8 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):

         # suggested_questions_after_answer
         config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
-            config)
+            config
+        )
         related_config_keys.extend(current_related_config_keys)

         # speech_to_text
@@ -86,9 +81,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):

         # moderation validation
         config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
-            tenant_id=tenant_id,
-            config=config,
-            only_structure_validate=only_structure_validate
+            tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
         )
         related_config_keys.extend(current_related_config_keys)

@@ -98,4 +91,3 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
         filtered_config = {key: config.get(key) for key in related_config_keys}

         return filtered_config
-
@@ -4,12 +4,10 @@ import os
 import threading
 import uuid
 from collections.abc import Generator
-from typing import Literal, Union, overload
+from typing import Any, Literal, Optional, Union, overload

 from flask import Flask, current_app
 from pydantic import ValidationError
-from sqlalchemy import select
-from sqlalchemy.orm import Session

 import contexts
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@@ -20,20 +18,15 @@ from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGe
 from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
 from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
 from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
-from core.app.entities.app_invoke_entities import (
-    AdvancedChatAppGenerateEntity,
-    InvokeFrom,
-)
+from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
 from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
 from core.file.message_file_parser import MessageFileParser
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
 from core.ops.ops_trace_manager import TraceQueueManager
-from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import SystemVariableKey
 from extensions.ext_database import db
 from models.account import Account
 from models.model import App, Conversation, EndUser, Message
-from models.workflow import ConversationVariable, Workflow
+from models.workflow import Workflow

 logger = logging.getLogger(__name__)
@@ -41,7 +34,8 @@ logger = logging.getLogger(__name__)
 class AdvancedChatAppGenerator(MessageBasedAppGenerator):
     @overload
     def generate(
-        self, app_model: App,
+        self,
+        app_model: App,
         workflow: Workflow,
         user: Union[Account, EndUser],
         args: dict,
@@ -51,7 +45,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):

     @overload
     def generate(
-        self, app_model: App,
+        self,
+        app_model: App,
         workflow: Workflow,
         user: Union[Account, EndUser],
         args: dict,
@@ -60,13 +55,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
     ) -> dict: ...

     def generate(
-        self, app_model: App,
+        self,
+        app_model: App,
         workflow: Workflow,
         user: Union[Account, EndUser],
         args: dict,
         invoke_from: InvokeFrom,
         stream: bool = True,
-    ):
+    ) -> dict[str, Any] | Generator[str, Any, None]:
         """
         Generate App response.

@@ -77,44 +73,37 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         :param invoke_from: invoke from source
         :param stream: is stream
         """
-        if not args.get('query'):
-            raise ValueError('query is required')
+        if not args.get("query"):
+            raise ValueError("query is required")

-        query = args['query']
+        query = args["query"]
         if not isinstance(query, str):
-            raise ValueError('query must be a string')
+            raise ValueError("query must be a string")

-        query = query.replace('\x00', '')
-        inputs = args['inputs']
+        query = query.replace("\x00", "")
+        inputs = args["inputs"]

-        extras = {
-            "auto_generate_conversation_name": args.get('auto_generate_name', False)
-        }
+        extras = {"auto_generate_conversation_name": args.get("auto_generate_name", False)}

         # get conversation
         conversation = None
-        conversation_id = args.get('conversation_id')
+        conversation_id = args.get("conversation_id")
         if conversation_id:
-            conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
+            conversation = self._get_conversation_by_user(
+                app_model=app_model, conversation_id=conversation_id, user=user
+            )

         # parse files
-        files = args['files'] if args.get('files') else []
+        files = args["files"] if args.get("files") else []
         message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
         if file_extra_config:
-            file_objs = message_file_parser.validate_and_transform_files_arg(
-                files,
-                file_extra_config,
-                user
-            )
+            file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
         else:
             file_objs = []

         # convert to app config
-        app_config = AdvancedChatAppConfigManager.get_app_config(
-            app_model=app_model,
-            workflow=workflow
-        )
+        app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)

         # get tracing instance
         user_id = user.id if isinstance(user, Account) else user.session_id
@@ -136,7 +125,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             stream=stream,
             invoke_from=invoke_from,
             extras=extras,
-            trace_manager=trace_manager
+            trace_manager=trace_manager,
         )
         contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)

@@ -146,15 +135,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             invoke_from=invoke_from,
             application_generate_entity=application_generate_entity,
             conversation=conversation,
-            stream=stream
+            stream=stream,
         )
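Read together, the `args.get(...)` accesses above imply the minimal shape of the `args` dict that `generate` accepts. An illustrative payload (keys from the code, values invented):

# Keys taken from the validation above; values are made up for illustration.
args = {
    "query": "What does this workflow do?",  # required, non-empty string
    "inputs": {"user_name": "Ada"},          # required
    "files": [],                             # optional, defaults to []
    "conversation_id": None,                 # optional; omit to start a new conversation
    "auto_generate_name": False,             # optional, defaults to False
}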
-    def single_iteration_generate(self, app_model: App,
-                                  workflow: Workflow,
-                                  node_id: str,
-                                  user: Account,
-                                  args: dict,
-                                  stream: bool = True):
+    def single_iteration_generate(
+        self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True
+    ) -> dict[str, Any] | Generator[str, Any, None]:
         """
         Generate App response.

@@ -166,43 +152,29 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
         :param stream: is stream
         """
         if not node_id:
-            raise ValueError('node_id is required')
+            raise ValueError("node_id is required")

-        if args.get('inputs') is None:
-            raise ValueError('inputs is required')
-
-        extras = {
-            "auto_generate_conversation_name": False
-        }
-
-        # get conversation
-        conversation = None
-        conversation_id = args.get('conversation_id')
-        if conversation_id:
-            conversation = self._get_conversation_by_user(app_model=app_model, conversation_id=conversation_id, user=user)
+        if args.get("inputs") is None:
+            raise ValueError("inputs is required")

         # convert to app config
-        app_config = AdvancedChatAppConfigManager.get_app_config(
-            app_model=app_model,
-            workflow=workflow
-        )
+        app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)

         # init application generate entity
         application_generate_entity = AdvancedChatAppGenerateEntity(
             task_id=str(uuid.uuid4()),
             app_config=app_config,
-            conversation_id=conversation.id if conversation else None,
+            conversation_id=None,
             inputs={},
-            query='',
+            query="",
             files=[],
             user_id=user.id,
             stream=stream,
             invoke_from=InvokeFrom.DEBUGGER,
-            extras=extras,
+            extras={"auto_generate_conversation_name": False},
             single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity(
-                node_id=node_id,
-                inputs=args['inputs']
-            )
+                node_id=node_id, inputs=args["inputs"]
+            ),
         )
         contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)

@@ -211,32 +183,42 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             user=user,
             invoke_from=InvokeFrom.DEBUGGER,
             application_generate_entity=application_generate_entity,
-            conversation=conversation,
-            stream=stream
+            conversation=None,
+            stream=stream,
         )
-    def _generate(self, *,
-                  workflow: Workflow,
-                  user: Union[Account, EndUser],
-                  invoke_from: InvokeFrom,
-                  application_generate_entity: AdvancedChatAppGenerateEntity,
-                  conversation: Conversation | None = None,
-                  stream: bool = True):
+    def _generate(
+        self,
+        *,
+        workflow: Workflow,
+        user: Union[Account, EndUser],
+        invoke_from: InvokeFrom,
+        application_generate_entity: AdvancedChatAppGenerateEntity,
+        conversation: Optional[Conversation] = None,
+        stream: bool = True,
+    ) -> dict[str, Any] | Generator[str, Any, None]:
         """
         Generate App response.

         :param workflow: Workflow
         :param user: account or end user
         :param invoke_from: invoke from source
         :param application_generate_entity: application generate entity
         :param conversation: conversation
         :param stream: is stream
         """
         is_first_conversation = False
         if not conversation:
             is_first_conversation = True

         # init generate records
-        (
-            conversation,
-            message
-        ) = self._init_generate_records(application_generate_entity, conversation)
+        (conversation, message) = self._init_generate_records(application_generate_entity, conversation)

         if is_first_conversation:
             # update conversation features
             conversation.override_model_configs = workflow.features
             db.session.commit()
-            # db.session.refresh(conversation)
+            db.session.refresh(conversation)

         # init queue manager
         queue_manager = MessageBasedAppQueueManager(
@@ -245,73 +227,21 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             invoke_from=application_generate_entity.invoke_from,
             conversation_id=conversation.id,
             app_mode=conversation.mode,
-            message_id=message.id
+            message_id=message.id,
         )

-        # Init conversation variables
-        stmt = select(ConversationVariable).where(
-            ConversationVariable.app_id == conversation.app_id, ConversationVariable.conversation_id == conversation.id
-        )
-        with Session(db.engine) as session:
-            conversation_variables = session.scalars(stmt).all()
-            if not conversation_variables:
-                # Create conversation variables if they don't exist.
-                conversation_variables = [
-                    ConversationVariable.from_variable(
-                        app_id=conversation.app_id, conversation_id=conversation.id, variable=variable
-                    )
-                    for variable in workflow.conversation_variables
-                ]
-                session.add_all(conversation_variables)
-            # Convert database entities to variables.
-            conversation_variables = [item.to_variable() for item in conversation_variables]
-
-            session.commit()
-
-        # Increment dialogue count.
-        conversation.dialogue_count += 1
-
-        conversation_id = conversation.id
-        conversation_dialogue_count = conversation.dialogue_count
-        db.session.commit()
-        db.session.refresh(conversation)
-
-        inputs = application_generate_entity.inputs
-        query = application_generate_entity.query
-        files = application_generate_entity.files
-
-        user_id = None
-        if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
-            end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
-            if end_user:
-                user_id = end_user.session_id
-        else:
-            user_id = application_generate_entity.user_id
-
-        # Create a variable pool.
-        system_inputs = {
-            SystemVariableKey.QUERY: query,
-            SystemVariableKey.FILES: files,
-            SystemVariableKey.CONVERSATION_ID: conversation_id,
-            SystemVariableKey.USER_ID: user_id,
-            SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
-        }
-        variable_pool = VariablePool(
-            system_variables=system_inputs,
-            user_inputs=inputs,
-            environment_variables=workflow.environment_variables,
-            conversation_variables=conversation_variables,
-        )
-        contexts.workflow_variable_pool.set(variable_pool)
-
         # new thread
-        worker_thread = threading.Thread(target=self._generate_worker, kwargs={
-            'flask_app': current_app._get_current_object(),
-            'application_generate_entity': application_generate_entity,
-            'queue_manager': queue_manager,
-            'message_id': message.id,
-            'context': contextvars.copy_context(),
-        })
+        worker_thread = threading.Thread(
+            target=self._generate_worker,
+            kwargs={
+                "flask_app": current_app._get_current_object(),  # type: ignore
+                "application_generate_entity": application_generate_entity,
+                "queue_manager": queue_manager,
+                "conversation_id": conversation.id,
+                "message_id": message.id,
+                "context": contextvars.copy_context(),
+            },
+        )

         worker_thread.start()
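The worker-thread wiring above captures the caller's contextvars with `contextvars.copy_context()` and replays it inside the thread (the `var.set(val)` loop visible in `_generate_worker` below). A minimal standalone sketch of that pattern, with illustrative names:

import contextvars
import threading

tenant_id = contextvars.ContextVar("tenant_id")


def worker(context: contextvars.Context) -> None:
    # Replay every (var, value) pair captured in the parent thread,
    # since a new thread starts with an empty context.
    for var, val in context.items():
        var.set(val)
    print(tenant_id.get())  # -> "t-123"


tenant_id.set("t-123")
thread = threading.Thread(target=worker, kwargs={"context": contextvars.copy_context()})
thread.start()
thread.join()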
@@ -326,16 +256,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             stream=stream,
         )

-        return AdvancedChatAppGenerateResponseConverter.convert(
-            response=response,
-            invoke_from=invoke_from
-        )
+        return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

-    def _generate_worker(self, flask_app: Flask,
-                         application_generate_entity: AdvancedChatAppGenerateEntity,
-                         queue_manager: AppQueueManager,
-                         message_id: str,
-                         context: contextvars.Context) -> None:
+    def _generate_worker(
+        self,
+        flask_app: Flask,
+        application_generate_entity: AdvancedChatAppGenerateEntity,
+        queue_manager: AppQueueManager,
+        conversation_id: str,
+        message_id: str,
+        context: contextvars.Context,
+    ) -> None:
         """
         Generate worker in a new thread.
         :param flask_app: Flask app
@@ -349,40 +280,30 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             var.set(val)
         with flask_app.app_context():
             try:
-                runner = AdvancedChatAppRunner()
-                if application_generate_entity.single_iteration_run:
-                    single_iteration_run = application_generate_entity.single_iteration_run
-                    runner.single_iteration_run(
-                        app_id=application_generate_entity.app_config.app_id,
-                        workflow_id=application_generate_entity.app_config.workflow_id,
-                        queue_manager=queue_manager,
-                        inputs=single_iteration_run.inputs,
-                        node_id=single_iteration_run.node_id,
-                        user_id=application_generate_entity.user_id
-                    )
-                else:
-                    # get message
-                    message = self._get_message(message_id)
+                # get conversation and message
+                conversation = self._get_conversation(conversation_id)
+                message = self._get_message(message_id)

-                    # chatbot app
-                    runner = AdvancedChatAppRunner()
-                    runner.run(
-                        application_generate_entity=application_generate_entity,
-                        queue_manager=queue_manager,
-                        message=message
-                    )
+                # chatbot app
+                runner = AdvancedChatAppRunner(
+                    application_generate_entity=application_generate_entity,
+                    queue_manager=queue_manager,
+                    conversation=conversation,
+                    message=message,
+                )
+
+                runner.run()
             except GenerateTaskStoppedException:
                 pass
             except InvokeAuthorizationError:
                 queue_manager.publish_error(
-                    InvokeAuthorizationError('Incorrect API key provided'),
-                    PublishFrom.APPLICATION_MANAGER
+                    InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
                 )
             except ValidationError as e:
                 logger.exception("Validation Error when generating")
                 queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
             except (ValueError, InvokeError) as e:
-                if os.environ.get("DEBUG", "false").lower() == 'true':
+                if os.environ.get("DEBUG", "false").lower() == "true":
                     logger.exception("Error when generating")
                 queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
             except Exception as e:
@@ -25,10 +25,7 @@ def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str):
     if not text_content or text_content.isspace():
         return
     return model_instance.invoke_tts(
-        content_text=text_content.strip(),
-        user="responding_tts",
-        tenant_id=tenant_id,
-        voice=voice
+        content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice
     )


@@ -44,28 +41,26 @@ def _process_future(future_queue, audio_queue):
         except Exception as e:
             logging.getLogger(__name__).warning(e)
             break
-    audio_queue.put(AudioTrunk("finish", b''))
+    audio_queue.put(AudioTrunk("finish", b""))


 class AppGeneratorTTSPublisher:
-
     def __init__(self, tenant_id: str, voice: str):
         self.logger = logging.getLogger(__name__)
         self.tenant_id = tenant_id
-        self.msg_text = ''
+        self.msg_text = ""
         self._audio_queue = queue.Queue()
         self._msg_queue = queue.Queue()
-        self.match = re.compile(r'[。.!?]')
+        self.match = re.compile(r"[。.!?]")
         self.model_manager = ModelManager()
         self.model_instance = self.model_manager.get_default_model_instance(
-            tenant_id=self.tenant_id,
-            model_type=ModelType.TTS
+            tenant_id=self.tenant_id, model_type=ModelType.TTS
         )
         self.voices = self.model_instance.get_tts_voices()
-        values = [voice.get('value') for voice in self.voices]
+        values = [voice.get("value") for voice in self.voices]
         self.voice = voice
         if not voice or voice not in values:
-            self.voice = self.voices[0].get('value')
+            self.voice = self.voices[0].get("value")
         self.MAX_SENTENCE = 2
         self._last_audio_event = None
         self._runtime_thread = threading.Thread(target=self._runtime).start()
@@ -85,8 +80,9 @@ class AppGeneratorTTSPublisher:
             message = self._msg_queue.get()
             if message is None:
                 if self.msg_text and len(self.msg_text.strip()) > 0:
-                    futures_result = self.executor.submit(_invoiceTTS, self.msg_text,
-                                                          self.model_instance, self.tenant_id, self.voice)
+                    futures_result = self.executor.submit(
+                        _invoiceTTS, self.msg_text, self.model_instance, self.tenant_id, self.voice
+                    )
                     future_queue.put(futures_result)
                 break
             elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
@@ -94,21 +90,20 @@ class AppGeneratorTTSPublisher:
             elif isinstance(message.event, QueueTextChunkEvent):
                 self.msg_text += message.event.text
             elif isinstance(message.event, QueueNodeSucceededEvent):
-                self.msg_text += message.event.outputs.get('output', '')
+                self.msg_text += message.event.outputs.get("output", "")
             self.last_message = message
             sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
             if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
                 self.MAX_SENTENCE += 1
-                text_content = ''.join(sentence_arr)
-                futures_result = self.executor.submit(_invoiceTTS, text_content,
-                                                      self.model_instance,
-                                                      self.tenant_id,
-                                                      self.voice)
+                text_content = "".join(sentence_arr)
+                futures_result = self.executor.submit(
+                    _invoiceTTS, text_content, self.model_instance, self.tenant_id, self.voice
+                )
                 future_queue.put(futures_result)
                 if text_tmp:
                     self.msg_text = text_tmp
                 else:
-                    self.msg_text = ''
+                    self.msg_text = ""

         except Exception as e:
             self.logger.warning(e)
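The publisher above buffers streamed text and flushes it to TTS in sentence-sized batches, using the `[。.!?]` pattern to find sentence ends. `_extract_sentence` itself is not part of this diff, so the following is an assumption about its split-and-carry behavior, written as a runnable standalone sketch:

import re

# Same terminator set as self.match above: CJK full stop, period, !, ?
SENTENCE_END = re.compile(r"[。.!?]")


def extract_sentences(buffer: str) -> tuple[list[str], str]:
    # Split completed sentences off the front of a streaming buffer,
    # returning them plus the unfinished remainder to carry over.
    sentences: list[str] = []
    start = 0
    for match in SENTENCE_END.finditer(buffer):
        sentences.append(buffer[start : match.end()])
        start = match.end()
    return sentences, buffer[start:]


done, rest = extract_sentences("Hello world. Still typ")
assert done == ["Hello world."] and rest == " Still typ"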
@@ -1,145 +1,197 @@
 import logging
 import os
 import time
 from collections.abc import Mapping
-from typing import Any, Optional, cast
+from typing import Any, cast
+
+from sqlalchemy import select
+from sqlalchemy.orm import Session

 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
-from core.app.apps.advanced_chat.workflow_event_trigger_callback import WorkflowEventTriggerCallback
-from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
-from core.app.apps.base_app_runner import AppRunner
+from core.app.apps.base_app_queue_manager import AppQueueManager
+from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
 from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
 from core.app.entities.app_invoke_entities import (
     AdvancedChatAppGenerateEntity,
     InvokeFrom,
 )
-from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent
+from core.app.entities.queue_entities import (
+    QueueAnnotationReplyEvent,
+    QueueStopEvent,
+    QueueTextChunkEvent,
+)
 from core.moderation.base import ModerationException
 from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
-from core.workflow.nodes.base_node import UserFrom
-from core.workflow.workflow_engine_manager import WorkflowEngineManager
+from core.workflow.entities.node_entities import UserFrom
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariableKey
+from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
-from models import App, Message, Workflow
+from models.model import App, Conversation, EndUser, Message
+from models.workflow import ConversationVariable, WorkflowType

 logger = logging.getLogger(__name__)
-class AdvancedChatAppRunner(AppRunner):
+class AdvancedChatAppRunner(WorkflowBasedAppRunner):
     """
     AdvancedChat Application Runner
     """

-    def run(
+    def __init__(
         self,
         application_generate_entity: AdvancedChatAppGenerateEntity,
         queue_manager: AppQueueManager,
+        conversation: Conversation,
         message: Message,
     ) -> None:
         """
-        Run application
         :param application_generate_entity: application generate entity
         :param queue_manager: application queue manager
+        :param conversation: conversation
         :param message: message
         """
-        app_config = application_generate_entity.app_config
+        super().__init__(queue_manager)
+
+        self.application_generate_entity = application_generate_entity
+        self.conversation = conversation
+        self.message = message
+
+    def run(self) -> None:
+        """
+        Run application
+        :return:
+        """
+        app_config = self.application_generate_entity.app_config
         app_config = cast(AdvancedChatAppConfig, app_config)

         app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
         if not app_record:
-            raise ValueError('App not found')
+            raise ValueError("App not found")

         workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
         if not workflow:
-            raise ValueError('Workflow not initialized')
+            raise ValueError("Workflow not initialized")

-        inputs = application_generate_entity.inputs
-        query = application_generate_entity.query
+        user_id = None
+        if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
+            end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
+            if end_user:
+                user_id = end_user.session_id
+        else:
+            user_id = self.application_generate_entity.user_id

-        # moderation
-        if self.handle_input_moderation(
-            queue_manager=queue_manager,
-            app_record=app_record,
-            app_generate_entity=application_generate_entity,
-            inputs=inputs,
-            query=query,
-            message_id=message.id,
-        ):
-            return
+        workflow_callbacks: list[WorkflowCallback] = []
+        if bool(os.environ.get("DEBUG", "False").lower() == "true"):
+            workflow_callbacks.append(WorkflowLoggingCallback())

-        # annotation reply
-        if self.handle_annotation_reply(
-            app_record=app_record,
-            message=message,
-            query=query,
-            queue_manager=queue_manager,
-            app_generate_entity=application_generate_entity,
-        ):
-            return
+        if self.application_generate_entity.single_iteration_run:
+            # if only single iteration run is requested
+            graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
+                workflow=workflow,
+                node_id=self.application_generate_entity.single_iteration_run.node_id,
+                user_inputs=self.application_generate_entity.single_iteration_run.inputs,
+            )
+        else:
+            inputs = self.application_generate_entity.inputs
+            query = self.application_generate_entity.query
+            files = self.application_generate_entity.files
+
+            # moderation
+            if self.handle_input_moderation(
+                app_record=app_record,
+                app_generate_entity=self.application_generate_entity,
+                inputs=inputs,
+                query=query,
+                message_id=self.message.id,
+            ):
+                return
+
+            # annotation reply
+            if self.handle_annotation_reply(
+                app_record=app_record,
+                message=self.message,
+                query=query,
+                app_generate_entity=self.application_generate_entity,
+            ):
+                return
+
+            # Init conversation variables
+            stmt = select(ConversationVariable).where(
+                ConversationVariable.app_id == self.conversation.app_id,
+                ConversationVariable.conversation_id == self.conversation.id,
+            )
+            with Session(db.engine) as session:
+                conversation_variables = session.scalars(stmt).all()
+                if not conversation_variables:
+                    # Create conversation variables if they don't exist.
+                    conversation_variables = [
+                        ConversationVariable.from_variable(
+                            app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
+                        )
+                        for variable in workflow.conversation_variables
+                    ]
+                    session.add_all(conversation_variables)
+                # Convert database entities to variables.
+                conversation_variables = [item.to_variable() for item in conversation_variables]
+
+                session.commit()
+
+            # Increment dialogue count.
+            self.conversation.dialogue_count += 1
+
+            conversation_dialogue_count = self.conversation.dialogue_count
+            db.session.commit()
+
+            # Create a variable pool.
+            system_inputs = {
+                SystemVariableKey.QUERY: query,
+                SystemVariableKey.FILES: files,
+                SystemVariableKey.CONVERSATION_ID: self.conversation.id,
+                SystemVariableKey.USER_ID: user_id,
+                SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count,
+            }
+
+            # init variable pool
+            variable_pool = VariablePool(
+                system_variables=system_inputs,
+                user_inputs=inputs,
+                environment_variables=workflow.environment_variables,
+                conversation_variables=conversation_variables,
+            )
+
+            # init graph
+            graph = self._init_graph(graph_config=workflow.graph_dict)

         db.session.close()

-        workflow_callbacks: list[WorkflowCallback] = [
-            WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
-        ]
-
-        if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
-            workflow_callbacks.append(WorkflowLoggingCallback())
-
         # RUN WORKFLOW
-        workflow_engine_manager = WorkflowEngineManager()
-        workflow_engine_manager.run_workflow(
-            workflow=workflow,
-            user_id=application_generate_entity.user_id,
-            user_from=UserFrom.ACCOUNT
-            if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
-            else UserFrom.END_USER,
-            invoke_from=application_generate_entity.invoke_from,
+        workflow_entry = WorkflowEntry(
+            tenant_id=workflow.tenant_id,
+            app_id=workflow.app_id,
+            workflow_id=workflow.id,
+            workflow_type=WorkflowType.value_of(workflow.type),
+            graph=graph,
+            graph_config=workflow.graph_dict,
+            user_id=self.application_generate_entity.user_id,
+            user_from=(
+                UserFrom.ACCOUNT
+                if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
+                else UserFrom.END_USER
+            ),
+            invoke_from=self.application_generate_entity.invoke_from,
+            call_depth=self.application_generate_entity.call_depth,
+            variable_pool=variable_pool,
+        )
+
+        generator = workflow_entry.run(
             callbacks=workflow_callbacks,
-            call_depth=application_generate_entity.call_depth,
         )

-    def single_iteration_run(
-        self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
-    ) -> None:
-        """
-        Single iteration run
-        """
-        app_record = db.session.query(App).filter(App.id == app_id).first()
-        if not app_record:
-            raise ValueError('App not found')
-
-        workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
-        if not workflow:
-            raise ValueError('Workflow not initialized')
-
-        workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]
-
-        workflow_engine_manager = WorkflowEngineManager()
-        workflow_engine_manager.single_step_run_iteration_workflow_node(
-            workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
-        )
-
-    def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
-        """
-        Get workflow
-        """
-        # fetch workflow by workflow_id
-        workflow = (
-            db.session.query(Workflow)
-            .filter(
-                Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
-            )
-            .first()
-        )
-
-        # return workflow
-        return workflow
+        for event in generator:
+            self._handle_event(workflow_entry, event)

     def handle_input_moderation(
         self,
-        queue_manager: AppQueueManager,
         app_record: App,
         app_generate_entity: AdvancedChatAppGenerateEntity,
         inputs: Mapping[str, Any],
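The conversation-variable block added above is a read-or-create pattern over SQLAlchemy 2.0's `select()`: load the rows for the conversation, seed them from the workflow's declared variables when none exist, then detach into plain values before committing. A generic, runnable sketch of the same pattern with a toy model (not Dify's actual `ConversationVariable`):

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class ConvVar(Base):  # toy stand-in for ConversationVariable
    __tablename__ = "conv_var"
    id: Mapped[int] = mapped_column(primary_key=True)
    conversation_id: Mapped[str] = mapped_column(String(36))
    name: Mapped[str] = mapped_column(String(64))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

defaults = ["topic", "mood"]  # stand-in for workflow.conversation_variables
with Session(engine) as session:
    stmt = select(ConvVar).where(ConvVar.conversation_id == "c1")
    rows = session.scalars(stmt).all()
    if not rows:
        # Seed persisted state from the declared defaults on first use.
        rows = [ConvVar(conversation_id="c1", name=n) for n in defaults]
        session.add_all(rows)
    names = [r.name for r in rows]  # detach into plain values before commit
    session.commit()

assert names == ["topic", "mood"]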
@@ -148,7 +200,6 @@ class AdvancedChatAppRunner(AppRunner):
     ) -> bool:
         """
         Handle input moderation
-        :param queue_manager: application queue manager
         :param app_record: app record
         :param app_generate_entity: application generate entity
         :param inputs: inputs
@@ -167,30 +218,19 @@ class AdvancedChatAppRunner(AppRunner):
                 message_id=message_id,
             )
         except ModerationException as e:
-            self._stream_output(
-                queue_manager=queue_manager,
-                text=str(e),
-                stream=app_generate_entity.stream,
-                stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION,
-            )
+            self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)
             return True

         return False

     def handle_annotation_reply(
-        self,
-        app_record: App,
-        message: Message,
-        query: str,
-        queue_manager: AppQueueManager,
-        app_generate_entity: AdvancedChatAppGenerateEntity,
+        self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
     ) -> bool:
         """
         Handle annotation reply
         :param app_record: app record
         :param message: message
         :param query: query
-        :param queue_manager: application queue manager
         :param app_generate_entity: application generate entity
         """
         # annotation reply
@@ -203,37 +243,21 @@ class AdvancedChatAppRunner(AppRunner):
         )

         if annotation_reply:
-            queue_manager.publish(
-                QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id), PublishFrom.APPLICATION_MANAGER
-            )
+            self._publish_event(QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id))

-            self._stream_output(
-                queue_manager=queue_manager,
-                text=annotation_reply.content,
-                stream=app_generate_entity.stream,
-                stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY,
+            self._complete_with_stream_output(
+                text=annotation_reply.content, stopped_by=QueueStopEvent.StopBy.ANNOTATION_REPLY
             )
             return True

         return False

-    def _stream_output(
-        self, queue_manager: AppQueueManager, text: str, stream: bool, stopped_by: QueueStopEvent.StopBy
-    ) -> None:
+    def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
         """
         Direct output
-        :param queue_manager: application queue manager
         :param text: text
-        :param stream: stream
         :return:
         """
-        if stream:
-            index = 0
-            for token in text:
-                queue_manager.publish(QueueTextChunkEvent(text=token), PublishFrom.APPLICATION_MANAGER)
-                index += 1
-                time.sleep(0.01)
-        else:
-            queue_manager.publish(QueueTextChunkEvent(text=text), PublishFrom.APPLICATION_MANAGER)
+        self._publish_event(QueueTextChunkEvent(text=text))

-        queue_manager.publish(QueueStopEvent(stopped_by=stopped_by), PublishFrom.APPLICATION_MANAGER)
+        self._publish_event(QueueStopEvent(stopped_by=stopped_by))
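The refactor above replaces repeated `queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)` calls with a `_publish_event` helper inherited from the new base runner. That helper's body is not shown in this diff, so the following toy reconstruction is an assumption about what it reduces to, written as a runnable sketch:

from dataclasses import dataclass
from enum import Enum


class PublishFrom(Enum):
    APPLICATION_MANAGER = 1


@dataclass
class QueueTextChunkEvent:
    text: str


class QueueManager:
    def publish(self, event, publish_from: PublishFrom) -> None:
        print(f"publish {event} from {publish_from.name}")


class BaseRunner:
    def __init__(self, queue_manager: QueueManager) -> None:
        self._queue_manager = queue_manager

    def _publish_event(self, event) -> None:
        # Centralize the constant PublishFrom argument so call sites
        # shrink to a single expression, as in the diff above.
        self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)


runner = BaseRunner(QueueManager())
runner._publish_event(QueueTextChunkEvent(text="hello"))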
@@ -28,15 +28,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         """
         blocking_response = cast(ChatbotAppBlockingResponse, blocking_response)
         response = {
-            'event': 'message',
-            'task_id': blocking_response.task_id,
-            'id': blocking_response.data.id,
-            'message_id': blocking_response.data.message_id,
-            'conversation_id': blocking_response.data.conversation_id,
-            'mode': blocking_response.data.mode,
-            'answer': blocking_response.data.answer,
-            'metadata': blocking_response.data.metadata,
-            'created_at': blocking_response.data.created_at
+            "event": "message",
+            "task_id": blocking_response.task_id,
+            "id": blocking_response.data.id,
+            "message_id": blocking_response.data.message_id,
+            "conversation_id": blocking_response.data.conversation_id,
+            "mode": blocking_response.data.mode,
+            "answer": blocking_response.data.answer,
+            "metadata": blocking_response.data.metadata,
+            "created_at": blocking_response.data.created_at,
         }

         return response
@@ -50,13 +50,15 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         """
         response = cls.convert_blocking_full_response(blocking_response)

-        metadata = response.get('metadata', {})
-        response['metadata'] = cls._get_simple_metadata(metadata)
+        metadata = response.get("metadata", {})
+        response["metadata"] = cls._get_simple_metadata(metadata)

         return response

     @classmethod
-    def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
+    def convert_stream_full_response(
+        cls, stream_response: Generator[AppStreamResponse, None, None]
+    ) -> Generator[str, Any, None]:
         """
         Convert stream full response.
         :param stream_response: stream response
@@ -67,14 +69,14 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             sub_stream_response = chunk.stream_response

             if isinstance(sub_stream_response, PingStreamResponse):
-                yield 'ping'
+                yield "ping"
                 continue

             response_chunk = {
-                'event': sub_stream_response.event.value,
-                'conversation_id': chunk.conversation_id,
-                'message_id': chunk.message_id,
-                'created_at': chunk.created_at
+                "event": sub_stream_response.event.value,
+                "conversation_id": chunk.conversation_id,
+                "message_id": chunk.message_id,
+                "created_at": chunk.created_at,
             }

             if isinstance(sub_stream_response, ErrorStreamResponse):
@@ -85,7 +87,9 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             yield json.dumps(response_chunk)

     @classmethod
-    def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) -> Generator[str, Any, None]:
+    def convert_stream_simple_response(
+        cls, stream_response: Generator[AppStreamResponse, None, None]
+    ) -> Generator[str, Any, None]:
         """
         Convert stream simple response.
         :param stream_response: stream response
@@ -96,20 +100,20 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             sub_stream_response = chunk.stream_response

             if isinstance(sub_stream_response, PingStreamResponse):
-                yield 'ping'
+                yield "ping"
                 continue

             response_chunk = {
-                'event': sub_stream_response.event.value,
-                'conversation_id': chunk.conversation_id,
-                'message_id': chunk.message_id,
-                'created_at': chunk.created_at
+                "event": sub_stream_response.event.value,
+                "conversation_id": chunk.conversation_id,
+                "message_id": chunk.message_id,
+                "created_at": chunk.created_at,
             }

             if isinstance(sub_stream_response, MessageEndStreamResponse):
                 sub_stream_response_dict = sub_stream_response.to_dict()
-                metadata = sub_stream_response_dict.get('metadata', {})
-                sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
+                metadata = sub_stream_response_dict.get("metadata", {})
+                sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                 response_chunk.update(sub_stream_response_dict)
             if isinstance(sub_stream_response, ErrorStreamResponse):
                 data = cls._error_to_stream_response(sub_stream_response.err)
@@ -2,9 +2,8 @@ import json
import logging
import time
from collections.abc import Generator
from typing import Any, Optional, Union, cast
from typing import Any, Optional, Union

import contexts
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@@ -22,6 +21,9 @@ from core.app.entities.queue_entities import (
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
QueuePingEvent,
QueueRetrieverResourcesEvent,
QueueStopEvent,
@@ -31,34 +33,28 @@ from core.app.entities.queue_entities import (
QueueWorkflowSucceededEvent,
)
from core.app.entities.task_entities import (
AdvancedChatTaskState,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ChatflowStreamGenerateRoute,
ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
StreamResponse,
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.file.file_obj import FileVar
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from events.message_event import message_was_created
from extensions.ext_database import db
from models.account import Account
from models.model import Conversation, EndUser, Message
from models.workflow import (
Workflow,
WorkflowNodeExecution,
WorkflowRunStatus,
)

@@ -69,22 +65,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_task_state: AdvancedChatTaskState

_task_state: WorkflowTaskState
_application_generate_entity: AdvancedChatAppGenerateEntity
_workflow: Workflow
_user: Union[Account, EndUser]
# Deprecated
_workflow_system_variables: dict[SystemVariableKey, Any]
_iteration_nested_relations: dict[str, list[str]]

def __init__(
self, application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool,
self,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool,
) -> None:
"""
Initialize AdvancedChatAppGenerateTaskPipeline.
@@ -106,7 +102,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._workflow = workflow
self._conversation = conversation
self._message = message
# Deprecated
self._workflow_system_variables = {
SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files,
@@ -114,12 +109,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
SystemVariableKey.USER_ID: user_id,
}

self._task_state = AdvancedChatTaskState(
usage=LLMUsage.empty_usage()
)
self._task_state = WorkflowTaskState()

self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
self._stream_generate_routes = self._get_stream_generate_routes()
self._conversation_name_generate_thread = None

def process(self):
@@ -133,13 +124,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc

# start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation,
self._application_generate_entity.query
self._conversation, self._application_generate_entity.query
)

generator = self._wrapper_process_stream_response(
trace_manager=self._application_generate_entity.trace_manager
)
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)

if self._stream:
return self._to_stream_response(generator)
else:
@@ -156,7 +145,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
elif isinstance(stream_response, MessageEndStreamResponse):
extras = {}
if stream_response.metadata:
extras['metadata'] = stream_response.metadata
extras["metadata"] = stream_response.metadata

return ChatbotAppBlockingResponse(
task_id=stream_response.task_id,
@@ -167,15 +156,17 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
message_id=self._message.id,
answer=self._task_state.answer,
created_at=int(self._message.created_at.timestamp()),
**extras
)
**extras,
),
)
else:
continue

raise Exception('Queue listening stopped unexpectedly.')
raise Exception("Queue listening stopped unexpectedly.")

def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) -> Generator[ChatbotAppStreamResponse, Any, None]:
def _to_stream_response(
self, generator: Generator[StreamResponse, None, None]
) -> Generator[ChatbotAppStreamResponse, Any, None]:
"""
To stream response.
:return:
@@ -185,7 +176,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
conversation_id=self._conversation.id,
message_id=self._message.id,
created_at=int(self._message.created_at.timestamp()),
stream_response=stream_response
stream_response=stream_response,
)

def _listenAudioMsg(self, publisher, task_id: str):
@@ -196,20 +187,24 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None

def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
Generator[StreamResponse, None, None]:

publisher = None
def _wrapper_process_stream_response(
self, trace_manager: Optional[TraceQueueManager] = None
) -> Generator[StreamResponse, None, None]:
tts_publisher = None
task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict

if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
'text_to_speech'].get('autoPlay') == 'enabled':
publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
if (
features_dict.get("text_to_speech")
and features_dict["text_to_speech"].get("enabled")
and features_dict["text_to_speech"].get("autoPlay") == "enabled"
):
tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))

for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listenAudioMsg(publisher, task_id=task_id)
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
@@ -220,9 +215,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# timeout
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
try:
if not publisher:
if not tts_publisher:
break
audio_trunk = publisher.checkAndGetAudio()
audio_trunk = tts_publisher.checkAndGetAudio()
if audio_trunk is None:
# release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
@@ -236,38 +231,38 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
except Exception as e:
logger.error(e)
break
yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)

def _process_stream_response(
self,
publisher: AppGeneratorTTSPublisher,
trace_manager: Optional[TraceQueueManager] = None
self,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> Generator[StreamResponse, None, None]:
"""
Process stream response.
:return:
"""
for message in self._queue_manager.listen():
if (message.event
and getattr(message.event, 'metadata', None)
and message.event.metadata.get('is_answer_previous_node', False)
and publisher):
publisher.publish(message=message)
elif (hasattr(message.event, 'execution_metadata')
and message.event.execution_metadata
and message.event.execution_metadata.get('is_answer_previous_node', False)
and publisher):
publisher.publish(message=message)
event = message.event
# init fake graph runtime state
graph_runtime_state = None
workflow_run = None

if isinstance(event, QueueErrorEvent):
for queue_message in self._queue_manager.listen():
event = queue_message.event

if isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
elif isinstance(event, QueueErrorEvent):
err = self._handle_error(event, self._message)
yield self._error_to_stream_response(err)
break
elif isinstance(event, QueueWorkflowStartedEvent):
workflow_run = self._handle_workflow_start()
# override graph runtime state
graph_runtime_state = event.graph_runtime_state

self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
# init workflow run
workflow_run = self._handle_workflow_run_start()

self._refetch_message()
self._message.workflow_run_id = workflow_run.id

db.session.commit()
@@ -275,137 +270,229 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
db.session.close()

yield self._workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance(event, QueueNodeStartedEvent):
workflow_node_execution = self._handle_node_start(event)
if not workflow_run:
raise Exception("Workflow run not initialized.")

# search stream_generate_routes if node id is answer start at node
if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes:
self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id]
# reset current route position to 0
self._task_state.current_stream_generate_state.current_route_position = 0
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)

# generate stream outputs when node started
yield from self._generate_stream_outputs_when_node_started()

yield self._workflow_node_start_to_stream_response(
response = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
workflow_node_execution=workflow_node_execution,
)
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
workflow_node_execution = self._handle_node_finished(event)

# stream outputs when node finished
generator = self._generate_stream_outputs_when_node_finished()
if generator:
yield from generator
if response:
yield response
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)

yield self._workflow_node_finish_to_stream_response(
response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
workflow_node_execution=workflow_node_execution,
)

if isinstance(event, QueueNodeFailedEvent):
yield from self._handle_iteration_exception(
task_id=self._application_generate_entity.task_id,
error=f'Child node failed: {event.error}'
if response:
yield response
elif isinstance(event, QueueNodeFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)

response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)

if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")

yield self._workflow_parallel_branch_start_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")

yield self._workflow_parallel_branch_finished_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationStartEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")

yield self._workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationNextEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")

yield self._workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")

yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")

if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")

workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=json.dumps(event.outputs) if event.outputs else None,
conversation_id=self._conversation.id,
trace_manager=trace_manager,
)

yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)

self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
elif isinstance(event, QueueWorkflowFailedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")

if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")

workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED,
error=event.error,
conversation_id=self._conversation.id,
trace_manager=trace_manager,
)

yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)

err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
break
elif isinstance(event, QueueStopEvent):
if workflow_run and graph_runtime_state:
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.STOPPED,
error=event.get_stop_reason(),
conversation_id=self._conversation.id,
trace_manager=trace_manager,
)
elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
if isinstance(event, QueueIterationNextEvent):
# clear ran node execution infos of current iteration
iteration_relations = self._iteration_nested_relations.get(event.node_id)
if iteration_relations:
for node_id in iteration_relations:
self._task_state.ran_node_execution_infos.pop(node_id, None)

yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
self._handle_iteration_operation(event)
elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
workflow_run = self._handle_workflow_finished(
event, conversation_id=self._conversation.id, trace_manager=trace_manager
)
if workflow_run:
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)

if workflow_run.status == WorkflowRunStatus.FAILED.value:
err_event = QueueErrorEvent(error=ValueError(f'Run failed: {workflow_run.error}'))
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
break

if isinstance(event, QueueStopEvent):
# Save message
self._save_message()

yield self._message_end_to_stream_response()
break
else:
self._queue_manager.publish(
QueueAdvancedChatMessageEndEvent(),
PublishFrom.TASK_PIPELINE
)
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer)

# Save message
self._save_message()
self._save_message(graph_runtime_state=graph_runtime_state)

yield self._message_end_to_stream_response()
break
elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event)

self._refetch_message()

self._message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)

db.session.commit()
db.session.refresh(self._message)
db.session.close()
elif isinstance(event, QueueAnnotationReplyEvent):
self._handle_annotation_reply(event)

self._refetch_message()

self._message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)

db.session.commit()
db.session.refresh(self._message)
db.session.close()
elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text
if delta_text is None:
continue

if not self._is_stream_out_support(
event=event
):
continue

# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(delta_text)
if should_direct_answer:
continue

# only publish tts message at text chunk streaming
if tts_publisher:
tts_publisher.publish(message=queue_message)

self._task_state.answer += delta_text
yield self._message_to_stream_response(delta_text, self._message.id)
elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation
yield self._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent):
yield self._ping_stream_response()
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")

output_moderation_answer = self._handle_output_moderation_when_task_finished(self._task_state.answer)
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer)

# Save message
self._save_message(graph_runtime_state=graph_runtime_state)

yield self._message_end_to_stream_response()
else:
continue
if publisher:
publisher.publish(None)

# publish None when task finished
if tts_publisher:
tts_publisher.publish(None)

if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join()

def _save_message(self) -> None:
def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
"""
Save message.
:return:
"""
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
self._refetch_message()

self._message.answer = self._task_state.answer
self._message.provider_response_latency = time.perf_counter() - self._start_at
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
if self._task_state.metadata else None

if self._task_state.metadata and self._task_state.metadata.get('usage'):
usage = LLMUsage(**self._task_state.metadata['usage'])
self._message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
)

if graph_runtime_state and graph_runtime_state.llm_usage:
usage = graph_runtime_state.llm_usage
self._message.message_tokens = usage.prompt_tokens
self._message.message_unit_price = usage.prompt_unit_price
self._message.message_price_unit = usage.prompt_price_unit
@@ -422,7 +509,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
application_generate_entity=self._application_generate_entity,
conversation=self._conversation,
is_first_message=self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras
extras=self._application_generate_entity.extras,
)

def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
@@ -432,331 +519,15 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
"""
extras = {}
if self._task_state.metadata:
extras['metadata'] = self._task_state.metadata
extras["metadata"] = self._task_state.metadata.copy()

if "annotation_reply" in extras["metadata"]:
del extras["metadata"]["annotation_reply"]

return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
id=self._message.id,
**extras
task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
)

def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]:
"""
Get stream generate routes.
:return:
"""
# find all answer nodes
graph = self._workflow.graph_dict
answer_node_configs = [
node for node in graph['nodes']
if node.get('data', {}).get('type') == NodeType.ANSWER.value
]

# parse stream output node value selectors of answer nodes
stream_generate_routes = {}
for node_config in answer_node_configs:
# get generate route for stream output
answer_node_id = node_config['id']
generate_route = AnswerNode.extract_generate_route_selectors(node_config)
start_node_ids = self._get_answer_start_at_node_ids(graph, answer_node_id)
if not start_node_ids:
continue

for start_node_id in start_node_ids:
stream_generate_routes[start_node_id] = ChatflowStreamGenerateRoute(
answer_node_id=answer_node_id,
generate_route=generate_route
)

return stream_generate_routes

def _get_answer_start_at_node_ids(self, graph: dict, target_node_id: str) \
-> list[str]:
"""
Get answer start at node id.
:param graph: graph
:param target_node_id: target node ID
:return:
"""
nodes = graph.get('nodes')
edges = graph.get('edges')

# fetch all ingoing edges from source node
ingoing_edges = []
for edge in edges:
if edge.get('target') == target_node_id:
ingoing_edges.append(edge)

if not ingoing_edges:
# check if it's the first node in the iteration
target_node = next((node for node in nodes if node.get('id') == target_node_id), None)
if not target_node:
return []

node_iteration_id = target_node.get('data', {}).get('iteration_id')
# get iteration start node id
for node in nodes:
if node.get('id') == node_iteration_id:
if node.get('data', {}).get('start_node_id') == target_node_id:
return [target_node_id]

return []

start_node_ids = []
for ingoing_edge in ingoing_edges:
source_node_id = ingoing_edge.get('source')
source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
if not source_node:
continue

node_type = source_node.get('data', {}).get('type')
node_iteration_id = source_node.get('data', {}).get('iteration_id')
iteration_start_node_id = None
if node_iteration_id:
iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')

if node_type in [
NodeType.ANSWER.value,
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value,
NodeType.ITERATION.value,
NodeType.LOOP.value
]:
start_node_id = target_node_id
start_node_ids.append(start_node_id)
elif node_type == NodeType.START.value or \
node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
start_node_id = source_node_id
start_node_ids.append(start_node_id)
else:
sub_start_node_ids = self._get_answer_start_at_node_ids(graph, source_node_id)
if sub_start_node_ids:
start_node_ids.extend(sub_start_node_ids)

return start_node_ids

def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
"""
Get iteration nested relations.
:param graph: graph
:return:
"""
nodes = graph.get('nodes')

iteration_ids = [node.get('id') for node in nodes
if node.get('data', {}).get('type') in [
NodeType.ITERATION.value,
NodeType.LOOP.value,
]]

return {
iteration_id: [
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
] for iteration_id in iteration_ids
}

def _generate_stream_outputs_when_node_started(self) -> Generator:
"""
Generate stream outputs.
:return:
"""
if self._task_state.current_stream_generate_state:
route_chunks = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position:
]

for route_chunk in route_chunks:
if route_chunk.type == 'text':
route_chunk = cast(TextGenerateRouteChunk, route_chunk)

# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(route_chunk.text)
if should_direct_answer:
continue

self._task_state.answer += route_chunk.text
yield self._message_to_stream_response(route_chunk.text, self._message.id)
else:
break

self._task_state.current_stream_generate_state.current_route_position += 1

# all route chunks are generated
if self._task_state.current_stream_generate_state.current_route_position == len(
self._task_state.current_stream_generate_state.generate_route
):
self._task_state.current_stream_generate_state = None

def _generate_stream_outputs_when_node_finished(self) -> Optional[Generator]:
"""
Generate stream outputs.
:return:
"""
if not self._task_state.current_stream_generate_state:
return

route_chunks = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position:]

for route_chunk in route_chunks:
if route_chunk.type == 'text':
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
self._task_state.answer += route_chunk.text
yield self._message_to_stream_response(route_chunk.text, self._message.id)
else:
value = None
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector
if not value_selector:
self._task_state.current_stream_generate_state.current_route_position += 1
continue

route_chunk_node_id = value_selector[0]

if route_chunk_node_id == 'sys':
# system variable
value = contexts.workflow_variable_pool.get().get(value_selector)
if value:
value = value.text
elif route_chunk_node_id in self._iteration_nested_relations:
# it's a iteration variable
if not self._iteration_state or route_chunk_node_id not in self._iteration_state.current_iterations:
continue
iteration_state = self._iteration_state.current_iterations[route_chunk_node_id]
iterator = iteration_state.inputs
if not iterator:
continue
iterator_selector = iterator.get('iterator_selector', [])
if value_selector[1] == 'index':
value = iteration_state.current_index
elif value_selector[1] == 'item':
value = iterator_selector[iteration_state.current_index] if iteration_state.current_index < len(
iterator_selector
) else None
else:
# check chunk node id is before current node id or equal to current node id
if route_chunk_node_id not in self._task_state.ran_node_execution_infos:
break

latest_node_execution_info = self._task_state.latest_node_execution_info

# get route chunk node execution info
route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id]
if (route_chunk_node_execution_info.node_type == NodeType.LLM
and latest_node_execution_info.node_type == NodeType.LLM):
# only LLM support chunk stream output
self._task_state.current_stream_generate_state.current_route_position += 1
continue

# get route chunk node execution
route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id
).first()

outputs = route_chunk_node_execution.outputs_dict

# get value from outputs
value = None
for key in value_selector[1:]:
if not value:
value = outputs.get(key) if outputs else None
else:
value = value.get(key)

if value is not None:
text = ''
if isinstance(value, str | int | float):
text = str(value)
elif isinstance(value, FileVar):
# convert file to markdown
text = value.to_markdown()
elif isinstance(value, dict):
# handle files
file_vars = self._fetch_files_from_variable_value(value)
if file_vars:
file_var = file_vars[0]
try:
file_var_obj = FileVar(**file_var)

# convert file to markdown
text = file_var_obj.to_markdown()
except Exception as e:
logger.error(f'Error creating file var: {e}')

if not text:
# other types
text = json.dumps(value, ensure_ascii=False)
elif isinstance(value, list):
# handle files
file_vars = self._fetch_files_from_variable_value(value)
for file_var in file_vars:
try:
file_var_obj = FileVar(**file_var)
except Exception as e:
logger.error(f'Error creating file var: {e}')
continue

# convert file to markdown
text = file_var_obj.to_markdown() + ' '

text = text.strip()

if not text and value:
# other types
text = json.dumps(value, ensure_ascii=False)

if text:
self._task_state.answer += text
yield self._message_to_stream_response(text, self._message.id)

self._task_state.current_stream_generate_state.current_route_position += 1

# all route chunks are generated
if self._task_state.current_stream_generate_state.current_route_position == len(
self._task_state.current_stream_generate_state.generate_route
):
self._task_state.current_stream_generate_state = None

def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
"""
Is stream out support
:param event: queue text chunk event
:return:
"""
if not event.metadata:
return True

if 'node_id' not in event.metadata:
return True

node_type = event.metadata.get('node_type')
stream_output_value_selector = event.metadata.get('value_selector')
if not stream_output_value_selector:
return False

if not self._task_state.current_stream_generate_state:
return False

route_chunk = self._task_state.current_stream_generate_state.generate_route[
self._task_state.current_stream_generate_state.current_route_position]

if route_chunk.type != 'var':
return False

if node_type != NodeType.LLM:
# only LLM support chunk stream output
return False

route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector

# check chunk node id is before current node id or equal to current node id
if value_selector != stream_output_value_selector:
return False

return True

def _handle_output_moderation_chunk(self, text: str) -> bool:
"""
Handle output moderation chunk.
@@ -768,17 +539,23 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# stop subscribe new token when output moderation should direct output
self._task_state.answer = self._output_moderation_handler.get_final_output()
self._queue_manager.publish(
QueueTextChunkEvent(
text=self._task_state.answer
), PublishFrom.TASK_PIPELINE
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
)

self._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
PublishFrom.TASK_PIPELINE
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
)
return True
else:
self._output_moderation_handler.append_new_token(text)

return False

def _refetch_message(self) -> None:
"""
Refetch message.
:return:
"""
message = db.session.query(Message).filter(Message.id == self._message.id).first()
if message:
self._message = message

@@ -1,203 +0,0 @@
from typing import Any, Optional

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from models.workflow import Workflow


class WorkflowEventTriggerCallback(WorkflowCallback):

def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
self._queue_manager = queue_manager

def on_workflow_run_started(self) -> None:
"""
Workflow run started
"""
self._queue_manager.publish(
QueueWorkflowStartedEvent(),
PublishFrom.APPLICATION_MANAGER
)

def on_workflow_run_succeeded(self) -> None:
"""
Workflow run succeeded
"""
self._queue_manager.publish(
QueueWorkflowSucceededEvent(),
PublishFrom.APPLICATION_MANAGER
)

def on_workflow_run_failed(self, error: str) -> None:
"""
Workflow run failed
"""
self._queue_manager.publish(
QueueWorkflowFailedEvent(
error=error
),
PublishFrom.APPLICATION_MANAGER
)

def on_workflow_node_execute_started(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
node_run_index: int = 1,
predecessor_node_id: Optional[str] = None) -> None:
"""
Workflow node execute started
"""
self._queue_manager.publish(
QueueNodeStartedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
node_run_index=node_run_index,
predecessor_node_id=predecessor_node_id
),
PublishFrom.APPLICATION_MANAGER
)

def on_workflow_node_execute_succeeded(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
inputs: Optional[dict] = None,
process_data: Optional[dict] = None,
outputs: Optional[dict] = None,
execution_metadata: Optional[dict] = None) -> None:
"""
Workflow node execute succeeded
"""
self._queue_manager.publish(
QueueNodeSucceededEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
inputs=inputs,
process_data=process_data,
outputs=outputs,
execution_metadata=execution_metadata
),
PublishFrom.APPLICATION_MANAGER
)

def on_workflow_node_execute_failed(self, node_id: str,
node_type: NodeType,
node_data: BaseNodeData,
error: str,
inputs: Optional[dict] = None,
outputs: Optional[dict] = None,
process_data: Optional[dict] = None) -> None:
"""
Workflow node execute failed
"""
self._queue_manager.publish(
QueueNodeFailedEvent(
node_id=node_id,
node_type=node_type,
node_data=node_data,
inputs=inputs,
outputs=outputs,
process_data=process_data,
error=error
),
PublishFrom.APPLICATION_MANAGER
)

def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
"""
Publish text chunk
"""
self._queue_manager.publish(
QueueTextChunkEvent(
text=text,
metadata={
"node_id": node_id,
**metadata
}
), PublishFrom.APPLICATION_MANAGER
)

def on_workflow_iteration_started(self,
node_id: str,
node_type: NodeType,
node_run_index: int = 1,
node_data: Optional[BaseNodeData] = None,
inputs: dict = None,
predecessor_node_id: Optional[str] = None,
metadata: Optional[dict] = None) -> None:
"""
Publish iteration started
"""
self._queue_manager.publish(
QueueIterationStartEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
node_data=node_data,
inputs=inputs,
predecessor_node_id=predecessor_node_id,
metadata=metadata
),
PublishFrom.APPLICATION_MANAGER
)

def on_workflow_iteration_next(self, node_id: str,
node_type: NodeType,
index: int,
node_run_index: int,
output: Optional[Any]) -> None:
"""
Publish iteration next
"""
self._queue_manager._publish(
QueueIterationNextEvent(
node_id=node_id,
node_type=node_type,
index=index,
node_run_index=node_run_index,
output=output
),
PublishFrom.APPLICATION_MANAGER
)

def on_workflow_iteration_completed(self, node_id: str,
node_type: NodeType,
node_run_index: int,
outputs: dict) -> None:
"""
Publish iteration completed
"""
self._queue_manager._publish(
QueueIterationCompletedEvent(
node_id=node_id,
node_type=node_type,
node_run_index=node_run_index,
outputs=outputs
),
PublishFrom.APPLICATION_MANAGER
)

def on_event(self, event: AppQueueEvent) -> None:
"""
Publish event
"""
self._queue_manager.publish(
event,
PublishFrom.APPLICATION_MANAGER
)
@@ -28,15 +28,19 @@ class AgentChatAppConfig(EasyUIBasedAppConfig):
"""
Agent Chatbot App Config Entity.
"""

agent: Optional[AgentEntity] = None


class AgentChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App,
app_model_config: AppModelConfig,
conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None) -> AgentChatAppConfig:
def get_app_config(
cls,
app_model: App,
app_model_config: AppModelConfig,
conversation: Optional[Conversation] = None,
override_config_dict: Optional[dict] = None,
) -> AgentChatAppConfig:
"""
Convert app model config to agent chat app config
:param app_model: app model
@@ -66,22 +70,12 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
app_model_config_from=config_from,
app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict,
model=ModelConfigManager.convert(
config=config_dict
),
prompt_template=PromptTemplateConfigManager.convert(
config=config_dict
),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=config_dict
),
dataset=DatasetConfigManager.convert(
config=config_dict
),
agent=AgentConfigManager.convert(
config=config_dict
),
additional_features=cls.convert_features(config_dict, app_mode)
model=ModelConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
dataset=DatasetConfigManager.convert(config=config_dict),
agent=AgentConfigManager.convert(config=config_dict),
additional_features=cls.convert_features(config_dict, app_mode),
)

app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
@@ -128,7 +122,8 @@ class AgentChatAppConfigManager(BaseAppConfigManager):

# suggested_questions_after_answer
config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
config)
config
)
related_config_keys.extend(current_related_config_keys)

# speech_to_text
@@ -145,13 +140,15 @@ class AgentChatAppConfigManager(BaseAppConfigManager):

# dataset configs
# dataset_query_variable
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode,
config)
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
tenant_id, app_mode, config
)
related_config_keys.extend(current_related_config_keys)

# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id,
config)
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id, config
)
related_config_keys.extend(current_related_config_keys)

related_config_keys = list(set(related_config_keys))
@@ -170,10 +167,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
:param config: app model config args
"""
if not config.get("agent_mode"):
config["agent_mode"] = {
"enabled": False,
"tools": []
}
config["agent_mode"] = {"enabled": False, "tools": []}

if not isinstance(config["agent_mode"], dict):
raise ValueError("agent_mode must be of object type")
@@ -187,8 +181,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
if not config["agent_mode"].get("strategy"):
config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value

if config["agent_mode"]["strategy"] not in [member.value for member in
list(PlanningStrategy.__members__.values())]:
if config["agent_mode"]["strategy"] not in [
member.value for member in list(PlanningStrategy.__members__.values())
]:
raise ValueError("strategy in agent_mode must be in the specified strategy list")

if not config["agent_mode"].get("tools"):
@@ -210,7 +205,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
raise ValueError("enabled in agent_mode.tools must be of boolean type")

if key == "dataset":
if 'id' not in tool_item:
if "id" not in tool_item:
raise ValueError("id is required in dataset")

try:

@@ -30,7 +30,8 @@ logger = logging.getLogger(__name__)
|
||||
class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
@@ -39,19 +40,17 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self, app_model: App,
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[False] = False,
|
||||
) -> dict: ...
|
||||
|
||||
def generate(self, app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True) \
|
||||
-> Union[dict, Generator[dict, None, None]]:
|
||||
def generate(
|
||||
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
|
||||
) -> Union[dict, Generator[dict, None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@@ -62,60 +61,48 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
:param stream: is stream
|
||||
"""
|
||||
if not stream:
|
||||
raise ValueError('Agent Chat App does not support blocking mode')
|
||||
raise ValueError("Agent Chat App does not support blocking mode")
|
||||
|
||||
if not args.get('query'):
|
||||
raise ValueError('query is required')
|
||||
if not args.get("query"):
|
||||
raise ValueError("query is required")
|
||||
|
||||
query = args['query']
|
||||
query = args["query"]
|
||||
if not isinstance(query, str):
|
||||
raise ValueError('query must be a string')
|
||||
raise ValueError("query must be a string")
|
||||
|
||||
query = query.replace('\x00', '')
|
||||
inputs = args['inputs']
|
||||
query = query.replace("\x00", "")
|
||||
inputs = args["inputs"]
|
||||
|
||||
extras = {
|
||||
"auto_generate_conversation_name": args.get('auto_generate_name', True)
|
||||
}
|
||||
extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}
|
||||
|
||||
# get conversation
|
||||
conversation = None
|
||||
if args.get('conversation_id'):
|
||||
conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
|
||||
if args.get("conversation_id"):
|
||||
conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)
|
||||
|
||||
# get app model config
|
||||
app_model_config = self._get_app_model_config(
|
||||
app_model=app_model,
|
||||
conversation=conversation
|
||||
)
|
||||
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)
|
||||
|
||||
# validate override model config
|
||||
override_model_config_dict = None
|
||||
if args.get('model_config'):
|
||||
if args.get("model_config"):
|
||||
if invoke_from != InvokeFrom.DEBUGGER:
|
||||
raise ValueError('Only in App debug mode can override model config')
|
||||
raise ValueError("Only in App debug mode can override model config")
|
||||
|
||||
# validate config
|
||||
override_model_config_dict = AgentChatAppConfigManager.config_validate(
|
||||
tenant_id=app_model.tenant_id,
|
||||
config=args.get('model_config')
|
||||
tenant_id=app_model.tenant_id, config=args.get("model_config")
|
||||
)
|
||||
|
||||
# always enable retriever resource in debugger mode
|
||||
override_model_config_dict["retriever_resource"] = {
|
||||
"enabled": True
|
||||
}
|
||||
override_model_config_dict["retriever_resource"] = {"enabled": True}
|
||||
|
||||
# parse files
|
||||
files = args['files'] if args.get('files') else []
|
||||
files = args["files"] if args.get("files") else []
|
||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||
if file_extra_config:
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(
|
||||
files,
|
||||
file_extra_config,
|
||||
user
|
||||
)
|
||||
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
@@ -124,7 +111,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             app_model=app_model,
             app_model_config=app_model_config,
             conversation=conversation,
-            override_config_dict=override_model_config_dict
+            override_config_dict=override_model_config_dict,
         )

         # get tracing instance

@@ -145,14 +132,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             invoke_from=invoke_from,
             extras=extras,
             call_depth=0,
-            trace_manager=trace_manager
+            trace_manager=trace_manager,
         )

         # init generate records
-        (
-            conversation,
-            message
-        ) = self._init_generate_records(application_generate_entity, conversation)
+        (conversation, message) = self._init_generate_records(application_generate_entity, conversation)

         # init queue manager
         queue_manager = MessageBasedAppQueueManager(

@@ -161,17 +145,20 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             invoke_from=application_generate_entity.invoke_from,
             conversation_id=conversation.id,
             app_mode=conversation.mode,
-            message_id=message.id
+            message_id=message.id,
         )

         # new thread
-        worker_thread = threading.Thread(target=self._generate_worker, kwargs={
-            'flask_app': current_app._get_current_object(),
-            'application_generate_entity': application_generate_entity,
-            'queue_manager': queue_manager,
-            'conversation_id': conversation.id,
-            'message_id': message.id,
-        })
+        worker_thread = threading.Thread(
+            target=self._generate_worker,
+            kwargs={
+                "flask_app": current_app._get_current_object(),
+                "application_generate_entity": application_generate_entity,
+                "queue_manager": queue_manager,
+                "conversation_id": conversation.id,
+                "message_id": message.id,
+            },
+        )

         worker_thread.start()

@@ -185,13 +172,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             stream=stream,
         )

-        return AgentChatAppGenerateResponseConverter.convert(
-            response=response,
-            invoke_from=invoke_from
-        )
+        return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

     def _generate_worker(
-        self, flask_app: Flask,
+        self,
+        flask_app: Flask,
         application_generate_entity: AgentChatAppGenerateEntity,
         queue_manager: AppQueueManager,
         conversation_id: str,

@@ -224,14 +209,13 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
                 pass
             except InvokeAuthorizationError:
                 queue_manager.publish_error(
-                    InvokeAuthorizationError('Incorrect API key provided'),
-                    PublishFrom.APPLICATION_MANAGER
+                    InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
                 )
             except ValidationError as e:
                 logger.exception("Validation Error when generating")
                 queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
             except (ValueError, InvokeError) as e:
-                if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
+                if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
                     logger.exception("Error when generating")
                 queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
             except Exception as e:
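Note: the worker hunk above passes `current_app._get_current_object()` into the thread because `current_app` is a context-local proxy that cannot be used outside the request that created it. A minimal sketch of the pattern (the worker body here is illustrative, not Dify's):

    import threading
    from flask import Flask, current_app

    app = Flask(__name__)

    def worker(flask_app: Flask) -> None:
        # Re-enter an application context in the new thread; using the
        # proxy directly here would raise "Working outside of application context".
        with flask_app.app_context():
            print(current_app.name)

    with app.app_context():
        # _get_current_object() unwraps the proxy to the real Flask app object.
        t = threading.Thread(target=worker, kwargs={"flask_app": current_app._get_current_object()})
        t.start()
        t.join()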
@@ -30,7 +30,8 @@ class AgentChatAppRunner(AppRunner):
     """

     def run(
-        self, application_generate_entity: AgentChatAppGenerateEntity,
+        self,
+        application_generate_entity: AgentChatAppGenerateEntity,
         queue_manager: AppQueueManager,
         conversation: Conversation,
         message: Message,

@@ -65,7 +66,7 @@ class AgentChatAppRunner(AppRunner):
             prompt_template_entity=app_config.prompt_template,
             inputs=inputs,
             files=files,
-            query=query
+            query=query,
         )

         memory = None

@@ -73,13 +74,10 @@ class AgentChatAppRunner(AppRunner):
             # get memory of conversation (read-only)
             model_instance = ModelInstance(
                 provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
-                model=application_generate_entity.model_conf.model
+                model=application_generate_entity.model_conf.model,
             )

-            memory = TokenBufferMemory(
-                conversation=conversation,
-                model_instance=model_instance
-            )
+            memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)

         # organize all inputs and template to prompt messages
         # Include: prompt template, inputs, query(optional), files(optional)

@@ -91,7 +89,7 @@ class AgentChatAppRunner(AppRunner):
             inputs=inputs,
             files=files,
             query=query,
-            memory=memory
+            memory=memory,
         )

         # moderation

@@ -103,7 +101,7 @@ class AgentChatAppRunner(AppRunner):
                 app_generate_entity=application_generate_entity,
                 inputs=inputs,
                 query=query,
-                message_id=message.id
+                message_id=message.id,
             )
         except ModerationException as e:
             self.direct_output(

@@ -111,7 +109,7 @@ class AgentChatAppRunner(AppRunner):
                 app_generate_entity=application_generate_entity,
                 prompt_messages=prompt_messages,
                 text=str(e),
-                stream=application_generate_entity.stream
+                stream=application_generate_entity.stream,
             )
             return

@@ -122,13 +120,13 @@ class AgentChatAppRunner(AppRunner):
                 message=message,
                 query=query,
                 user_id=application_generate_entity.user_id,
-                invoke_from=application_generate_entity.invoke_from
+                invoke_from=application_generate_entity.invoke_from,
             )

             if annotation_reply:
                 queue_manager.publish(
                     QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
-                    PublishFrom.APPLICATION_MANAGER
+                    PublishFrom.APPLICATION_MANAGER,
                 )

                 self.direct_output(

@@ -136,7 +134,7 @@ class AgentChatAppRunner(AppRunner):
                     app_generate_entity=application_generate_entity,
                     prompt_messages=prompt_messages,
                     text=annotation_reply.content,
-                    stream=application_generate_entity.stream
+                    stream=application_generate_entity.stream,
                 )
                 return

@@ -148,7 +146,7 @@ class AgentChatAppRunner(AppRunner):
                 app_id=app_record.id,
                 external_data_tools=external_data_tools,
                 inputs=inputs,
-                query=query
+                query=query,
             )

         # reorganize all inputs and template to prompt messages

@@ -161,14 +159,14 @@ class AgentChatAppRunner(AppRunner):
             inputs=inputs,
             files=files,
             query=query,
-            memory=memory
+            memory=memory,
         )

         # check hosting moderation
         hosting_moderation_result = self.check_hosting_moderation(
             application_generate_entity=application_generate_entity,
             queue_manager=queue_manager,
-            prompt_messages=prompt_messages
+            prompt_messages=prompt_messages,
         )

         if hosting_moderation_result:

@@ -177,9 +175,9 @@ class AgentChatAppRunner(AppRunner):
         agent_entity = app_config.agent

         # load tool variables
-        tool_conversation_variables = self._load_tool_variables(conversation_id=conversation.id,
-                                                                user_id=application_generate_entity.user_id,
-                                                                tenant_id=app_config.tenant_id)
+        tool_conversation_variables = self._load_tool_variables(
+            conversation_id=conversation.id, user_id=application_generate_entity.user_id, tenant_id=app_config.tenant_id
+        )

         # convert db variables to tool variables
         tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)

@@ -187,7 +185,7 @@ class AgentChatAppRunner(AppRunner):
         # init model instance
         model_instance = ModelInstance(
             provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
-            model=application_generate_entity.model_conf.model
+            model=application_generate_entity.model_conf.model,
         )
         prompt_message, _ = self.organize_prompt_messages(
             app_record=app_record,

@@ -238,7 +236,7 @@ class AgentChatAppRunner(AppRunner):
             prompt_messages=prompt_message,
             variables_pool=tool_variables,
             db_variables=tool_conversation_variables,
-            model_instance=model_instance
+            model_instance=model_instance,
         )

         invoke_result = runner.run(

@@ -252,17 +250,21 @@ class AgentChatAppRunner(AppRunner):
             invoke_result=invoke_result,
             queue_manager=queue_manager,
             stream=application_generate_entity.stream,
-            agent=True
+            agent=True,
         )

     def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
         """
         load tool variables from database
         """
-        tool_variables: ToolConversationVariables = db.session.query(ToolConversationVariables).filter(
-            ToolConversationVariables.conversation_id == conversation_id,
-            ToolConversationVariables.tenant_id == tenant_id
-        ).first()
+        tool_variables: ToolConversationVariables = (
+            db.session.query(ToolConversationVariables)
+            .filter(
+                ToolConversationVariables.conversation_id == conversation_id,
+                ToolConversationVariables.tenant_id == tenant_id,
+            )
+            .first()
+        )

         if tool_variables:
             # save tool variables to session, so that we can update it later

@@ -273,34 +275,40 @@ class AgentChatAppRunner(AppRunner):
                 conversation_id=conversation_id,
                 user_id=user_id,
                 tenant_id=tenant_id,
-                variables_str='[]',
+                variables_str="[]",
             )
             db.session.add(tool_variables)
             db.session.commit()

         return tool_variables

-    def _convert_db_variables_to_tool_variables(self, db_variables: ToolConversationVariables) -> ToolRuntimeVariablePool:
+    def _convert_db_variables_to_tool_variables(
+        self, db_variables: ToolConversationVariables
+    ) -> ToolRuntimeVariablePool:
         """
         convert db variables to tool variables
         """
-        return ToolRuntimeVariablePool(**{
-            'conversation_id': db_variables.conversation_id,
-            'user_id': db_variables.user_id,
-            'tenant_id': db_variables.tenant_id,
-            'pool': db_variables.variables
-        })
+        return ToolRuntimeVariablePool(
+            **{
+                "conversation_id": db_variables.conversation_id,
+                "user_id": db_variables.user_id,
+                "tenant_id": db_variables.tenant_id,
+                "pool": db_variables.variables,
+            }
+        )

-    def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigWithCredentialsEntity,
-                                         message: Message) -> LLMUsage:
+    def _get_usage_of_all_agent_thoughts(
+        self, model_config: ModelConfigWithCredentialsEntity, message: Message
+    ) -> LLMUsage:
         """
         Get usage of all agent thoughts
         :param model_config: model config
         :param message: message
         :return:
         """
-        agent_thoughts = (db.session.query(MessageAgentThought)
-                          .filter(MessageAgentThought.message_id == message.id).all())
+        agent_thoughts = (
+            db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).all()
+        )

         all_message_tokens = 0
         all_answer_tokens = 0

@@ -312,8 +320,5 @@ class AgentChatAppRunner(AppRunner):
         model_type_instance = cast(LargeLanguageModel, model_type_instance)

         return model_type_instance._calc_response_usage(
-            model_config.model,
-            model_config.credentials,
-            all_message_tokens,
-            all_answer_tokens
+            model_config.model, model_config.credentials, all_message_tokens, all_answer_tokens
         )
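Note: the `_load_tool_variables` and `_get_usage_of_all_agent_thoughts` rewrites above show the formatter's treatment of long SQLAlchemy chains: the whole expression is wrapped in parentheses so each call can sit on its own line, instead of backslash continuations or alignment indents. A runnable sketch of the same shape against an in-memory SQLite database (the model here is a stand-in, not Dify's actual schema; assumes SQLAlchemy 1.4+):

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class ToolConversationVariables(Base):
        __tablename__ = "tool_conversation_variables"
        id = Column(Integer, primary_key=True)
        conversation_id = Column(String)
        tenant_id = Column(String)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(ToolConversationVariables(conversation_id="c1", tenant_id="t1"))
        session.commit()
        # Parentheses let the fluent chain break across lines without
        # backslashes; behavior is identical to the one-line form.
        row = (
            session.query(ToolConversationVariables)
            .filter(
                ToolConversationVariables.conversation_id == "c1",
                ToolConversationVariables.tenant_id == "t1",
            )
            .first()
        )
        assert row is not None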
@@ -23,15 +23,15 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         :return:
         """
         response = {
-            'event': 'message',
-            'task_id': blocking_response.task_id,
-            'id': blocking_response.data.id,
-            'message_id': blocking_response.data.message_id,
-            'conversation_id': blocking_response.data.conversation_id,
-            'mode': blocking_response.data.mode,
-            'answer': blocking_response.data.answer,
-            'metadata': blocking_response.data.metadata,
-            'created_at': blocking_response.data.created_at
+            "event": "message",
+            "task_id": blocking_response.task_id,
+            "id": blocking_response.data.id,
+            "message_id": blocking_response.data.message_id,
+            "conversation_id": blocking_response.data.conversation_id,
+            "mode": blocking_response.data.mode,
+            "answer": blocking_response.data.answer,
+            "metadata": blocking_response.data.metadata,
+            "created_at": blocking_response.data.created_at,
         }

         return response

@@ -45,14 +45,15 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
         """
         response = cls.convert_blocking_full_response(blocking_response)

-        metadata = response.get('metadata', {})
-        response['metadata'] = cls._get_simple_metadata(metadata)
+        metadata = response.get("metadata", {})
+        response["metadata"] = cls._get_simple_metadata(metadata)

         return response

     @classmethod
-    def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-            -> Generator[str, None, None]:
+    def convert_stream_full_response(
+        cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+    ) -> Generator[str, None, None]:
         """
         Convert stream full response.
         :param stream_response: stream response

@@ -63,14 +64,14 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             sub_stream_response = chunk.stream_response

             if isinstance(sub_stream_response, PingStreamResponse):
-                yield 'ping'
+                yield "ping"
                 continue

             response_chunk = {
-                'event': sub_stream_response.event.value,
-                'conversation_id': chunk.conversation_id,
-                'message_id': chunk.message_id,
-                'created_at': chunk.created_at
+                "event": sub_stream_response.event.value,
+                "conversation_id": chunk.conversation_id,
+                "message_id": chunk.message_id,
+                "created_at": chunk.created_at,
             }

             if isinstance(sub_stream_response, ErrorStreamResponse):

@@ -81,8 +82,9 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             yield json.dumps(response_chunk)

     @classmethod
-    def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-            -> Generator[str, None, None]:
+    def convert_stream_simple_response(
+        cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
+    ) -> Generator[str, None, None]:
         """
         Convert stream simple response.
         :param stream_response: stream response

@@ -93,20 +95,20 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             sub_stream_response = chunk.stream_response

             if isinstance(sub_stream_response, PingStreamResponse):
-                yield 'ping'
+                yield "ping"
                 continue

             response_chunk = {
-                'event': sub_stream_response.event.value,
-                'conversation_id': chunk.conversation_id,
-                'message_id': chunk.message_id,
-                'created_at': chunk.created_at
+                "event": sub_stream_response.event.value,
+                "conversation_id": chunk.conversation_id,
+                "message_id": chunk.message_id,
+                "created_at": chunk.created_at,
             }

             if isinstance(sub_stream_response, MessageEndStreamResponse):
                 sub_stream_response_dict = sub_stream_response.to_dict()
-                metadata = sub_stream_response_dict.get('metadata', {})
-                sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
+                metadata = sub_stream_response_dict.get("metadata", {})
+                sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                 response_chunk.update(sub_stream_response_dict)
             if isinstance(sub_stream_response, ErrorStreamResponse):
                 data = cls._error_to_stream_response(sub_stream_response.err)
@@ -13,32 +13,33 @@ class AppGenerateResponseConverter(ABC):
     _blocking_response_type: type[AppBlockingResponse]

     @classmethod
-    def convert(cls, response: Union[
-        AppBlockingResponse,
-        Generator[AppStreamResponse, Any, None]
-    ], invoke_from: InvokeFrom):
+    def convert(
+        cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
+    ) -> dict[str, Any] | Generator[str, Any, None]:
         if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
             if isinstance(response, AppBlockingResponse):
                 return cls.convert_blocking_full_response(response)
             else:

                 def _generate_full_response() -> Generator[str, Any, None]:
                     for chunk in cls.convert_stream_full_response(response):
-                        if chunk == 'ping':
-                            yield f'event: {chunk}\n\n'
+                        if chunk == "ping":
+                            yield f"event: {chunk}\n\n"
                         else:
-                            yield f'data: {chunk}\n\n'
+                            yield f"data: {chunk}\n\n"

                 return _generate_full_response()
         else:
             if isinstance(response, AppBlockingResponse):
                 return cls.convert_blocking_simple_response(response)
             else:

                 def _generate_simple_response() -> Generator[str, Any, None]:
                     for chunk in cls.convert_stream_simple_response(response):
-                        if chunk == 'ping':
-                            yield f'event: {chunk}\n\n'
+                        if chunk == "ping":
+                            yield f"event: {chunk}\n\n"
                         else:
-                            yield f'data: {chunk}\n\n'
+                            yield f"data: {chunk}\n\n"

                 return _generate_simple_response()

@@ -54,14 +55,16 @@ class AppGenerateResponseConverter(ABC):

     @classmethod
     @abstractmethod
-    def convert_stream_full_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \
-            -> Generator[str, None, None]:
+    def convert_stream_full_response(
+        cls, stream_response: Generator[AppStreamResponse, None, None]
+    ) -> Generator[str, None, None]:
         raise NotImplementedError

     @classmethod
     @abstractmethod
-    def convert_stream_simple_response(cls, stream_response: Generator[AppStreamResponse, None, None]) \
-            -> Generator[str, None, None]:
+    def convert_stream_simple_response(
+        cls, stream_response: Generator[AppStreamResponse, None, None]
+    ) -> Generator[str, None, None]:
         raise NotImplementedError

     @classmethod

@@ -72,24 +75,26 @@ class AppGenerateResponseConverter(ABC):
         :return:
         """
         # show_retrieve_source
-        if 'retriever_resources' in metadata:
-            metadata['retriever_resources'] = []
-            for resource in metadata['retriever_resources']:
-                metadata['retriever_resources'].append({
-                    'segment_id': resource['segment_id'],
-                    'position': resource['position'],
-                    'document_name': resource['document_name'],
-                    'score': resource['score'],
-                    'content': resource['content'],
-                })
+        if "retriever_resources" in metadata:
+            metadata["retriever_resources"] = []
+            for resource in metadata["retriever_resources"]:
+                metadata["retriever_resources"].append(
+                    {
+                        "segment_id": resource["segment_id"],
+                        "position": resource["position"],
+                        "document_name": resource["document_name"],
+                        "score": resource["score"],
+                        "content": resource["content"],
+                    }
+                )

         # show annotation reply
-        if 'annotation_reply' in metadata:
-            del metadata['annotation_reply']
+        if "annotation_reply" in metadata:
+            del metadata["annotation_reply"]

         # show usage
-        if 'usage' in metadata:
-            del metadata['usage']
+        if "usage" in metadata:
+            del metadata["usage"]

         return metadata

@@ -101,16 +106,16 @@ class AppGenerateResponseConverter(ABC):
         :return:
         """
         error_responses = {
-            ValueError: {'code': 'invalid_param', 'status': 400},
-            ProviderTokenNotInitError: {'code': 'provider_not_initialize', 'status': 400},
+            ValueError: {"code": "invalid_param", "status": 400},
+            ProviderTokenNotInitError: {"code": "provider_not_initialize", "status": 400},
             QuotaExceededError: {
-                'code': 'provider_quota_exceeded',
-                'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
-                           "Please go to Settings -> Model Provider to complete your own provider credentials.",
-                'status': 400
+                "code": "provider_quota_exceeded",
+                "message": "Your quota for Dify Hosted Model Provider has been exhausted. "
+                "Please go to Settings -> Model Provider to complete your own provider credentials.",
+                "status": 400,
             },
-            ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400},
-            InvokeError: {'code': 'completion_request_error', 'status': 400}
+            ModelCurrentlyNotSupportError: {"code": "model_currently_not_support", "status": 400},
+            InvokeError: {"code": "completion_request_error", "status": 400},
         }

         # Determine the response based on the type of exception

@@ -120,13 +125,13 @@ class AppGenerateResponseConverter(ABC):
                 data = v

         if data:
-            data.setdefault('message', getattr(e, 'description', str(e)))
+            data.setdefault("message", getattr(e, "description", str(e)))
         else:
             logging.error(e)
             data = {
-                'code': 'internal_server_error',
-                'message': 'Internal Server Error, please contact support.',
-                'status': 500
+                "code": "internal_server_error",
+                "message": "Internal Server Error, please contact support.",
+                "status": 500,
             }

         return data
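Note: the `event: ...` and `data: ...` strings yielded in `convert` above follow the Server-Sent Events wire format, where each message is one or more `field: value` lines terminated by a blank line, and `event: ping` frames serve as keepalives. A minimal sketch of a consumer for the frames these converters emit (the parsing is simplified and assumes one field per frame):

    import json
    from collections.abc import Iterator

    def parse_sse(frames: Iterator[str]) -> Iterator[dict]:
        # Each frame is "data: <json>\n\n" for payloads or "event: ping\n\n"
        # for keepalives; the blank line terminates every SSE message.
        for frame in frames:
            field, _, value = frame.strip().partition(": ")
            if field == "data":
                yield json.loads(value)

    stream = ['data: {"event": "message", "answer": "hi"}\n\n', "event: ping\n\n"]
    assert list(parse_sse(iter(stream))) == [{"event": "message", "answer": "hi"}]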
@@ -16,10 +16,10 @@ class BaseAppGenerator:
     def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
         user_input_value = inputs.get(var.variable)
         if var.required and not user_input_value:
-            raise ValueError(f'{var.variable} is required in input form')
+            raise ValueError(f"{var.variable} is required in input form")
         if not var.required and not user_input_value:
             # TODO: should we return None here if the default value is None?
-            return var.default or ''
+            return var.default or ""
         if (
             var.type
             in (

@@ -34,7 +34,7 @@ class BaseAppGenerator:
         if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
             # may raise ValueError if user_input_value is not a valid number
             try:
-                if '.' in user_input_value:
+                if "." in user_input_value:
                     return float(user_input_value)
                 else:
                     return int(user_input_value)

@@ -43,14 +43,14 @@ class BaseAppGenerator:
         if var.type == VariableEntityType.SELECT:
             options = var.options or []
             if user_input_value not in options:
-                raise ValueError(f'{var.variable} in input form must be one of the following: {options}')
+                raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
         elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
             if var.max_length and user_input_value and len(user_input_value) > var.max_length:
-                raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters')
+                raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")

         return user_input_value

     def _sanitize_value(self, value: Any) -> Any:
         if isinstance(value, str):
-            return value.replace('\x00', '')
+            return value.replace("\x00", "")
         return value
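Note: the NUMBER branch above decides between `int` and `float` purely on the presence of a decimal point. A worked example of that rule in isolation (no handling beyond the `ValueError` the caller already expects for non-numeric strings):

    def parse_number(user_input_value: str) -> int | float:
        # "3.14" -> float, "3" -> int; anything else raises ValueError,
        # mirroring the branch in _validate_input above.
        if "." in user_input_value:
            return float(user_input_value)
        return int(user_input_value)

    assert parse_number("3.14") == 3.14
    assert parse_number("3") == 3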
@@ -24,9 +24,7 @@ class PublishFrom(Enum):


 class AppQueueManager:
-    def __init__(self, task_id: str,
-                 user_id: str,
-                 invoke_from: InvokeFrom) -> None:
+    def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None:
         if not user_id:
             raise ValueError("user is required")

@@ -34,9 +32,10 @@ class AppQueueManager:
         self._user_id = user_id
         self._invoke_from = invoke_from

-        user_prefix = 'account' if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
-        redis_client.setex(AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800,
-                           f"{user_prefix}-{self._user_id}")
+        user_prefix = "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user"
+        redis_client.setex(
+            AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
+        )

         q = queue.Queue()

@@ -66,8 +65,7 @@ class AppQueueManager:
                 # publish two messages to make sure the client can receive the stop signal
                 # and stop listening after the stop signal processed
                 self.publish(
-                    QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL),
-                    PublishFrom.TASK_PIPELINE
+                    QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE
                 )

             if elapsed_time // 10 > last_ping_time:

@@ -88,9 +86,7 @@ class AppQueueManager:
        :param pub_from: publish from
        :return:
        """
-        self.publish(QueueErrorEvent(
-            error=e
-        ), pub_from)
+        self.publish(QueueErrorEvent(error=e), pub_from)

     def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
         """

@@ -122,8 +118,8 @@ class AppQueueManager:
         if result is None:
             return

-        user_prefix = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end-user'
-        if result.decode('utf-8') != f"{user_prefix}-{user_id}":
+        user_prefix = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user"
+        if result.decode("utf-8") != f"{user_prefix}-{user_id}":
             return

         stopped_cache_key = cls._generate_stopped_cache_key(task_id)

@@ -168,9 +164,11 @@ class AppQueueManager:
             for item in data:
                 self._check_for_sqlalchemy_models(item)
         else:
-            if isinstance(data, DeclarativeMeta) or hasattr(data, '_sa_instance_state'):
-                raise TypeError("Critical Error: Passing SQLAlchemy Model instances "
-                                "that cause thread safety issues is not allowed.")
+            if isinstance(data, DeclarativeMeta) or hasattr(data, "_sa_instance_state"):
+                raise TypeError(
+                    "Critical Error: Passing SQLAlchemy Model instances "
+                    "that cause thread safety issues is not allowed."
+                )


 class GenerateTaskStoppedException(Exception):
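Note: the constructor above stamps a Redis key with a 30-minute TTL recording which user owns a task, and the stop path later checks that stamp before honoring a stop request. A minimal sketch of the same idea against a plain `redis` client (the key name is illustrative, not Dify's exact cache key, and a reachable Redis server is assumed):

    import redis

    r = redis.Redis()

    def register_task(task_id: str, user_prefix: str, user_id: str) -> None:
        # setex = SET with an expiry; the key disappears after 1800 seconds.
        r.setex(f"generate_task_belong:{task_id}", 1800, f"{user_prefix}-{user_id}")

    def may_stop(task_id: str, user_prefix: str, user_id: str) -> bool:
        # Only the user who started the task may stop it.
        owner = r.get(f"generate_task_belong:{task_id}")
        return owner is not None and owner.decode("utf-8") == f"{user_prefix}-{user_id}"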
@@ -1,6 +1,6 @@
 import time
-from collections.abc import Generator
-from typing import TYPE_CHECKING, Optional, Union
+from collections.abc import Generator, Mapping
+from typing import TYPE_CHECKING, Any, Optional, Union

 from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom

@@ -31,12 +31,15 @@ if TYPE_CHECKING:


 class AppRunner:
-    def get_pre_calculate_rest_tokens(self, app_record: App,
-                                      model_config: ModelConfigWithCredentialsEntity,
-                                      prompt_template_entity: PromptTemplateEntity,
-                                      inputs: dict[str, str],
-                                      files: list["FileVar"],
-                                      query: Optional[str] = None) -> int:
+    def get_pre_calculate_rest_tokens(
+        self,
+        app_record: App,
+        model_config: ModelConfigWithCredentialsEntity,
+        prompt_template_entity: PromptTemplateEntity,
+        inputs: dict[str, str],
+        files: list["FileVar"],
+        query: Optional[str] = None,
+    ) -> int:
         """
         Get pre calculate rest tokens
         :param app_record: app record

@@ -49,18 +52,20 @@ class AppRunner:
         """
         # Invoke model
         model_instance = ModelInstance(
-            provider_model_bundle=model_config.provider_model_bundle,
-            model=model_config.model
+            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
         )

         model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)

         max_tokens = 0
         for parameter_rule in model_config.model_schema.parameter_rules:
-            if (parameter_rule.name == 'max_tokens'
-                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
-                max_tokens = (model_config.parameters.get(parameter_rule.name)
-                              or model_config.parameters.get(parameter_rule.use_template)) or 0
+            if parameter_rule.name == "max_tokens" or (
+                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
+            ):
+                max_tokens = (
+                    model_config.parameters.get(parameter_rule.name)
+                    or model_config.parameters.get(parameter_rule.use_template)
+                ) or 0

         if model_context_tokens is None:
             return -1

@@ -75,36 +80,39 @@ class AppRunner:
             prompt_template_entity=prompt_template_entity,
             inputs=inputs,
             files=files,
-            query=query
+            query=query,
         )

-        prompt_tokens = model_instance.get_llm_num_tokens(
-            prompt_messages
-        )
+        prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)

         rest_tokens = model_context_tokens - max_tokens - prompt_tokens
         if rest_tokens < 0:
-            raise InvokeBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
-                                        "or shrink the max token, or switch to a llm with a larger token limit size.")
+            raise InvokeBadRequestError(
+                "Query or prefix prompt is too long, you can reduce the prefix prompt, "
+                "or shrink the max token, or switch to a llm with a larger token limit size."
+            )

         return rest_tokens

-    def recalc_llm_max_tokens(self, model_config: ModelConfigWithCredentialsEntity,
-                              prompt_messages: list[PromptMessage]):
+    def recalc_llm_max_tokens(
+        self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]
+    ):
         # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
         model_instance = ModelInstance(
-            provider_model_bundle=model_config.provider_model_bundle,
-            model=model_config.model
+            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
        )

         model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)

         max_tokens = 0
         for parameter_rule in model_config.model_schema.parameter_rules:
-            if (parameter_rule.name == 'max_tokens'
-                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
-                max_tokens = (model_config.parameters.get(parameter_rule.name)
-                              or model_config.parameters.get(parameter_rule.use_template)) or 0
+            if parameter_rule.name == "max_tokens" or (
+                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
+            ):
+                max_tokens = (
+                    model_config.parameters.get(parameter_rule.name)
+                    or model_config.parameters.get(parameter_rule.use_template)
+                ) or 0

         if model_context_tokens is None:
             return -1
@@ -112,27 +120,28 @@ class AppRunner:
         if max_tokens is None:
             max_tokens = 0

-        prompt_tokens = model_instance.get_llm_num_tokens(
-            prompt_messages
-        )
+        prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)

         if prompt_tokens + max_tokens > model_context_tokens:
             max_tokens = max(model_context_tokens - prompt_tokens, 16)

         for parameter_rule in model_config.model_schema.parameter_rules:
-            if (parameter_rule.name == 'max_tokens'
-                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
+            if parameter_rule.name == "max_tokens" or (
+                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
+            ):
                 model_config.parameters[parameter_rule.name] = max_tokens

-    def organize_prompt_messages(self, app_record: App,
-                                 model_config: ModelConfigWithCredentialsEntity,
-                                 prompt_template_entity: PromptTemplateEntity,
-                                 inputs: dict[str, str],
-                                 files: list["FileVar"],
-                                 query: Optional[str] = None,
-                                 context: Optional[str] = None,
-                                 memory: Optional[TokenBufferMemory] = None) \
-            -> tuple[list[PromptMessage], Optional[list[str]]]:
+    def organize_prompt_messages(
+        self,
+        app_record: App,
+        model_config: ModelConfigWithCredentialsEntity,
+        prompt_template_entity: PromptTemplateEntity,
+        inputs: dict[str, str],
+        files: list["FileVar"],
+        query: Optional[str] = None,
+        context: Optional[str] = None,
+        memory: Optional[TokenBufferMemory] = None,
+    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
         """
         Organize prompt messages
         :param context:
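Note: `recalc_llm_max_tokens` above keeps `prompt_tokens + max_tokens` within the model's context window, flooring the completion budget at 16 tokens. Worked numbers: with a 4096-token context and a 4000-token prompt, a requested `max_tokens` of 512 is clamped to `max(4096 - 4000, 16) = 96`. A sketch of the clamp in isolation:

    def clamp_max_tokens(model_context_tokens: int, prompt_tokens: int, max_tokens: int) -> int:
        # Mirror of the clamp in recalc_llm_max_tokens above.
        if prompt_tokens + max_tokens > model_context_tokens:
            return max(model_context_tokens - prompt_tokens, 16)
        return max_tokens

    assert clamp_max_tokens(4096, 4000, 512) == 96
    assert clamp_max_tokens(4096, 4090, 512) == 16  # the floor kicks in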
@@ -152,60 +161,54 @@ class AppRunner:
                 app_mode=AppMode.value_of(app_record.mode),
                 prompt_template_entity=prompt_template_entity,
                 inputs=inputs,
-                query=query if query else '',
+                query=query if query else "",
                 files=files,
                 context=context,
                 memory=memory,
-                model_config=model_config
+                model_config=model_config,
             )
         else:
-            memory_config = MemoryConfig(
-                window=MemoryConfig.WindowConfig(
-                    enabled=False
-                )
-            )
+            memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))

             model_mode = ModelMode.value_of(model_config.mode)
             if model_mode == ModelMode.COMPLETION:
                 advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template
-                prompt_template = CompletionModelPromptTemplate(
-                    text=advanced_completion_prompt_template.prompt
-                )
+                prompt_template = CompletionModelPromptTemplate(text=advanced_completion_prompt_template.prompt)

                 if advanced_completion_prompt_template.role_prefix:
                     memory_config.role_prefix = MemoryConfig.RolePrefix(
                         user=advanced_completion_prompt_template.role_prefix.user,
-                        assistant=advanced_completion_prompt_template.role_prefix.assistant
+                        assistant=advanced_completion_prompt_template.role_prefix.assistant,
                     )
             else:
                 prompt_template = []
                 for message in prompt_template_entity.advanced_chat_prompt_template.messages:
-                    prompt_template.append(ChatModelMessage(
-                        text=message.text,
-                        role=message.role
-                    ))
+                    prompt_template.append(ChatModelMessage(text=message.text, role=message.role))

             prompt_transform = AdvancedPromptTransform()
             prompt_messages = prompt_transform.get_prompt(
                 prompt_template=prompt_template,
                 inputs=inputs,
-                query=query if query else '',
+                query=query if query else "",
                 files=files,
                 context=context,
                 memory_config=memory_config,
                 memory=memory,
-                model_config=model_config
+                model_config=model_config,
             )
         stop = model_config.stop

         return prompt_messages, stop

-    def direct_output(self, queue_manager: AppQueueManager,
-                      app_generate_entity: EasyUIBasedAppGenerateEntity,
-                      prompt_messages: list,
-                      text: str,
-                      stream: bool,
-                      usage: Optional[LLMUsage] = None) -> None:
+    def direct_output(
+        self,
+        queue_manager: AppQueueManager,
+        app_generate_entity: EasyUIBasedAppGenerateEntity,
+        prompt_messages: list,
+        text: str,
+        stream: bool,
+        usage: Optional[LLMUsage] = None,
+    ) -> None:
         """
         Direct output
         :param queue_manager: application queue manager
@@ -222,17 +225,10 @@ class AppRunner:
                 chunk = LLMResultChunk(
                     model=app_generate_entity.model_conf.model,
                     prompt_messages=prompt_messages,
-                    delta=LLMResultChunkDelta(
-                        index=index,
-                        message=AssistantPromptMessage(content=token)
-                    )
+                    delta=LLMResultChunkDelta(index=index, message=AssistantPromptMessage(content=token)),
                 )

-                queue_manager.publish(
-                    QueueLLMChunkEvent(
-                        chunk=chunk
-                    ), PublishFrom.APPLICATION_MANAGER
-                )
+                queue_manager.publish(QueueLLMChunkEvent(chunk=chunk), PublishFrom.APPLICATION_MANAGER)
                 index += 1
                 time.sleep(0.01)

@@ -242,15 +238,19 @@ class AppRunner:
                     model=app_generate_entity.model_conf.model,
                     prompt_messages=prompt_messages,
                     message=AssistantPromptMessage(content=text),
-                    usage=usage if usage else LLMUsage.empty_usage()
+                    usage=usage if usage else LLMUsage.empty_usage(),
                 ),
-            ), PublishFrom.APPLICATION_MANAGER
+            ),
+            PublishFrom.APPLICATION_MANAGER,
         )

-    def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
-                              queue_manager: AppQueueManager,
-                              stream: bool,
-                              agent: bool = False) -> None:
+    def _handle_invoke_result(
+        self,
+        invoke_result: Union[LLMResult, Generator],
+        queue_manager: AppQueueManager,
+        stream: bool,
+        agent: bool = False,
+    ) -> None:
         """
         Handle invoke result
         :param invoke_result: invoke result
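Note: `direct_output` above simulates streaming for a canned answer (moderation or annotation replies) by emitting the text one character at a time with a 10 ms pause, so blocking and streaming clients share one code path. A stripped-down sketch of that loop; the real method wraps each token in an `LLMResultChunk` and publishes it through the queue manager, which a plain list stands in for here:

    import time

    def direct_output(text: str, stream: bool) -> list[str]:
        published: list[str] = []
        if stream:
            # One chunk per character, paced at roughly 100 chunks/second.
            for token in text:
                published.append(token)
                time.sleep(0.01)
        else:
            published.append(text)
        return published

    assert "".join(direct_output("hello", stream=True)) == "hello"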
@@ -260,21 +260,13 @@ class AppRunner:
         :return:
         """
         if not stream:
-            self._handle_invoke_result_direct(
-                invoke_result=invoke_result,
-                queue_manager=queue_manager,
-                agent=agent
-            )
+            self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
         else:
-            self._handle_invoke_result_stream(
-                invoke_result=invoke_result,
-                queue_manager=queue_manager,
-                agent=agent
-            )
+            self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)

-    def _handle_invoke_result_direct(self, invoke_result: LLMResult,
-                                     queue_manager: AppQueueManager,
-                                     agent: bool) -> None:
+    def _handle_invoke_result_direct(
+        self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool
+    ) -> None:
         """
         Handle invoke result direct
         :param invoke_result: invoke result

@@ -285,12 +277,13 @@ class AppRunner:
         queue_manager.publish(
             QueueMessageEndEvent(
                 llm_result=invoke_result,
-            ), PublishFrom.APPLICATION_MANAGER
+            ),
+            PublishFrom.APPLICATION_MANAGER,
         )

-    def _handle_invoke_result_stream(self, invoke_result: Generator,
-                                     queue_manager: AppQueueManager,
-                                     agent: bool) -> None:
+    def _handle_invoke_result_stream(
+        self, invoke_result: Generator, queue_manager: AppQueueManager, agent: bool
+    ) -> None:
         """
         Handle invoke result
         :param invoke_result: invoke result

@@ -300,21 +293,13 @@ class AppRunner:
         """
         model = None
         prompt_messages = []
-        text = ''
+        text = ""
         usage = None
         for result in invoke_result:
             if not agent:
-                queue_manager.publish(
-                    QueueLLMChunkEvent(
-                        chunk=result
-                    ), PublishFrom.APPLICATION_MANAGER
-                )
+                queue_manager.publish(QueueLLMChunkEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)
             else:
-                queue_manager.publish(
-                    QueueAgentMessageEvent(
-                        chunk=result
-                    ), PublishFrom.APPLICATION_MANAGER
-                )
+                queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER)

             text += result.delta.message.content
@@ -331,25 +316,24 @@ class AppRunner:
             usage = LLMUsage.empty_usage()

         llm_result = LLMResult(
-            model=model,
-            prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(content=text),
-            usage=usage
+            model=model, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=text), usage=usage
         )

         queue_manager.publish(
             QueueMessageEndEvent(
                 llm_result=llm_result,
-            ), PublishFrom.APPLICATION_MANAGER
+            ),
+            PublishFrom.APPLICATION_MANAGER,
         )

     def moderation_for_inputs(
-        self, app_id: str,
-        tenant_id: str,
-        app_generate_entity: AppGenerateEntity,
-        inputs: dict,
-        query: str,
-        message_id: str,
+        self,
+        app_id: str,
+        tenant_id: str,
+        app_generate_entity: AppGenerateEntity,
+        inputs: Mapping[str, Any],
+        query: str,
+        message_id: str,
     ) -> tuple[bool, dict, str]:
         """
         Process sensitive_word_avoidance.

@@ -367,14 +351,17 @@ class AppRunner:
             tenant_id=tenant_id,
             app_config=app_generate_entity.app_config,
             inputs=inputs,
-            query=query if query else '',
+            query=query if query else "",
             message_id=message_id,
-            trace_manager=app_generate_entity.trace_manager
+            trace_manager=app_generate_entity.trace_manager,
         )

-    def check_hosting_moderation(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
-                                 queue_manager: AppQueueManager,
-                                 prompt_messages: list[PromptMessage]) -> bool:
+    def check_hosting_moderation(
+        self,
+        application_generate_entity: EasyUIBasedAppGenerateEntity,
+        queue_manager: AppQueueManager,
+        prompt_messages: list[PromptMessage],
+    ) -> bool:
         """
         Check hosting moderation
         :param application_generate_entity: application generate entity

@@ -384,8 +371,7 @@ class AppRunner:
         """
         hosting_moderation_feature = HostingModerationFeature()
         moderation_result = hosting_moderation_feature.check(
-            application_generate_entity=application_generate_entity,
-            prompt_messages=prompt_messages
+            application_generate_entity=application_generate_entity, prompt_messages=prompt_messages
         )

         if moderation_result:

@@ -393,18 +379,20 @@ class AppRunner:
                 queue_manager=queue_manager,
                 app_generate_entity=application_generate_entity,
                 prompt_messages=prompt_messages,
-                text="I apologize for any confusion, " \
-                     "but I'm an AI assistant to be helpful, harmless, and honest.",
-                stream=application_generate_entity.stream
+                text="I apologize for any confusion, " "but I'm an AI assistant to be helpful, harmless, and honest.",
+                stream=application_generate_entity.stream,
             )

         return moderation_result

-    def fill_in_inputs_from_external_data_tools(self, tenant_id: str,
-                                                app_id: str,
-                                                external_data_tools: list[ExternalDataVariableEntity],
-                                                inputs: dict,
-                                                query: str) -> dict:
+    def fill_in_inputs_from_external_data_tools(
+        self,
+        tenant_id: str,
+        app_id: str,
+        external_data_tools: list[ExternalDataVariableEntity],
+        inputs: dict,
+        query: str,
+    ) -> dict:
         """
         Fill in variable inputs from external data tools if exists.

@@ -417,18 +405,12 @@ class AppRunner:
         """
         external_data_fetch_feature = ExternalDataFetch()
         return external_data_fetch_feature.fetch(
-            tenant_id=tenant_id,
-            app_id=app_id,
-            external_data_tools=external_data_tools,
-            inputs=inputs,
-            query=query
+            tenant_id=tenant_id, app_id=app_id, external_data_tools=external_data_tools, inputs=inputs, query=query
         )

-    def query_app_annotations_to_reply(self, app_record: App,
-                                       message: Message,
-                                       query: str,
-                                       user_id: str,
-                                       invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
+    def query_app_annotations_to_reply(
+        self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
+    ) -> Optional[MessageAnnotation]:
         """
         Query app annotations to reply
         :param app_record: app record

@@ -440,9 +422,5 @@ class AppRunner:
         """
         annotation_reply_feature = AnnotationReplyFeature()
         return annotation_reply_feature.query(
-            app_record=app_record,
-            message=message,
-            query=query,
-            user_id=user_id,
-            invoke_from=invoke_from
+            app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from
         )
@@ -22,15 +22,19 @@ class ChatAppConfig(EasyUIBasedAppConfig):
     """
     Chatbot App Config Entity.
     """

     pass


 class ChatAppConfigManager(BaseAppConfigManager):
     @classmethod
-    def get_app_config(cls, app_model: App,
-                       app_model_config: AppModelConfig,
-                       conversation: Optional[Conversation] = None,
-                       override_config_dict: Optional[dict] = None) -> ChatAppConfig:
+    def get_app_config(
+        cls,
+        app_model: App,
+        app_model_config: AppModelConfig,
+        conversation: Optional[Conversation] = None,
+        override_config_dict: Optional[dict] = None,
+    ) -> ChatAppConfig:
         """
         Convert app model config to chat app config
         :param app_model: app model

@@ -51,7 +55,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
             config_dict = app_model_config_dict.copy()
         else:
             if not override_config_dict:
-                raise Exception('override_config_dict is required when config_from is ARGS')
+                raise Exception("override_config_dict is required when config_from is ARGS")

             config_dict = override_config_dict

@@ -63,19 +67,11 @@ class ChatAppConfigManager(BaseAppConfigManager):
             app_model_config_from=config_from,
             app_model_config_id=app_model_config.id,
             app_model_config_dict=config_dict,
-            model=ModelConfigManager.convert(
-                config=config_dict
-            ),
-            prompt_template=PromptTemplateConfigManager.convert(
-                config=config_dict
-            ),
-            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
-                config=config_dict
-            ),
-            dataset=DatasetConfigManager.convert(
-                config=config_dict
-            ),
-            additional_features=cls.convert_features(config_dict, app_mode)
+            model=ModelConfigManager.convert(config=config_dict),
+            prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
+            sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
+            dataset=DatasetConfigManager.convert(config=config_dict),
+            additional_features=cls.convert_features(config_dict, app_mode),
         )

         app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(

@@ -113,8 +109,9 @@ class ChatAppConfigManager(BaseAppConfigManager):
         related_config_keys.extend(current_related_config_keys)

         # dataset_query_variable
-        config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode,
-                                                                                             config)
+        config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
+            tenant_id, app_mode, config
+        )
         related_config_keys.extend(current_related_config_keys)

         # opening_statement

@@ -123,7 +120,8 @@ class ChatAppConfigManager(BaseAppConfigManager):

         # suggested_questions_after_answer
         config, current_related_config_keys = SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults(
-            config)
+            config
+        )
         related_config_keys.extend(current_related_config_keys)

         # speech_to_text

@@ -139,8 +137,9 @@ class ChatAppConfigManager(BaseAppConfigManager):
         related_config_keys.extend(current_related_config_keys)

         # moderation validation
-        config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id,
-                                                                                                            config)
+        config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
+            tenant_id, config
+        )
         related_config_keys.extend(current_related_config_keys)

         related_config_keys = list(set(related_config_keys))
@@ -30,7 +30,8 @@ logger = logging.getLogger(__name__)
 class ChatAppGenerator(MessageBasedAppGenerator):
     @overload
     def generate(
-        self, app_model: App,
+        self,
+        app_model: App,
         user: Union[Account, EndUser],
         args: Any,
         invoke_from: InvokeFrom,

@@ -39,7 +40,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):

     @overload
     def generate(
-        self, app_model: App,
+        self,
+        app_model: App,
         user: Union[Account, EndUser],
         args: Any,
         invoke_from: InvokeFrom,

@@ -47,7 +49,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
     ) -> dict: ...

     def generate(
-        self, app_model: App,
+        self,
+        app_model: App,
         user: Union[Account, EndUser],
         args: Any,
         invoke_from: InvokeFrom,

@@ -62,58 +65,46 @@ class ChatAppGenerator(MessageBasedAppGenerator):
        :param invoke_from: invoke from source
        :param stream: is stream
        """
-        if not args.get('query'):
-            raise ValueError('query is required')
+        if not args.get("query"):
+            raise ValueError("query is required")

-        query = args['query']
+        query = args["query"]
         if not isinstance(query, str):
-            raise ValueError('query must be a string')
+            raise ValueError("query must be a string")

-        query = query.replace('\x00', '')
-        inputs = args['inputs']
+        query = query.replace("\x00", "")
+        inputs = args["inputs"]

-        extras = {
-            "auto_generate_conversation_name": args.get('auto_generate_name', True)
-        }
+        extras = {"auto_generate_conversation_name": args.get("auto_generate_name", True)}

         # get conversation
         conversation = None
-        if args.get('conversation_id'):
-            conversation = self._get_conversation_by_user(app_model, args.get('conversation_id'), user)
+        if args.get("conversation_id"):
+            conversation = self._get_conversation_by_user(app_model, args.get("conversation_id"), user)

         # get app model config
-        app_model_config = self._get_app_model_config(
-            app_model=app_model,
-            conversation=conversation
-        )
+        app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)

         # validate override model config
         override_model_config_dict = None
-        if args.get('model_config'):
+        if args.get("model_config"):
             if invoke_from != InvokeFrom.DEBUGGER:
-                raise ValueError('Only in App debug mode can override model config')
+                raise ValueError("Only in App debug mode can override model config")

             # validate config
             override_model_config_dict = ChatAppConfigManager.config_validate(
-                tenant_id=app_model.tenant_id,
-                config=args.get('model_config')
+                tenant_id=app_model.tenant_id, config=args.get("model_config")
             )

             # always enable retriever resource in debugger mode
-            override_model_config_dict["retriever_resource"] = {
-                "enabled": True
-            }
+            override_model_config_dict["retriever_resource"] = {"enabled": True}

         # parse files
-        files = args['files'] if args.get('files') else []
+        files = args["files"] if args.get("files") else []
         message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
         file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
         if file_extra_config:
-            file_objs = message_file_parser.validate_and_transform_files_arg(
-                files,
-                file_extra_config,
-                user
-            )
+            file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
         else:
             file_objs = []

@@ -122,7 +113,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
             app_model=app_model,
             app_model_config=app_model_config,
             conversation=conversation,
-            override_config_dict=override_model_config_dict
+            override_config_dict=override_model_config_dict,
         )

         # get tracing instance

@@ -141,14 +132,11 @@ class ChatAppGenerator(MessageBasedAppGenerator):
             stream=stream,
             invoke_from=invoke_from,
             extras=extras,
-            trace_manager=trace_manager
+            trace_manager=trace_manager,
         )

         # init generate records
-        (
-            conversation,
-            message
-        ) = self._init_generate_records(application_generate_entity, conversation)
+        (conversation, message) = self._init_generate_records(application_generate_entity, conversation)

         # init queue manager
         queue_manager = MessageBasedAppQueueManager(

@@ -157,17 +145,20 @@ class ChatAppGenerator(MessageBasedAppGenerator):
             invoke_from=application_generate_entity.invoke_from,
             conversation_id=conversation.id,
             app_mode=conversation.mode,
-            message_id=message.id
+            message_id=message.id,
         )

         # new thread
-        worker_thread = threading.Thread(target=self._generate_worker, kwargs={
-            'flask_app': current_app._get_current_object(),
-            'application_generate_entity': application_generate_entity,
-            'queue_manager': queue_manager,
-            'conversation_id': conversation.id,
-            'message_id': message.id,
-        })
+        worker_thread = threading.Thread(
+            target=self._generate_worker,
+            kwargs={
+                "flask_app": current_app._get_current_object(),
+                "application_generate_entity": application_generate_entity,
+                "queue_manager": queue_manager,
+                "conversation_id": conversation.id,
+                "message_id": message.id,
+            },
+        )

         worker_thread.start()

@@ -181,16 +172,16 @@ class ChatAppGenerator(MessageBasedAppGenerator):
             stream=stream,
         )

-        return ChatAppGenerateResponseConverter.convert(
-            response=response,
-            invoke_from=invoke_from
-        )
+        return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

-    def _generate_worker(self, flask_app: Flask,
-                         application_generate_entity: ChatAppGenerateEntity,
-                         queue_manager: AppQueueManager,
-                         conversation_id: str,
-                         message_id: str) -> None:
+    def _generate_worker(
+        self,
+        flask_app: Flask,
+        application_generate_entity: ChatAppGenerateEntity,
+        queue_manager: AppQueueManager,
+        conversation_id: str,
+        message_id: str,
+    ) -> None:
         """
         Generate worker in a new thread.
         :param flask_app: Flask app

@@ -212,20 +203,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
                     application_generate_entity=application_generate_entity,
                     queue_manager=queue_manager,
                     conversation=conversation,
-                    message=message
+                    message=message,
                 )
             except GenerateTaskStoppedException:
                 pass
             except InvokeAuthorizationError:
                 queue_manager.publish_error(
-                    InvokeAuthorizationError('Incorrect API key provided'),
-                    PublishFrom.APPLICATION_MANAGER
+                    InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
                 )
             except ValidationError as e:
                 logger.exception("Validation Error when generating")
                 queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
             except (ValueError, InvokeError) as e:
-                if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
+                if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
                     logger.exception("Error when generating")
                 queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
             except Exception as e:
@@ -24,10 +24,13 @@ class ChatAppRunner(AppRunner):
Chat Application Runner
"""

def run(self, application_generate_entity: ChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message) -> None:
def run(
self,
application_generate_entity: ChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
) -> None:
"""
Run application
:param application_generate_entity: application generate entity
@@ -58,7 +61,7 @@ class ChatAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query
query=query,
)

memory = None
@@ -66,13 +69,10 @@ class ChatAppRunner(AppRunner):
# get memory of conversation (read-only)
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model
model=application_generate_entity.model_conf.model,
)

memory = TokenBufferMemory(
conversation=conversation,
model_instance=model_instance
)
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)

# organize all inputs and template to prompt messages
# Include: prompt template, inputs, query(optional), files(optional)
@@ -84,7 +84,7 @@ class ChatAppRunner(AppRunner):
inputs=inputs,
files=files,
query=query,
memory=memory
memory=memory,
)

# moderation
@@ -96,7 +96,7 @@ class ChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
message_id=message.id
message_id=message.id,
)
except ModerationException as e:
self.direct_output(
@@ -104,7 +104,7 @@ class ChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text=str(e),
stream=application_generate_entity.stream
stream=application_generate_entity.stream,
)
return

@@ -115,13 +115,13 @@ class ChatAppRunner(AppRunner):
message=message,
query=query,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from
invoke_from=application_generate_entity.invoke_from,
)

if annotation_reply:
queue_manager.publish(
QueueAnnotationReplyEvent(message_annotation_id=annotation_reply.id),
PublishFrom.APPLICATION_MANAGER
PublishFrom.APPLICATION_MANAGER,
)

self.direct_output(
@@ -129,7 +129,7 @@ class ChatAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text=annotation_reply.content,
stream=application_generate_entity.stream
stream=application_generate_entity.stream,
)
return

@@ -141,7 +141,7 @@ class ChatAppRunner(AppRunner):
app_id=app_record.id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
query=query,
)

# get context from datasets
@@ -152,7 +152,7 @@ class ChatAppRunner(AppRunner):
app_record.id,
message.id,
application_generate_entity.user_id,
application_generate_entity.invoke_from
application_generate_entity.invoke_from,
)

dataset_retrieval = DatasetRetrieval(application_generate_entity)
@@ -181,29 +181,26 @@ class ChatAppRunner(AppRunner):
files=files,
query=query,
context=context,
memory=memory
memory=memory,
)

# check hosting moderation
hosting_moderation_result = self.check_hosting_moderation(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
prompt_messages=prompt_messages
prompt_messages=prompt_messages,
)

if hosting_moderation_result:
return

# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recalc_llm_max_tokens(
model_config=application_generate_entity.model_conf,
prompt_messages=prompt_messages
)
self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)

# Invoke model
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model
model=application_generate_entity.model_conf.model,
)

db.session.close()
@@ -218,7 +215,5 @@ class ChatAppRunner(AppRunner):

# handle invoke result
self._handle_invoke_result(
invoke_result=invoke_result,
queue_manager=queue_manager,
stream=application_generate_entity.stream
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
)

@@ -23,15 +23,15 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
:return:
"""
response = {
'event': 'message',
'task_id': blocking_response.task_id,
'id': blocking_response.data.id,
'message_id': blocking_response.data.message_id,
'conversation_id': blocking_response.data.conversation_id,
'mode': blocking_response.data.mode,
'answer': blocking_response.data.answer,
'metadata': blocking_response.data.metadata,
'created_at': blocking_response.data.created_at
"event": "message",
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
"conversation_id": blocking_response.data.conversation_id,
"mode": blocking_response.data.mode,
"answer": blocking_response.data.answer,
"metadata": blocking_response.data.metadata,
"created_at": blocking_response.data.created_at,
}

return response
@@ -45,14 +45,15 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
"""
response = cls.convert_blocking_full_response(blocking_response)

metadata = response.get('metadata', {})
response['metadata'] = cls._get_simple_metadata(metadata)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)

return response

@classmethod
def convert_stream_full_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_full_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
) -> Generator[str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -63,14 +64,14 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response

if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue

response_chunk = {
'event': sub_stream_response.event.value,
'conversation_id': chunk.conversation_id,
'message_id': chunk.message_id,
'created_at': chunk.created_at
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}

if isinstance(sub_stream_response, ErrorStreamResponse):
@@ -81,8 +82,9 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
yield json.dumps(response_chunk)

@classmethod
def convert_stream_simple_response(cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_simple_response(
cls, stream_response: Generator[ChatbotAppStreamResponse, None, None]
) -> Generator[str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
@@ -93,20 +95,20 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response

if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue

response_chunk = {
'event': sub_stream_response.event.value,
'conversation_id': chunk.conversation_id,
'message_id': chunk.message_id,
'created_at': chunk.created_at
"event": sub_stream_response.event.value,
"conversation_id": chunk.conversation_id,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}

if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
metadata = sub_stream_response_dict.get('metadata', {})
sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)

@@ -17,14 +17,15 @@ class CompletionAppConfig(EasyUIBasedAppConfig):
"""
Completion App Config Entity.
"""

pass


class CompletionAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App,
app_model_config: AppModelConfig,
override_config_dict: Optional[dict] = None) -> CompletionAppConfig:
def get_app_config(
cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: Optional[dict] = None
) -> CompletionAppConfig:
"""
Convert app model config to completion app config
:param app_model: app model
@@ -51,19 +52,11 @@ class CompletionAppConfigManager(BaseAppConfigManager):
app_model_config_from=config_from,
app_model_config_id=app_model_config.id,
app_model_config_dict=config_dict,
model=ModelConfigManager.convert(
config=config_dict
),
prompt_template=PromptTemplateConfigManager.convert(
config=config_dict
),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=config_dict
),
dataset=DatasetConfigManager.convert(
config=config_dict
),
additional_features=cls.convert_features(config_dict, app_mode)
model=ModelConfigManager.convert(config=config_dict),
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
dataset=DatasetConfigManager.convert(config=config_dict),
additional_features=cls.convert_features(config_dict, app_mode),
)

app_config.variables, app_config.external_data_variables = BasicVariablesConfigManager.convert(
@@ -101,8 +94,9 @@ class CompletionAppConfigManager(BaseAppConfigManager):
related_config_keys.extend(current_related_config_keys)

# dataset_query_variable
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(tenant_id, app_mode,
config)
config, current_related_config_keys = DatasetConfigManager.validate_and_set_defaults(
tenant_id, app_mode, config
)
related_config_keys.extend(current_related_config_keys)

# text_to_speech
@@ -114,8 +108,9 @@ class CompletionAppConfigManager(BaseAppConfigManager):
related_config_keys.extend(current_related_config_keys)

# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(tenant_id,
config)
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id, config
)
related_config_keys.extend(current_related_config_keys)

related_config_keys = list(set(related_config_keys))

@@ -32,7 +32,8 @@ logger = logging.getLogger(__name__)
class CompletionAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
@@ -41,19 +42,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator):

@overload
def generate(
self, app_model: App,
self,
app_model: App,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
) -> dict: ...

def generate(self, app_model: App,
user: Union[Account, EndUser],
args: Any,
invoke_from: InvokeFrom,
stream: bool = True) \
-> Union[dict, Generator[str, None, None]]:
def generate(
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
) -> Union[dict, Generator[str, None, None]]:
"""
Generate App response.

@@ -63,12 +62,12 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
query = args['query']
query = args["query"]
if not isinstance(query, str):
raise ValueError('query must be a string')
raise ValueError("query must be a string")

query = query.replace('\x00', '')
inputs = args['inputs']
query = query.replace("\x00", "")
inputs = args["inputs"]

extras = {}

@@ -76,41 +75,31 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
conversation = None

# get app model config
app_model_config = self._get_app_model_config(
app_model=app_model,
conversation=conversation
)
app_model_config = self._get_app_model_config(app_model=app_model, conversation=conversation)

# validate override model config
override_model_config_dict = None
if args.get('model_config'):
if args.get("model_config"):
if invoke_from != InvokeFrom.DEBUGGER:
raise ValueError('Only in App debug mode can override model config')
raise ValueError("Only in App debug mode can override model config")

# validate config
override_model_config_dict = CompletionAppConfigManager.config_validate(
tenant_id=app_model.tenant_id,
config=args.get('model_config')
tenant_id=app_model.tenant_id, config=args.get("model_config")
)

# parse files
files = args['files'] if args.get('files') else []
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_extra_config,
user
)
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
else:
file_objs = []

# convert to app config
app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config,
override_config_dict=override_model_config_dict
app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
)

# get tracing instance
@@ -128,14 +117,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
stream=stream,
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager
trace_manager=trace_manager,
)

# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity)
(conversation, message) = self._init_generate_records(application_generate_entity)

# init queue manager
queue_manager = MessageBasedAppQueueManager(
@@ -144,16 +130,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
message_id=message.id,
)

# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'message_id': message.id,
})
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"message_id": message.id,
},
)

worker_thread.start()

@@ -167,15 +156,15 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
stream=stream,
)

return CompletionAppGenerateResponseConverter.convert(
response=response,
invoke_from=invoke_from
)
return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

def _generate_worker(self, flask_app: Flask,
application_generate_entity: CompletionAppGenerateEntity,
queue_manager: AppQueueManager,
message_id: str) -> None:
def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: CompletionAppGenerateEntity,
queue_manager: AppQueueManager,
message_id: str,
) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
@@ -194,20 +183,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
message=message
message=message,
)
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == "true":
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
@@ -216,12 +204,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
finally:
db.session.close()

def generate_more_like_this(self, app_model: App,
message_id: str,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
stream: bool = True) \
-> Union[dict, Generator[str, None, None]]:
def generate_more_like_this(
self,
app_model: App,
message_id: str,
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[str, None, None]]:
"""
Generate App response.

@@ -231,13 +221,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
).first()
message = (
db.session.query(Message)
.filter(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
.first()
)

if not message:
raise MessageNotExistsError()
@@ -250,29 +244,23 @@ class CompletionAppGenerator(MessageBasedAppGenerator):

app_model_config = message.app_model_config
override_model_config_dict = app_model_config.to_dict()
model_dict = override_model_config_dict['model']
completion_params = model_dict.get('completion_params')
completion_params['temperature'] = 0.9
model_dict['completion_params'] = completion_params
override_model_config_dict['model'] = model_dict
model_dict = override_model_config_dict["model"]
completion_params = model_dict.get("completion_params")
completion_params["temperature"] = 0.9
model_dict["completion_params"] = completion_params
override_model_config_dict["model"] = model_dict

# parse files
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(
message.files,
file_extra_config,
user
)
file_objs = message_file_parser.validate_and_transform_files_arg(message.files, file_extra_config, user)
else:
file_objs = []

# convert to app config
app_config = CompletionAppConfigManager.get_app_config(
app_model=app_model,
app_model_config=app_model_config,
override_config_dict=override_model_config_dict
app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict
)

# init application generate entity
@@ -286,14 +274,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
user_id=user.id,
stream=stream,
invoke_from=invoke_from,
extras={}
extras={},
)

# init generate records
(
conversation,
message
) = self._init_generate_records(application_generate_entity)
(conversation, message) = self._init_generate_records(application_generate_entity)

# init queue manager
queue_manager = MessageBasedAppQueueManager(
@@ -302,16 +287,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
app_mode=conversation.mode,
message_id=message.id
message_id=message.id,
)

# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'message_id': message.id,
})
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"message_id": message.id,
},
)

worker_thread.start()

@@ -325,7 +313,4 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
stream=stream,
)

return CompletionAppGenerateResponseConverter.convert(
response=response,
invoke_from=invoke_from
)
return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

@@ -22,9 +22,9 @@ class CompletionAppRunner(AppRunner):
Completion Application Runner
"""

def run(self, application_generate_entity: CompletionAppGenerateEntity,
queue_manager: AppQueueManager,
message: Message) -> None:
def run(
self, application_generate_entity: CompletionAppGenerateEntity, queue_manager: AppQueueManager, message: Message
) -> None:
"""
Run application
:param application_generate_entity: application generate entity
@@ -54,7 +54,7 @@ class CompletionAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query
query=query,
)

# organize all inputs and template to prompt messages
@@ -65,7 +65,7 @@ class CompletionAppRunner(AppRunner):
prompt_template_entity=app_config.prompt_template,
inputs=inputs,
files=files,
query=query
query=query,
)

# moderation
@@ -77,7 +77,7 @@ class CompletionAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
inputs=inputs,
query=query,
message_id=message.id
message_id=message.id,
)
except ModerationException as e:
self.direct_output(
@@ -85,7 +85,7 @@ class CompletionAppRunner(AppRunner):
app_generate_entity=application_generate_entity,
prompt_messages=prompt_messages,
text=str(e),
stream=application_generate_entity.stream
stream=application_generate_entity.stream,
)
return

@@ -97,7 +97,7 @@ class CompletionAppRunner(AppRunner):
app_id=app_record.id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
query=query,
)

# get context from datasets
@@ -108,7 +108,7 @@ class CompletionAppRunner(AppRunner):
app_record.id,
message.id,
application_generate_entity.user_id,
application_generate_entity.invoke_from
application_generate_entity.invoke_from,
)

dataset_config = app_config.dataset
@@ -126,7 +126,7 @@ class CompletionAppRunner(AppRunner):
invoke_from=application_generate_entity.invoke_from,
show_retrieve_source=app_config.additional_features.show_retrieve_source,
hit_callback=hit_callback,
message_id=message.id
message_id=message.id,
)

# reorganize all inputs and template to prompt messages
@@ -139,29 +139,26 @@ class CompletionAppRunner(AppRunner):
inputs=inputs,
files=files,
query=query,
context=context
context=context,
)

# check hosting moderation
hosting_moderation_result = self.check_hosting_moderation(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
prompt_messages=prompt_messages
prompt_messages=prompt_messages,
)

if hosting_moderation_result:
return

# Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
self.recalc_llm_max_tokens(
model_config=application_generate_entity.model_conf,
prompt_messages=prompt_messages
)
self.recalc_llm_max_tokens(model_config=application_generate_entity.model_conf, prompt_messages=prompt_messages)

# Invoke model
model_instance = ModelInstance(
provider_model_bundle=application_generate_entity.model_conf.provider_model_bundle,
model=application_generate_entity.model_conf.model
model=application_generate_entity.model_conf.model,
)

db.session.close()
@@ -176,8 +173,5 @@ class CompletionAppRunner(AppRunner):

# handle invoke result
self._handle_invoke_result(
invoke_result=invoke_result,
queue_manager=queue_manager,
stream=application_generate_entity.stream
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
)

@@ -23,14 +23,14 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
:return:
"""
response = {
'event': 'message',
'task_id': blocking_response.task_id,
'id': blocking_response.data.id,
'message_id': blocking_response.data.message_id,
'mode': blocking_response.data.mode,
'answer': blocking_response.data.answer,
'metadata': blocking_response.data.metadata,
'created_at': blocking_response.data.created_at
"event": "message",
"task_id": blocking_response.task_id,
"id": blocking_response.data.id,
"message_id": blocking_response.data.message_id,
"mode": blocking_response.data.mode,
"answer": blocking_response.data.answer,
"metadata": blocking_response.data.metadata,
"created_at": blocking_response.data.created_at,
}

return response
@@ -44,14 +44,15 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
"""
response = cls.convert_blocking_full_response(blocking_response)

metadata = response.get('metadata', {})
response['metadata'] = cls._get_simple_metadata(metadata)
metadata = response.get("metadata", {})
response["metadata"] = cls._get_simple_metadata(metadata)

return response

@classmethod
def convert_stream_full_response(cls, stream_response: Generator[CompletionAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_full_response(
cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
) -> Generator[str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -62,13 +63,13 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response

if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue

response_chunk = {
'event': sub_stream_response.event.value,
'message_id': chunk.message_id,
'created_at': chunk.created_at
"event": sub_stream_response.event.value,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}

if isinstance(sub_stream_response, ErrorStreamResponse):
@@ -79,8 +80,9 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
yield json.dumps(response_chunk)

@classmethod
def convert_stream_simple_response(cls, stream_response: Generator[CompletionAppStreamResponse, None, None]) \
-> Generator[str, None, None]:
def convert_stream_simple_response(
cls, stream_response: Generator[CompletionAppStreamResponse, None, None]
) -> Generator[str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
@@ -91,19 +93,19 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
sub_stream_response = chunk.stream_response

if isinstance(sub_stream_response, PingStreamResponse):
yield 'ping'
yield "ping"
continue

response_chunk = {
'event': sub_stream_response.event.value,
'message_id': chunk.message_id,
'created_at': chunk.created_at
"event": sub_stream_response.event.value,
"message_id": chunk.message_id,
"created_at": chunk.created_at,
}

if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.to_dict()
metadata = sub_stream_response_dict.get('metadata', {})
sub_stream_response_dict['metadata'] = cls._get_simple_metadata(metadata)
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)

@@ -35,23 +35,23 @@ logger = logging.getLogger(__name__)


class MessageBasedAppGenerator(BaseAppGenerator):

def _handle_response(
self, application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity
],
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool = False,
self,
application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity,
],
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
stream: bool = False,
) -> Union[
ChatbotAppBlockingResponse,
CompletionAppBlockingResponse,
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
]:
"""
Handle response.
@@ -70,7 +70,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
conversation=conversation,
message=message,
user=user,
stream=stream
stream=stream,
)

try:
@@ -82,12 +82,13 @@ class MessageBasedAppGenerator(BaseAppGenerator):
logger.exception(e)
raise e

def _get_conversation_by_user(self, app_model: App, conversation_id: str,
user: Union[Account, EndUser]) -> Conversation:
def _get_conversation_by_user(
self, app_model: App, conversation_id: str, user: Union[Account, EndUser]
) -> Conversation:
conversation_filter = [
Conversation.id == conversation_id,
Conversation.app_id == app_model.id,
Conversation.status == 'normal'
Conversation.status == "normal",
]

if isinstance(user, Account):
@@ -100,19 +101,18 @@ class MessageBasedAppGenerator(BaseAppGenerator):
if not conversation:
raise ConversationNotExistsError()

if conversation.status != 'normal':
if conversation.status != "normal":
raise ConversationCompletedError()

return conversation

def _get_app_model_config(self, app_model: App,
conversation: Optional[Conversation] = None) \
-> AppModelConfig:
def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
if conversation:
app_model_config = db.session.query(AppModelConfig).filter(
AppModelConfig.id == conversation.app_model_config_id,
AppModelConfig.app_id == app_model.id
).first()
app_model_config = (
db.session.query(AppModelConfig)
.filter(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
.first()
)

if not app_model_config:
raise AppModelConfigBrokenError()
@@ -127,15 +127,16 @@ class MessageBasedAppGenerator(BaseAppGenerator):

return app_model_config

def _init_generate_records(self,
application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity
],
conversation: Optional[Conversation] = None) \
-> tuple[Conversation, Message]:
def _init_generate_records(
self,
application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity,
],
conversation: Optional[Conversation] = None,
) -> tuple[Conversation, Message]:
"""
Initialize generate records
:param application_generate_entity: application generate entity
@@ -148,10 +149,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
end_user_id = None
account_id = None
if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
from_source = 'api'
from_source = "api"
end_user_id = application_generate_entity.user_id
else:
from_source = 'console'
from_source = "console"
account_id = application_generate_entity.user_id

if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity):
@@ -164,8 +165,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
model_provider = application_generate_entity.model_conf.provider
model_id = application_generate_entity.model_conf.model
override_model_configs = None
if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS \
and app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]:
if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in [
AppMode.AGENT_CHAT,
AppMode.CHAT,
AppMode.COMPLETION,
]:
override_model_configs = app_config.app_model_config_dict

# get conversation introduction
@@ -179,12 +183,12 @@ class MessageBasedAppGenerator(BaseAppGenerator):
model_id=model_id,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=app_config.app_mode.value,
name='New conversation',
name="New conversation",
inputs=application_generate_entity.inputs,
introduction=introduction,
system_instruction="",
system_instruction_tokens=0,
status='normal',
status="normal",
invoke_from=application_generate_entity.invoke_from.value,
from_source=from_source,
from_end_user_id=end_user_id,
@@ -216,11 +220,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
answer_price_unit=0,
provider_response_latency=0,
total_price=0,
currency='USD',
currency="USD",
invoke_from=application_generate_entity.invoke_from.value,
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id
from_account_id=account_id,
)

db.session.add(message)
@@ -232,10 +236,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
message_id=message.id,
type=file.type.value,
transfer_method=file.transfer_method.value,
belongs_to='user',
belongs_to="user",
url=file.url,
upload_file_id=file.related_id,
created_by_role=('account' if account_id else 'end_user'),
created_by_role=("account" if account_id else "end_user"),
created_by=account_id or end_user_id,
)
db.session.add(message_file)
@@ -269,11 +273,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param conversation_id: conversation id
:return: conversation
"""
conversation = (
db.session.query(Conversation)
.filter(Conversation.id == conversation_id)
.first()
)
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()

if not conversation:
raise ConversationNotExistsError()
@@ -286,10 +286,6 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param message_id: message id
:return: message
"""
message = (
db.session.query(Message)
.filter(Message.id == message_id)
.first()
)
message = db.session.query(Message).filter(Message.id == message_id).first()

return message

@@ -12,12 +12,9 @@ from core.app.entities.queue_entities import (


class MessageBasedAppQueueManager(AppQueueManager):
def __init__(self, task_id: str,
user_id: str,
invoke_from: InvokeFrom,
conversation_id: str,
app_mode: str,
message_id: str) -> None:
def __init__(
self, task_id: str, user_id: str, invoke_from: InvokeFrom, conversation_id: str, app_mode: str, message_id: str
) -> None:
super().__init__(task_id, user_id, invoke_from)

self._conversation_id = str(conversation_id)
@@ -30,7 +27,7 @@ class MessageBasedAppQueueManager(AppQueueManager):
message_id=self._message_id,
conversation_id=self._conversation_id,
app_mode=self._app_mode,
event=event
event=event,
)

def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
@@ -45,17 +42,15 @@ class MessageBasedAppQueueManager(AppQueueManager):
message_id=self._message_id,
conversation_id=self._conversation_id,
app_mode=self._app_mode,
event=event
event=event,
)

self._q.put(message)

if isinstance(event, QueueStopEvent
| QueueErrorEvent
| QueueMessageEndEvent
| QueueAdvancedChatMessageEndEvent):
if isinstance(
event, QueueStopEvent | QueueErrorEvent | QueueMessageEndEvent | QueueAdvancedChatMessageEndEvent
):
self.stop_listen()

if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise GenerateTaskStoppedException()

@@ -12,6 +12,7 @@ class WorkflowAppConfig(WorkflowUIBasedAppConfig):
"""
Workflow App Config Entity.
"""

pass


@@ -26,13 +27,9 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
app_id=app_model.id,
app_mode=app_mode,
workflow_id=workflow.id,
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(
config=features_dict
),
variables=WorkflowVariablesConfigManager.convert(
workflow=workflow
),
additional_features=cls.convert_features(features_dict, app_mode)
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=features_dict),
variables=WorkflowVariablesConfigManager.convert(workflow=workflow),
additional_features=cls.convert_features(features_dict, app_mode),
)

return app_config
@@ -50,8 +47,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):

# file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
config=config,
is_vision=False
config=config, is_vision=False
)
related_config_keys.extend(current_related_config_keys)

@@ -61,9 +57,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):

# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id=tenant_id,
config=config,
only_structure_validate=only_structure_validate
tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
)
related_config_keys.extend(current_related_config_keys)

@@ -4,7 +4,7 @@ import os
import threading
import uuid
from collections.abc import Generator
from typing import Literal, Union, overload
from typing import Any, Literal, Optional, Union, overload

from flask import Flask, current_app
from pydantic import ValidationError
@@ -34,32 +34,40 @@ logger = logging.getLogger(__name__)
class WorkflowAppGenerator(BaseAppGenerator):
@overload
def generate(
self, app_model: App,
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: Literal[True] = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Generator[str, None, None]: ...

@overload
def generate(
self, app_model: App,
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> dict: ...

def generate(
self, app_model: App,
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: dict,
invoke_from: InvokeFrom,
stream: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
):
"""
Generate App response.
@@ -71,27 +79,21 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
:param call_depth: call depth
:param workflow_thread_pool_id: workflow thread pool id
"""
inputs = args['inputs']
inputs = args["inputs"]

# parse files
files = args['files'] if args.get('files') else []
files = args["files"] if args.get("files") else []
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config:
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
file_extra_config,
user
)
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
else:
file_objs = []

# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow
)
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)

# get tracing instance
user_id = user.id if isinstance(user, Account) else user.session_id
@@ -107,7 +109,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
stream=stream,
invoke_from=invoke_from,
call_depth=call_depth,
trace_manager=trace_manager
trace_manager=trace_manager,
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)

@@ -118,16 +120,20 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
stream=stream,
workflow_thread_pool_id=workflow_thread_pool_id,
)

def _generate(
self, app_model: App,
self,
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
application_generate_entity: WorkflowAppGenerateEntity,
invoke_from: InvokeFrom,
stream: bool = True,
) -> Union[dict, Generator[str, None, None]]:
workflow_thread_pool_id: Optional[str] = None,
) -> dict[str, Any] | Generator[str, None, None]:
"""
Generate App response.

@@ -137,22 +143,27 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
:param stream: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
# init queue manager
queue_manager = WorkflowAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
app_mode=app_model.mode
app_mode=app_model.mode,
)

# new thread
worker_thread = threading.Thread(target=self._generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
'context': contextvars.copy_context()
})
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(),  # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"context": contextvars.copy_context(),
"workflow_thread_pool_id": workflow_thread_pool_id,
},
)

worker_thread.start()

@@ -165,17 +176,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
stream=stream,
)

return WorkflowAppGenerateResponseConverter.convert(
response=response,
invoke_from=invoke_from
)
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

def single_iteration_generate(self, app_model: App,
workflow: Workflow,
node_id: str,
user: Account,
args: dict,
stream: bool = True):
def single_iteration_generate(
self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True
) -> dict[str, Any] | Generator[str, Any, None]:
"""
Generate App response.

@@ -187,20 +192,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param stream: is stream
"""
if not node_id:
raise ValueError('node_id is required')
raise ValueError("node_id is required")

if args.get('inputs') is None:
raise ValueError('inputs is required')

extras = {
"auto_generate_conversation_name": False
}
if args.get("inputs") is None:
raise ValueError("inputs is required")

# convert to app config
app_config = WorkflowAppConfigManager.get_app_config(
app_model=app_model,
workflow=workflow
)
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)

# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
@@ -211,11 +209,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
user_id=user.id,
stream=stream,
invoke_from=InvokeFrom.DEBUGGER,
extras=extras,
extras={"auto_generate_conversation_name": False},
single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity(
node_id=node_id,
inputs=args['inputs']
)
node_id=node_id, inputs=args["inputs"]
),
)
contexts.tenant_id.set(application_generate_entity.app_config.tenant_id)

@@ -225,18 +222,23 @@ class WorkflowAppGenerator(BaseAppGenerator):
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
stream=stream
stream=stream,
)

def _generate_worker(self, flask_app: Flask,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
context: contextvars.Context) -> None:
def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
context: contextvars.Context,
workflow_thread_pool_id: Optional[str] = None,
) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param workflow_thread_pool_id: workflow thread pool id
:return:
"""
for var, val in context.items():
@@ -244,50 +246,40 @@ class WorkflowAppGenerator(BaseAppGenerator):
with flask_app.app_context():
try:
# workflow app
runner = WorkflowAppRunner()
if application_generate_entity.single_iteration_run:
single_iteration_run = application_generate_entity.single_iteration_run
runner.single_iteration_run(
app_id=application_generate_entity.app_config.app_id,
workflow_id=application_generate_entity.app_config.workflow_id,
queue_manager=queue_manager,
inputs=single_iteration_run.inputs,
node_id=single_iteration_run.node_id,
user_id=application_generate_entity.user_id
)
else:
runner.run(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager
)
runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id,
)

runner.run()
except GenerateTaskStoppedException:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e:
if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
if os.environ.get("DEBUG") and os.environ.get("DEBUG", "false").lower() == "true":
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.remove()
db.session.close()

def _handle_response(self, application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool = False) -> Union[
WorkflowAppBlockingResponse,
Generator[WorkflowAppStreamResponse, None, None]
]:
def _handle_response(
self,
application_generate_entity: WorkflowAppGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
Handle response.
:param application_generate_entity: application generate entity
@@ -303,7 +295,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow=workflow,
queue_manager=queue_manager,
user=user,
stream=stream
stream=stream,
)

try:

@@ -12,10 +12,7 @@ from core.app.entities.queue_entities import (
|
||||
|
||||
|
||||
class WorkflowAppQueueManager(AppQueueManager):
|
||||
def __init__(self, task_id: str,
|
||||
user_id: str,
|
||||
invoke_from: InvokeFrom,
|
||||
app_mode: str) -> None:
|
||||
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
|
||||
super().__init__(task_id, user_id, invoke_from)
|
||||
|
||||
self._app_mode = app_mode
|
||||
@@ -27,19 +24,18 @@ class WorkflowAppQueueManager(AppQueueManager):
|
||||
:param pub_from:
|
||||
:return:
|
||||
"""
|
||||
message = WorkflowQueueMessage(
|
||||
task_id=self._task_id,
|
||||
app_mode=self._app_mode,
|
||||
event=event
|
||||
)
|
||||
message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event)
|
||||
|
||||
self._q.put(message)
|
||||
|
||||
if isinstance(event, QueueStopEvent
|
||||
| QueueErrorEvent
|
||||
| QueueMessageEndEvent
|
||||
| QueueWorkflowSucceededEvent
|
||||
| QueueWorkflowFailedEvent):
|
||||
if isinstance(
|
||||
event,
|
||||
QueueStopEvent
|
||||
| QueueErrorEvent
|
||||
| QueueMessageEndEvent
|
||||
| QueueWorkflowSucceededEvent
|
||||
| QueueWorkflowFailedEvent,
|
||||
):
|
||||
self.stop_listen()
|
||||
|
||||
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
||||
|
||||
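The reformatted isinstance call relies on Python 3.10+ union types: `A | B` builds a union object that isinstance accepts directly, so no tuple is needed. A quick self-contained check of that behaviour:

# Python 3.10+: X | Y unions are accepted by isinstance directly.
class QueueStopEvent: ...
class QueueErrorEvent: ...

event = QueueStopEvent()
assert isinstance(event, QueueStopEvent | QueueErrorEvent)
# Equivalent to the older tuple form:
assert isinstance(event, (QueueStopEvent, QueueErrorEvent))
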
@@ -4,129 +4,125 @@ from typing import Optional, cast

from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow.workflow_event_trigger_callback import WorkflowEventTriggerCallback
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
from core.app.entities.app_invoke_entities import (
    InvokeFrom,
    WorkflowAppGenerateEntity,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App, EndUser
from models.workflow import Workflow
from models.workflow import WorkflowType

logger = logging.getLogger(__name__)


class WorkflowAppRunner:
class WorkflowAppRunner(WorkflowBasedAppRunner):
    """
    Workflow Application Runner
    """

    def run(self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager) -> None:
    def __init__(
        self,
        application_generate_entity: WorkflowAppGenerateEntity,
        queue_manager: AppQueueManager,
        workflow_thread_pool_id: Optional[str] = None,
    ) -> None:
        """
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param workflow_thread_pool_id: workflow thread pool id
        """
        self.application_generate_entity = application_generate_entity
        self.queue_manager = queue_manager
        self.workflow_thread_pool_id = workflow_thread_pool_id

    def run(self) -> None:
        """
        Run application
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :return:
        """
        app_config = application_generate_entity.app_config
        app_config = self.application_generate_entity.app_config
        app_config = cast(WorkflowAppConfig, app_config)

        user_id = None
        if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
            end_user = db.session.query(EndUser).filter(EndUser.id == application_generate_entity.user_id).first()
        if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
            end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
            if end_user:
                user_id = end_user.session_id
        else:
            user_id = application_generate_entity.user_id
            user_id = self.application_generate_entity.user_id

        app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
        if not app_record:
            raise ValueError('App not found')
            raise ValueError("App not found")

        workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
        if not workflow:
            raise ValueError('Workflow not initialized')

        inputs = application_generate_entity.inputs
        files = application_generate_entity.files
            raise ValueError("Workflow not initialized")

        db.session.close()

        workflow_callbacks: list[WorkflowCallback] = [
            WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)
        ]

        if bool(os.environ.get('DEBUG', 'False').lower() == 'true'):
        workflow_callbacks: list[WorkflowCallback] = []
        if bool(os.environ.get("DEBUG", "False").lower() == "true"):
            workflow_callbacks.append(WorkflowLoggingCallback())

        # Create a variable pool.
        system_inputs = {
            SystemVariableKey.FILES: files,
            SystemVariableKey.USER_ID: user_id,
        }
        variable_pool = VariablePool(
            system_variables=system_inputs,
            user_inputs=inputs,
            environment_variables=workflow.environment_variables,
            conversation_variables=[],
        )
        # if only single iteration run is requested
        if self.application_generate_entity.single_iteration_run:
            # if only single iteration run is requested
            graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
                workflow=workflow,
                node_id=self.application_generate_entity.single_iteration_run.node_id,
                user_inputs=self.application_generate_entity.single_iteration_run.inputs,
            )
        else:
            inputs = self.application_generate_entity.inputs
            files = self.application_generate_entity.files

            # Create a variable pool.
            system_inputs = {
                SystemVariableKey.FILES: files,
                SystemVariableKey.USER_ID: user_id,
            }

            variable_pool = VariablePool(
                system_variables=system_inputs,
                user_inputs=inputs,
                environment_variables=workflow.environment_variables,
                conversation_variables=[],
            )

            # init graph
            graph = self._init_graph(graph_config=workflow.graph_dict)

        # RUN WORKFLOW
        workflow_engine_manager = WorkflowEngineManager()
        workflow_engine_manager.run_workflow(
            workflow=workflow,
            user_id=application_generate_entity.user_id,
            user_from=UserFrom.ACCOUNT
            if application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
            else UserFrom.END_USER,
            invoke_from=application_generate_entity.invoke_from,
            callbacks=workflow_callbacks,
            call_depth=application_generate_entity.call_depth,
        workflow_entry = WorkflowEntry(
            tenant_id=workflow.tenant_id,
            app_id=workflow.app_id,
            workflow_id=workflow.id,
            workflow_type=WorkflowType.value_of(workflow.type),
            graph=graph,
            graph_config=workflow.graph_dict,
            user_id=self.application_generate_entity.user_id,
            user_from=(
                UserFrom.ACCOUNT
                if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER]
                else UserFrom.END_USER
            ),
            invoke_from=self.application_generate_entity.invoke_from,
            call_depth=self.application_generate_entity.call_depth,
            variable_pool=variable_pool,
            thread_pool_id=self.workflow_thread_pool_id,
        )

    def single_iteration_run(
        self, app_id: str, workflow_id: str, queue_manager: AppQueueManager, inputs: dict, node_id: str, user_id: str
    ) -> None:
        """
        Single iteration run
        """
        app_record = db.session.query(App).filter(App.id == app_id).first()
        if not app_record:
            raise ValueError('App not found')
        generator = workflow_entry.run(callbacks=workflow_callbacks)

        if not app_record.workflow_id:
            raise ValueError('Workflow not initialized')

        workflow = self.get_workflow(app_model=app_record, workflow_id=workflow_id)
        if not workflow:
            raise ValueError('Workflow not initialized')

        workflow_callbacks = [WorkflowEventTriggerCallback(queue_manager=queue_manager, workflow=workflow)]

        workflow_engine_manager = WorkflowEngineManager()
        workflow_engine_manager.single_step_run_iteration_workflow_node(
            workflow=workflow, node_id=node_id, user_id=user_id, user_inputs=inputs, callbacks=workflow_callbacks
        )

    def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
        """
        Get workflow
        """
        # fetch workflow by workflow_id
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
            )
            .first()
        )

        # return workflow
        return workflow
        for event in generator:
            self._handle_event(workflow_entry, event)

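The refactor above moves the runner from a stateless run(entity, queue_manager) to a constructor-injected object whose run() reads everything from self, builds a WorkflowEntry, then drains its generator and translates each engine event onto the queue. A condensed, runnable sketch of that control flow under simplified, illustrative types (_StubEntry stands in for WorkflowEntry):

from typing import Any, Iterable

class _StubEntry:
    def run(self) -> Iterable[str]:
        # Stand-in for WorkflowEntry.run(): yields engine events.
        yield "node_started"
        yield "node_succeeded"

class MiniWorkflowRunner:
    """Condensed control flow of the refactored runner (illustrative only)."""

    def __init__(self, entity: Any, queue_manager: Any, thread_pool_id: str | None = None) -> None:
        # Dependencies are injected once, at construction time...
        self.entity = entity
        self.queue_manager = queue_manager
        self.thread_pool_id = thread_pool_id

    def run(self) -> None:
        # ...so run() takes no arguments: build the entry, then drain and
        # translate every event it yields.
        workflow_entry = _StubEntry()
        for event in workflow_entry.run():
            self._handle_event(workflow_entry, event)

    def _handle_event(self, entry: _StubEntry, event: str) -> None:
        print(f"publish to queue: {event}")

MiniWorkflowRunner(entity=None, queue_manager=None).run()
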
@@ -35,8 +35,9 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
        return cls.convert_blocking_full_response(blocking_response)

    @classmethod
    def convert_stream_full_response(cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]) \
            -> Generator[str, None, None]:
    def convert_stream_full_response(
        cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
    ) -> Generator[str, None, None]:
        """
        Convert stream full response.
        :param stream_response: stream response
@@ -47,12 +48,12 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
            sub_stream_response = chunk.stream_response

            if isinstance(sub_stream_response, PingStreamResponse):
                yield 'ping'
                yield "ping"
                continue

            response_chunk = {
                'event': sub_stream_response.event.value,
                'workflow_run_id': chunk.workflow_run_id,
                "event": sub_stream_response.event.value,
                "workflow_run_id": chunk.workflow_run_id,
            }

            if isinstance(sub_stream_response, ErrorStreamResponse):
@@ -63,8 +64,9 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
            yield json.dumps(response_chunk)

    @classmethod
    def convert_stream_simple_response(cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]) \
            -> Generator[str, None, None]:
    def convert_stream_simple_response(
        cls, stream_response: Generator[WorkflowAppStreamResponse, None, None]
    ) -> Generator[str, None, None]:
        """
        Convert stream simple response.
        :param stream_response: stream response
@@ -75,12 +77,12 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
            sub_stream_response = chunk.stream_response

            if isinstance(sub_stream_response, PingStreamResponse):
                yield 'ping'
                yield "ping"
                continue

            response_chunk = {
                'event': sub_stream_response.event.value,
                'workflow_run_id': chunk.workflow_run_id,
                "event": sub_stream_response.event.value,
                "workflow_run_id": chunk.workflow_run_id,
            }

            if isinstance(sub_stream_response, ErrorStreamResponse):

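Both converters share the same shape: pings pass through as a literal "ping", everything else is serialized into an envelope keyed by event name and run id. A reduced, runnable sketch of that loop (the dataclasses are stand-ins, not the real response types):

import json
from dataclasses import dataclass
from enum import Enum
from typing import Generator, Iterable

class StreamEvent(Enum):
    PING = "ping"
    WORKFLOW_STARTED = "workflow_started"

@dataclass
class SubResponse:
    event: StreamEvent

@dataclass
class Chunk:
    workflow_run_id: str
    stream_response: SubResponse

def convert_stream(chunks: Iterable[Chunk]) -> Generator[str, None, None]:
    # Reduced form of convert_stream_full_response: pings pass through as a
    # bare "ping"; everything else becomes a JSON envelope.
    for chunk in chunks:
        sub = chunk.stream_response
        if sub.event is StreamEvent.PING:
            yield "ping"
            continue
        yield json.dumps({"event": sub.event.value, "workflow_run_id": chunk.workflow_run_id})

print(list(convert_stream([Chunk("run-1", SubResponse(StreamEvent.WORKFLOW_STARTED))])))
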
@@ -1,3 +1,4 @@
import json
import logging
import time
from collections.abc import Generator
@@ -15,10 +16,12 @@ from core.app.entities.queue_entities import (
    QueueIterationCompletedEvent,
    QueueIterationNextEvent,
    QueueIterationStartEvent,
    QueueMessageReplaceEvent,
    QueueNodeFailedEvent,
    QueueNodeStartedEvent,
    QueueNodeSucceededEvent,
    QueueParallelBranchRunFailedEvent,
    QueueParallelBranchRunStartedEvent,
    QueueParallelBranchRunSucceededEvent,
    QueuePingEvent,
    QueueStopEvent,
    QueueTextChunkEvent,
@@ -32,19 +35,16 @@ from core.app.entities.task_entities import (
    MessageAudioStreamResponse,
    StreamResponse,
    TextChunkStreamResponse,
    TextReplaceStreamResponse,
    WorkflowAppBlockingResponse,
    WorkflowAppStreamResponse,
    WorkflowFinishStreamResponse,
    WorkflowStreamGenerateNodes,
    WorkflowStartStreamResponse,
    WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.end.end_node import EndNode
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser
@@ -52,8 +52,8 @@ from models.workflow import (
    Workflow,
    WorkflowAppLog,
    WorkflowAppLogCreatedFrom,
    WorkflowNodeExecution,
    WorkflowRun,
    WorkflowRunStatus,
)

logger = logging.getLogger(__name__)
@@ -63,18 +63,21 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
    """
    WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
    """

    _workflow: Workflow
    _user: Union[Account, EndUser]
    _task_state: WorkflowTaskState
    _application_generate_entity: WorkflowAppGenerateEntity
    _workflow_system_variables: dict[SystemVariableKey, Any]
    _iteration_nested_relations: dict[str, list[str]]

    def __init__(self, application_generate_entity: WorkflowAppGenerateEntity,
                 workflow: Workflow,
                 queue_manager: AppQueueManager,
                 user: Union[Account, EndUser],
                 stream: bool) -> None:
    def __init__(
        self,
        application_generate_entity: WorkflowAppGenerateEntity,
        workflow: Workflow,
        queue_manager: AppQueueManager,
        user: Union[Account, EndUser],
        stream: bool,
    ) -> None:
        """
        Initialize GenerateTaskPipeline.
        :param application_generate_entity: application generate entity
@@ -93,14 +96,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
        self._workflow = workflow
        self._workflow_system_variables = {
            SystemVariableKey.FILES: application_generate_entity.files,
            SystemVariableKey.USER_ID: user_id
            SystemVariableKey.USER_ID: user_id,
        }

        self._task_state = WorkflowTaskState(
            iteration_nested_node_ids=[]
        )
        self._stream_generate_nodes = self._get_stream_generate_nodes()
        self._iteration_nested_relations = self._get_iteration_nested_relations(self._workflow.graph_dict)
        self._task_state = WorkflowTaskState()

    def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
        """
@@ -111,16 +110,13 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
        db.session.refresh(self._user)
        db.session.close()

        generator = self._wrapper_process_stream_response(
            trace_manager=self._application_generate_entity.trace_manager
        )
        generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
        if self._stream:
            return self._to_stream_response(generator)
        else:
            return self._to_blocking_response(generator)

    def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) \
            -> WorkflowAppBlockingResponse:
    def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> WorkflowAppBlockingResponse:
        """
        To blocking response.
        :return:
@@ -129,43 +125,42 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
        if isinstance(stream_response, ErrorStreamResponse):
            raise stream_response.err
        elif isinstance(stream_response, WorkflowFinishStreamResponse):
            workflow_run = db.session.query(WorkflowRun).filter(
                WorkflowRun.id == self._task_state.workflow_run_id).first()

            response = WorkflowAppBlockingResponse(
                task_id=self._application_generate_entity.task_id,
                workflow_run_id=workflow_run.id,
                workflow_run_id=stream_response.data.id,
                data=WorkflowAppBlockingResponse.Data(
                    id=workflow_run.id,
                    workflow_id=workflow_run.workflow_id,
                    status=workflow_run.status,
                    outputs=workflow_run.outputs_dict,
                    error=workflow_run.error,
                    elapsed_time=workflow_run.elapsed_time,
                    total_tokens=workflow_run.total_tokens,
                    total_steps=workflow_run.total_steps,
                    created_at=int(workflow_run.created_at.timestamp()),
                    finished_at=int(workflow_run.finished_at.timestamp())
                )
                    id=stream_response.data.id,
                    workflow_id=stream_response.data.workflow_id,
                    status=stream_response.data.status,
                    outputs=stream_response.data.outputs,
                    error=stream_response.data.error,
                    elapsed_time=stream_response.data.elapsed_time,
                    total_tokens=stream_response.data.total_tokens,
                    total_steps=stream_response.data.total_steps,
                    created_at=int(stream_response.data.created_at),
                    finished_at=int(stream_response.data.finished_at),
                ),
            )

            return response
        else:
            continue

        raise Exception('Queue listening stopped unexpectedly.')
        raise Exception("Queue listening stopped unexpectedly.")

    def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \
            -> Generator[WorkflowAppStreamResponse, None, None]:
    def _to_stream_response(
        self, generator: Generator[StreamResponse, None, None]
    ) -> Generator[WorkflowAppStreamResponse, None, None]:
        """
        To stream response.
        :return:
        """
        workflow_run_id = None
        for stream_response in generator:
            yield WorkflowAppStreamResponse(
                workflow_run_id=self._task_state.workflow_run_id,
                stream_response=stream_response
            )
            if isinstance(stream_response, WorkflowStartStreamResponse):
                workflow_run_id = stream_response.workflow_run_id

            yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)

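The rewritten _to_stream_response no longer reads the run id from mutable task state; it latches it from the first WorkflowStartStreamResponse and attaches it to every later chunk. The same latch pattern in isolation, with illustrative stand-in classes:

from dataclasses import dataclass

@dataclass
class StartResponse:
    workflow_run_id: str

@dataclass
class TextResponse:
    text: str

def wrap(stream):
    # Latch the run id from the start event, then pair it with every chunk.
    workflow_run_id = None
    for response in stream:
        if isinstance(response, StartResponse):
            workflow_run_id = response.workflow_run_id
        yield (workflow_run_id, response)

events = list(wrap([StartResponse("run-1"), TextResponse("hi")]))
assert events[0][0] == "run-1" and events[1][0] == "run-1"
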
    def _listenAudioMsg(self, publisher, task_id: str):
        if not publisher:
@@ -175,20 +170,24 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
            return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
        return None

    def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
            Generator[StreamResponse, None, None]:

        publisher = None
    def _wrapper_process_stream_response(
        self, trace_manager: Optional[TraceQueueManager] = None
    ) -> Generator[StreamResponse, None, None]:
        tts_publisher = None
        task_id = self._application_generate_entity.task_id
        tenant_id = self._application_generate_entity.app_config.tenant_id
        features_dict = self._workflow.features_dict

        if features_dict.get('text_to_speech') and features_dict['text_to_speech'].get('enabled') and features_dict[
                'text_to_speech'].get('autoPlay') == 'enabled':
            publisher = AppGeneratorTTSPublisher(tenant_id, features_dict['text_to_speech'].get('voice'))
        for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
        if (
            features_dict.get("text_to_speech")
            and features_dict["text_to_speech"].get("enabled")
            and features_dict["text_to_speech"].get("autoPlay") == "enabled"
        ):
            tts_publisher = AppGeneratorTTSPublisher(tenant_id, features_dict["text_to_speech"].get("voice"))

        for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
            while True:
                audio_response = self._listenAudioMsg(publisher, task_id=task_id)
                audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
                if audio_response:
                    yield audio_response
                else:
@@ -198,9 +197,9 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
        start_listener_time = time.time()
        while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
            try:
                if not publisher:
                if not tts_publisher:
                    break
                audio_trunk = publisher.checkAndGetAudio()
                audio_trunk = tts_publisher.checkAndGetAudio()
                if audio_trunk is None:
                    # release cpu
                    # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
@@ -213,105 +212,176 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
            except Exception as e:
                logger.error(e)
                break
        yield MessageAudioEndStreamResponse(audio='', task_id=task_id)

        yield MessageAudioEndStreamResponse(audio="", task_id=task_id)

    def _process_stream_response(
        self,
        publisher: AppGeneratorTTSPublisher,
        trace_manager: Optional[TraceQueueManager] = None
        tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
        trace_manager: Optional[TraceQueueManager] = None,
    ) -> Generator[StreamResponse, None, None]:
        """
        Process stream response.
        :return:
        """
        for message in self._queue_manager.listen():
            if publisher:
                publisher.publish(message=message)
            event = message.event
        graph_runtime_state = None
        workflow_run = None

            if isinstance(event, QueueErrorEvent):
        for queue_message in self._queue_manager.listen():
            event = queue_message.event

            if isinstance(event, QueuePingEvent):
                yield self._ping_stream_response()
            elif isinstance(event, QueueErrorEvent):
                err = self._handle_error(event)
                yield self._error_to_stream_response(err)
                break
            elif isinstance(event, QueueWorkflowStartedEvent):
                workflow_run = self._handle_workflow_start()
                # override graph runtime state
                graph_runtime_state = event.graph_runtime_state

                # init workflow run
                workflow_run = self._handle_workflow_run_start()
                yield self._workflow_start_to_stream_response(
                    task_id=self._application_generate_entity.task_id,
                    workflow_run=workflow_run
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                )
            elif isinstance(event, QueueNodeStartedEvent):
                workflow_node_execution = self._handle_node_start(event)
                if not workflow_run:
                    raise Exception("Workflow run not initialized.")

                # search stream_generate_routes if node id is answer start at node
                if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_nodes:
                    self._task_state.current_stream_generate_state = self._stream_generate_nodes[event.node_id]
                workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)

                # generate stream outputs when node started
                yield from self._generate_stream_outputs_when_node_started()

                yield self._workflow_node_start_to_stream_response(
                response = self._workflow_node_start_to_stream_response(
                    event=event,
                    task_id=self._application_generate_entity.task_id,
                    workflow_node_execution=workflow_node_execution
                    workflow_node_execution=workflow_node_execution,
                )
            elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
                workflow_node_execution = self._handle_node_finished(event)

                yield self._workflow_node_finish_to_stream_response(
                if response:
                    yield response
            elif isinstance(event, QueueNodeSucceededEvent):
                workflow_node_execution = self._handle_workflow_node_execution_success(event)

                response = self._workflow_node_finish_to_stream_response(
                    event=event,
                    task_id=self._application_generate_entity.task_id,
                    workflow_node_execution=workflow_node_execution
                    workflow_node_execution=workflow_node_execution,
                )

                if isinstance(event, QueueNodeFailedEvent):
                    yield from self._handle_iteration_exception(
                        task_id=self._application_generate_entity.task_id,
                        error=f'Child node failed: {event.error}'
                    )
            elif isinstance(event, QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent):
                if isinstance(event, QueueIterationNextEvent):
                    # clear ran node execution infos of current iteration
                    iteration_relations = self._iteration_nested_relations.get(event.node_id)
                    if iteration_relations:
                        for node_id in iteration_relations:
                            self._task_state.ran_node_execution_infos.pop(node_id, None)
                if response:
                    yield response
            elif isinstance(event, QueueNodeFailedEvent):
                workflow_node_execution = self._handle_workflow_node_execution_failed(event)

                yield self._handle_iteration_to_stream_response(self._application_generate_entity.task_id, event)
                self._handle_iteration_operation(event)
            elif isinstance(event, QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent):
                workflow_run = self._handle_workflow_finished(
                    event, trace_manager=trace_manager
                response = self._workflow_node_finish_to_stream_response(
                    event=event,
                    task_id=self._application_generate_entity.task_id,
                    workflow_node_execution=workflow_node_execution,
                )

                if response:
                    yield response
            elif isinstance(event, QueueParallelBranchRunStartedEvent):
                if not workflow_run:
                    raise Exception("Workflow run not initialized.")

                yield self._workflow_parallel_branch_start_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                )
            elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
                if not workflow_run:
                    raise Exception("Workflow run not initialized.")

                yield self._workflow_parallel_branch_finished_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                )
            elif isinstance(event, QueueIterationStartEvent):
                if not workflow_run:
                    raise Exception("Workflow run not initialized.")

                yield self._workflow_iteration_start_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                )
            elif isinstance(event, QueueIterationNextEvent):
                if not workflow_run:
                    raise Exception("Workflow run not initialized.")

                yield self._workflow_iteration_next_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                )
            elif isinstance(event, QueueIterationCompletedEvent):
                if not workflow_run:
                    raise Exception("Workflow run not initialized.")

                yield self._workflow_iteration_completed_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
                )
            elif isinstance(event, QueueWorkflowSucceededEvent):
                if not workflow_run:
                    raise Exception("Workflow run not initialized.")

                if not graph_runtime_state:
                    raise Exception("Graph runtime state not initialized.")

                workflow_run = self._handle_workflow_run_success(
                    workflow_run=workflow_run,
                    start_at=graph_runtime_state.start_at,
                    total_tokens=graph_runtime_state.total_tokens,
                    total_steps=graph_runtime_state.node_run_steps,
                    outputs=json.dumps(event.outputs)
                    if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs
                    else None,
                    conversation_id=None,
                    trace_manager=trace_manager,
                )

                # save workflow app log
                self._save_workflow_app_log(workflow_run)

                yield self._workflow_finish_to_stream_response(
                    task_id=self._application_generate_entity.task_id,
                    workflow_run=workflow_run
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                )
            elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
                if not workflow_run:
                    raise Exception("Workflow run not initialized.")

                if not graph_runtime_state:
                    raise Exception("Graph runtime state not initialized.")

                workflow_run = self._handle_workflow_run_failed(
                    workflow_run=workflow_run,
                    start_at=graph_runtime_state.start_at,
                    total_tokens=graph_runtime_state.total_tokens,
                    total_steps=graph_runtime_state.node_run_steps,
                    status=WorkflowRunStatus.FAILED
                    if isinstance(event, QueueWorkflowFailedEvent)
                    else WorkflowRunStatus.STOPPED,
                    error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
                    conversation_id=None,
                    trace_manager=trace_manager,
                )

                # save workflow app log
                self._save_workflow_app_log(workflow_run)

                yield self._workflow_finish_to_stream_response(
                    task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
                )
            elif isinstance(event, QueueTextChunkEvent):
                delta_text = event.text
                if delta_text is None:
                    continue

                if not self._is_stream_out_support(
                        event=event
                ):
                    continue
                # only publish tts message at text chunk streaming
                if tts_publisher:
                    tts_publisher.publish(message=queue_message)

                self._task_state.answer += delta_text
                yield self._text_chunk_to_stream_response(delta_text)
            elif isinstance(event, QueueMessageReplaceEvent):
                yield self._text_replace_to_stream_response(event.text)
            elif isinstance(event, QueuePingEvent):
                yield self._ping_stream_response()
            else:
                continue

        if publisher:
            publisher.publish(None)

        if tts_publisher:
            tts_publisher.publish(None)

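Two details of the loop above are easy to miss: the TTS publisher now only sees text-chunk messages rather than every queue message, and the final publish(None) acts as an end-of-stream sentinel. A toy publisher/sentinel pair sketching that contract (illustrative names, not the AppGeneratorTTSPublisher API):

import queue

class ToySpeechPublisher:
    # Illustrative: messages accumulate until a None sentinel closes the stream.
    def __init__(self) -> None:
        self._buffer: queue.Queue = queue.Queue()

    def publish(self, message) -> None:
        self._buffer.put(message)

    def drain(self) -> list:
        out = []
        while True:
            message = self._buffer.get()
            if message is None:  # end-of-stream sentinel, mirrors publish(None)
                return out
            out.append(message)

publisher = ToySpeechPublisher()
publisher.publish("chunk-1")
publisher.publish("chunk-2")
publisher.publish(None)
assert publisher.drain() == ["chunk-1", "chunk-2"]
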
    def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None:
        """
@@ -329,15 +399,15 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
            # not save log for debugging
            return

        workflow_app_log = WorkflowAppLog(
            tenant_id=workflow_run.tenant_id,
            app_id=workflow_run.app_id,
            workflow_id=workflow_run.workflow_id,
            workflow_run_id=workflow_run.id,
            created_from=created_from.value,
            created_by_role=('account' if isinstance(self._user, Account) else 'end_user'),
            created_by=self._user.id,
        )
        workflow_app_log = WorkflowAppLog()
        workflow_app_log.tenant_id = workflow_run.tenant_id
        workflow_app_log.app_id = workflow_run.app_id
        workflow_app_log.workflow_id = workflow_run.workflow_id
        workflow_app_log.workflow_run_id = workflow_run.id
        workflow_app_log.created_from = created_from.value
        workflow_app_log.created_by_role = "account" if isinstance(self._user, Account) else "end_user"
        workflow_app_log.created_by = self._user.id

        db.session.add(workflow_app_log)
        db.session.commit()
        db.session.close()
@@ -349,185 +419,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
        :return:
        """
        response = TextChunkStreamResponse(
            task_id=self._application_generate_entity.task_id,
            data=TextChunkStreamResponse.Data(text=text)
            task_id=self._application_generate_entity.task_id, data=TextChunkStreamResponse.Data(text=text)
        )

        return response

    def _text_replace_to_stream_response(self, text: str) -> TextReplaceStreamResponse:
        """
        Text replace to stream response.
        :param text: text
        :return:
        """
        return TextReplaceStreamResponse(
            task_id=self._application_generate_entity.task_id,
            text=TextReplaceStreamResponse.Data(text=text)
        )

    def _get_stream_generate_nodes(self) -> dict[str, WorkflowStreamGenerateNodes]:
        """
        Get stream generate nodes.
        :return:
        """
        # find all answer nodes
        graph = self._workflow.graph_dict
        end_node_configs = [
            node for node in graph['nodes']
            if node.get('data', {}).get('type') == NodeType.END.value
        ]

        # parse stream output node value selectors of end nodes
        stream_generate_routes = {}
        for node_config in end_node_configs:
            # get generate route for stream output
            end_node_id = node_config['id']
            generate_nodes = EndNode.extract_generate_nodes(graph, node_config)
            start_node_ids = self._get_end_start_at_node_ids(graph, end_node_id)
            if not start_node_ids:
                continue

            for start_node_id in start_node_ids:
                stream_generate_routes[start_node_id] = WorkflowStreamGenerateNodes(
                    end_node_id=end_node_id,
                    stream_node_ids=generate_nodes
                )

        return stream_generate_routes

    def _get_end_start_at_node_ids(self, graph: dict, target_node_id: str) \
            -> list[str]:
        """
        Get end start at node id.
        :param graph: graph
        :param target_node_id: target node ID
        :return:
        """
        nodes = graph.get('nodes')
        edges = graph.get('edges')

        # fetch all ingoing edges from source node
        ingoing_edges = []
        for edge in edges:
            if edge.get('target') == target_node_id:
                ingoing_edges.append(edge)

        if not ingoing_edges:
            return []

        start_node_ids = []
        for ingoing_edge in ingoing_edges:
            source_node_id = ingoing_edge.get('source')
            source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
            if not source_node:
                continue

            node_type = source_node.get('data', {}).get('type')
            node_iteration_id = source_node.get('data', {}).get('iteration_id')
            iteration_start_node_id = None
            if node_iteration_id:
                iteration_node = next((node for node in nodes if node.get('id') == node_iteration_id), None)
                iteration_start_node_id = iteration_node.get('data', {}).get('start_node_id')

            if node_type in [
                NodeType.IF_ELSE.value,
                NodeType.QUESTION_CLASSIFIER.value
            ]:
                start_node_id = target_node_id
                start_node_ids.append(start_node_id)
            elif node_type == NodeType.START.value or \
                    node_iteration_id is not None and iteration_start_node_id == source_node.get('id'):
                start_node_id = source_node_id
                start_node_ids.append(start_node_id)
            else:
                sub_start_node_ids = self._get_end_start_at_node_ids(graph, source_node_id)
                if sub_start_node_ids:
                    start_node_ids.extend(sub_start_node_ids)

        return start_node_ids

    def _generate_stream_outputs_when_node_started(self) -> Generator:
        """
        Generate stream outputs.
        :return:
        """
        if self._task_state.current_stream_generate_state:
            stream_node_ids = self._task_state.current_stream_generate_state.stream_node_ids

            for node_id, node_execution_info in self._task_state.ran_node_execution_infos.items():
                if node_id not in stream_node_ids:
                    continue

                node_execution_info = self._task_state.ran_node_execution_infos[node_id]

                # get chunk node execution
                route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
                    WorkflowNodeExecution.id == node_execution_info.workflow_node_execution_id).first()

                if not route_chunk_node_execution:
                    continue

                outputs = route_chunk_node_execution.outputs_dict

                if not outputs:
                    continue

                # get value from outputs
                text = outputs.get('text')

                if text:
                    self._task_state.answer += text
                    yield self._text_chunk_to_stream_response(text)

            db.session.close()

    def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
        """
        Is stream out support
        :param event: queue text chunk event
        :return:
        """
        if not event.metadata:
            return False

        if 'node_id' not in event.metadata:
            return False

        node_id = event.metadata.get('node_id')
        node_type = event.metadata.get('node_type')
        stream_output_value_selector = event.metadata.get('value_selector')
        if not stream_output_value_selector:
            return False

        if not self._task_state.current_stream_generate_state:
            return False

        if node_id not in self._task_state.current_stream_generate_state.stream_node_ids:
            return False

        if node_type != NodeType.LLM:
            # only LLM support chunk stream output
            return False

        return True

    def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
        """
        Get iteration nested relations.
        :param graph: graph
        :return:
        """
        nodes = graph.get('nodes')

        iteration_ids = [node.get('id') for node in nodes
                         if node.get('data', {}).get('type') in [
                             NodeType.ITERATION.value,
                             NodeType.LOOP.value,
                         ]]

        return {
            iteration_id: [
                node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
            ] for iteration_id in iteration_ids
        }

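For reference, the removed _get_iteration_nested_relations is a plain grouping comprehension: map each iteration/loop node id to the ids of nodes whose iteration_id points back at it. A standalone version under the assumption that the NodeType enum values are the strings "iteration" and "loop" (the sample graph is made up):

def iteration_nested_relations(graph: dict) -> dict[str, list[str]]:
    # Group child node ids under the iteration/loop node that owns them.
    nodes = graph.get("nodes", [])
    iteration_ids = [n["id"] for n in nodes if n.get("data", {}).get("type") in ("iteration", "loop")]
    return {
        it_id: [n["id"] for n in nodes if n.get("data", {}).get("iteration_id") == it_id]
        for it_id in iteration_ids
    }

# Made-up sample graph:
graph = {"nodes": [
    {"id": "iter-1", "data": {"type": "iteration"}},
    {"id": "llm-1", "data": {"type": "llm", "iteration_id": "iter-1"}},
]}
assert iteration_nested_relations(graph) == {"iter-1": ["llm-1"]}
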
@@ -1,200 +0,0 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
class WorkflowEventTriggerCallback(WorkflowCallback):
|
||||
|
||||
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
|
||||
self._queue_manager = queue_manager
|
||||
|
||||
def on_workflow_run_started(self) -> None:
|
||||
"""
|
||||
Workflow run started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowStartedEvent(),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_run_succeeded(self) -> None:
|
||||
"""
|
||||
Workflow run succeeded
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowSucceededEvent(),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_run_failed(self, error: str) -> None:
|
||||
"""
|
||||
Workflow run failed
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowFailedEvent(
|
||||
error=error
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_started(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
node_run_index: int = 1,
|
||||
predecessor_node_id: Optional[str] = None) -> None:
|
||||
"""
|
||||
Workflow node execute started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeStartedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
node_run_index=node_run_index,
|
||||
predecessor_node_id=predecessor_node_id
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_succeeded(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
inputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
execution_metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Workflow node execute succeeded
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeSucceededEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
execution_metadata=execution_metadata
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_failed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
error: str,
|
||||
inputs: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Workflow node execute failed
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeFailedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
process_data=process_data,
|
||||
error=error
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Publish text chunk
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueTextChunkEvent(
|
||||
text=text,
|
||||
metadata={
|
||||
"node_id": node_id,
|
||||
**metadata
|
||||
}
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_started(self,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int = 1,
|
||||
node_data: Optional[BaseNodeData] = None,
|
||||
inputs: dict = None,
|
||||
predecessor_node_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Publish iteration started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueIterationStartEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_run_index=node_run_index,
|
||||
node_data=node_data,
|
||||
inputs=inputs,
|
||||
predecessor_node_id=predecessor_node_id,
|
||||
metadata=metadata
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_next(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
index: int,
|
||||
node_run_index: int,
|
||||
output: Optional[Any]) -> None:
|
||||
"""
|
||||
Publish iteration next
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueIterationNextEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
index=index,
|
||||
node_run_index=node_run_index,
|
||||
output=output
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_completed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int,
|
||||
outputs: dict) -> None:
|
||||
"""
|
||||
Publish iteration completed
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueIterationCompletedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_run_index=node_run_index,
|
||||
outputs=outputs
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_event(self, event: AppQueueEvent) -> None:
|
||||
"""
|
||||
Publish event
|
||||
"""
|
||||
pass
|
||||
371
api/core/app/apps/workflow_app_runner.py
Normal file
371
api/core/app/apps/workflow_app_runner.py
Normal file
@@ -0,0 +1,371 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
IterationRunFailedEvent,
|
||||
IterationRunNextEvent,
|
||||
IterationRunStartedEvent,
|
||||
IterationRunSucceededEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
ParallelBranchRunFailedEvent,
|
||||
ParallelBranchRunStartedEvent,
|
||||
ParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.iteration.entities import IterationNodeData
|
||||
from core.workflow.nodes.node_mapping import node_classes
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from models.model import App
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
class WorkflowBasedAppRunner(AppRunner):
|
||||
def __init__(self, queue_manager: AppQueueManager):
|
||||
self.queue_manager = queue_manager
|
||||
|
||||
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
|
||||
"""
|
||||
Init graph
|
||||
"""
|
||||
if "nodes" not in graph_config or "edges" not in graph_config:
|
||||
raise ValueError("nodes or edges not found in workflow graph")
|
||||
|
||||
if not isinstance(graph_config.get("nodes"), list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
return graph
|
||||
|
||||
def _get_graph_and_variable_pool_of_single_iteration(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single iteration
|
||||
"""
|
||||
# fetch workflow graph
|
||||
graph_config = workflow.graph_dict
|
||||
if not graph_config:
|
||||
raise ValueError("workflow graph not found")
|
||||
|
||||
graph_config = cast(dict[str, Any], graph_config)
|
||||
|
||||
if "nodes" not in graph_config or "edges" not in graph_config:
|
||||
raise ValueError("nodes or edges not found in workflow graph")
|
||||
|
||||
if not isinstance(graph_config.get("nodes"), list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
# filter nodes only in iteration
|
||||
node_configs = [
|
||||
node
|
||||
for node in graph_config.get("nodes", [])
|
||||
if node.get("id") == node_id or node.get("data", {}).get("iteration_id", "") == node_id
|
||||
]
|
||||
|
||||
graph_config["nodes"] = node_configs
|
||||
|
||||
node_ids = [node.get("id") for node in node_configs]
|
||||
|
||||
# filter edges only in iteration
|
||||
edge_configs = [
|
||||
edge
|
||||
for edge in graph_config.get("edges", [])
|
||||
if (edge.get("source") is None or edge.get("source") in node_ids)
|
||||
and (edge.get("target") is None or edge.get("target") in node_ids)
|
||||
]
|
||||
|
||||
graph_config["edges"] = edge_configs
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config, root_node_id=node_id)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
# fetch node config from node id
|
||||
iteration_node_config = None
|
||||
for node in node_configs:
|
||||
if node.get("id") == node_id:
|
||||
iteration_node_config = node
|
||||
break
|
||||
|
||||
if not iteration_node_config:
|
||||
raise ValueError("iteration node id not found in workflow graph")
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType.value_of(iteration_node_config.get("data", {}).get("type"))
|
||||
node_cls = node_classes.get(node_type)
|
||||
node_cls = cast(type[BaseNode], node_cls)
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
|
||||
try:
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=workflow.graph_dict, config=iteration_node_config
|
||||
)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id=workflow.tenant_id,
|
||||
node_type=node_type,
|
||||
node_data=IterationNodeData(**iteration_node_config.get("data", {})),
|
||||
)
|
||||
|
||||
return graph, variable_pool
|
||||
|
||||
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Handle event
|
||||
:param workflow_entry: workflow entry
|
||||
:param event: event
|
||||
"""
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueWorkflowStartedEvent(graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state)
|
||||
)
|
||||
elif isinstance(event, GraphRunSucceededEvent):
|
||||
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
self._publish_event(QueueWorkflowFailedEvent(error=event.error))
|
||||
elif isinstance(event, NodeRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeStartedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
node_run_index=event.route_node_state.index,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
self._publish_event(
|
||||
QueueNodeSucceededEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeFailedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
error=event.route_node_state.node_run_result.error
|
||||
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
|
||||
else "Unknown error",
|
||||
                    in_iteration_id=event.in_iteration_id,
                )
            )
        elif isinstance(event, NodeRunStreamChunkEvent):
            self._publish_event(
                QueueTextChunkEvent(
                    text=event.chunk_content,
                    from_variable_selector=event.from_variable_selector,
                    in_iteration_id=event.in_iteration_id,
                )
            )
        elif isinstance(event, NodeRunRetrieverResourceEvent):
            self._publish_event(
                QueueRetrieverResourcesEvent(
                    retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id
                )
            )
        elif isinstance(event, ParallelBranchRunStartedEvent):
            self._publish_event(
                QueueParallelBranchRunStartedEvent(
                    parallel_id=event.parallel_id,
                    parallel_start_node_id=event.parallel_start_node_id,
                    parent_parallel_id=event.parent_parallel_id,
                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                    in_iteration_id=event.in_iteration_id,
                )
            )
        elif isinstance(event, ParallelBranchRunSucceededEvent):
            self._publish_event(
                QueueParallelBranchRunSucceededEvent(
                    parallel_id=event.parallel_id,
                    parallel_start_node_id=event.parallel_start_node_id,
                    parent_parallel_id=event.parent_parallel_id,
                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                    in_iteration_id=event.in_iteration_id,
                )
            )
        elif isinstance(event, ParallelBranchRunFailedEvent):
            self._publish_event(
                QueueParallelBranchRunFailedEvent(
                    parallel_id=event.parallel_id,
                    parallel_start_node_id=event.parallel_start_node_id,
                    parent_parallel_id=event.parent_parallel_id,
                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                    in_iteration_id=event.in_iteration_id,
                    error=event.error,
                )
            )
        elif isinstance(event, IterationRunStartedEvent):
            self._publish_event(
                QueueIterationStartEvent(
                    node_execution_id=event.iteration_id,
                    node_id=event.iteration_node_id,
                    node_type=event.iteration_node_type,
                    node_data=event.iteration_node_data,
                    parallel_id=event.parallel_id,
                    parallel_start_node_id=event.parallel_start_node_id,
                    parent_parallel_id=event.parent_parallel_id,
                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                    start_at=event.start_at,
                    node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
                    inputs=event.inputs,
                    predecessor_node_id=event.predecessor_node_id,
                    metadata=event.metadata,
                )
            )
        elif isinstance(event, IterationRunNextEvent):
            self._publish_event(
                QueueIterationNextEvent(
                    node_execution_id=event.iteration_id,
                    node_id=event.iteration_node_id,
                    node_type=event.iteration_node_type,
                    node_data=event.iteration_node_data,
                    parallel_id=event.parallel_id,
                    parallel_start_node_id=event.parallel_start_node_id,
                    parent_parallel_id=event.parent_parallel_id,
                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                    index=event.index,
                    node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
                    output=event.pre_iteration_output,
                )
            )
        elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
            self._publish_event(
                QueueIterationCompletedEvent(
                    node_execution_id=event.iteration_id,
                    node_id=event.iteration_node_id,
                    node_type=event.iteration_node_type,
                    node_data=event.iteration_node_data,
                    parallel_id=event.parallel_id,
                    parallel_start_node_id=event.parallel_start_node_id,
                    parent_parallel_id=event.parent_parallel_id,
                    parent_parallel_start_node_id=event.parent_parallel_start_node_id,
                    start_at=event.start_at,
                    node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
                    inputs=event.inputs,
                    outputs=event.outputs,
                    metadata=event.metadata,
                    steps=event.steps,
                    error=event.error if isinstance(event, IterationRunFailedEvent) else None,
                )
            )

    def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
        """
        Get workflow
        """
        # fetch workflow by workflow_id
        workflow = (
            db.session.query(Workflow)
            .filter(
                Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
            )
            .first()
        )

        # return workflow
        return workflow

    def _publish_event(self, event: AppQueueEvent) -> None:
        self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
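# --- Editor's note: an illustrative sketch, not part of the diff above. The elif
# chain is a plain isinstance dispatch that copies each graph-engine event into
# its queue counterpart field by field. The stand-in classes below are invented
# for illustration; the real event types are the ones referenced above. Unlike
# the real chain, which silently ignores unmatched events, this sketch fails
# loudly so the behaviour is visible.
from dataclasses import dataclass


@dataclass
class StreamChunkSketch:  # stand-in for NodeRunStreamChunkEvent
    chunk_content: str


@dataclass
class TextChunkSketch:  # stand-in for QueueTextChunkEvent
    text: str


def to_queue_event_sketch(event):
    # Translate one graph event into its queue event, copying fields across.
    if isinstance(event, StreamChunkSketch):
        return TextChunkSketch(text=event.chunk_content)
    raise TypeError(f"unhandled event type: {type(event).__name__}")


assert to_queue_event_sketch(StreamChunkSketch(chunk_content="hi")).text == "hi"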

@@ -1,10 +1,24 @@
from typing import Optional

from core.app.entities.queue_entities import AppQueueEvent
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.event import (
    GraphEngineEvent,
    GraphRunFailedEvent,
    GraphRunStartedEvent,
    GraphRunSucceededEvent,
    IterationRunFailedEvent,
    IterationRunNextEvent,
    IterationRunStartedEvent,
    IterationRunSucceededEvent,
    NodeRunFailedEvent,
    NodeRunStartedEvent,
    NodeRunStreamChunkEvent,
    NodeRunSucceededEvent,
    ParallelBranchRunFailedEvent,
    ParallelBranchRunStartedEvent,
    ParallelBranchRunSucceededEvent,
)

_TEXT_COLOR_MAPPING = {
    "blue": "36;1",
@@ -16,138 +30,184 @@ _TEXT_COLOR_MAPPING = {


class WorkflowLoggingCallback(WorkflowCallback):

    def __init__(self) -> None:
        self.current_node_id = None

    def on_workflow_run_started(self) -> None:
        """
        Workflow run started
        """
        self.print_text("\n[on_workflow_run_started]", color='pink')

    def on_event(self, event: GraphEngineEvent) -> None:
        if isinstance(event, GraphRunStartedEvent):
            self.print_text("\n[GraphRunStartedEvent]", color="pink")
        elif isinstance(event, GraphRunSucceededEvent):
            self.print_text("\n[GraphRunSucceededEvent]", color="green")
        elif isinstance(event, GraphRunFailedEvent):
            self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red")
        elif isinstance(event, NodeRunStartedEvent):
            self.on_workflow_node_execute_started(event=event)
        elif isinstance(event, NodeRunSucceededEvent):
            self.on_workflow_node_execute_succeeded(event=event)
        elif isinstance(event, NodeRunFailedEvent):
            self.on_workflow_node_execute_failed(event=event)
        elif isinstance(event, NodeRunStreamChunkEvent):
            self.on_node_text_chunk(event=event)
        elif isinstance(event, ParallelBranchRunStartedEvent):
            self.on_workflow_parallel_started(event=event)
        elif isinstance(event, ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent):
            self.on_workflow_parallel_completed(event=event)
        elif isinstance(event, IterationRunStartedEvent):
            self.on_workflow_iteration_started(event=event)
        elif isinstance(event, IterationRunNextEvent):
            self.on_workflow_iteration_next(event=event)
        elif isinstance(event, IterationRunSucceededEvent | IterationRunFailedEvent):
            self.on_workflow_iteration_completed(event=event)
        else:
            self.print_text(f"\n[{event.__class__.__name__}]", color="blue")

    def on_workflow_run_succeeded(self) -> None:
        """
        Workflow run succeeded
        """
        self.print_text("\n[on_workflow_run_succeeded]", color='green')

    def on_workflow_run_failed(self, error: str) -> None:
        """
        Workflow run failed
        """
        self.print_text("\n[on_workflow_run_failed]", color='red')

    def on_workflow_node_execute_started(self, node_id: str,
                                         node_type: NodeType,
                                         node_data: BaseNodeData,
                                         node_run_index: int = 1,
                                         predecessor_node_id: Optional[str] = None) -> None:
    def on_workflow_node_execute_started(self, event: NodeRunStartedEvent) -> None:
        """
        Workflow node execute started
        """
        self.print_text("\n[on_workflow_node_execute_started]", color='yellow')
        self.print_text(f"Node ID: {node_id}", color='yellow')
        self.print_text(f"Type: {node_type.value}", color='yellow')
        self.print_text(f"Index: {node_run_index}", color='yellow')
        if predecessor_node_id:
            self.print_text(f"Predecessor Node ID: {predecessor_node_id}", color='yellow')
        self.print_text("\n[NodeRunStartedEvent]", color="yellow")
        self.print_text(f"Node ID: {event.node_id}", color="yellow")
        self.print_text(f"Node Title: {event.node_data.title}", color="yellow")
        self.print_text(f"Type: {event.node_type.value}", color="yellow")

    def on_workflow_node_execute_succeeded(self, node_id: str,
                                           node_type: NodeType,
                                           node_data: BaseNodeData,
                                           inputs: Optional[dict] = None,
                                           process_data: Optional[dict] = None,
                                           outputs: Optional[dict] = None,
                                           execution_metadata: Optional[dict] = None) -> None:
    def on_workflow_node_execute_succeeded(self, event: NodeRunSucceededEvent) -> None:
        """
        Workflow node execute succeeded
        """
        self.print_text("\n[on_workflow_node_execute_succeeded]", color='green')
        self.print_text(f"Node ID: {node_id}", color='green')
        self.print_text(f"Type: {node_type.value}", color='green')
        self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='green')
        self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='green')
        self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='green')
        self.print_text(f"Metadata: {jsonable_encoder(execution_metadata) if execution_metadata else ''}",
                        color='green')
        route_node_state = event.route_node_state

    def on_workflow_node_execute_failed(self, node_id: str,
                                        node_type: NodeType,
                                        node_data: BaseNodeData,
                                        error: str,
                                        inputs: Optional[dict] = None,
                                        outputs: Optional[dict] = None,
                                        process_data: Optional[dict] = None) -> None:
        self.print_text("\n[NodeRunSucceededEvent]", color="green")
        self.print_text(f"Node ID: {event.node_id}", color="green")
        self.print_text(f"Node Title: {event.node_data.title}", color="green")
        self.print_text(f"Type: {event.node_type.value}", color="green")

        if route_node_state.node_run_result:
            node_run_result = route_node_state.node_run_result
            self.print_text(
                f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", color="green"
            )
            self.print_text(
                f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
                color="green",
            )
            self.print_text(
                f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}",
                color="green",
            )
            self.print_text(
                f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}",
                color="green",
            )

    def on_workflow_node_execute_failed(self, event: NodeRunFailedEvent) -> None:
        """
        Workflow node execute failed
        """
        self.print_text("\n[on_workflow_node_execute_failed]", color='red')
        self.print_text(f"Node ID: {node_id}", color='red')
        self.print_text(f"Type: {node_type.value}", color='red')
        self.print_text(f"Error: {error}", color='red')
        self.print_text(f"Inputs: {jsonable_encoder(inputs) if inputs else ''}", color='red')
        self.print_text(f"Process Data: {jsonable_encoder(process_data) if process_data else ''}", color='red')
        self.print_text(f"Outputs: {jsonable_encoder(outputs) if outputs else ''}", color='red')
        route_node_state = event.route_node_state

    def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
        self.print_text("\n[NodeRunFailedEvent]", color="red")
        self.print_text(f"Node ID: {event.node_id}", color="red")
        self.print_text(f"Node Title: {event.node_data.title}", color="red")
        self.print_text(f"Type: {event.node_type.value}", color="red")

        if route_node_state.node_run_result:
            node_run_result = route_node_state.node_run_result
            self.print_text(f"Error: {node_run_result.error}", color="red")
            self.print_text(
                f"Inputs: {jsonable_encoder(node_run_result.inputs) if node_run_result.inputs else ''}", color="red"
            )
            self.print_text(
                f"Process Data: {jsonable_encoder(node_run_result.process_data) if node_run_result.process_data else ''}",
                color="red",
            )
            self.print_text(
                f"Outputs: {jsonable_encoder(node_run_result.outputs) if node_run_result.outputs else ''}", color="red"
            )

    def on_node_text_chunk(self, event: NodeRunStreamChunkEvent) -> None:
        """
        Publish text chunk
        """
        if not self.current_node_id or self.current_node_id != node_id:
            self.current_node_id = node_id
            self.print_text('\n[on_node_text_chunk]')
            self.print_text(f"Node ID: {node_id}")
            self.print_text(f"Metadata: {jsonable_encoder(metadata) if metadata else ''}")
        route_node_state = event.route_node_state
        if not self.current_node_id or self.current_node_id != route_node_state.node_id:
            self.current_node_id = route_node_state.node_id
            self.print_text("\n[NodeRunStreamChunkEvent]")
            self.print_text(f"Node ID: {route_node_state.node_id}")

        self.print_text(text, color="pink", end="")
        node_run_result = route_node_state.node_run_result
        if node_run_result:
            self.print_text(
                f"Metadata: {jsonable_encoder(node_run_result.metadata) if node_run_result.metadata else ''}"
            )

    def on_workflow_iteration_started(self,
                                      node_id: str,
                                      node_type: NodeType,
                                      node_run_index: int = 1,
                                      node_data: Optional[BaseNodeData] = None,
                                      inputs: dict = None,
                                      predecessor_node_id: Optional[str] = None,
                                      metadata: Optional[dict] = None) -> None:
        self.print_text(event.chunk_content, color="pink", end="")

    def on_workflow_parallel_started(self, event: ParallelBranchRunStartedEvent) -> None:
        """
        Publish parallel started
        """
        self.print_text("\n[ParallelBranchRunStartedEvent]", color="blue")
        self.print_text(f"Parallel ID: {event.parallel_id}", color="blue")
        self.print_text(f"Branch ID: {event.parallel_start_node_id}", color="blue")
        if event.in_iteration_id:
            self.print_text(f"Iteration ID: {event.in_iteration_id}", color="blue")

    def on_workflow_parallel_completed(
        self, event: ParallelBranchRunSucceededEvent | ParallelBranchRunFailedEvent
    ) -> None:
        """
        Publish parallel completed
        """
        if isinstance(event, ParallelBranchRunSucceededEvent):
            color = "blue"
        elif isinstance(event, ParallelBranchRunFailedEvent):
            color = "red"

        self.print_text(
            "\n[ParallelBranchRunSucceededEvent]"
            if isinstance(event, ParallelBranchRunSucceededEvent)
            else "\n[ParallelBranchRunFailedEvent]",
            color=color,
        )
        self.print_text(f"Parallel ID: {event.parallel_id}", color=color)
        self.print_text(f"Branch ID: {event.parallel_start_node_id}", color=color)
        if event.in_iteration_id:
            self.print_text(f"Iteration ID: {event.in_iteration_id}", color=color)

        if isinstance(event, ParallelBranchRunFailedEvent):
            self.print_text(f"Error: {event.error}", color=color)

    def on_workflow_iteration_started(self, event: IterationRunStartedEvent) -> None:
        """
        Publish iteration started
        """
        self.print_text("\n[on_workflow_iteration_started]", color='blue')
        self.print_text(f"Node ID: {node_id}", color='blue')
        self.print_text("\n[IterationRunStartedEvent]", color="blue")
        self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")

    def on_workflow_iteration_next(self, node_id: str,
                                   node_type: NodeType,
                                   index: int,
                                   node_run_index: int,
                                   output: Optional[dict]) -> None:
    def on_workflow_iteration_next(self, event: IterationRunNextEvent) -> None:
        """
        Publish iteration next
        """
        self.print_text("\n[on_workflow_iteration_next]", color='blue')
        self.print_text("\n[IterationRunNextEvent]", color="blue")
        self.print_text(f"Iteration Node ID: {event.iteration_id}", color="blue")
        self.print_text(f"Iteration Index: {event.index}", color="blue")

    def on_workflow_iteration_completed(self, node_id: str,
                                        node_type: NodeType,
                                        node_run_index: int,
                                        outputs: dict) -> None:
    def on_workflow_iteration_completed(self, event: IterationRunSucceededEvent | IterationRunFailedEvent) -> None:
        """
        Publish iteration completed
        """
        self.print_text("\n[on_workflow_iteration_completed]", color='blue')
        self.print_text(
            "\n[IterationRunSucceededEvent]"
            if isinstance(event, IterationRunSucceededEvent)
            else "\n[IterationRunFailedEvent]",
            color="blue",
        )
        self.print_text(f"Node ID: {event.iteration_id}", color="blue")

    def on_event(self, event: AppQueueEvent) -> None:
        """
        Publish event
        """
        self.print_text("\n[on_workflow_event]", color='blue')
        self.print_text(f"Event: {jsonable_encoder(event)}", color='blue')

    def print_text(
        self, text: str, color: Optional[str] = None, end: str = "\n"
    ) -> None:
    def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None:
        """Print text with highlighting and no end characters."""
        text_to_print = self._get_colored_text(text, color) if color else text
        print(f'{text_to_print}', end=end)
        print(f"{text_to_print}", end=end)

    def _get_colored_text(self, text: str, color: str) -> str:
        """Get colored text."""

@@ -15,13 +15,14 @@ class InvokeFrom(Enum):
    """
    Invoke From.
    """
    SERVICE_API = 'service-api'
    WEB_APP = 'web-app'
    EXPLORE = 'explore'
    DEBUGGER = 'debugger'

    SERVICE_API = "service-api"
    WEB_APP = "web-app"
    EXPLORE = "explore"
    DEBUGGER = "debugger"

    @classmethod
    def value_of(cls, value: str) -> 'InvokeFrom':
    def value_of(cls, value: str) -> "InvokeFrom":
        """
        Get value of given mode.

@@ -31,7 +32,7 @@ class InvokeFrom(Enum):
        for mode in cls:
            if mode.value == value:
                return mode
        raise ValueError(f'invalid invoke from value {value}')
        raise ValueError(f"invalid invoke from value {value}")

    def to_source(self) -> str:
        """
@@ -40,21 +41,22 @@ class InvokeFrom(Enum):
        :return: source
        """
        if self == InvokeFrom.WEB_APP:
            return 'web_app'
            return "web_app"
        elif self == InvokeFrom.DEBUGGER:
            return 'dev'
            return "dev"
        elif self == InvokeFrom.EXPLORE:
            return 'explore_app'
            return "explore_app"
        elif self == InvokeFrom.SERVICE_API:
            return 'api'
            return "api"

        return 'dev'
        return "dev"


class ModelConfigWithCredentialsEntity(BaseModel):
    """
    Model Config With Credentials Entity.
    """

    provider: str
    model: str
    model_schema: AIModelEntity
@@ -72,6 +74,7 @@ class AppGenerateEntity(BaseModel):
    """
    App Generate Entity.
    """

    task_id: str

    # app config
@@ -102,6 +105,7 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
    """
    Chat Application Generate Entity.
    """

    # app config
    app_config: EasyUIBasedAppConfig
    model_conf: ModelConfigWithCredentialsEntity
@@ -116,6 +120,7 @@ class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
    """
    Chat Application Generate Entity.
    """

    conversation_id: Optional[str] = None


@@ -123,6 +128,7 @@ class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
    """
    Completion Application Generate Entity.
    """

    pass


@@ -130,6 +136,7 @@ class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
    """
    Agent Chat Application Generate Entity.
    """

    conversation_id: Optional[str] = None


@@ -137,6 +144,7 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
    """
    Advanced Chat Application Generate Entity.
    """

    # app config
    app_config: WorkflowUIBasedAppConfig

@@ -147,15 +155,18 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
        """
        Single Iteration Run Entity.
        """

        node_id: str
        inputs: dict

    single_iteration_run: Optional[SingleIterationRunEntity] = None


class WorkflowAppGenerateEntity(AppGenerateEntity):
    """
    Workflow Application Generate Entity.
    """

    # app config
    app_config: WorkflowUIBasedAppConfig

@@ -163,6 +174,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
        """
        Single Iteration Run Entity.
        """

        node_id: str
        inputs: dict

@@ -1,3 +1,4 @@
from datetime import datetime
from enum import Enum
from typing import Any, Optional

@@ -5,13 +6,15 @@ from pydantic import BaseModel, field_validator

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState


class QueueEvent(str, Enum):
    """
    QueueEvent enum
    """

    LLM_CHUNK = "llm_chunk"
    TEXT_CHUNK = "text_chunk"
    AGENT_MESSAGE = "agent_message"
@@ -31,6 +34,9 @@ class QueueEvent(str, Enum):
    ANNOTATION_REPLY = "annotation_reply"
    AGENT_THOUGHT = "agent_thought"
    MESSAGE_FILE = "message_file"
    PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
    PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
    PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
    ERROR = "error"
    PING = "ping"
    STOP = "stop"
@@ -38,46 +44,73 @@ class QueueEvent(str, Enum):

class AppQueueEvent(BaseModel):
    """
    QueueEvent entity
    QueueEvent abstract entity
    """

    event: QueueEvent


class QueueLLMChunkEvent(AppQueueEvent):
    """
    QueueLLMChunkEvent entity
    Only for basic mode apps
    """

    event: QueueEvent = QueueEvent.LLM_CHUNK
    chunk: LLMResultChunk


class QueueIterationStartEvent(AppQueueEvent):
    """
    QueueIterationStartEvent entity
    """

    event: QueueEvent = QueueEvent.ITERATION_START
    node_execution_id: str
    node_id: str
    node_type: NodeType
    node_data: BaseNodeData
    parallel_id: Optional[str] = None
    """parallel id if node is in parallel"""
    parallel_start_node_id: Optional[str] = None
    """parallel start node id if node is in parallel"""
    parent_parallel_id: Optional[str] = None
    """parent parallel id if node is in parallel"""
    parent_parallel_start_node_id: Optional[str] = None
    """parent parallel start node id if node is in parallel"""
    start_at: datetime

    node_run_index: int
    inputs: dict = None
    inputs: Optional[dict[str, Any]] = None
    predecessor_node_id: Optional[str] = None
    metadata: Optional[dict] = None
    metadata: Optional[dict[str, Any]] = None


class QueueIterationNextEvent(AppQueueEvent):
    """
    QueueIterationNextEvent entity
    """

    event: QueueEvent = QueueEvent.ITERATION_NEXT

    index: int
    node_execution_id: str
    node_id: str
    node_type: NodeType
    node_data: BaseNodeData
    parallel_id: Optional[str] = None
    """parallel id if node is in parallel"""
    parallel_start_node_id: Optional[str] = None
    """parallel start node id if node is in parallel"""
    parent_parallel_id: Optional[str] = None
    """parent parallel id if node is in parallel"""
    parent_parallel_start_node_id: Optional[str] = None
    """parent parallel start node id if node is in parallel"""

    node_run_index: int
    output: Optional[Any] = None # output for the current iteration
    output: Optional[Any] = None  # output for the current iteration

    @field_validator('output', mode='before')
    @field_validator("output", mode="before")
    @classmethod
    def set_output(cls, v):
        """
@@ -87,41 +120,66 @@ class QueueIterationNextEvent(AppQueueEvent):
            return None
        if isinstance(v, int | float | str | bool | dict | list):
            return v
        raise ValueError('output must be a valid type')
        raise ValueError("output must be a valid type")


class QueueIterationCompletedEvent(AppQueueEvent):
    """
    QueueIterationCompletedEvent entity
    """
    event:QueueEvent = QueueEvent.ITERATION_COMPLETED

    event: QueueEvent = QueueEvent.ITERATION_COMPLETED

    node_execution_id: str
    node_id: str
    node_type: NodeType

    node_data: BaseNodeData
    parallel_id: Optional[str] = None
    """parallel id if node is in parallel"""
    parallel_start_node_id: Optional[str] = None
    """parallel start node id if node is in parallel"""
    parent_parallel_id: Optional[str] = None
    """parent parallel id if node is in parallel"""
    parent_parallel_start_node_id: Optional[str] = None
    """parent parallel start node id if node is in parallel"""
    start_at: datetime

    node_run_index: int
    outputs: dict
    inputs: Optional[dict[str, Any]] = None
    outputs: Optional[dict[str, Any]] = None
    metadata: Optional[dict[str, Any]] = None
    steps: int = 0

    error: Optional[str] = None


class QueueTextChunkEvent(AppQueueEvent):
    """
    QueueTextChunkEvent entity
    """

    event: QueueEvent = QueueEvent.TEXT_CHUNK
    text: str
    metadata: Optional[dict] = None
    from_variable_selector: Optional[list[str]] = None
    """from variable selector"""
    in_iteration_id: Optional[str] = None
    """iteration id if node is in iteration"""


class QueueAgentMessageEvent(AppQueueEvent):
    """
    QueueMessageEvent entity
    """

    event: QueueEvent = QueueEvent.AGENT_MESSAGE
    chunk: LLMResultChunk


class QueueMessageReplaceEvent(AppQueueEvent):
    """
    QueueMessageReplaceEvent entity
    """

    event: QueueEvent = QueueEvent.MESSAGE_REPLACE
    text: str

@@ -130,14 +188,18 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
    """
    QueueRetrieverResourcesEvent entity
    """

    event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES
    retriever_resources: list[dict]
    in_iteration_id: Optional[str] = None
    """iteration id if node is in iteration"""


class QueueAnnotationReplyEvent(AppQueueEvent):
    """
    QueueAnnotationReplyEvent entity
    """

    event: QueueEvent = QueueEvent.ANNOTATION_REPLY
    message_annotation_id: str

@@ -146,6 +208,7 @@ class QueueMessageEndEvent(AppQueueEvent):
    """
    QueueMessageEndEvent entity
    """

    event: QueueEvent = QueueEvent.MESSAGE_END
    llm_result: Optional[LLMResult] = None

@@ -154,6 +217,7 @@ class QueueAdvancedChatMessageEndEvent(AppQueueEvent):
    """
    QueueAdvancedChatMessageEndEvent entity
    """

    event: QueueEvent = QueueEvent.ADVANCED_CHAT_MESSAGE_END


@@ -161,20 +225,25 @@ class QueueWorkflowStartedEvent(AppQueueEvent):
    """
    QueueWorkflowStartedEvent entity
    """

    event: QueueEvent = QueueEvent.WORKFLOW_STARTED
    graph_runtime_state: GraphRuntimeState


class QueueWorkflowSucceededEvent(AppQueueEvent):
    """
    QueueWorkflowSucceededEvent entity
    """

    event: QueueEvent = QueueEvent.WORKFLOW_SUCCEEDED
    outputs: Optional[dict[str, Any]] = None


class QueueWorkflowFailedEvent(AppQueueEvent):
    """
    QueueWorkflowFailedEvent entity
    """

    event: QueueEvent = QueueEvent.WORKFLOW_FAILED
    error: str

@@ -183,29 +252,55 @@ class QueueNodeStartedEvent(AppQueueEvent):
    """
    QueueNodeStartedEvent entity
    """

    event: QueueEvent = QueueEvent.NODE_STARTED

    node_execution_id: str
    node_id: str
    node_type: NodeType
    node_data: BaseNodeData
    node_run_index: int = 1
    predecessor_node_id: Optional[str] = None
    parallel_id: Optional[str] = None
    """parallel id if node is in parallel"""
    parallel_start_node_id: Optional[str] = None
    """parallel start node id if node is in parallel"""
    parent_parallel_id: Optional[str] = None
    """parent parallel id if node is in parallel"""
    parent_parallel_start_node_id: Optional[str] = None
    """parent parallel start node id if node is in parallel"""
    in_iteration_id: Optional[str] = None
    """iteration id if node is in iteration"""
    start_at: datetime


class QueueNodeSucceededEvent(AppQueueEvent):
    """
    QueueNodeSucceededEvent entity
    """

    event: QueueEvent = QueueEvent.NODE_SUCCEEDED

    node_execution_id: str
    node_id: str
    node_type: NodeType
    node_data: BaseNodeData
    parallel_id: Optional[str] = None
    """parallel id if node is in parallel"""
    parallel_start_node_id: Optional[str] = None
    """parallel start node id if node is in parallel"""
    parent_parallel_id: Optional[str] = None
    """parent parallel id if node is in parallel"""
    parent_parallel_start_node_id: Optional[str] = None
    """parent parallel start node id if node is in parallel"""
    in_iteration_id: Optional[str] = None
    """iteration id if node is in iteration"""
    start_at: datetime

    inputs: Optional[dict] = None
    process_data: Optional[dict] = None
    outputs: Optional[dict] = None
    execution_metadata: Optional[dict] = None
    inputs: Optional[dict[str, Any]] = None
    process_data: Optional[dict[str, Any]] = None
    outputs: Optional[dict[str, Any]] = None
    execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None

    error: Optional[str] = None

@@ -214,15 +309,28 @@ class QueueNodeFailedEvent(AppQueueEvent):
    """
    QueueNodeFailedEvent entity
    """

    event: QueueEvent = QueueEvent.NODE_FAILED

    node_execution_id: str
    node_id: str
    node_type: NodeType
    node_data: BaseNodeData
    parallel_id: Optional[str] = None
    """parallel id if node is in parallel"""
    parallel_start_node_id: Optional[str] = None
    """parallel start node id if node is in parallel"""
    parent_parallel_id: Optional[str] = None
    """parent parallel id if node is in parallel"""
    parent_parallel_start_node_id: Optional[str] = None
    """parent parallel start node id if node is in parallel"""
    in_iteration_id: Optional[str] = None
    """iteration id if node is in iteration"""
    start_at: datetime

    inputs: Optional[dict] = None
    outputs: Optional[dict] = None
    process_data: Optional[dict] = None
    inputs: Optional[dict[str, Any]] = None
    process_data: Optional[dict[str, Any]] = None
    outputs: Optional[dict[str, Any]] = None

    error: str

@@ -231,6 +339,7 @@ class QueueAgentThoughtEvent(AppQueueEvent):
    """
    QueueAgentThoughtEvent entity
    """

    event: QueueEvent = QueueEvent.AGENT_THOUGHT
    agent_thought_id: str

@@ -239,6 +348,7 @@ class QueueMessageFileEvent(AppQueueEvent):
    """
    QueueMessageFileEvent entity
    """

    event: QueueEvent = QueueEvent.MESSAGE_FILE
    message_file_id: str

@@ -247,6 +357,7 @@ class QueueErrorEvent(AppQueueEvent):
    """
    QueueErrorEvent entity
    """

    event: QueueEvent = QueueEvent.ERROR
    error: Any = None

@@ -255,6 +366,7 @@ class QueuePingEvent(AppQueueEvent):
    """
    QueuePingEvent entity
    """

    event: QueueEvent = QueueEvent.PING


@@ -262,10 +374,12 @@ class QueueStopEvent(AppQueueEvent):
    """
    QueueStopEvent entity
    """

    class StopBy(Enum):
        """
        Stop by enum
        """

        USER_MANUAL = "user-manual"
        ANNOTATION_REPLY = "annotation-reply"
        OUTPUT_MODERATION = "output-moderation"
@@ -274,11 +388,25 @@ class QueueStopEvent(AppQueueEvent):
    event: QueueEvent = QueueEvent.STOP
    stopped_by: StopBy

    def get_stop_reason(self) -> str:
        """
        To stop reason
        """
        reason_mapping = {
            QueueStopEvent.StopBy.USER_MANUAL: "Stopped by user.",
            QueueStopEvent.StopBy.ANNOTATION_REPLY: "Stopped by annotation reply.",
            QueueStopEvent.StopBy.OUTPUT_MODERATION: "Stopped by output moderation.",
            QueueStopEvent.StopBy.INPUT_MODERATION: "Stopped by input moderation.",
        }

        return reason_mapping.get(self.stopped_by, "Stopped by unknown reason.")
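# --- Editor's note: an illustrative sketch, not part of the diff. get_stop_reason
# is a dict lookup with a default, so an unrecognized StopBy member falls back to
# "Stopped by unknown reason." instead of raising KeyError. The same idiom,
# stand-alone:
reasons_sketch = {"user-manual": "Stopped by user."}
assert reasons_sketch.get("output-moderation", "Stopped by unknown reason.") == "Stopped by unknown reason."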


class QueueMessage(BaseModel):
    """
    QueueMessage entity
    QueueMessage abstract entity
    """

    task_id: str
    app_mode: str
    event: AppQueueEvent
@@ -288,6 +416,7 @@ class MessageQueueMessage(QueueMessage):
    """
    MessageQueueMessage entity
    """

    message_id: str
    conversation_id: str

@@ -296,4 +425,57 @@ class WorkflowQueueMessage(QueueMessage):
    """
    WorkflowQueueMessage entity
    """

    pass


class QueueParallelBranchRunStartedEvent(AppQueueEvent):
    """
    QueueParallelBranchRunStartedEvent entity
    """

    event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_STARTED

    parallel_id: str
    parallel_start_node_id: str
    parent_parallel_id: Optional[str] = None
    """parent parallel id if node is in parallel"""
    parent_parallel_start_node_id: Optional[str] = None
    """parent parallel start node id if node is in parallel"""
    in_iteration_id: Optional[str] = None
    """iteration id if node is in iteration"""


class QueueParallelBranchRunSucceededEvent(AppQueueEvent):
    """
    QueueParallelBranchRunSucceededEvent entity
    """

    event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_SUCCEEDED

    parallel_id: str
    parallel_start_node_id: str
    parent_parallel_id: Optional[str] = None
    """parent parallel id if node is in parallel"""
    parent_parallel_start_node_id: Optional[str] = None
    """parent parallel start node id if node is in parallel"""
    in_iteration_id: Optional[str] = None
    """iteration id if node is in iteration"""


class QueueParallelBranchRunFailedEvent(AppQueueEvent):
    """
    QueueParallelBranchRunFailedEvent entity
    """

    event: QueueEvent = QueueEvent.PARALLEL_BRANCH_RUN_FAILED

    parallel_id: str
    parallel_start_node_id: str
    parent_parallel_id: Optional[str] = None
    """parent parallel id if node is in parallel"""
    parent_parallel_start_node_id: Optional[str] = None
    """parent parallel start node id if node is in parallel"""
    in_iteration_id: Optional[str] = None
    """iteration id if node is in iteration"""
    error: str

@@ -3,44 +3,16 @@ from typing import Any, Optional

from pydantic import BaseModel, ConfigDict

from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.entities import GenerateRouteChunk
from models.workflow import WorkflowNodeExecutionStatus


class WorkflowStreamGenerateNodes(BaseModel):
    """
    WorkflowStreamGenerateNodes entity
    """
    end_node_id: str
    stream_node_ids: list[str]


class ChatflowStreamGenerateRoute(BaseModel):
    """
    ChatflowStreamGenerateRoute entity
    """
    answer_node_id: str
    generate_route: list[GenerateRouteChunk]
    current_route_position: int = 0


class NodeExecutionInfo(BaseModel):
    """
    NodeExecutionInfo entity
    """
    workflow_node_execution_id: str
    node_type: NodeType
    start_at: float


class TaskState(BaseModel):
    """
    TaskState entity
    """

    metadata: dict = {}


@@ -48,6 +20,7 @@ class EasyUITaskState(TaskState):
    """
    EasyUITaskState entity
    """

    llm_result: LLMResult


@@ -55,34 +28,15 @@ class WorkflowTaskState(TaskState):
    """
    WorkflowTaskState entity
    """

    answer: str = ""

    workflow_run_id: Optional[str] = None
    start_at: Optional[float] = None
    total_tokens: int = 0
    total_steps: int = 0

    ran_node_execution_infos: dict[str, NodeExecutionInfo] = {}
    latest_node_execution_info: Optional[NodeExecutionInfo] = None

    current_stream_generate_state: Optional[WorkflowStreamGenerateNodes] = None

    iteration_nested_node_ids: list[str] = None


class AdvancedChatTaskState(WorkflowTaskState):
    """
    AdvancedChatTaskState entity
    """
    usage: LLMUsage

    current_stream_generate_state: Optional[ChatflowStreamGenerateRoute] = None


class StreamEvent(Enum):
    """
    Stream event
    """

    PING = "ping"
    ERROR = "error"
    MESSAGE = "message"
@@ -97,6 +51,8 @@ class StreamEvent(Enum):
    WORKFLOW_FINISHED = "workflow_finished"
    NODE_STARTED = "node_started"
    NODE_FINISHED = "node_finished"
    PARALLEL_BRANCH_STARTED = "parallel_branch_started"
    PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
    ITERATION_STARTED = "iteration_started"
    ITERATION_NEXT = "iteration_next"
    ITERATION_COMPLETED = "iteration_completed"
@@ -108,6 +64,7 @@ class StreamResponse(BaseModel):
    """
    StreamResponse entity
    """

    event: StreamEvent
    task_id: str

@@ -119,6 +76,7 @@ class ErrorStreamResponse(StreamResponse):
    """
    ErrorStreamResponse entity
    """

    event: StreamEvent = StreamEvent.ERROR
    err: Exception
    model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -128,6 +86,7 @@ class MessageStreamResponse(StreamResponse):
    """
    MessageStreamResponse entity
    """

    event: StreamEvent = StreamEvent.MESSAGE
    id: str
    answer: str

@@ -137,6 +96,7 @@ class MessageAudioStreamResponse(StreamResponse):
    """
    MessageAudioStreamResponse entity
    """

    event: StreamEvent = StreamEvent.TTS_MESSAGE
    audio: str

@@ -145,6 +105,7 @@ class MessageAudioEndStreamResponse(StreamResponse):
    """
    MessageAudioEndStreamResponse entity
    """

    event: StreamEvent = StreamEvent.TTS_MESSAGE_END
    audio: str

@@ -153,6 +114,7 @@ class MessageEndStreamResponse(StreamResponse):
    """
    MessageEndStreamResponse entity
    """

    event: StreamEvent = StreamEvent.MESSAGE_END
    id: str
    metadata: dict = {}
@@ -162,6 +124,7 @@ class MessageFileStreamResponse(StreamResponse):
    """
    MessageFileStreamResponse entity
    """

    event: StreamEvent = StreamEvent.MESSAGE_FILE
    id: str
    type: str
@@ -173,6 +136,7 @@ class MessageReplaceStreamResponse(StreamResponse):
    """
    MessageReplaceStreamResponse entity
    """

    event: StreamEvent = StreamEvent.MESSAGE_REPLACE
    answer: str

@@ -181,6 +145,7 @@ class AgentThoughtStreamResponse(StreamResponse):
    """
    AgentThoughtStreamResponse entity
    """

    event: StreamEvent = StreamEvent.AGENT_THOUGHT
    id: str
    position: int
@@ -196,6 +161,7 @@ class AgentMessageStreamResponse(StreamResponse):
    """
    AgentMessageStreamResponse entity
    """

    event: StreamEvent = StreamEvent.AGENT_MESSAGE
    id: str
    answer: str
@@ -210,6 +176,7 @@ class WorkflowStartStreamResponse(StreamResponse):
        """
        Data entity
        """

        id: str
        workflow_id: str
        sequence_number: int
@@ -230,6 +197,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
        """
        Data entity
        """

        id: str
        workflow_id: str
        sequence_number: int
@@ -258,6 +226,7 @@ class NodeStartStreamResponse(StreamResponse):
        """
        Data entity
        """

        id: str
        node_id: str
        node_type: str
@@ -267,6 +236,11 @@ class NodeStartStreamResponse(StreamResponse):
        inputs: Optional[dict] = None
        created_at: int
        extras: dict = {}
        parallel_id: Optional[str] = None
        parallel_start_node_id: Optional[str] = None
        parent_parallel_id: Optional[str] = None
        parent_parallel_start_node_id: Optional[str] = None
        iteration_id: Optional[str] = None

    event: StreamEvent = StreamEvent.NODE_STARTED
    workflow_run_id: str
@@ -286,8 +260,13 @@ class NodeStartStreamResponse(StreamResponse):
                "predecessor_node_id": self.data.predecessor_node_id,
                "inputs": None,
                "created_at": self.data.created_at,
                "extras": {}
            }
                "extras": {},
                "parallel_id": self.data.parallel_id,
                "parallel_start_node_id": self.data.parallel_start_node_id,
                "parent_parallel_id": self.data.parent_parallel_id,
                "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
                "iteration_id": self.data.iteration_id,
            },
        }


@@ -300,6 +279,7 @@ class NodeFinishStreamResponse(StreamResponse):
        """
        Data entity
        """

        id: str
        node_id: str
        node_type: str
@@ -316,6 +296,11 @@ class NodeFinishStreamResponse(StreamResponse):
        created_at: int
        finished_at: int
        files: Optional[list[dict]] = []
        parallel_id: Optional[str] = None
        parallel_start_node_id: Optional[str] = None
        parent_parallel_id: Optional[str] = None
        parent_parallel_start_node_id: Optional[str] = None
        iteration_id: Optional[str] = None

    event: StreamEvent = StreamEvent.NODE_FINISHED
    workflow_run_id: str
@@ -342,11 +327,62 @@ class NodeFinishStreamResponse(StreamResponse):
                "execution_metadata": None,
                "created_at": self.data.created_at,
                "finished_at": self.data.finished_at,
                "files": []
            }
                "files": [],
                "parallel_id": self.data.parallel_id,
                "parallel_start_node_id": self.data.parallel_start_node_id,
                "parent_parallel_id": self.data.parent_parallel_id,
                "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
                "iteration_id": self.data.iteration_id,
            },
        }


class ParallelBranchStartStreamResponse(StreamResponse):
    """
    ParallelBranchStartStreamResponse entity
    """

    class Data(BaseModel):
        """
        Data entity
        """

        parallel_id: str
        parallel_branch_id: str
        parent_parallel_id: Optional[str] = None
        parent_parallel_start_node_id: Optional[str] = None
        iteration_id: Optional[str] = None
        created_at: int

    event: StreamEvent = StreamEvent.PARALLEL_BRANCH_STARTED
    workflow_run_id: str
    data: Data


class ParallelBranchFinishedStreamResponse(StreamResponse):
    """
    ParallelBranchFinishedStreamResponse entity
    """

    class Data(BaseModel):
        """
        Data entity
        """

        parallel_id: str
        parallel_branch_id: str
        parent_parallel_id: Optional[str] = None
        parent_parallel_start_node_id: Optional[str] = None
        iteration_id: Optional[str] = None
        status: str
        error: Optional[str] = None
        created_at: int

    event: StreamEvent = StreamEvent.PARALLEL_BRANCH_FINISHED
    workflow_run_id: str
    data: Data

class IterationNodeStartStreamResponse(StreamResponse):
    """
    IterationNodeStartStreamResponse entity
@@ -356,6 +392,7 @@ class IterationNodeStartStreamResponse(StreamResponse):
        """
        Data entity
        """

        id: str
        node_id: str
        node_type: str
@@ -364,6 +401,8 @@ class IterationNodeStartStreamResponse(StreamResponse):
        extras: dict = {}
        metadata: dict = {}
        inputs: dict = {}
        parallel_id: Optional[str] = None
        parallel_start_node_id: Optional[str] = None

    event: StreamEvent = StreamEvent.ITERATION_STARTED
    workflow_run_id: str
@@ -379,6 +418,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
        """
        Data entity
        """

        id: str
        node_id: str
        node_type: str
@@ -387,6 +427,8 @@ class IterationNodeNextStreamResponse(StreamResponse):
        created_at: int
        pre_iteration_output: Optional[Any] = None
        extras: dict = {}
        parallel_id: Optional[str] = None
        parallel_start_node_id: Optional[str] = None

    event: StreamEvent = StreamEvent.ITERATION_NEXT
    workflow_run_id: str
@@ -402,14 +444,15 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
        """
        Data entity
        """

        id: str
        node_id: str
        node_type: str
        title: str
        outputs: Optional[dict] = None
        created_at: int
        extras: dict = None
        inputs: dict = None
        extras: Optional[dict] = None
        inputs: Optional[dict] = None
        status: WorkflowNodeExecutionStatus
        error: Optional[str] = None
        elapsed_time: float
@@ -417,6 +460,8 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
        execution_metadata: Optional[dict] = None
        finished_at: int
        steps: int
        parallel_id: Optional[str] = None
        parallel_start_node_id: Optional[str] = None

    event: StreamEvent = StreamEvent.ITERATION_COMPLETED
    workflow_run_id: str
@@ -432,6 +477,7 @@ class TextChunkStreamResponse(StreamResponse):
        """
        Data entity
        """

        text: str

    event: StreamEvent = StreamEvent.TEXT_CHUNK
@@ -447,6 +493,7 @@ class TextReplaceStreamResponse(StreamResponse):
        """
        Data entity
        """

        text: str

    event: StreamEvent = StreamEvent.TEXT_REPLACE
@@ -457,6 +504,7 @@ class PingStreamResponse(StreamResponse):
    """
    PingStreamResponse entity
    """

    event: StreamEvent = StreamEvent.PING


@@ -464,6 +512,7 @@ class AppStreamResponse(BaseModel):
    """
    AppStreamResponse entity
    """

    stream_response: StreamResponse


@@ -471,6 +520,7 @@ class ChatbotAppStreamResponse(AppStreamResponse):
    """
    ChatbotAppStreamResponse entity
    """

    conversation_id: str
    message_id: str
    created_at: int
@@ -480,6 +530,7 @@ class CompletionAppStreamResponse(AppStreamResponse):
    """
    CompletionAppStreamResponse entity
    """

    message_id: str
    created_at: int

@@ -488,13 +539,15 @@ class WorkflowAppStreamResponse(AppStreamResponse):
    """
    WorkflowAppStreamResponse entity
    """
    workflow_run_id: str

    workflow_run_id: Optional[str] = None


class AppBlockingResponse(BaseModel):
    """
    AppBlockingResponse entity
    """

    task_id: str

    def to_dict(self) -> dict:
@@ -510,6 +563,7 @@ class ChatbotAppBlockingResponse(AppBlockingResponse):
        """
        Data entity
        """

        id: str
        mode: str
        conversation_id: str
@@ -530,6 +584,7 @@ class CompletionAppBlockingResponse(AppBlockingResponse):
        """
        Data entity
        """

        id: str
        mode: str
        message_id: str
@@ -549,6 +604,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
        """
        Data entity
        """

        id: str
        workflow_id: str
        status: str
@@ -562,25 +618,3 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):

    workflow_run_id: str
    data: Data


class WorkflowIterationState(BaseModel):
    """
    WorkflowIterationState entity
    """

    class Data(BaseModel):
        """
        Data entity
        """
        parent_iteration_id: Optional[str] = None
        iteration_id: str
        current_index: int
        iteration_steps_boundary: list[int] = None
        node_execution_id: str
        started_at: float
        inputs: dict = None
        total_tokens: int = 0
        node_data: BaseNodeData

    current_iterations: dict[str, Data] = None

@@ -13,11 +13,9 @@ logger = logging.getLogger(__name__)


class AnnotationReplyFeature:
    def query(self, app_record: App,
              message: Message,
              query: str,
              user_id: str,
              invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
    def query(
        self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
    ) -> Optional[MessageAnnotation]:
        """
        Query app annotations to reply
        :param app_record: app record
@@ -27,8 +25,9 @@ class AnnotationReplyFeature:
        :param invoke_from: invoke from
        :return:
        """
        annotation_setting = db.session.query(AppAnnotationSetting).filter(
            AppAnnotationSetting.app_id == app_record.id).first()
        annotation_setting = (
            db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_record.id).first()
        )

        if not annotation_setting:
            return None
@@ -41,55 +40,50 @@ class AnnotationReplyFeature:
            embedding_model_name = collection_binding_detail.model_name

            dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                embedding_provider_name,
                embedding_model_name,
                'annotation'
                embedding_provider_name, embedding_model_name, "annotation"
            )

            dataset = Dataset(
                id=app_record.id,
                tenant_id=app_record.tenant_id,
                indexing_technique='high_quality',
                indexing_technique="high_quality",
                embedding_model_provider=embedding_provider_name,
                embedding_model=embedding_model_name,
                collection_binding_id=dataset_collection_binding.id
                collection_binding_id=dataset_collection_binding.id,
            )

            vector = Vector(dataset, attributes=['doc_id', 'annotation_id', 'app_id'])
            vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])

            documents = vector.search_by_vector(
                query=query,
                top_k=1,
                score_threshold=score_threshold,
                filter={
                    'group_id': [dataset.id]
                }
                query=query, top_k=1, score_threshold=score_threshold, filter={"group_id": [dataset.id]}
            )

            if documents:
                annotation_id = documents[0].metadata['annotation_id']
                score = documents[0].metadata['score']
                annotation_id = documents[0].metadata["annotation_id"]
                score = documents[0].metadata["score"]
                annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
                if annotation:
                    if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]:
                        from_source = 'api'
                        from_source = "api"
                    else:
                        from_source = 'console'
                        from_source = "console"

                    # insert annotation history
                    AppAnnotationService.add_annotation_history(annotation.id,
                                                                app_record.id,
                                                                annotation.question,
                                                                annotation.content,
                                                                query,
                                                                user_id,
                                                                message.id,
                                                                from_source,
                                                                score)
                    AppAnnotationService.add_annotation_history(
                        annotation.id,
                        app_record.id,
                        annotation.question,
                        annotation.content,
                        query,
                        user_id,
                        message.id,
                        from_source,
                        score,
                    )

                    return annotation
        except Exception as e:
            logger.warning(f'Query annotation failed, exception: {str(e)}.')
            logger.warning(f"Query annotation failed, exception: {str(e)}.")
            return None

        return None

@@ -8,8 +8,9 @@ logger = logging.getLogger(__name__)


class HostingModerationFeature:
    def check(self, application_generate_entity: EasyUIBasedAppGenerateEntity,
              prompt_messages: list[PromptMessage]) -> bool:
    def check(
        self, application_generate_entity: EasyUIBasedAppGenerateEntity, prompt_messages: list[PromptMessage]
    ) -> bool:
        """
        Check hosting moderation
        :param application_generate_entity: application generate entity
@@ -23,9 +24,6 @@ class HostingModerationFeature:
            if isinstance(prompt_message.content, str):
                text += prompt_message.content + "\n"

        moderation_result = moderation.check_moderation(
            model_config,
            text
        )
        moderation_result = moderation.check_moderation(model_config, text)

        return moderation_result

@@ -19,7 +19,7 @@ class RateLimit:
    _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60  # recalculate request_count from request_detail every 5 minutes
    _instance_dict = {}

    def __new__(cls: type['RateLimit'], client_id: str, max_active_requests: int):
    def __new__(cls: type["RateLimit"], client_id: str, max_active_requests: int):
        if client_id not in cls._instance_dict:
            instance = super().__new__(cls)
            cls._instance_dict[client_id] = instance
@@ -27,13 +27,13 @@ class RateLimit:

    def __init__(self, client_id: str, max_active_requests: int):
        self.max_active_requests = max_active_requests
        if hasattr(self, 'initialized'):
        if hasattr(self, "initialized"):
            return
        self.initialized = True
        self.client_id = client_id
        self.active_requests_key = self._ACTIVE_REQUESTS_KEY.format(client_id)
        self.max_active_requests_key = self._MAX_ACTIVE_REQUESTS_KEY.format(client_id)
        self.last_recalculate_time = float('-inf')
        self.last_recalculate_time = float("-inf")
        self.flush_cache(use_local_value=True)

    def flush_cache(self, use_local_value=False):
@@ -46,7 +46,7 @@ class RateLimit:
                pipe.execute()
        else:
            with redis_client.pipeline() as pipe:
                self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode('utf-8'))
                self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode("utf-8"))
                redis_client.expire(self.max_active_requests_key, timedelta(days=1))

        # flush max active requests (in-transit request list)
@@ -54,8 +54,11 @@ class RateLimit:
            return
        request_details = redis_client.hgetall(self.active_requests_key)
        redis_client.expire(self.active_requests_key, timedelta(days=1))
        timeout_requests = [k for k, v in request_details.items() if
                            time.time() - float(v.decode('utf-8')) > RateLimit._REQUEST_MAX_ALIVE_TIME]
        timeout_requests = [
            k
            for k, v in request_details.items()
            if time.time() - float(v.decode("utf-8")) > RateLimit._REQUEST_MAX_ALIVE_TIME
        ]
        if timeout_requests:
            redis_client.hdel(self.active_requests_key, *timeout_requests)

@@ -69,8 +72,10 @@ class RateLimit:

        active_requests_count = redis_client.hlen(self.active_requests_key)
        if active_requests_count >= self.max_active_requests:
            raise AppInvokeQuotaExceededError("Too many requests. Please try again later. The current maximum "
                                              "concurrent requests allowed is {}.".format(self.max_active_requests))
            raise AppInvokeQuotaExceededError(
                "Too many requests. Please try again later. The current maximum "
                "concurrent requests allowed is {}.".format(self.max_active_requests)
            )
        redis_client.hset(self.active_requests_key, request_id, str(time.time()))
        return request_id
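# --- Editor's note: an illustrative sketch, not part of the diff. The admission
# check above counts in-flight requests in a Redis hash and rejects new ones at
# the cap; a matching exit() call (seen in RateLimitGenerator below) removes the
# entry afterwards. The same invariant in a self-contained, in-memory form, with
# names invented for illustration:
import time


class InMemoryRateLimitSketch:
    def __init__(self, max_active_requests: int) -> None:
        self.max_active_requests = max_active_requests
        self.active: dict[str, float] = {}  # request_id -> admission timestamp

    def enter(self, request_id: str) -> str:
        # Reject when the in-flight count has reached the cap.
        if len(self.active) >= self.max_active_requests:
            raise RuntimeError("Too many requests. Please try again later.")
        self.active[request_id] = time.time()
        return request_id

    def exit(self, request_id: str) -> None:
        self.active.pop(request_id, None)


limiter = InMemoryRateLimitSketch(max_active_requests=1)
limiter.enter("r1")
limiter.exit("r1")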

@@ -116,5 +121,5 @@ class RateLimitGenerator:
        if not self.closed:
            self.closed = True
            self.rate_limit.exit(self.request_id)
            if self.generator is not None and hasattr(self.generator, 'close'):
            if self.generator is not None and hasattr(self.generator, "close"):
                self.generator.close()

@@ -25,25 +25,25 @@ from .variables import (
)

__all__ = [
    'IntegerVariable',
    'FloatVariable',
    'ObjectVariable',
    'SecretVariable',
    'StringVariable',
    'ArrayAnyVariable',
    'Variable',
    'SegmentType',
    'SegmentGroup',
    'Segment',
    'NoneSegment',
    'NoneVariable',
    'IntegerSegment',
    'FloatSegment',
    'ObjectSegment',
    'ArrayAnySegment',
    'StringSegment',
    'ArrayStringVariable',
    'ArrayNumberVariable',
    'ArrayObjectVariable',
    'ArraySegment',
    "IntegerVariable",
    "FloatVariable",
    "ObjectVariable",
    "SecretVariable",
    "StringVariable",
    "ArrayAnyVariable",
    "Variable",
    "SegmentType",
    "SegmentGroup",
    "Segment",
    "NoneSegment",
    "NoneVariable",
    "IntegerSegment",
    "FloatSegment",
    "ObjectSegment",
    "ArrayAnySegment",
    "StringSegment",
    "ArrayStringVariable",
    "ArrayNumberVariable",
    "ArrayObjectVariable",
    "ArraySegment",
]

@@ -28,12 +28,12 @@ from .variables import (


 def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
-    if (value_type := mapping.get('value_type')) is None:
-        raise VariableError('missing value type')
-    if not mapping.get('name'):
-        raise VariableError('missing name')
-    if (value := mapping.get('value')) is None:
-        raise VariableError('missing value')
+    if (value_type := mapping.get("value_type")) is None:
+        raise VariableError("missing value type")
+    if not mapping.get("name"):
+        raise VariableError("missing name")
+    if (value := mapping.get("value")) is None:
+        raise VariableError("missing value")
     match value_type:
         case SegmentType.STRING:
             result = StringVariable.model_validate(mapping)

@@ -44,7 +44,7 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
         case SegmentType.NUMBER if isinstance(value, float):
             result = FloatVariable.model_validate(mapping)
         case SegmentType.NUMBER if not isinstance(value, float | int):
-            raise VariableError(f'invalid number value {value}')
+            raise VariableError(f"invalid number value {value}")
         case SegmentType.OBJECT if isinstance(value, dict):
             result = ObjectVariable.model_validate(mapping)
         case SegmentType.ARRAY_STRING if isinstance(value, list):

@@ -54,9 +54,9 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
         case SegmentType.ARRAY_OBJECT if isinstance(value, list):
             result = ArrayObjectVariable.model_validate(mapping)
         case _:
-            raise VariableError(f'not supported value type {value_type}')
+            raise VariableError(f"not supported value type {value_type}")
     if result.size > dify_config.MAX_VARIABLE_SIZE:
-        raise VariableError(f'variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}')
+        raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
     return result

@@ -73,4 +73,4 @@ def build_segment(value: Any, /) -> Segment:
         return ObjectSegment(value=value)
     if isinstance(value, list):
         return ArrayAnySegment(value=value)
-    raise ValueError(f'not supported value {value}')
+    raise ValueError(f"not supported value {value}")
@@ -4,14 +4,14 @@ from core.workflow.entities.variable_pool import VariablePool

 from . import SegmentGroup, factory

-VARIABLE_PATTERN = re.compile(r'\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}')
+VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")


 def convert_template(*, template: str, variable_pool: VariablePool):
     parts = re.split(VARIABLE_PATTERN, template)
     segments = []
     for part in filter(lambda x: x, parts):
-        if '.' in part and (value := variable_pool.get(part.split('.'))):
+        if "." in part and (value := variable_pool.get(part.split("."))):
             segments.append(value)
         else:
             segments.append(factory.build_segment(part))
@@ -8,15 +8,15 @@ class SegmentGroup(Segment):

     @property
     def text(self):
-        return ''.join([segment.text for segment in self.value])
+        return "".join([segment.text for segment in self.value])

     @property
     def log(self):
-        return ''.join([segment.log for segment in self.value])
+        return "".join([segment.log for segment in self.value])

     @property
     def markdown(self):
-        return ''.join([segment.markdown for segment in self.value])
+        return "".join([segment.markdown for segment in self.value])

     def to_object(self):
         return [segment.to_object() for segment in self.value]
@@ -14,13 +14,13 @@ class Segment(BaseModel):
     value_type: SegmentType
     value: Any

-    @field_validator('value_type')
+    @field_validator("value_type")
     def validate_value_type(cls, value):
         """
         This validator checks if the provided value is equal to the default value of the 'value_type' field.
         If the value is different, a ValueError is raised.
         """
-        if value != cls.model_fields['value_type'].default:
+        if value != cls.model_fields["value_type"].default:
             raise ValueError("Cannot modify 'value_type'")
         return value

@@ -50,15 +50,15 @@ class NoneSegment(Segment):

     @property
     def text(self) -> str:
-        return 'null'
+        return "null"

     @property
     def log(self) -> str:
-        return 'null'
+        return "null"

     @property
     def markdown(self) -> str:
-        return 'null'
+        return "null"


 class StringSegment(Segment):

@@ -76,24 +76,21 @@ class IntegerSegment(Segment):
     value: int
-
-
-


 class ObjectSegment(Segment):
     value_type: SegmentType = SegmentType.OBJECT
     value: Mapping[str, Any]

     @property
     def text(self) -> str:
-        return json.dumps(self.model_dump()['value'], ensure_ascii=False)
+        return json.dumps(self.model_dump()["value"], ensure_ascii=False)

     @property
     def log(self) -> str:
-        return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
+        return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2)

     @property
     def markdown(self) -> str:
-        return json.dumps(self.model_dump()['value'], ensure_ascii=False, indent=2)
+        return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2)


 class ArraySegment(Segment):

@@ -101,11 +98,11 @@ class ArraySegment(Segment):
     def markdown(self) -> str:
         items = []
         for item in self.value:
-            if hasattr(item, 'to_markdown'):
+            if hasattr(item, "to_markdown"):
                 items.append(item.to_markdown())
             else:
                 items.append(str(item))
-        return '\n'.join(items)
+        return "\n".join(items)


 class ArrayAnySegment(ArraySegment):

@@ -126,4 +123,3 @@ class ArrayNumberSegment(ArraySegment):
 class ArrayObjectSegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_OBJECT
     value: Sequence[Mapping[str, Any]]
-
@@ -2,14 +2,14 @@ from enum import Enum


 class SegmentType(str, Enum):
-    NONE = 'none'
-    NUMBER = 'number'
-    STRING = 'string'
-    SECRET = 'secret'
-    ARRAY_ANY = 'array[any]'
-    ARRAY_STRING = 'array[string]'
-    ARRAY_NUMBER = 'array[number]'
-    ARRAY_OBJECT = 'array[object]'
-    OBJECT = 'object'
+    NONE = "none"
+    NUMBER = "number"
+    STRING = "string"
+    SECRET = "secret"
+    ARRAY_ANY = "array[any]"
+    ARRAY_STRING = "array[string]"
+    ARRAY_NUMBER = "array[number]"
+    ARRAY_OBJECT = "array[object]"
+    OBJECT = "object"

-    GROUP = 'group'
+    GROUP = "group"

@@ -23,11 +23,11 @@ class Variable(Segment):
     """

     id: str = Field(
-        default='',
+        default="",
         description="Unique identity for variable. It's only used by environment variables now.",
     )
     name: str
-    description: str = Field(default='', description='Description of the variable.')
+    description: str = Field(default="", description="Description of the variable.")


 class StringVariable(StringSegment, Variable):

@@ -62,7 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable):
     pass
-


 class SecretVariable(StringVariable):
     value_type: SegmentType = SegmentType.SECRET

@@ -32,10 +32,13 @@ class BasedGenerateTaskPipeline:
     _task_state: TaskState
     _application_generate_entity: AppGenerateEntity

-    def __init__(self, application_generate_entity: AppGenerateEntity,
-                 queue_manager: AppQueueManager,
-                 user: Union[Account, EndUser],
-                 stream: bool) -> None:
+    def __init__(
+        self,
+        application_generate_entity: AppGenerateEntity,
+        queue_manager: AppQueueManager,
+        user: Union[Account, EndUser],
+        stream: bool,
+    ) -> None:
         """
         Initialize GenerateTaskPipeline.
         :param application_generate_entity: application generate entity

@@ -61,35 +64,39 @@ class BasedGenerateTaskPipeline:
         e = event.error

         if isinstance(e, InvokeAuthorizationError):
-            err = InvokeAuthorizationError('Incorrect API key provided')
+            err = InvokeAuthorizationError("Incorrect API key provided")
         elif isinstance(e, InvokeError) or isinstance(e, ValueError):
             err = e
         else:
-            err = Exception(e.description if getattr(e, 'description', None) is not None else str(e))
+            err = Exception(e.description if getattr(e, "description", None) is not None else str(e))

         if message:
-            message = db.session.query(Message).filter(Message.id == message.id).first()
-            err_desc = self._error_to_desc(err)
-            message.status = 'error'
-            message.error = err_desc
-
-            db.session.commit()
+            refetch_message = db.session.query(Message).filter(Message.id == message.id).first()
+
+            if refetch_message:
+                err_desc = self._error_to_desc(err)
+                refetch_message.status = "error"
+                refetch_message.error = err_desc
+
+                db.session.commit()

         return err

-    def _error_to_desc(cls, e: Exception) -> str:
+    def _error_to_desc(self, e: Exception) -> str:
         """
         Error to desc.
         :param e: exception
         :return:
         """
         if isinstance(e, QuotaExceededError):
-            return ("Your quota for Dify Hosted Model Provider has been exhausted. "
-                    "Please go to Settings -> Model Provider to complete your own provider credentials.")
+            return (
+                "Your quota for Dify Hosted Model Provider has been exhausted. "
+                "Please go to Settings -> Model Provider to complete your own provider credentials."
+            )

-        message = getattr(e, 'description', str(e))
+        message = getattr(e, "description", str(e))
         if not message:
-            message = 'Internal Server Error, please contact support.'
+            message = "Internal Server Error, please contact support."

         return message

@@ -99,10 +106,7 @@ class BasedGenerateTaskPipeline:
         :param e: exception
         :return:
         """
-        return ErrorStreamResponse(
-            task_id=self._application_generate_entity.task_id,
-            err=e
-        )
+        return ErrorStreamResponse(task_id=self._application_generate_entity.task_id, err=e)

     def _ping_stream_response(self) -> PingStreamResponse:
         """

@@ -123,11 +127,8 @@ class BasedGenerateTaskPipeline:
         return OutputModeration(
             tenant_id=app_config.tenant_id,
             app_id=app_config.app_id,
-            rule=ModerationRule(
-                type=sensitive_word_avoidance.type,
-                config=sensitive_word_avoidance.config
-            ),
-            queue_manager=self._queue_manager
+            rule=ModerationRule(type=sensitive_word_avoidance.type, config=sensitive_word_avoidance.config),
+            queue_manager=self._queue_manager,
         )

     def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:

@@ -141,8 +142,7 @@ class BasedGenerateTaskPipeline:
             self._output_moderation_handler.stop_thread()

             completion = self._output_moderation_handler.moderation_completion(
-                completion=completion,
-                public_event=False
+                completion=completion, public_event=False
             )

             self._output_moderation_handler = None
@@ -64,23 +64,21 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
     """
     EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
     """
-    _task_state: EasyUITaskState
-    _application_generate_entity: Union[
-        ChatAppGenerateEntity,
-        CompletionAppGenerateEntity,
-        AgentChatAppGenerateEntity
-    ]
-
-    def __init__(self, application_generate_entity: Union[
-                     ChatAppGenerateEntity,
-                     CompletionAppGenerateEntity,
-                     AgentChatAppGenerateEntity
-                 ],
-                 queue_manager: AppQueueManager,
-                 conversation: Conversation,
-                 message: Message,
-                 user: Union[Account, EndUser],
-                 stream: bool) -> None:
+    _task_state: EasyUITaskState
+    _application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
+
+    def __init__(
+        self,
+        application_generate_entity: Union[
+            ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity
+        ],
+        queue_manager: AppQueueManager,
+        conversation: Conversation,
+        message: Message,
+        user: Union[Account, EndUser],
+        stream: bool,
+    ) -> None:
         """
         Initialize GenerateTaskPipeline.
         :param application_generate_entity: application generate entity

@@ -101,18 +99,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                 model=self._model_config.model,
                 prompt_messages=[],
                 message=AssistantPromptMessage(content=""),
-                usage=LLMUsage.empty_usage()
+                usage=LLMUsage.empty_usage(),
             )
         )

         self._conversation_name_generate_thread = None

     def process(
-            self,
+        self,
     ) -> Union[
         ChatbotAppBlockingResponse,
         CompletionAppBlockingResponse,
-        Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]
+        Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
     ]:
         """
         Process generate task pipeline.

@@ -125,22 +123,18 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
             # start generate conversation name thread
             self._conversation_name_generate_thread = self._generate_conversation_name(
-                self._conversation,
-                self._application_generate_entity.query
+                self._conversation, self._application_generate_entity.query
             )

-        generator = self._wrapper_process_stream_response(
-            trace_manager=self._application_generate_entity.trace_manager
-        )
+        generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
         if self._stream:
             return self._to_stream_response(generator)
         else:
             return self._to_blocking_response(generator)

-    def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> Union[
-        ChatbotAppBlockingResponse,
-        CompletionAppBlockingResponse
-    ]:
+    def _to_blocking_response(
+        self, generator: Generator[StreamResponse, None, None]
+    ) -> Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]:
         """
         Process blocking response.
         :return:

@@ -149,11 +143,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
             if isinstance(stream_response, ErrorStreamResponse):
                 raise stream_response.err
             elif isinstance(stream_response, MessageEndStreamResponse):
-                extras = {
-                    'usage': jsonable_encoder(self._task_state.llm_result.usage)
-                }
+                extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)}
                 if self._task_state.metadata:
-                    extras['metadata'] = self._task_state.metadata
+                    extras["metadata"] = self._task_state.metadata

                 if self._conversation.mode == AppMode.COMPLETION.value:
                     response = CompletionAppBlockingResponse(

@@ -164,8 +156,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                             message_id=self._message.id,
                             answer=self._task_state.llm_result.message.content,
                             created_at=int(self._message.created_at.timestamp()),
-                            **extras
-                        )
+                            **extras,
+                        ),
                     )
                 else:
                     response = ChatbotAppBlockingResponse(
@@ -177,18 +169,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                             message_id=self._message.id,
                             answer=self._task_state.llm_result.message.content,
                             created_at=int(self._message.created_at.timestamp()),
-                            **extras
-                        )
+                            **extras,
+                        ),
                     )

                 return response
             else:
                 continue

-        raise Exception('Queue listening stopped unexpectedly.')
+        raise Exception("Queue listening stopped unexpectedly.")

-    def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \
-            -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]:
+    def _to_stream_response(
+        self, generator: Generator[StreamResponse, None, None]
+    ) -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]:
         """
         To stream response.
         :return:

@@ -198,14 +191,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                 yield CompletionAppStreamResponse(
                     message_id=self._message.id,
                     created_at=int(self._message.created_at.timestamp()),
-                    stream_response=stream_response
+                    stream_response=stream_response,
                 )
             else:
                 yield ChatbotAppStreamResponse(
                     conversation_id=self._conversation.id,
                     message_id=self._message.id,
                     created_at=int(self._message.created_at.timestamp()),
-                    stream_response=stream_response
+                    stream_response=stream_response,
                 )

     def _listenAudioMsg(self, publisher, task_id: str):

@@ -217,15 +210,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
             return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
         return None

-    def _wrapper_process_stream_response(self, trace_manager: Optional[TraceQueueManager] = None) -> \
-            Generator[StreamResponse, None, None]:
-
+    def _wrapper_process_stream_response(
+        self, trace_manager: Optional[TraceQueueManager] = None
+    ) -> Generator[StreamResponse, None, None]:
         tenant_id = self._application_generate_entity.app_config.tenant_id
         task_id = self._application_generate_entity.task_id
         publisher = None
-        text_to_speech_dict = self._app_config.app_model_config_dict.get('text_to_speech')
-        if text_to_speech_dict and text_to_speech_dict.get('autoPlay') == 'enabled' and text_to_speech_dict.get('enabled'):
-            publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get('voice', None))
+        text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech")
+        if (
+            text_to_speech_dict
+            and text_to_speech_dict.get("autoPlay") == "enabled"
+            and text_to_speech_dict.get("enabled")
+        ):
+            publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None))
         for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
             while True:
                 audio_response = self._listenAudioMsg(publisher, task_id)

@@ -250,14 +247,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                     break
                 else:
                     start_listener_time = time.time()
-                    yield MessageAudioStreamResponse(audio=audio.audio,
-                                                     task_id=task_id)
-        yield MessageAudioEndStreamResponse(audio='', task_id=task_id)
+                    yield MessageAudioStreamResponse(audio=audio.audio, task_id=task_id)
+        yield MessageAudioEndStreamResponse(audio="", task_id=task_id)

     def _process_stream_response(
-        self,
-        publisher: AppGeneratorTTSPublisher,
-        trace_manager: Optional[TraceQueueManager] = None
+        self, publisher: AppGeneratorTTSPublisher, trace_manager: Optional[TraceQueueManager] = None
     ) -> Generator[StreamResponse, None, None]:
         """
         Process stream response.
@@ -333,9 +327,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         if self._conversation_name_generate_thread:
             self._conversation_name_generate_thread.join()

-    def _save_message(
-        self, trace_manager: Optional[TraceQueueManager] = None
-    ) -> None:
+    def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> None:
         """
         Save message.
         :return:

@@ -347,31 +339,32 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()

         self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
-            self._model_config.mode,
-            self._task_state.llm_result.prompt_messages
+            self._model_config.mode, self._task_state.llm_result.prompt_messages
         )
         self._message.message_tokens = usage.prompt_tokens
         self._message.message_unit_price = usage.prompt_unit_price
         self._message.message_price_unit = usage.prompt_price_unit
-        self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \
-            if llm_result.message.content else ''
+        self._message.answer = (
+            PromptTemplateParser.remove_template_variables(llm_result.message.content.strip())
+            if llm_result.message.content
+            else ""
+        )
         self._message.answer_tokens = usage.completion_tokens
         self._message.answer_unit_price = usage.completion_unit_price
         self._message.answer_price_unit = usage.completion_price_unit
         self._message.provider_response_latency = time.perf_counter() - self._start_at
         self._message.total_price = usage.total_price
         self._message.currency = usage.currency
-        self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
-            if self._task_state.metadata else None
+        self._message.message_metadata = (
+            json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
+        )

         db.session.commit()

         if trace_manager:
             trace_manager.add_trace_task(
                 TraceTask(
-                    TraceTaskName.MESSAGE_TRACE,
-                    conversation_id=self._conversation.id,
-                    message_id=self._message.id
+                    TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation.id, message_id=self._message.id
                 )
             )

@@ -379,11 +372,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
             self._message,
             application_generate_entity=self._application_generate_entity,
             conversation=self._conversation,
-            is_first_message=self._application_generate_entity.app_config.app_mode in [
-                AppMode.AGENT_CHAT,
-                AppMode.CHAT
-            ] and self._application_generate_entity.conversation_id is None,
-            extras=self._application_generate_entity.extras
+            is_first_message=self._application_generate_entity.app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT]
+            and self._application_generate_entity.conversation_id is None,
+            extras=self._application_generate_entity.extras,
         )

     def _handle_stop(self, event: QueueStopEvent) -> None:

@@ -395,22 +386,17 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         model = model_config.model

         model_instance = ModelInstance(
-            provider_model_bundle=model_config.provider_model_bundle,
-            model=model_config.model
+            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
         )

         # calculate num tokens
         prompt_tokens = 0
         if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
-            prompt_tokens = model_instance.get_llm_num_tokens(
-                self._task_state.llm_result.prompt_messages
-            )
+            prompt_tokens = model_instance.get_llm_num_tokens(self._task_state.llm_result.prompt_messages)

         completion_tokens = 0
         if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
-            completion_tokens = model_instance.get_llm_num_tokens(
-                [self._task_state.llm_result.message]
-            )
+            completion_tokens = model_instance.get_llm_num_tokens([self._task_state.llm_result.message])

         credentials = model_config.credentials

@@ -418,10 +404,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         model_type_instance = model_config.provider_model_bundle.model_type_instance
         model_type_instance = cast(LargeLanguageModel, model_type_instance)
         self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
-            model,
-            credentials,
-            prompt_tokens,
-            completion_tokens
+            model, credentials, prompt_tokens, completion_tokens
         )

     def _message_end_to_stream_response(self) -> MessageEndStreamResponse:

@@ -429,16 +412,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         Message end to stream response.
         :return:
         """
-        self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage)
+        self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage)

         extras = {}
         if self._task_state.metadata:
-            extras['metadata'] = self._task_state.metadata
+            extras["metadata"] = self._task_state.metadata

         return MessageEndStreamResponse(
-            task_id=self._application_generate_entity.task_id,
-            id=self._message.id,
-            **extras
+            task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
         )

     def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:

@@ -449,9 +430,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         :return:
         """
         return AgentMessageStreamResponse(
-            task_id=self._application_generate_entity.task_id,
-            id=message_id,
-            answer=answer
+            task_id=self._application_generate_entity.task_id, id=message_id, answer=answer
         )

     def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Optional[AgentThoughtStreamResponse]:

@@ -461,9 +440,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
         :return:
         """
         agent_thought: MessageAgentThought = (
-            db.session.query(MessageAgentThought)
-            .filter(MessageAgentThought.id == event.agent_thought_id)
-            .first()
+            db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first()
         )
         db.session.refresh(agent_thought)
         db.session.close()

@@ -478,7 +455,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                 tool=agent_thought.tool,
                 tool_labels=agent_thought.tool_labels,
                 tool_input=agent_thought.tool_input,
-                message_files=agent_thought.files
+                message_files=agent_thought.files,
             )

         return None

@@ -500,15 +477,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
                         prompt_messages=self._task_state.llm_result.prompt_messages,
                         delta=LLMResultChunkDelta(
                             index=0,
-                            message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
-                        )
+                            message=AssistantPromptMessage(content=self._task_state.llm_result.message.content),
+                        ),
                     )
-                ), PublishFrom.TASK_PIPELINE
+                ),
+                PublishFrom.TASK_PIPELINE,
             )

             self._queue_manager.publish(
-                QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
-                PublishFrom.TASK_PIPELINE
+                QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION), PublishFrom.TASK_PIPELINE
             )
             return True
         else:
@@ -8,7 +8,6 @@ from core.app.entities.app_invoke_entities import (
     AgentChatAppGenerateEntity,
     ChatAppGenerateEntity,
     CompletionAppGenerateEntity,
-    InvokeFrom,
 )
 from core.app.entities.queue_entities import (
     QueueAnnotationReplyEvent,

@@ -16,11 +15,11 @@ from core.app.entities.queue_entities import (
     QueueRetrieverResourcesEvent,
 )
 from core.app.entities.task_entities import (
-    AdvancedChatTaskState,
     EasyUITaskState,
     MessageFileStreamResponse,
     MessageReplaceStreamResponse,
     MessageStreamResponse,
+    WorkflowTaskState,
 )
 from core.llm_generator.llm_generator import LLMGenerator
 from core.tools.tool_file_manager import ToolFileManager

@@ -31,12 +30,9 @@ from services.annotation_service import AppAnnotationService

 class MessageCycleManage:
     _application_generate_entity: Union[
-        ChatAppGenerateEntity,
-        CompletionAppGenerateEntity,
-        AgentChatAppGenerateEntity,
-        AdvancedChatAppGenerateEntity
+        ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity
     ]
-    _task_state: Union[EasyUITaskState, AdvancedChatTaskState]
+    _task_state: Union[EasyUITaskState, WorkflowTaskState]

     def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]:
         """

@@ -45,17 +41,23 @@ class MessageCycleManage:
         :param query: query
         :return: thread
         """
         if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
             return None

         is_first_message = self._application_generate_entity.conversation_id is None
         extras = self._application_generate_entity.extras
-        auto_generate_conversation_name = extras.get('auto_generate_conversation_name', True)
+        auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True)

         if auto_generate_conversation_name and is_first_message:
             # start generate thread
-            thread = Thread(target=self._generate_conversation_name_worker, kwargs={
-                'flask_app': current_app._get_current_object(),
-                'conversation_id': conversation.id,
-                'query': query
-            })
+            thread = Thread(
+                target=self._generate_conversation_name_worker,
+                kwargs={
+                    "flask_app": current_app._get_current_object(),  # type: ignore
+                    "conversation_id": conversation.id,
+                    "query": query,
+                },
+            )

             thread.start()
@@ -63,17 +65,13 @@ class MessageCycleManage:

         return None

-    def _generate_conversation_name_worker(self,
-                                           flask_app: Flask,
-                                           conversation_id: str,
-                                           query: str):
+    def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
         with flask_app.app_context():
             # get conversation and message
-            conversation = (
-                db.session.query(Conversation)
-                .filter(Conversation.id == conversation_id)
-                .first()
-            )
+            conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()

             if not conversation:
                 return

             if conversation.mode != AppMode.COMPLETION.value:
                 app_model = conversation.app

@@ -100,12 +98,9 @@ class MessageCycleManage:
         annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
         if annotation:
             account = annotation.account
-            self._task_state.metadata['annotation_reply'] = {
-                'id': annotation.id,
-                'account': {
-                    'id': annotation.account_id,
-                    'name': account.name if account else 'Dify user'
-                }
-            }
+            self._task_state.metadata["annotation_reply"] = {
+                "id": annotation.id,
+                "account": {"id": annotation.account_id, "name": account.name if account else "Dify user"},
+            }

             return annotation

@@ -119,28 +114,7 @@ class MessageCycleManage:
         :return:
         """
         if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
-            self._task_state.metadata['retriever_resources'] = event.retriever_resources
-
-    def _get_response_metadata(self) -> dict:
-        """
-        Get response metadata by invoke from.
-        :return:
-        """
-        metadata = {}
-
-        # show_retrieve_source
-        if 'retriever_resources' in self._task_state.metadata:
-            metadata['retriever_resources'] = self._task_state.metadata['retriever_resources']
-
-        # show annotation reply
-        if 'annotation_reply' in self._task_state.metadata:
-            metadata['annotation_reply'] = self._task_state.metadata['annotation_reply']
-
-        # show usage
-        if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
-            metadata['usage'] = self._task_state.metadata['usage']
-
-        return metadata
+            self._task_state.metadata["retriever_resources"] = event.retriever_resources

     def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
         """

@@ -148,27 +122,23 @@ class MessageCycleManage:
         :param event: event
         :return:
         """
-        message_file: MessageFile = (
-            db.session.query(MessageFile)
-            .filter(MessageFile.id == event.message_file_id)
-            .first()
-        )
+        message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first()

         if message_file:
             # get tool file id
-            tool_file_id = message_file.url.split('/')[-1]
+            tool_file_id = message_file.url.split("/")[-1]
             # trim extension
-            tool_file_id = tool_file_id.split('.')[0]
+            tool_file_id = tool_file_id.split(".")[0]

             # get extension
-            if '.' in message_file.url:
+            if "." in message_file.url:
                 extension = f'.{message_file.url.split(".")[-1]}'
                 if len(extension) > 10:
-                    extension = '.bin'
+                    extension = ".bin"
             else:
-                extension = '.bin'
+                extension = ".bin"
             # add sign url to local file
-            if message_file.url.startswith('http'):
+            if message_file.url.startswith("http"):
                 url = message_file.url
             else:
                 url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension)

@@ -177,8 +147,8 @@ class MessageCycleManage:
                 task_id=self._application_generate_entity.task_id,
                 id=message_file.id,
                 type=message_file.type,
-                belongs_to=message_file.belongs_to or 'user',
-                url=url
+                belongs_to=message_file.belongs_to or "user",
+                url=url,
             )

         return None

@@ -190,11 +160,7 @@ class MessageCycleManage:
         :param message_id: message id
         :return:
         """
-        return MessageStreamResponse(
-            task_id=self._application_generate_entity.task_id,
-            id=message_id,
-            answer=answer
-        )
+        return MessageStreamResponse(task_id=self._application_generate_entity.task_id, id=message_id, answer=answer)

     def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse:
         """

@@ -202,7 +168,4 @@ class MessageCycleManage:
         :param answer: answer
         :return:
         """
-        return MessageReplaceStreamResponse(
-            task_id=self._application_generate_entity.task_id,
-            answer=answer
-        )
+        return MessageReplaceStreamResponse(task_id=self._application_generate_entity.task_id, answer=answer)
@@ -1,33 +1,41 @@
 import json
 import time
 from datetime import datetime, timezone
-from typing import Optional, Union, cast
+from typing import Any, Optional, Union, cast

-from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
 from core.app.entities.queue_entities import (
+    QueueIterationCompletedEvent,
+    QueueIterationNextEvent,
+    QueueIterationStartEvent,
     QueueNodeFailedEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueStopEvent,
     QueueWorkflowFailedEvent,
     QueueWorkflowSucceededEvent,
+    QueueParallelBranchRunFailedEvent,
+    QueueParallelBranchRunStartedEvent,
+    QueueParallelBranchRunSucceededEvent,
 )
 from core.app.entities.task_entities import (
-    NodeExecutionInfo,
+    IterationNodeCompletedStreamResponse,
+    IterationNodeNextStreamResponse,
+    IterationNodeStartStreamResponse,
     NodeFinishStreamResponse,
     NodeStartStreamResponse,
+    ParallelBranchFinishedStreamResponse,
+    ParallelBranchStartStreamResponse,
     WorkflowFinishStreamResponse,
     WorkflowStartStreamResponse,
     WorkflowTaskState,
 )
-from core.app.task_pipeline.workflow_iteration_cycle_manage import WorkflowIterationCycleManage
 from core.file.file_obj import FileVar
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.ops.entities.trace_entity import TraceTaskName
 from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
 from core.tools.tool_manager import ToolManager
-from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
+from core.workflow.entities.node_entities import NodeType
 from core.workflow.enums import SystemVariableKey
 from core.workflow.nodes.tool.entities import ToolNodeData
-from core.workflow.workflow_engine_manager import WorkflowEngineManager
+from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
 from models.account import Account
 from models.model import EndUser
@@ -41,54 +49,56 @@ from models.workflow import (
     WorkflowRunStatus,
     WorkflowRunTriggeredFrom,
 )
-from services.workflow_service import WorkflowService


-class WorkflowCycleManage(WorkflowIterationCycleManage):
-    def _init_workflow_run(self, workflow: Workflow,
-                           triggered_from: WorkflowRunTriggeredFrom,
-                           user: Union[Account, EndUser],
-                           user_inputs: dict,
-                           system_inputs: Optional[dict] = None) -> WorkflowRun:
-        """
-        Init workflow run
-        :param workflow: Workflow instance
-        :param triggered_from: triggered from
-        :param user: account or end user
-        :param user_inputs: user variables inputs
-        :param system_inputs: system inputs, like: query, files
-        :return:
-        """
-        max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \
-            .filter(WorkflowRun.tenant_id == workflow.tenant_id) \
-            .filter(WorkflowRun.app_id == workflow.app_id) \
-            .scalar() or 0
+class WorkflowCycleManage:
+    _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
+    _workflow: Workflow
+    _user: Union[Account, EndUser]
+    _task_state: WorkflowTaskState
+    _workflow_system_variables: dict[SystemVariableKey, Any]
+
+    def _handle_workflow_run_start(self) -> WorkflowRun:
+        max_sequence = (
+            db.session.query(db.func.max(WorkflowRun.sequence_number))
+            .filter(WorkflowRun.tenant_id == self._workflow.tenant_id)
+            .filter(WorkflowRun.app_id == self._workflow.app_id)
+            .scalar()
+            or 0
+        )
         new_sequence_number = max_sequence + 1

-        inputs = {**user_inputs}
-        for key, value in (system_inputs or {}).items():
-            if key.value == 'conversation':
+        inputs = {**self._application_generate_entity.inputs}
+        for key, value in (self._workflow_system_variables or {}).items():
+            if key.value == "conversation":
                 continue

-            inputs[f'sys.{key.value}'] = value
-        inputs = WorkflowEngineManager.handle_special_values(inputs)
+            inputs[f"sys.{key.value}"] = value
+
+        inputs = WorkflowEntry.handle_special_values(inputs)
+
+        triggered_from = (
+            WorkflowRunTriggeredFrom.DEBUGGING
+            if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
+            else WorkflowRunTriggeredFrom.APP_RUN
+        )

         # init workflow run
-        workflow_run = WorkflowRun(
-            tenant_id=workflow.tenant_id,
-            app_id=workflow.app_id,
-            sequence_number=new_sequence_number,
-            workflow_id=workflow.id,
-            type=workflow.type,
-            triggered_from=triggered_from.value,
-            version=workflow.version,
-            graph=workflow.graph,
-            inputs=json.dumps(inputs),
-            status=WorkflowRunStatus.RUNNING.value,
-            created_by_role=(CreatedByRole.ACCOUNT.value
-                             if isinstance(user, Account) else CreatedByRole.END_USER.value),
-            created_by=user.id
-        )
+        workflow_run = WorkflowRun()
+        workflow_run.tenant_id = self._workflow.tenant_id
+        workflow_run.app_id = self._workflow.app_id
+        workflow_run.sequence_number = new_sequence_number
+        workflow_run.workflow_id = self._workflow.id
+        workflow_run.type = self._workflow.type
+        workflow_run.triggered_from = triggered_from.value
+        workflow_run.version = self._workflow.version
+        workflow_run.graph = self._workflow.graph
+        workflow_run.inputs = json.dumps(inputs)
+        workflow_run.status = WorkflowRunStatus.RUNNING.value
+        workflow_run.created_by_role = (
+            CreatedByRole.ACCOUNT.value if isinstance(self._user, Account) else CreatedByRole.END_USER.value
+        )
+        workflow_run.created_by = self._user.id

         db.session.add(workflow_run)
         db.session.commit()
@@ -97,33 +107,37 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):

         return workflow_run

-    def _workflow_run_success(
-        self, workflow_run: WorkflowRun,
+    def _handle_workflow_run_success(
+        self,
+        workflow_run: WorkflowRun,
         start_at: float,
         total_tokens: int,
         total_steps: int,
         outputs: Optional[str] = None,
         conversation_id: Optional[str] = None,
-        trace_manager: Optional[TraceQueueManager] = None
+        trace_manager: Optional[TraceQueueManager] = None,
     ) -> WorkflowRun:
         """
        Workflow run success
        :param workflow_run: workflow run
        :param start_at: start time
        :param total_tokens: total tokens
        :param total_steps: total steps
        :param outputs: outputs
        :param conversation_id: conversation id
        :return:
        """
+        workflow_run = self._refetch_workflow_run(workflow_run.id)

         workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
         workflow_run.outputs = outputs
-        workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id)
+        workflow_run.elapsed_time = time.perf_counter() - start_at
         workflow_run.total_tokens = total_tokens
         workflow_run.total_steps = total_steps
         workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)

         db.session.commit()
+        db.session.refresh(workflow_run)
+        db.session.close()

         if trace_manager:
             trace_manager.add_trace_task(

@@ -135,34 +149,64 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
                 )
             )

-        db.session.close()

         return workflow_run

-    def _workflow_run_failed(
-        self, workflow_run: WorkflowRun,
+    def _handle_workflow_run_failed(
+        self,
+        workflow_run: WorkflowRun,
         start_at: float,
         total_tokens: int,
         total_steps: int,
         status: WorkflowRunStatus,
         error: str,
         conversation_id: Optional[str] = None,
-        trace_manager: Optional[TraceQueueManager] = None
+        trace_manager: Optional[TraceQueueManager] = None,
     ) -> WorkflowRun:
         """
        Workflow run failed
        :param workflow_run: workflow run
        :param start_at: start time
        :param total_tokens: total tokens
        :param total_steps: total steps
        :param status: status
        :param error: error message
        :return:
        """
+        workflow_run = self._refetch_workflow_run(workflow_run.id)

         workflow_run.status = status.value
         workflow_run.error = error
-        workflow_run.elapsed_time = WorkflowService.get_elapsed_time(workflow_run_id=workflow_run.id)
+        workflow_run.elapsed_time = time.perf_counter() - start_at
         workflow_run.total_tokens = total_tokens
         workflow_run.total_steps = total_steps
         workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)

         db.session.commit()

+        running_workflow_node_executions = (
+            db.session.query(WorkflowNodeExecution)
+            .filter(
+                WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
+                WorkflowNodeExecution.app_id == workflow_run.app_id,
+                WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
+                WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
+                WorkflowNodeExecution.workflow_run_id == workflow_run.id,
+                WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
+            )
+            .all()
+        )
+
+        for workflow_node_execution in running_workflow_node_executions:
+            workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
+            workflow_node_execution.error = error
+            workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            workflow_node_execution.elapsed_time = (
+                workflow_node_execution.finished_at - workflow_node_execution.created_at
+            ).total_seconds()
+            db.session.commit()
+
+        db.session.refresh(workflow_run)
+        db.session.close()
@@ -178,39 +222,26 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):

         return workflow_run

-    def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun,
-                                               node_id: str,
-                                               node_type: NodeType,
-                                               node_title: str,
-                                               node_run_index: int = 1,
-                                               predecessor_node_id: Optional[str] = None) -> WorkflowNodeExecution:
-        """
-        Init workflow node execution from workflow run
-        :param workflow_run: workflow run
-        :param node_id: node id
-        :param node_type: node type
-        :param node_title: node title
-        :param node_run_index: run index
-        :param predecessor_node_id: predecessor node id if exists
-        :return:
-        """
+    def _handle_node_execution_start(
+        self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
+    ) -> WorkflowNodeExecution:
         # init workflow node execution
-        workflow_node_execution = WorkflowNodeExecution(
-            tenant_id=workflow_run.tenant_id,
-            app_id=workflow_run.app_id,
-            workflow_id=workflow_run.workflow_id,
-            triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
-            workflow_run_id=workflow_run.id,
-            predecessor_node_id=predecessor_node_id,
-            index=node_run_index,
-            node_id=node_id,
-            node_type=node_type.value,
-            title=node_title,
-            status=WorkflowNodeExecutionStatus.RUNNING.value,
-            created_by_role=workflow_run.created_by_role,
-            created_by=workflow_run.created_by,
-            created_at=datetime.now(timezone.utc).replace(tzinfo=None)
-        )
+        workflow_node_execution = WorkflowNodeExecution()
+        workflow_node_execution.tenant_id = workflow_run.tenant_id
+        workflow_node_execution.app_id = workflow_run.app_id
+        workflow_node_execution.workflow_id = workflow_run.workflow_id
+        workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
+        workflow_node_execution.workflow_run_id = workflow_run.id
+        workflow_node_execution.predecessor_node_id = event.predecessor_node_id
+        workflow_node_execution.index = event.node_run_index
+        workflow_node_execution.node_execution_id = event.node_execution_id
+        workflow_node_execution.node_id = event.node_id
+        workflow_node_execution.node_type = event.node_type.value
+        workflow_node_execution.title = event.node_data.title
+        workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
+        workflow_node_execution.created_by_role = workflow_run.created_by_role
+        workflow_node_execution.created_by = workflow_run.created_by
+        workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)

         db.session.add(workflow_node_execution)
         db.session.commit()

@@ -219,33 +250,26 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):

         return workflow_node_execution

-    def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution,
-                                         start_at: float,
-                                         inputs: Optional[dict] = None,
-                                         process_data: Optional[dict] = None,
-                                         outputs: Optional[dict] = None,
-                                         execution_metadata: Optional[dict] = None) -> WorkflowNodeExecution:
+    def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
         """
         Workflow node execution success
-        :param workflow_node_execution: workflow node execution
-        :param start_at: start time
-        :param inputs: inputs
-        :param process_data: process data
-        :param outputs: outputs
-        :param execution_metadata: execution metadata
+        :param event: queue node succeeded event
         :return:
         """
-        inputs = WorkflowEngineManager.handle_special_values(inputs)
-        outputs = WorkflowEngineManager.handle_special_values(outputs)
+        workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
+
+        inputs = WorkflowEntry.handle_special_values(event.inputs)
+        outputs = WorkflowEntry.handle_special_values(event.outputs)
+
         workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
-        workflow_node_execution.elapsed_time = time.perf_counter() - start_at
         workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
-        workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
+        workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
         workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
-        workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \
-            if execution_metadata else None
+        workflow_node_execution.execution_metadata = (
+            json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
+        )
         workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()

         db.session.commit()
         db.session.refresh(workflow_node_execution)

@@ -253,33 +277,24 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):

         return workflow_node_execution

-    def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution,
-                                        start_at: float,
-                                        error: str,
-                                        inputs: Optional[dict] = None,
-                                        process_data: Optional[dict] = None,
-                                        outputs: Optional[dict] = None,
-                                        execution_metadata: Optional[dict] = None
-                                        ) -> WorkflowNodeExecution:
+    def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution:
         """
         Workflow node execution failed
-        :param workflow_node_execution: workflow node execution
-        :param start_at: start time
-        :param error: error message
+        :param event: queue node failed event
         :return:
         """
-        inputs = WorkflowEngineManager.handle_special_values(inputs)
-        outputs = WorkflowEngineManager.handle_special_values(outputs)
+        workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
+
+        inputs = WorkflowEntry.handle_special_values(event.inputs)
+        outputs = WorkflowEntry.handle_special_values(event.outputs)
+
         workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
-        workflow_node_execution.error = error
-        workflow_node_execution.elapsed_time = time.perf_counter() - start_at
+        workflow_node_execution.error = event.error
         workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
         workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
-        workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
+        workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
         workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
-        workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \
-            if execution_metadata else None
+        workflow_node_execution.elapsed_time = (workflow_node_execution.finished_at - event.start_at).total_seconds()

         db.session.commit()
         db.session.refresh(workflow_node_execution)
@@ -287,8 +302,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _workflow_start_to_stream_response(self, task_id: str,
|
||||
workflow_run: WorkflowRun) -> WorkflowStartStreamResponse:
|
||||
#################################################
|
||||
# to stream responses #
|
||||
#################################################
|
||||
|
||||
def _workflow_start_to_stream_response(
|
||||
self, task_id: str, workflow_run: WorkflowRun
|
||||
) -> WorkflowStartStreamResponse:
|
||||
"""
|
||||
Workflow start to stream response.
|
||||
:param task_id: task id
|
||||
@@ -302,13 +322,14 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
|
||||
id=workflow_run.id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
sequence_number=workflow_run.sequence_number,
|
||||
inputs=workflow_run.inputs_dict,
|
||||
created_at=int(workflow_run.created_at.timestamp())
|
||||
)
|
||||
inputs=workflow_run.inputs_dict or {},
|
||||
created_at=int(workflow_run.created_at.timestamp()),
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_finish_to_stream_response(self, task_id: str,
|
||||
workflow_run: WorkflowRun) -> WorkflowFinishStreamResponse:
|
||||
def _workflow_finish_to_stream_response(
|
||||
self, task_id: str, workflow_run: WorkflowRun
|
||||
) -> WorkflowFinishStreamResponse:
|
||||
"""
|
||||
Workflow finish to stream response.
|
||||
:param task_id: task id
|
||||
@@ -348,14 +369,13 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
|
||||
created_by=created_by,
|
||||
created_at=int(workflow_run.created_at.timestamp()),
|
||||
finished_at=int(workflow_run.finished_at.timestamp()),
|
||||
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict)
|
||||
)
|
||||
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict or {}),
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_node_start_to_stream_response(self, event: QueueNodeStartedEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution) \
|
||||
-> NodeStartStreamResponse:
|
||||
def _workflow_node_start_to_stream_response(
|
||||
self, event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution
|
||||
) -> Optional[NodeStartStreamResponse]:
|
||||
"""
|
||||
Workflow node start to stream response.
|
||||
:param event: queue node started event
|
||||
@@ -363,6 +383,9 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
|
||||
:param workflow_node_execution: workflow node execution
|
||||
:return:
|
||||
"""
|
||||
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
|
||||
return None
|
||||
|
||||
response = NodeStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||
@@ -374,29 +397,42 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
|
||||
index=workflow_node_execution.index,
|
||||
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||
inputs=workflow_node_execution.inputs_dict,
|
||||
created_at=int(workflow_node_execution.created_at.timestamp())
|
||||
)
|
||||
created_at=int(workflow_node_execution.created_at.timestamp()),
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
),
|
||||
)
|
||||
|
||||
# extras logic
|
||||
if event.node_type == NodeType.TOOL:
|
||||
node_data = cast(ToolNodeData, event.node_data)
|
||||
response.data.extras['icon'] = ToolManager.get_tool_icon(
|
||||
response.data.extras["icon"] = ToolManager.get_tool_icon(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
provider_type=node_data.provider_type,
|
||||
provider_id=node_data.provider_id
|
||||
provider_id=node_data.provider_id,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _workflow_node_finish_to_stream_response(self, task_id: str, workflow_node_execution: WorkflowNodeExecution) \
|
||||
-> NodeFinishStreamResponse:
|
||||
def _workflow_node_finish_to_stream_response(
|
||||
self,
|
||||
event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Optional[NodeFinishStreamResponse]:
|
||||
"""
|
||||
Workflow node finish to stream response.
|
||||
:param event: queue node succeeded or failed event
|
||||
:param task_id: task id
|
||||
:param workflow_node_execution: workflow node execution
|
||||
:return:
|
||||
"""
|
||||
if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]:
|
||||
return None
|
||||
|
||||
return NodeFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||
@@ -416,181 +452,153 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
|
||||
execution_metadata=workflow_node_execution.execution_metadata_dict,
|
||||
created_at=int(workflow_node_execution.created_at.timestamp()),
|
||||
finished_at=int(workflow_node_execution.finished_at.timestamp()),
|
||||
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict)
|
||||
)
|
||||
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
),
|
||||
)
|
||||
|
||||
def _handle_workflow_start(self) -> WorkflowRun:
self._task_state.start_at = time.perf_counter()

workflow_run = self._init_workflow_run(
workflow=self._workflow,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
else WorkflowRunTriggeredFrom.APP_RUN,
user=self._user,
user_inputs=self._application_generate_entity.inputs,
system_inputs=self._workflow_system_variables
def _workflow_parallel_branch_start_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse:
"""
Workflow parallel branch start to stream response
:param task_id: task id
:param workflow_run: workflow run
:param event: parallel branch run started event
:return:
"""
return ParallelBranchStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=ParallelBranchStartStreamResponse.Data(
parallel_id=event.parallel_id,
parallel_branch_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
created_at=int(time.time()),
),
)
latest_node_execution_info = NodeExecutionInfo(
workflow_node_execution_id=workflow_node_execution.id,
node_type=event.node_type,
start_at=time.perf_counter()
def _workflow_iteration_start_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent
) -> IterationNodeStartStreamResponse:
"""
Workflow iteration start to stream response
:param task_id: task id
:param workflow_run: workflow run
:param event: iteration start event
:return:
"""
return IterationNodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=IterationNodeStartStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
metadata=event.metadata or {},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)

self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
self._task_state.latest_node_execution_info = latest_node_execution_info
def _workflow_iteration_next_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
) -> IterationNodeNextStreamResponse:
"""
Workflow iteration next to stream response
:param task_id: task id
:param workflow_run: workflow run
:param event: iteration next event
:return:
"""
return IterationNodeNextStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=IterationNodeNextStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
index=event.index,
pre_iteration_output=event.output,
created_at=int(time.time()),
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)

self._task_state.total_steps += 1

db.session.close()

return workflow_node_execution

def _handle_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution:
current_node_execution = self._task_state.ran_node_execution_infos[event.node_id]
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first()

execution_metadata = event.execution_metadata if isinstance(event, QueueNodeSucceededEvent) else None

if self._iteration_state and self._iteration_state.current_iterations:
if not execution_metadata:
execution_metadata = {}
current_iteration_data = None
for iteration_node_id in self._iteration_state.current_iterations:
data = self._iteration_state.current_iterations[iteration_node_id]
if data.parent_iteration_id == None:
current_iteration_data = data
break

if current_iteration_data:
execution_metadata[NodeRunMetadataKey.ITERATION_ID] = current_iteration_data.iteration_id
execution_metadata[NodeRunMetadataKey.ITERATION_INDEX] = current_iteration_data.current_index

if isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._workflow_node_execution_success(
workflow_node_execution=workflow_node_execution,
start_at=current_node_execution.start_at,
inputs=event.inputs,
process_data=event.process_data,
def _workflow_iteration_completed_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent
) -> IterationNodeCompletedStreamResponse:
"""
Workflow iteration completed to stream response
:param task_id: task id
:param workflow_run: workflow run
:param event: iteration completed event
:return:
"""
return IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=IterationNodeCompletedStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
outputs=event.outputs,
execution_metadata=execution_metadata
)

if execution_metadata and execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
self._task_state.total_tokens += (
int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))

if self._iteration_state:
for iteration_node_id in self._iteration_state.current_iterations:
data = self._iteration_state.current_iterations[iteration_node_id]
if execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
data.total_tokens += int(execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS))

if workflow_node_execution.node_type == NodeType.LLM.value:
outputs = workflow_node_execution.outputs_dict
usage_dict = outputs.get('usage', {})
self._task_state.metadata['usage'] = usage_dict
else:
workflow_node_execution = self._workflow_node_execution_failed(
workflow_node_execution=workflow_node_execution,
start_at=current_node_execution.start_at,
error=event.error,
inputs=event.inputs,
process_data=event.process_data,
outputs=event.outputs,
execution_metadata=execution_metadata
)

db.session.close()

return workflow_node_execution

def _handle_workflow_finished(
self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None
) -> Optional[WorkflowRun]:
workflow_run = db.session.query(WorkflowRun).filter(
WorkflowRun.id == self._task_state.workflow_run_id).first()
if not workflow_run:
return None

if conversation_id is None:
conversation_id = self._application_generate_entity.inputs.get('sys.conversation_id')
if isinstance(event, QueueStopEvent):
workflow_run = self._workflow_run_failed(
workflow_run=workflow_run,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
status=WorkflowRunStatus.STOPPED,
error='Workflow stopped.',
conversation_id=conversation_id,
trace_manager=trace_manager
)

latest_node_execution_info = self._task_state.latest_node_execution_info
if latest_node_execution_info:
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == latest_node_execution_info.workflow_node_execution_id).first()
if (workflow_node_execution
and workflow_node_execution.status == WorkflowNodeExecutionStatus.RUNNING.value):
self._workflow_node_execution_failed(
workflow_node_execution=workflow_node_execution,
start_at=latest_node_execution_info.start_at,
error='Workflow stopped.'
)
elif isinstance(event, QueueWorkflowFailedEvent):
workflow_run = self._workflow_run_failed(
workflow_run=workflow_run,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
status=WorkflowRunStatus.FAILED,
error=event.error,
conversation_id=conversation_id,
trace_manager=trace_manager
)
else:
if self._task_state.latest_node_execution_info:
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first()
outputs = workflow_node_execution.outputs
else:
outputs = None

workflow_run = self._workflow_run_success(
workflow_run=workflow_run,
total_tokens=self._task_state.total_tokens,
total_steps=self._task_state.total_steps,
outputs=outputs,
conversation_id=conversation_id,
trace_manager=trace_manager
)

self._task_state.workflow_run_id = workflow_run.id

db.session.close()

return workflow_run
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
status=WorkflowNodeExecutionStatus.SUCCEEDED,
error=None,
elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
execution_metadata=event.metadata,
finished_at=int(time.time()),
steps=event.steps,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)

def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]:
"""
@@ -641,9 +649,45 @@ class WorkflowCycleManage(WorkflowIterationCycleManage):
return None

if isinstance(value, dict):
if '__variant' in value and value['__variant'] == FileVar.__name__:
if "__variant" in value and value["__variant"] == FileVar.__name__:
return value
elif isinstance(value, FileVar):
return value.to_dict()

return None
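The value check above accepts a file in two shapes: a dict already tagged with a __variant marker equal to the FileVar class name, or a live FileVar instance that still needs to_dict(). A standalone sketch of that dispatch, using a stand-in FileVar class (illustrative only, not the real model):

from typing import Optional

class FileVar:
    """Stand-in for the real FileVar model; only to_dict() matters here."""
    def __init__(self, url: str) -> None:
        self.url = url

    def to_dict(self) -> dict:
        return {"__variant": FileVar.__name__, "url": self.url}

def as_file_dict(value) -> Optional[dict]:
    if isinstance(value, dict):
        if "__variant" in value and value["__variant"] == FileVar.__name__:
            return value
    elif isinstance(value, FileVar):
        return value.to_dict()
    return None

print(as_file_dict(FileVar("https://example.com/a.png")))   # serialized dict
print(as_file_dict({"__variant": "FileVar", "url": "cached"}))  # passed through
print(as_file_dict("not a file"))                            # -> None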
def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
"""
Refetch workflow run
:param workflow_run_id: workflow run id
:return:
"""
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()

if not workflow_run:
raise Exception(f"Workflow run not found: {workflow_run_id}")

return workflow_run

def _refetch_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
"""
Refetch workflow node execution
:param node_execution_id: workflow node execution id
:return:
"""
workflow_node_execution = (
db.session.query(WorkflowNodeExecution)
.filter(
WorkflowNodeExecution.tenant_id == self._application_generate_entity.app_config.tenant_id,
WorkflowNodeExecution.app_id == self._application_generate_entity.app_config.app_id,
WorkflowNodeExecution.workflow_id == self._workflow.id,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.node_execution_id == node_execution_id,
)
.first()
)

if not workflow_node_execution:
raise Exception(f"Workflow node execution not found: {node_execution_id}")

return workflow_node_execution
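Both _refetch_* helpers share a contract: re-query by id and raise if the row is gone, so callers never have to null-check the return value. A hedged caller-side sketch (the handler function and error handling here are hypothetical, not Dify's code):

def on_node_succeeded(manager, event):
    # The refetch raises rather than returning None, so everything
    # after this call can rely on a live row.
    try:
        node_execution = manager._refetch_workflow_node_execution(event.node_execution_id)
    except Exception as exc:
        # e.g. the run was purged mid-stream; surface and stop.
        raise RuntimeError(f"stream handling aborted: {exc}") from exc
    return node_execution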
@@ -1,16 +0,0 @@
from typing import Any, Union

from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
from core.workflow.enums import SystemVariableKey
from models.account import Account
from models.model import EndUser
from models.workflow import Workflow


class WorkflowCycleStateManager:
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
_workflow_system_variables: dict[SystemVariableKey, Any]

@@ -1,290 +0,0 @@
import json
import time
from collections.abc import Generator
from datetime import datetime, timezone
from typing import Optional, Union

from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
)
from core.app.entities.task_entities import (
IterationNodeCompletedStreamResponse,
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
NodeExecutionInfo,
WorkflowIterationState,
)
from core.app.task_pipeline.workflow_cycle_state_manager import WorkflowCycleStateManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
from models.workflow import (
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
WorkflowNodeExecutionTriggeredFrom,
WorkflowRun,
)


class WorkflowIterationCycleManage(WorkflowCycleStateManager):
_iteration_state: WorkflowIterationState = None

def _init_iteration_state(self) -> WorkflowIterationState:
if not self._iteration_state:
self._iteration_state = WorkflowIterationState(
current_iterations={}
)

def _handle_iteration_to_stream_response(self, task_id: str, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) \
-> Union[IterationNodeStartStreamResponse, IterationNodeNextStreamResponse, IterationNodeCompletedStreamResponse]:
"""
Handle iteration to stream response
:param task_id: task id
:param event: iteration event
:return:
"""
if isinstance(event, QueueIterationStartEvent):
return IterationNodeStartStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeStartStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
created_at=int(time.time()),
extras={},
inputs=event.inputs,
metadata=event.metadata
)
)
elif isinstance(event, QueueIterationNextEvent):
current_iteration = self._iteration_state.current_iterations[event.node_id]

return IterationNodeNextStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeNextStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=current_iteration.node_data.title,
index=event.index,
pre_iteration_output=event.output,
created_at=int(time.time()),
extras={}
)
)
elif isinstance(event, QueueIterationCompletedEvent):
current_iteration = self._iteration_state.current_iterations[event.node_id]

return IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeCompletedStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type.value,
title=current_iteration.node_data.title,
outputs=event.outputs,
created_at=int(time.time()),
extras={},
inputs=current_iteration.inputs,
status=WorkflowNodeExecutionStatus.SUCCEEDED,
error=None,
elapsed_time=time.perf_counter() - current_iteration.started_at,
total_tokens=current_iteration.total_tokens,
execution_metadata={
'total_tokens': current_iteration.total_tokens,
},
finished_at=int(time.time()),
steps=current_iteration.current_index
)
)

def _init_iteration_execution_from_workflow_run(self,
workflow_run: WorkflowRun,
node_id: str,
node_type: NodeType,
node_title: str,
node_run_index: int = 1,
inputs: Optional[dict] = None,
predecessor_node_id: Optional[str] = None
) -> WorkflowNodeExecution:
workflow_node_execution = WorkflowNodeExecution(
tenant_id=workflow_run.tenant_id,
app_id=workflow_run.app_id,
workflow_id=workflow_run.workflow_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
workflow_run_id=workflow_run.id,
predecessor_node_id=predecessor_node_id,
index=node_run_index,
node_id=node_id,
node_type=node_type.value,
inputs=json.dumps(inputs) if inputs else None,
title=node_title,
status=WorkflowNodeExecutionStatus.RUNNING.value,
created_by_role=workflow_run.created_by_role,
created_by=workflow_run.created_by,
execution_metadata=json.dumps({
'started_run_index': node_run_index + 1,
'current_index': 0,
'steps_boundary': [],
}),
created_at=datetime.now(timezone.utc).replace(tzinfo=None)
)

db.session.add(workflow_node_execution)
db.session.commit()
db.session.refresh(workflow_node_execution)
db.session.close()

return workflow_node_execution
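The iteration row keeps its per-step bookkeeping in the execution_metadata JSON column: started_run_index, current_index, and a steps_boundary list marking where each iteration step's node runs begin. A minimal round trip of that read-modify-write pattern, with plain json and made-up values:

import json

# seeded at iteration start, as above
metadata = json.dumps({"started_run_index": 2, "current_index": 0, "steps_boundary": []})

# ... one iteration step later (what _handle_iteration_next does against the DB) ...
state = json.loads(metadata)
state["current_index"] = 1
state["steps_boundary"].append(5)  # node_run_index where step 1 started
metadata = json.dumps(state)

print(metadata)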
def _handle_iteration_operation(self, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) -> WorkflowNodeExecution:
if isinstance(event, QueueIterationStartEvent):
return self._handle_iteration_started(event)
elif isinstance(event, QueueIterationNextEvent):
return self._handle_iteration_next(event)
elif isinstance(event, QueueIterationCompletedEvent):
return self._handle_iteration_completed(event)

def _handle_iteration_started(self, event: QueueIterationStartEvent) -> WorkflowNodeExecution:
self._init_iteration_state()

workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
workflow_node_execution = self._init_iteration_execution_from_workflow_run(
workflow_run=workflow_run,
node_id=event.node_id,
node_type=NodeType.ITERATION,
node_title=event.node_data.title,
node_run_index=event.node_run_index,
inputs=event.inputs,
predecessor_node_id=event.predecessor_node_id
)

latest_node_execution_info = NodeExecutionInfo(
workflow_node_execution_id=workflow_node_execution.id,
node_type=NodeType.ITERATION,
start_at=time.perf_counter()
)

self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
self._task_state.latest_node_execution_info = latest_node_execution_info

self._iteration_state.current_iterations[event.node_id] = WorkflowIterationState.Data(
parent_iteration_id=None,
iteration_id=event.node_id,
current_index=0,
iteration_steps_boundary=[],
node_execution_id=workflow_node_execution.id,
started_at=time.perf_counter(),
inputs=event.inputs,
total_tokens=0,
node_data=event.node_data
)

db.session.close()

return workflow_node_execution

def _handle_iteration_next(self, event: QueueIterationNextEvent) -> WorkflowNodeExecution:
if event.node_id not in self._iteration_state.current_iterations:
return
current_iteration = self._iteration_state.current_iterations[event.node_id]
current_iteration.current_index = event.index
current_iteration.iteration_steps_boundary.append(event.node_run_index)
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_iteration.node_execution_id
).first()

original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
if original_node_execution_metadata:
original_node_execution_metadata['current_index'] = event.index
original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)

db.session.commit()

db.session.close()

def _handle_iteration_completed(self, event: QueueIterationCompletedEvent):
if event.node_id not in self._iteration_state.current_iterations:
return

current_iteration = self._iteration_state.current_iterations[event.node_id]
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_iteration.node_execution_id
).first()

workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
workflow_node_execution.outputs = json.dumps(WorkflowEngineManager.handle_special_values(event.outputs)) if event.outputs else None
workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at

original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
if original_node_execution_metadata:
original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)

db.session.commit()

# remove current iteration
self._iteration_state.current_iterations.pop(event.node_id, None)

# set latest node execution info
latest_node_execution_info = NodeExecutionInfo(
workflow_node_execution_id=workflow_node_execution.id,
node_type=NodeType.ITERATION,
start_at=time.perf_counter()
)

self._task_state.latest_node_execution_info = latest_node_execution_info

db.session.close()

def _handle_iteration_exception(self, task_id: str, error: str) -> Generator[IterationNodeCompletedStreamResponse, None, None]:
"""
Handle iteration exception
"""
if not self._iteration_state or not self._iteration_state.current_iterations:
return

for node_id, current_iteration in self._iteration_state.current_iterations.items():
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == current_iteration.node_execution_id
).first()

workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at

db.session.commit()
db.session.close()

yield IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=self._task_state.workflow_run_id,
data=IterationNodeCompletedStreamResponse.Data(
id=node_id,
node_id=node_id,
node_type=NodeType.ITERATION.value,
title=current_iteration.node_data.title,
outputs={},
created_at=int(time.time()),
extras={},
inputs=current_iteration.inputs,
status=WorkflowNodeExecutionStatus.FAILED,
error=error,
elapsed_time=time.perf_counter() - current_iteration.started_at,
total_tokens=current_iteration.total_tokens,
execution_metadata={
'total_tokens': current_iteration.total_tokens,
},
finished_at=int(time.time()),
steps=current_iteration.current_index
)
)
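_handle_iteration_exception is a generator: on failure the caller drains it and forwards one failed completion event per iteration still in flight. A toy driver with the same shape (names and payload are illustrative, not Dify's stream schema):

def handle_exception(task_id: str, error: str):
    # stand-in for the in-flight iteration state
    in_flight = ["iter_node_1", "iter_node_2"]
    for node_id in in_flight:
        yield {"task_id": task_id, "node_id": node_id, "status": "failed", "error": error}

for event in handle_exception("task-1", "Workflow stopped."):
    print(event)  # each in-flight iteration gets its own failure event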
@@ -16,31 +16,32 @@ _TEXT_COLOR_MAPPING = {
"red": "31;1",
}


def get_colored_text(text: str, color: str) -> str:
"""Get colored text."""
color_str = _TEXT_COLOR_MAPPING[color]
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
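get_colored_text stacks two ANSI escape sequences: the color code looked up in _TEXT_COLOR_MAPPING (e.g. 31;1 for bright red) plus 1;3 for bold italic, followed by a reset. The same escapes in isolation:

RED = "31;1"
text = "error"
print(f"\u001b[{RED}m\033[1;3m{text}\u001b[0m and back to normal")
# In an ANSI-capable terminal, "error" renders bold, italic, bright red,
# and the trailing \u001b[0m resets all attributes.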

def print_text(
text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
) -> None:
def print_text(text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None) -> None:
"""Print text with highlighting and no end characters."""
text_to_print = get_colored_text(text, color) if color else text
print(text_to_print, end=end, file=file)
if file:
file.flush()  # ensure all printed content are written to file


class DifyAgentCallbackHandler(BaseModel):
"""Callback Handler that prints to std out."""
color: Optional[str] = ''

color: Optional[str] = ""
current_loop: int = 1

def __init__(self, color: Optional[str] = None) -> None:
super().__init__()
"""Initialize callback handler."""
# use a specific color is not specified
self.color = color or 'green'
self.color = color or "green"
self.current_loop = 1

def on_tool_start(
@@ -58,7 +59,7 @@ class DifyAgentCallbackHandler(BaseModel):
tool_outputs: Sequence[ToolInvokeMessage],
message_id: Optional[str] = None,
timer: Optional[Any] = None,
trace_manager: Optional[TraceQueueManager] = None
trace_manager: Optional[TraceQueueManager] = None,
) -> None:
"""If not the final action, print out observation."""
print_text("\n[on_tool_end]\n", color=self.color)
@@ -79,26 +80,21 @@ class DifyAgentCallbackHandler(BaseModel):
)
)

def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> None:
"""Do nothing."""
print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='red')
print_text("\n[on_tool_error] Error: " + str(error) + "\n", color="red")

def on_agent_start(
self, thought: str
) -> None:
def on_agent_start(self, thought: str) -> None:
"""Run on agent start."""
if thought:
print_text("\n[on_agent_start] \nCurrent Loop: " + \
str(self.current_loop) + \
"\nThought: " + thought + "\n", color=self.color)
print_text(
"\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\nThought: " + thought + "\n",
color=self.color,
)
else:
print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color)

def on_agent_finish(
self, color: Optional[str] = None, **kwargs: Any
) -> None:
def on_agent_finish(self, color: Optional[str] = None, **kwargs: Any) -> None:
"""Run on agent end."""
print_text("\n[on_agent_finish]\n Loop: " + str(self.current_loop) + "\n", color=self.color)

@@ -107,9 +103,9 @@ class DifyAgentCallbackHandler(BaseModel):
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != "true"

@property
def ignore_chat_model(self) -> bool:
"""Whether to ignore chat model callbacks."""
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true'
return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != "true"
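Both ignore_* properties reduce to the same predicate: suppress the callbacks unless the DEBUG environment variable is exactly "true" (case-insensitively). The predicate on its own:

import os

def callbacks_ignored() -> bool:
    return not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != "true"

os.environ["DEBUG"] = "true"
print(callbacks_ignored())  # False -> callbacks fire
os.environ["DEBUG"] = "1"
print(callbacks_ignored())  # True  -> anything but "true" still ignores them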
@@ -1,4 +1,3 @@

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
@@ -11,11 +10,9 @@ from models.model import DatasetRetrieverResource
class DatasetIndexToolCallbackHandler:
"""Callback handler for dataset tool."""

def __init__(self, queue_manager: AppQueueManager,
app_id: str,
message_id: str,
user_id: str,
invoke_from: InvokeFrom) -> None:
def __init__(
self, queue_manager: AppQueueManager, app_id: str, message_id: str, user_id: str, invoke_from: InvokeFrom
) -> None:
self._queue_manager = queue_manager
self._app_id = app_id
self._message_id = message_id
@@ -29,11 +26,12 @@ class DatasetIndexToolCallbackHandler:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=query,
source='app',
source="app",
source_app_id=self._app_id,
created_by_role=('account'
if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
created_by=self._user_id
created_by_role=(
"account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"
),
created_by=self._user_id,
)

db.session.add(dataset_query)
@@ -43,18 +41,15 @@ class DatasetIndexToolCallbackHandler:
"""Handle tool end."""
for document in documents:
query = db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == document.metadata['doc_id']
DocumentSegment.index_node_id == document.metadata["doc_id"]
)

# if 'dataset_id' in document.metadata:
if 'dataset_id' in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])
if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])

# add hit count to document segment
query.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False
)
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)

db.session.commit()
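The hit-count bump above is issued as a single bulk UPDATE via query.update(...); synchronize_session=False skips reconciling already-loaded objects, which is acceptable when nothing in the session reads those rows afterwards. A self-contained sketch of the same pattern on a toy table (illustrative schema, not Dify's):

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Segment(Base):
    __tablename__ = "segments"
    id = Column(Integer, primary_key=True)
    index_node_id = Column(String)
    hit_count = Column(Integer, default=0)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Segment(index_node_id="doc-1"))
    session.commit()

    # one UPDATE statement, no per-row ORM round-trips
    session.query(Segment).filter(Segment.index_node_id == "doc-1").update(
        {Segment.hit_count: Segment.hit_count + 1}, synchronize_session=False
    )
    session.commit()
    print(session.query(Segment).one().hit_count)  # -> 1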
@@ -64,26 +59,25 @@ class DatasetIndexToolCallbackHandler:
for item in resource:
dataset_retriever_resource = DatasetRetrieverResource(
message_id=self._message_id,
position=item.get('position'),
dataset_id=item.get('dataset_id'),
dataset_name=item.get('dataset_name'),
document_id=item.get('document_id'),
document_name=item.get('document_name'),
data_source_type=item.get('data_source_type'),
segment_id=item.get('segment_id'),
score=item.get('score') if 'score' in item else None,
hit_count=item.get('hit_count') if 'hit_count' else None,
word_count=item.get('word_count') if 'word_count' in item else None,
segment_position=item.get('segment_position') if 'segment_position' in item else None,
index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None,
content=item.get('content'),
retriever_from=item.get('retriever_from'),
created_by=self._user_id
position=item.get("position"),
dataset_id=item.get("dataset_id"),
dataset_name=item.get("dataset_name"),
document_id=item.get("document_id"),
document_name=item.get("document_name"),
data_source_type=item.get("data_source_type"),
segment_id=item.get("segment_id"),
score=item.get("score") if "score" in item else None,
hit_count=item.get("hit_count") if "hit_count" else None,
word_count=item.get("word_count") if "word_count" in item else None,
segment_position=item.get("segment_position") if "segment_position" in item else None,
index_node_hash=item.get("index_node_hash") if "index_node_hash" in item else None,
content=item.get("content"),
retriever_from=item.get("retriever_from"),
created_by=self._user_id,
)
db.session.add(dataset_retriever_resource)
db.session.commit()

self._queue_manager.publish(
QueueRetrieverResourcesEvent(retriever_resources=resource),
PublishFrom.APPLICATION_MANAGER
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
)

@@ -2,4 +2,4 @@ from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackH


class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler):
"""Callback Handler that prints to std out."""
"""Callback Handler that prints to std out."""

@@ -29,9 +29,13 @@ class CacheEmbedding(Embeddings):
embedding_queue_indices = []
for i, text in enumerate(texts):
hash = helper.generate_text_hash(text)
embedding = db.session.query(Embedding).filter_by(model_name=self._model_instance.model,
hash=hash,
provider_name=self._model_instance.provider).first()
embedding = (
db.session.query(Embedding)
.filter_by(
model_name=self._model_instance.model, hash=hash, provider_name=self._model_instance.provider
)
.first()
)
if embedding:
text_embeddings[i] = embedding.get_embedding()
else:
@@ -41,17 +45,18 @@ class CacheEmbedding(Embeddings):
embedding_queue_embeddings = []
try:
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
model_schema = model_type_instance.get_model_schema(self._model_instance.model,
self._model_instance.credentials)
max_chunks = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] \
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties else 1
model_schema = model_type_instance.get_model_schema(
self._model_instance.model, self._model_instance.credentials
)
max_chunks = (
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
else 1
)
for i in range(0, len(embedding_queue_texts), max_chunks):
batch_texts = embedding_queue_texts[i:i + max_chunks]
batch_texts = embedding_queue_texts[i : i + max_chunks]

embedding_result = self._model_instance.invoke_text_embedding(
texts=batch_texts,
user=self._user
)
embedding_result = self._model_instance.invoke_text_embedding(texts=batch_texts, user=self._user)

for vector in embedding_result.embeddings:
try:
@@ -60,16 +65,18 @@ class CacheEmbedding(Embeddings):
except IntegrityError:
db.session.rollback()
except Exception as e:
logging.exception('Failed transform embedding: ', e)
logging.exception("Failed transform embedding: ", e)
cache_embeddings = []
try:
for i, embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
text_embeddings[i] = embedding
hash = helper.generate_text_hash(texts[i])
if hash not in cache_embeddings:
embedding_cache = Embedding(model_name=self._model_instance.model,
hash=hash,
provider_name=self._model_instance.provider)
embedding_cache = Embedding(
model_name=self._model_instance.model,
hash=hash,
provider_name=self._model_instance.provider,
)
embedding_cache.set_embedding(embedding)
db.session.add(embedding_cache)
cache_embeddings.append(hash)
@@ -78,7 +85,7 @@ class CacheEmbedding(Embeddings):
db.session.rollback()
except Exception as ex:
db.session.rollback()
logger.error('Failed to embed documents: ', ex)
logger.error("Failed to embed documents: ", ex)
raise ex

return text_embeddings
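The loop above slices the queued texts into groups of at most max_chunks (a per-model property that defaults to 1) so each provider call stays within the model's batch limit. The slicing on its own:

def batched(items: list, max_chunks: int):
    # same stride pattern as range(0, len(...), max_chunks) above
    for i in range(0, len(items), max_chunks):
        yield items[i : i + max_chunks]

texts = ["a", "b", "c", "d", "e"]
for batch in batched(texts, max_chunks=2):
    print(batch)  # ['a', 'b'], ['c', 'd'], ['e'] -- each would be one embedding call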
@@ -87,16 +94,13 @@ class CacheEmbedding(Embeddings):
"""Embed query text."""
# use doc embedding cache or store if not exists
hash = helper.generate_text_hash(text)
embedding_cache_key = f'{self._model_instance.provider}_{self._model_instance.model}_{hash}'
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}"
embedding = redis_client.get(embedding_cache_key)
if embedding:
redis_client.expire(embedding_cache_key, 600)
return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
try:
embedding_result = self._model_instance.invoke_text_embedding(
texts=[text],
user=self._user
)
embedding_result = self._model_instance.invoke_text_embedding(texts=[text], user=self._user)

embedding_results = embedding_result.embeddings[0]
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
@@ -116,6 +120,6 @@ class CacheEmbedding(Embeddings):
except IntegrityError:
db.session.rollback()
except:
logging.exception('Failed to add embedding to redis')
logging.exception("Failed to add embedding to redis")

return embedding_results