Compare commits

...

90 Commits

Author SHA1 Message Date
-LAN-
967eb81112 chore: bump version to 0.14.0 (#11679)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-16 15:49:17 +08:00
zxhlyh
9f602f73eb fix: workflow continue on error edge color (#11689) 2024-12-16 15:39:53 +08:00
Joel
41de7e76ec fix: iteration output array type causes always outputting string array (#11686) 2024-12-16 15:06:03 +08:00
Joel
607a22ad12 fix: tool constant params change cause page crashed (#11682) 2024-12-16 14:33:00 +08:00
wangbin77
4b402c4041 fix: enhance workflow.tool_published performance (#11640)
Co-authored-by: wangbin <wangbin35@xiaomi.com>
2024-12-16 13:05:38 +08:00
zhongliliu-butterfly
daccb10d8c fix: volcengine_maas and baichuan message error (#11625)
Co-authored-by: zhongliliu <liuzlx@digitalchina.com>
2024-12-16 13:05:27 +08:00
Kazuhisa Wada
63f1dd7877 Make max_submit_count configurable via Config (#11673) 2024-12-16 12:59:37 +08:00
zhaobingshuang
79801f5c30 fix: deepseek reports an error when using Response Format #11677 (#11678)
Co-authored-by: zhaobs <zhaobs@cailian.net>
2024-12-16 12:58:03 +08:00
非法操作
9c7a1bc067 fix: change http node params from dict to list tuple (#11665) 2024-12-15 21:27:39 +08:00
非法操作
cf0ff88120 feat: add grok-2-1212 and grok-2-vision-1212 (#11672) 2024-12-15 21:18:24 +08:00
Novice
e0b67536e0 fix: remove the unused QueueWorkflowPartialSuccessEvent handle in workflow (#11669)
Co-authored-by: Novice Lee <novicelee@NoviPro.local>
2024-12-15 21:18:14 +08:00
github-actions[bot]
94c7dcc7f1 chore: translate i18n files (#11639)
Co-authored-by: douxc <7553076+douxc@users.noreply.github.com>
2024-12-15 17:22:45 +08:00
luckylhb90
38e155d819 feat: log add trace id (#11599)
Co-authored-by: hobo.l <hobo.l@binance.com>
2024-12-15 17:22:25 +08:00
yihong
efd5575683 fix: _handle_workflow_run_partial_success args is wrong (#11562)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-15 17:22:13 +08:00
yihong
1a7c213405 fix: ExternalDatasetService.process_external_api wrong args (#11586)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-15 17:22:03 +08:00
yihong
8e3d60c359 fix: account.id should account_id (#11628)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-15 17:18:17 +08:00
Bowen Liang
924b4fe742 test: run vdb tests on TiDB Vector with docker in CI tests (#11645) 2024-12-15 17:16:40 +08:00
yihong
7e154a467b fix: better error message for stream (#11635)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-15 17:16:04 +08:00
IWAI, Masaharu
b90f1581be Update translate to Japanese: natural Japanese expression (#11647)
Co-authored-by: IWAI, Masaharu <iwai_masaharu@funkit.co.jp>
2024-12-15 17:15:24 +08:00
yihong
821992e21f fix: langfuse do not have created_at args and fix the typing in the file (#11648)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-15 17:13:46 +08:00
IWAI, Masaharu
f0c0ce9db1 fix: rename README filename: Japanese language code is 'JA' (#11651) 2024-12-15 17:13:34 +08:00
Junyan Qin
8ecb9aaa91 fix: remove unnecessary curly braces in wf api doc (#11658) 2024-12-15 17:12:26 +08:00
yihong
22258fb0bf fix: filter bug for keywork cause code can not reach (#11666)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-15 17:12:06 +08:00
NFish
a725b8bb6e Feat: new entry point for app creation (#10847) 2024-12-13 17:29:09 +08:00
Kevin9703
bdfdccd511 fix: app log filter value error (#11624) 2024-12-13 16:40:34 +08:00
yihong
194bc60429 fix: split dir for opendal tests (#11627)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-13 16:31:00 +08:00
-LAN-
430ca3322b chore(dependency): bump gunicorn to 23.0 (#11560)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-13 16:16:58 +08:00
zxhlyh
3d803c2e80 Fix/pdf preview in build (#11621) 2024-12-13 11:01:53 +08:00
Hiroshi Fujita
fa3dcbb3bc feat(devcontainer): add alias to stop Docker containers (#11616) 2024-12-13 10:03:58 +08:00
yihong
ee342063d8 ci: better print version for ruff to check the change (#11587)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-12 21:44:00 +08:00
JasonVV
bb3bc60f83 feat(model): add vertex_ai Gemini 2.0 Flash Exp (#11604) 2024-12-12 20:20:49 +08:00
crazywoola
e7a4cfac4d fix: name of llama-3.3-70b-specdec (#11596) 2024-12-12 16:33:49 +08:00
Alok Shrivastwa
6478aa1c9d Added new models and Removed the deleted ones for Groq #11455 (#11456)
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: Alok Shrivastwa <Alok.Shrivastwa@microland.com>
2024-12-12 14:11:30 +08:00
Warren Chen
7b5839335a [ref] use one method to get boto client for aws bedrock (#11506) 2024-12-12 13:56:52 +08:00
github-actions[bot]
a360af8687 chore: translate i18n files (#11577)
Co-authored-by: JzoNgKVO <27049666+JzoNgKVO@users.noreply.github.com>
2024-12-12 13:47:39 +08:00
yihong
36cb25b341 fix: support mdx files close #11557 (#11565)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-12 13:37:56 +08:00
Joe
e565ecdaef fix: change workflow trace id (#11585) 2024-12-12 13:37:29 +08:00
KVOJJJin
f96fdc2970 Feat: dark mode for logs and annotations (#11575) 2024-12-12 10:09:48 +08:00
Jiang
0d04cdc323 Lindorm vdb (#11574)
Co-authored-by: jiangzhijie <jiangzhijie.jzj@alibaba-inc.com>
2024-12-12 09:43:27 +08:00
非法操作
926f604f09 feat: add gemini-2.0-flash-exp (#11570) 2024-12-12 09:33:39 +08:00
yihong
180743612c fix: better opendal tests (#11569)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-12 09:33:30 +08:00
liuzhenghua
d05f189049 Fix: RateLimit requests were not released when a streaming generation exception occurred (#11540) 2024-12-11 19:16:35 +08:00
github-actions[bot]
ceaa9f1101 chore: translate i18n files (#11545)
Co-authored-by: zxhlyh <16177003+zxhlyh@users.noreply.github.com>
2024-12-11 18:04:14 +08:00
zxhlyh
6f4cbe0bde fix: workflow continue on error doc link (#11554) 2024-12-11 18:03:41 +08:00
-LAN-
8d4bb9b40d feat: integrate opendal storage (#11508)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-11 14:50:54 +08:00
Novice
1765fe2a29 fix: iteration node in parallel mode token count error (#11539)
Co-authored-by: Novice Lee <novicelee@NoviPro.local>
2024-12-11 14:23:01 +08:00
Novice
79a710ce98 Feat: continue on error (#11458)
Co-authored-by: Novice Lee <novicelee@NovicedeMacBook-Pro.local>
Co-authored-by: Novice Lee <novicelee@NoviPro.local>
2024-12-11 14:22:42 +08:00
zxhlyh
bec5451f12 feat: workflow continue on error (#11474) 2024-12-11 14:21:38 +08:00
Yi Xiao
86dfdcb8ec chore: update thai lang in app page (#11541) 2024-12-11 12:08:09 +08:00
Tommy
42d986b96d [Pixtral] Add new model ; add vision (#11231) 2024-12-11 10:14:16 +08:00
zkyTech
fbc4ca980c fix: Remove duplicate 'response_format' parameter from model YAML files (#11531)
Co-authored-by: zhangkunyuan <zhangkunyuan@cmhi.chinamobile.com>
2024-12-11 10:10:53 +08:00
Paul van Oorschot
80c52e0ea4 feat: Add llama-3.3 models for Groq (#11533) 2024-12-11 09:59:46 +08:00
yihong
50b76dd5a2 fix: better error message for url add external knowledge (#11537)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-11 09:55:48 +08:00
yihong
225fcd5e41 Revert "fix: total tokens is wrong which is zero in inter way, close … (#11536) 2024-12-11 09:54:46 +08:00
yihong
afffd345bc fix: can not start local by REMOTE_SETTINGS_SOURCE_NAME change it to … (#11535)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-11 09:35:25 +08:00
yihong
716576043d fix: issue 11247 that Completion mode content maybe list or str (#11504)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-10 23:22:14 +08:00
Wei Mingzhi
28231d39a4 Remove the processing of single quote when testing API tools. (#11390) 2024-12-10 19:53:38 +08:00
非法操作
9e23c3d625 chore: LOCAL_FILE also try to use remote_url as Prompt message (#11443) 2024-12-10 10:56:49 +08:00
Charlie.Wei
bdd5869244 Msg file preview (#11466)
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-12-10 10:53:37 +08:00
barabicu
fc1415d705 chore: fix typo in Japanese localization (#11502) 2024-12-10 09:29:16 +08:00
문정현
8218f62478 chore : fix translation Typo in ko-KR localization (#11509) 2024-12-10 09:09:26 +08:00
-LAN-
fd354d999d fix(app_generator_service): overload type hints (#11507)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-10 09:06:34 +08:00
orangeclk
ec00b25793 feat: add siliconflow qwq and llama3.3 model (#11492) 2024-12-10 08:49:45 +08:00
huanshare
967b7d89e3 feat:add apollo configuration to load env file (#11210)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: huanshare <liuhuan101@longfor.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2024-12-10 02:51:20 +08:00
Yingchun Lai
32f8439143 fix: add the missing abab6.5t-chat model of Minimax (#11484) 2024-12-09 17:59:20 +08:00
-LAN-
0ff8bd2aa9 chore: bump version to 0.13.2 (#11489)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-09 17:57:23 +08:00
Hash Brown
2866383228 fix: cannot close notification manually (#11490) 2024-12-09 17:55:06 +08:00
Jyong
00ac7edeb3 improve message clean logic (#11487) 2024-12-09 16:12:30 +08:00
-LAN-
537068cfde refactor(iteration_node): use Sequence and Mapping in parameters (#11483)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-09 15:41:20 +08:00
suzuki.sh
c3c6a48059 Fix the token count at the iteration node (#11235)
Co-authored-by: -LAN- <laipz8200@outlook.com>
2024-12-09 15:02:04 +08:00
zhaobingshuang
5c166b3f40 fix: tags could not be saved when the Workflow Tool was created (#11481)
Co-authored-by: zhaobs <zhaobs@cailian.net>
2024-12-09 14:38:02 +08:00
kurokobo
230fa3286b feat: add 'Open in Explore' link for each apps on studio (#11402) 2024-12-09 12:04:03 +08:00
Muneyuki Noguchi
061c0b10fd Fix the Japanese translation for 'Detail' (#11476) 2024-12-09 11:18:28 +08:00
Yi Xiao
32f8a98cf8 feat: ifelse condition variable editable after selection (#11431) 2024-12-09 11:06:47 +08:00
Charlie.Wei
6c60ecb237 Refactor: Remove redundant style and simplify Mermaid component (#11472) 2024-12-09 09:47:58 +08:00
xiandan-erizo
c3fae5e801 Update ext_redis.py (#11214) 2024-12-09 09:35:52 +08:00
VoidIsVoid
a594e256ae remove mermail render cache (#11470)
Co-authored-by: Gimling <huangjl@ruyi.ai>
2024-12-09 09:33:18 +08:00
Trey Dong
41d90c2408 fix(api): throw error when notion block can not find (#11433) 2024-12-09 09:10:59 +08:00
yihong
7ff42b1b7a fix: unit tests env will need clear too (#11445)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-09 09:04:11 +08:00
Kazuki Takamatsu
4d7cfd0de5 Fix model provider of vertex ai (#11437) 2024-12-08 08:44:49 +08:00
Hash Brown
266d32bd77 fix: cannot upload animated webp image as app icon (#11453) 2024-12-08 08:37:21 +08:00
非法操作
7e1184c071 feat: support json_schema for ollama models (#11449) 2024-12-08 08:36:12 +08:00
非法操作
1ce51e57ab feat: add gemini exp 1206 (#11444) 2024-12-07 22:28:10 +08:00
非法操作
142b4fd699 feat: add zhipu glm_4v_flash (#11440) 2024-12-07 22:27:57 +08:00
Hash Brown
cc8feaa483 style: EmojiPicker component top padding (#11452) 2024-12-07 22:26:28 +08:00
yihong
d9d5d35a77 fix: issue #10596 by making the iteration node outputs right (#11394)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2024-12-07 16:28:15 +08:00
Huỳnh Gia Bôi
9277156b6c fix(document_extractor): pptx file type and missing metadata_filename UnstructuredIO (#11364)
Co-authored-by: Julian Huynh <julian.huynh@immersio.io>
2024-12-06 18:55:59 +08:00
KVOJJJin
1490a19fa1 Fix: compatible with outputs data structure (#11432) 2024-12-06 17:35:35 +08:00
Jyong
9b7adcd4d9 update tidb batch get endpoint to basic mode (#11426) 2024-12-06 17:06:46 +08:00
Jyong
a8d32f9964 fix external retrieval without segment id (#11423) 2024-12-06 14:45:15 +08:00
431 changed files with 12062 additions and 6658 deletions

View File

@@ -7,5 +7,6 @@ echo 'alias start-api="cd /workspaces/dify/api && poetry run python -m flask run
echo 'alias start-worker="cd /workspaces/dify/api && poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc
echo 'alias start-web="cd /workspaces/dify/web && npm run dev"' >> ~/.bashrc
echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify down"' >> ~/.bashrc
source /home/vscode/.bashrc
source /home/vscode/.bashrc

View File

@@ -37,6 +37,7 @@ jobs:
- name: Ruff check
if: steps.changed-files.outputs.any_changed == 'true'
run: |
poetry run -C api ruff --version
poetry run -C api ruff check ./api
poetry run -C api ruff format --check ./api

View File

@@ -51,7 +51,7 @@ jobs:
- name: Expose Service Ports
run: sh .github/workflows/expose_service_ports.sh
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
- name: Set up Vector Stores (TiDB, Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase)
uses: hoverkraft-tech/compose-action@v2.0.2
with:
compose-file: |
@@ -67,6 +67,7 @@ jobs:
pgvector
chroma
elasticsearch
tidb
- name: Test Vector Stores
run: poetry run -C api bash dev/pytest/pytest_vdb.sh

View File

@@ -56,20 +56,36 @@ DB_DATABASE=dify
# Storage configuration
# use for store upload files, private keys...
# storage type: local, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase
STORAGE_TYPE=local
STORAGE_LOCAL_PATH=storage
# storage type: opendal, s3, aliyun-oss, azure-blob, baidu-obs, google-storage, huawei-obs, oci-storage, tencent-cos, volcengine-tos, supabase
STORAGE_TYPE=opendal
# Apache OpenDAL storage configuration, refer to https://github.com/apache/opendal
STORAGE_OPENDAL_SCHEME=fs
# OpenDAL FS
OPENDAL_FS_ROOT=storage
# OpenDAL S3
OPENDAL_S3_ROOT=/
OPENDAL_S3_BUCKET=your-bucket-name
OPENDAL_S3_ENDPOINT=https://s3.amazonaws.com
OPENDAL_S3_ACCESS_KEY_ID=your-access-key
OPENDAL_S3_SECRET_ACCESS_KEY=your-secret-key
OPENDAL_S3_REGION=your-region
OPENDAL_S3_SERVER_SIDE_ENCRYPTION=
# S3 Storage configuration
S3_USE_AWS_MANAGED_IAM=false
S3_ENDPOINT=https://your-bucket-name.storage.s3.cloudflare.com
S3_BUCKET_NAME=your-bucket-name
S3_ACCESS_KEY=your-access-key
S3_SECRET_KEY=your-secret-key
S3_REGION=your-region
# Azure Blob Storage configuration
AZURE_BLOB_ACCOUNT_NAME=your-account-name
AZURE_BLOB_ACCOUNT_KEY=your-account-key
AZURE_BLOB_CONTAINER_NAME=your-container-name
AZURE_BLOB_ACCOUNT_URL=https://<your_account_name>.blob.core.windows.net
# Aliyun oss Storage configuration
ALIYUN_OSS_BUCKET_NAME=your-bucket-name
ALIYUN_OSS_ACCESS_KEY=your-access-key
@@ -79,6 +95,7 @@ ALIYUN_OSS_AUTH_VERSION=v1
ALIYUN_OSS_REGION=your-region
# Don't start with '/'. OSS doesn't support leading slash in object names.
ALIYUN_OSS_PATH=your-path
# Google Storage configuration
GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string
@@ -125,8 +142,8 @@ SUPABASE_URL=your-server-url
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase
# Vector database configuration
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase
VECTOR_STORE=weaviate
# Weaviate configuration
@@ -277,6 +294,7 @@ VIKINGDB_SOCKET_TIMEOUT=30
LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
LINDORM_USERNAME=admin
LINDORM_PASSWORD=admin
USING_UGC_INDEX=False
# OceanBase Vector configuration
OCEANBASE_VECTOR_HOST=127.0.0.1
@@ -381,6 +399,8 @@ LOG_FILE_BACKUP_COUNT=5
LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S
# Log Timezone
LOG_TZ=UTC
# Log format
LOG_FORMAT=%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s
# Indexing configuration
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
@@ -413,3 +433,5 @@ RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
CREATE_TIDB_SERVICE_JOB_ENABLED=false
# Maximum number of submitted thread count in a ThreadPool for parallel node execution
MAX_SUBMIT_COUNT=100

View File

@@ -1,11 +1,51 @@
from pydantic_settings import SettingsConfigDict
import logging
from typing import Any
from configs.deploy import DeploymentConfig
from configs.enterprise import EnterpriseFeatureConfig
from configs.extra import ExtraServiceConfig
from configs.feature import FeatureConfig
from configs.middleware import MiddlewareConfig
from configs.packaging import PackagingInfo
from pydantic.fields import FieldInfo
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
from .deploy import DeploymentConfig
from .enterprise import EnterpriseFeatureConfig
from .extra import ExtraServiceConfig
from .feature import FeatureConfig
from .middleware import MiddlewareConfig
from .packaging import PackagingInfo
from .remote_settings_sources import RemoteSettingsSource, RemoteSettingsSourceConfig, RemoteSettingsSourceName
from .remote_settings_sources.apollo import ApolloSettingsSource
logger = logging.getLogger(__name__)
class RemoteSettingsSourceFactory(PydanticBaseSettingsSource):
def __init__(self, settings_cls: type[BaseSettings]):
super().__init__(settings_cls)
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
raise NotImplementedError
def __call__(self) -> dict[str, Any]:
current_state = self.current_state
remote_source_name = current_state.get("REMOTE_SETTINGS_SOURCE_NAME")
if not remote_source_name:
return {}
remote_source: RemoteSettingsSource | None = None
match remote_source_name:
case RemoteSettingsSourceName.APOLLO:
remote_source = ApolloSettingsSource(current_state)
case _:
logger.warning(f"Unsupported remote source: {remote_source_name}")
return {}
d: dict[str, Any] = {}
for field_name, field in self.settings_cls.model_fields.items():
field_value, field_key, value_is_complex = remote_source.get_field_value(field, field_name)
field_value = remote_source.prepare_field_value(field_name, field, field_value, value_is_complex)
if field_value is not None:
d[field_key] = field_value
return d
class DifyConfig(
@@ -19,6 +59,8 @@ class DifyConfig(
MiddlewareConfig,
# Extra service configs
ExtraServiceConfig,
# Remote source configs
RemoteSettingsSourceConfig,
# Enterprise feature configs
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
EnterpriseFeatureConfig,
@@ -35,3 +77,20 @@ class DifyConfig(
# please consider to arrange it in the proper config group of existed or added
# for better readability and maintainability.
# Thanks for your concentration and consideration.
@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
return (
init_settings,
env_settings,
RemoteSettingsSourceFactory(settings_cls),
dotenv_settings,
file_secret_settings,
)
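A minimal sketch, not Dify code, of how a custom PydanticBaseSettingsSource plugs into settings_customise_sources, mirroring the RemoteSettingsSourceFactory wiring above; class and field names here are illustrative only.
from typing import Any
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource
class StaticSource(PydanticBaseSettingsSource):
    def get_field_value(self, field, field_name):  # unused because __call__ is overridden
        raise NotImplementedError
    def __call__(self) -> dict[str, Any]:
        return {"GREETING": "from-remote"}  # pretend these values came from a remote store
class DemoConfig(BaseSettings):
    GREETING: str = "from-default"
    @classmethod
    def settings_customise_sources(
        cls, settings_cls, init_settings, env_settings, dotenv_settings, file_secret_settings
    ):
        # same relative position as in DifyConfig: after env vars, before dotenv
        return (init_settings, env_settings, StaticSource(settings_cls), dotenv_settings, file_secret_settings)
print(DemoConfig().GREETING)  # "from-remote" unless GREETING is set in the real environment
With this ordering, real environment variables still override anything served by the remote source.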

View File

@@ -439,6 +439,17 @@ class WorkflowConfig(BaseSettings):
)
class WorkflowNodeExecutionConfig(BaseSettings):
"""
Configuration for workflow node execution
"""
MAX_SUBMIT_COUNT: PositiveInt = Field(
description="Maximum number of submitted thread count in a ThreadPool for parallel node execution",
default=100,
)
class AuthConfig(BaseSettings):
"""
Configuration for authentication and OAuth
@@ -775,6 +786,7 @@ class FeatureConfig(
ToolConfig,
UpdateConfig,
WorkflowConfig,
WorkflowNodeExecutionConfig,
WorkspaceConfig,
LoginConfig,
# hosted services config
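A hedged sketch (toy class, not Dify's FeatureConfig) of what mixing WorkflowNodeExecutionConfig into the config buys: MAX_SUBMIT_COUNT can now be set from the environment and fed to the parallel-execution thread pool; the ThreadPoolExecutor use below is illustrative only.
import os
from concurrent.futures import ThreadPoolExecutor
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class NodeExecutionDemo(BaseSettings):
    MAX_SUBMIT_COUNT: PositiveInt = Field(default=100)
os.environ["MAX_SUBMIT_COUNT"] = "200"
limit = NodeExecutionDemo().MAX_SUBMIT_COUNT  # 200, overriding the default of 100
pool = ThreadPoolExecutor(max_workers=limit)  # thread pool sized from the setting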

View File

@@ -1,54 +1,69 @@
from typing import Any, Optional
from typing import Any, Literal, Optional
from urllib.parse import quote_plus
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
from pydantic_settings import BaseSettings
from configs.middleware.cache.redis_config import RedisConfig
from configs.middleware.storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
from configs.middleware.storage.amazon_s3_storage_config import S3StorageConfig
from configs.middleware.storage.azure_blob_storage_config import AzureBlobStorageConfig
from configs.middleware.storage.baidu_obs_storage_config import BaiduOBSStorageConfig
from configs.middleware.storage.google_cloud_storage_config import GoogleCloudStorageConfig
from configs.middleware.storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
from configs.middleware.storage.oci_storage_config import OCIStorageConfig
from configs.middleware.storage.supabase_storage_config import SupabaseStorageConfig
from configs.middleware.storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from configs.middleware.storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
from configs.middleware.vdb.analyticdb_config import AnalyticdbConfig
from configs.middleware.vdb.baidu_vector_config import BaiduVectorDBConfig
from configs.middleware.vdb.chroma_config import ChromaConfig
from configs.middleware.vdb.couchbase_config import CouchbaseConfig
from configs.middleware.vdb.elasticsearch_config import ElasticsearchConfig
from configs.middleware.vdb.lindorm_config import LindormConfig
from configs.middleware.vdb.milvus_config import MilvusConfig
from configs.middleware.vdb.myscale_config import MyScaleConfig
from configs.middleware.vdb.oceanbase_config import OceanBaseVectorConfig
from configs.middleware.vdb.opensearch_config import OpenSearchConfig
from configs.middleware.vdb.oracle_config import OracleConfig
from configs.middleware.vdb.pgvector_config import PGVectorConfig
from configs.middleware.vdb.pgvectors_config import PGVectoRSConfig
from configs.middleware.vdb.qdrant_config import QdrantConfig
from configs.middleware.vdb.relyt_config import RelytConfig
from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig
from configs.middleware.vdb.tidb_on_qdrant_config import TidbOnQdrantConfig
from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
from configs.middleware.vdb.upstash_config import UpstashConfig
from configs.middleware.vdb.vikingdb_config import VikingDBConfig
from configs.middleware.vdb.weaviate_config import WeaviateConfig
from .cache.redis_config import RedisConfig
from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
from .storage.amazon_s3_storage_config import S3StorageConfig
from .storage.azure_blob_storage_config import AzureBlobStorageConfig
from .storage.baidu_obs_storage_config import BaiduOBSStorageConfig
from .storage.google_cloud_storage_config import GoogleCloudStorageConfig
from .storage.huawei_obs_storage_config import HuaweiCloudOBSStorageConfig
from .storage.oci_storage_config import OCIStorageConfig
from .storage.opendal_storage_config import OpenDALStorageConfig
from .storage.supabase_storage_config import SupabaseStorageConfig
from .storage.tencent_cos_storage_config import TencentCloudCOSStorageConfig
from .storage.volcengine_tos_storage_config import VolcengineTOSStorageConfig
from .vdb.analyticdb_config import AnalyticdbConfig
from .vdb.baidu_vector_config import BaiduVectorDBConfig
from .vdb.chroma_config import ChromaConfig
from .vdb.couchbase_config import CouchbaseConfig
from .vdb.elasticsearch_config import ElasticsearchConfig
from .vdb.lindorm_config import LindormConfig
from .vdb.milvus_config import MilvusConfig
from .vdb.myscale_config import MyScaleConfig
from .vdb.oceanbase_config import OceanBaseVectorConfig
from .vdb.opensearch_config import OpenSearchConfig
from .vdb.oracle_config import OracleConfig
from .vdb.pgvector_config import PGVectorConfig
from .vdb.pgvectors_config import PGVectoRSConfig
from .vdb.qdrant_config import QdrantConfig
from .vdb.relyt_config import RelytConfig
from .vdb.tencent_vector_config import TencentVectorDBConfig
from .vdb.tidb_on_qdrant_config import TidbOnQdrantConfig
from .vdb.tidb_vector_config import TiDBVectorConfig
from .vdb.upstash_config import UpstashConfig
from .vdb.vikingdb_config import VikingDBConfig
from .vdb.weaviate_config import WeaviateConfig
class StorageConfig(BaseSettings):
STORAGE_TYPE: str = Field(
STORAGE_TYPE: Literal[
"opendal",
"s3",
"aliyun-oss",
"azure-blob",
"baidu-obs",
"google-storage",
"huawei-obs",
"oci-storage",
"tencent-cos",
"volcengine-tos",
"supabase",
"local",
] = Field(
description="Type of storage to use."
" Options: 'local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', 'huawei-obs', "
"'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'local'.",
default="local",
" Options: 'opendal', '(deprecated) local', 's3', 'aliyun-oss', 'azure-blob', 'baidu-obs', 'google-storage', "
"'huawei-obs', 'oci-storage', 'tencent-cos', 'volcengine-tos', 'supabase'. Default is 'opendal'.",
default="opendal",
)
STORAGE_LOCAL_PATH: str = Field(
description="Path for local storage when STORAGE_TYPE is set to 'local'.",
default="storage",
deprecated=True,
)
@@ -73,7 +88,7 @@ class KeywordStoreConfig(BaseSettings):
)
class DatabaseConfig:
class DatabaseConfig(BaseSettings):
DB_HOST: str = Field(
description="Hostname or IP address of the database server.",
default="localhost",
@@ -235,6 +250,7 @@ class MiddlewareConfig(
GoogleCloudStorageConfig,
HuaweiCloudOBSStorageConfig,
OCIStorageConfig,
OpenDALStorageConfig,
S3StorageConfig,
SupabaseStorageConfig,
TencentCloudCOSStorageConfig,

View File

@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class BaiduOBSStorageConfig(BaseModel):
class BaiduOBSStorageConfig(BaseSettings):
"""
Configuration settings for Baidu Object Storage Service (OBS)
"""

View File

@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class HuaweiCloudOBSStorageConfig(BaseModel):
class HuaweiCloudOBSStorageConfig(BaseSettings):
"""
Configuration settings for Huawei Cloud Object Storage Service (OBS)
"""

View File

@@ -0,0 +1,51 @@
from enum import StrEnum
from typing import Literal
from pydantic import Field
from pydantic_settings import BaseSettings
class OpenDALScheme(StrEnum):
FS = "fs"
S3 = "s3"
class OpenDALStorageConfig(BaseSettings):
STORAGE_OPENDAL_SCHEME: str = Field(
default=OpenDALScheme.FS.value,
description="OpenDAL scheme.",
)
# FS
OPENDAL_FS_ROOT: str = Field(
default="storage",
description="Root path for local storage.",
)
# S3
OPENDAL_S3_ROOT: str = Field(
default="/",
description="Root path for S3 storage.",
)
OPENDAL_S3_BUCKET: str = Field(
default="",
description="S3 bucket name.",
)
OPENDAL_S3_ENDPOINT: str = Field(
default="https://s3.amazonaws.com",
description="S3 endpoint URL.",
)
OPENDAL_S3_ACCESS_KEY_ID: str = Field(
default="",
description="S3 access key ID.",
)
OPENDAL_S3_SECRET_ACCESS_KEY: str = Field(
default="",
description="S3 secret access key.",
)
OPENDAL_S3_REGION: str = Field(
default="",
description="S3 region.",
)
OPENDAL_S3_SERVER_SIDE_ENCRYPTION: Literal["aws:kms", ""] = Field(
default="",
description="S3 server-side encryption.",
)
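A hedged sketch of what these settings map to at runtime, assuming the Apache `opendal` Python binding (Dify's storage wrapper itself is not shown in this diff): the scheme selects an Operator and the remaining fields become its options.
import opendal
# fs scheme with OPENDAL_FS_ROOT=storage roughly reproduces the old
# STORAGE_TYPE=local / STORAGE_LOCAL_PATH=storage behaviour.
op = opendal.Operator("fs", root="storage")
op.write("upload_files/example.txt", b"hello dify")  # keys are paths under the root
print(op.read("upload_files/example.txt"))
# For the s3 scheme, the OPENDAL_S3_* fields above would be passed as options instead, e.g.
# opendal.Operator("s3", root="/", bucket="your-bucket-name", endpoint="https://s3.amazonaws.com")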

View File

@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class SupabaseStorageConfig(BaseModel):
class SupabaseStorageConfig(BaseSettings):
"""
Configuration settings for Supabase Object Storage Service
"""

View File

@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class VolcengineTOSStorageConfig(BaseModel):
class VolcengineTOSStorageConfig(BaseSettings):
"""
Configuration settings for Volcengine Tinder Object Storage (TOS)
"""

View File

@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field, PositiveInt
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class AnalyticdbConfig(BaseModel):
class AnalyticdbConfig(BaseSettings):
"""
Configuration for connecting to Alibaba Cloud AnalyticDB for PostgreSQL.
Refer to the following documentation for details on obtaining credentials:

View File

@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class CouchbaseConfig(BaseModel):
class CouchbaseConfig(BaseSettings):
"""
Couchbase configs
"""

View File

@@ -21,3 +21,14 @@ class LindormConfig(BaseSettings):
description="Lindorm password",
default=None,
)
DEFAULT_INDEX_TYPE: Optional[str] = Field(
description="Lindorm Vector Index Type, hnsw or flat is available in dify",
default="hnsw",
)
DEFAULT_DISTANCE_TYPE: Optional[str] = Field(
description="Vector Distance Type, support l2, cosinesimil, innerproduct", default="l2"
)
USING_UGC_INDEX: Optional[bool] = Field(
description="Using UGC index will store the same type of Index in a single index but can retrieve separately.",
default=False,
)

View File

@@ -1,7 +1,8 @@
from pydantic import BaseModel, Field, PositiveInt
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
class MyScaleConfig(BaseModel):
class MyScaleConfig(BaseSettings):
"""
Configuration settings for MyScale vector database
"""

View File

@@ -1,9 +1,10 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from pydantic_settings import BaseSettings
class VikingDBConfig(BaseModel):
class VikingDBConfig(BaseSettings):
"""
Configuration for connecting to Volcengine VikingDB.
Refer to the following documentation for details on obtaining credentials:

View File

@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
default="0.13.1",
default="0.14.0",
)
COMMIT_SHA: str = Field(

View File

@@ -0,0 +1,17 @@
from typing import Optional
from pydantic import Field
from .apollo import ApolloSettingsSourceInfo
from .base import RemoteSettingsSource
from .enums import RemoteSettingsSourceName
class RemoteSettingsSourceConfig(ApolloSettingsSourceInfo):
REMOTE_SETTINGS_SOURCE_NAME: RemoteSettingsSourceName | str = Field(
description="name of remote config source",
default="",
)
__all__ = ["RemoteSettingsSource", "RemoteSettingsSourceConfig", "RemoteSettingsSourceName"]

View File

@@ -0,0 +1,55 @@
from collections.abc import Mapping
from typing import Any, Optional
from pydantic import Field
from pydantic.fields import FieldInfo
from pydantic_settings import BaseSettings
from configs.remote_settings_sources.base import RemoteSettingsSource
from .client import ApolloClient
class ApolloSettingsSourceInfo(BaseSettings):
"""
Apollo remote settings source configuration
"""
APOLLO_APP_ID: Optional[str] = Field(
description="apollo app_id",
default=None,
)
APOLLO_CLUSTER: Optional[str] = Field(
description="apollo cluster",
default=None,
)
APOLLO_CONFIG_URL: Optional[str] = Field(
description="apollo config url",
default=None,
)
APOLLO_NAMESPACE: Optional[str] = Field(
description="apollo namespace",
default=None,
)
class ApolloSettingsSource(RemoteSettingsSource):
def __init__(self, configs: Mapping[str, Any]):
self.client = ApolloClient(
app_id=configs["APOLLO_APP_ID"],
cluster=configs["APOLLO_CLUSTER"],
config_url=configs["APOLLO_CONFIG_URL"],
start_hot_update=False,
_notification_map={configs["APOLLO_NAMESPACE"]: -1},
)
self.namespace = configs["APOLLO_NAMESPACE"]
self.remote_configs = self.client.get_all_dicts(self.namespace)
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
if not isinstance(self.remote_configs, dict):
raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")
field_value = self.remote_configs.get(field_name)
return field_value, field_name, False
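A hedged sketch of the environment needed to activate this source; the values are placeholders, and the variable names come from ApolloSettingsSourceInfo and RemoteSettingsSourceName above.
import os
os.environ.update(
    {
        "REMOTE_SETTINGS_SOURCE_NAME": "apollo",
        "APOLLO_APP_ID": "dify",
        "APOLLO_CLUSTER": "default",
        "APOLLO_CONFIG_URL": "http://apollo.example.com:8080",
        "APOLLO_NAMESPACE": "application",
    }
)
# A DifyConfig() constructed after this point would pull its fields from Apollo
# through RemoteSettingsSourceFactory.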

View File

@@ -0,0 +1,303 @@
import hashlib
import json
import logging
import os
import threading
import time
from pathlib import Path
from .python_3x import http_request, makedirs_wrapper
from .utils import (
CONFIGURATIONS,
NAMESPACE_NAME,
NOTIFICATION_ID,
get_value_from_dict,
init_ip,
no_key_cache_key,
signature,
url_encode_wrapper,
)
logger = logging.getLogger(__name__)
class ApolloClient:
def __init__(
self,
config_url,
app_id,
cluster="default",
secret="",
start_hot_update=True,
change_listener=None,
_notification_map=None,
):
# Core routing parameters
self.config_url = config_url
self.cluster = cluster
self.app_id = app_id
# Non-core parameters
self.ip = init_ip()
self.secret = secret
# Check the parameter variables
# Private control variables
self._cycle_time = 5
self._stopping = False
self._cache = {}
self._no_key = {}
self._hash = {}
self._pull_timeout = 75
self._cache_file_path = os.path.expanduser("~") + "/.dify/config/remote-settings/apollo/cache/"
self._long_poll_thread = None
self._change_listener = change_listener # "add" "delete" "update"
if _notification_map is None:
_notification_map = {"application": -1}
self._notification_map = _notification_map
self.last_release_key = None
# Private startup method
self._path_checker()
if start_hot_update:
self._start_hot_update()
# start the heartbeat thread
heartbeat = threading.Thread(target=self._heart_beat)
heartbeat.daemon = True
heartbeat.start()
def get_json_from_net(self, namespace="application"):
url = "{}/configs/{}/{}/{}?releaseKey={}&ip={}".format(
self.config_url, self.app_id, self.cluster, namespace, "", self.ip
)
try:
code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
if code == 200:
if not body:
logger.error(f"get_json_from_net load configs failed, body is {body}")
return None
data = json.loads(body)
data = data["configurations"]
return_data = {CONFIGURATIONS: data}
return return_data
else:
return None
except Exception:
logger.exception("an error occurred in get_json_from_net")
return None
def get_value(self, key, default_val=None, namespace="application"):
try:
# read memory configuration
namespace_cache = self._cache.get(namespace)
val = get_value_from_dict(namespace_cache, key)
if val is not None:
return val
no_key = no_key_cache_key(namespace, key)
if no_key in self._no_key:
return default_val
# read the network configuration
namespace_data = self.get_json_from_net(namespace)
val = get_value_from_dict(namespace_data, key)
if val is not None:
self._update_cache_and_file(namespace_data, namespace)
return val
# read the file configuration
namespace_cache = self._get_local_cache(namespace)
val = get_value_from_dict(namespace_cache, key)
if val is not None:
self._update_cache_and_file(namespace_cache, namespace)
return val
# If all of them are not obtained, the default value is returned
# and the local cache is set to None
self._set_local_cache_none(namespace, key)
return default_val
except Exception:
logger.exception("get_value has error, [key is %s], [namespace is %s]", key, namespace)
return default_val
# Record the key as missing for this namespace instead of caching the caller's default,
# so the lookup stays correct in real time: if a later call passes a different default,
# a cached default from an earlier call would silently return the wrong value.
def _set_local_cache_none(self, namespace, key):
no_key = no_key_cache_key(namespace, key)
self._no_key[no_key] = key
def _start_hot_update(self):
self._long_poll_thread = threading.Thread(target=self._listener)
# Run the long poll as a daemon thread so it exits automatically
# when the main thread exits.
self._long_poll_thread.daemon = True
self._long_poll_thread.start()
def stop(self):
self._stopping = True
logger.info("Stopping listener...")
# Invoke the registered change listener; any exception it raises is caught and logged.
def _call_listener(self, namespace, old_kv, new_kv):
if self._change_listener is None:
return
if old_kv is None:
old_kv = {}
if new_kv is None:
new_kv = {}
try:
for key in old_kv:
new_value = new_kv.get(key)
old_value = old_kv.get(key)
if new_value is None:
# If the new value is None, the key and its value have been deleted.
self._change_listener("delete", namespace, key, old_value)
continue
if new_value != old_value:
self._change_listener("update", namespace, key, new_value)
continue
for key in new_kv:
new_value = new_kv.get(key)
old_value = old_kv.get(key)
if old_value is None:
self._change_listener("add", namespace, key, new_value)
except BaseException as e:
logger.warning(str(e))
def _path_checker(self):
if not os.path.isdir(self._cache_file_path):
makedirs_wrapper(self._cache_file_path)
# update the local cache and file cache
def _update_cache_and_file(self, namespace_data, namespace="application"):
# update the local cache
self._cache[namespace] = namespace_data
# update the file cache
new_string = json.dumps(namespace_data)
new_hash = hashlib.md5(new_string.encode("utf-8")).hexdigest()
if self._hash.get(namespace) == new_hash:
pass
else:
file_path = Path(self._cache_file_path) / f"{self.app_id}_configuration_{namespace}.txt"
file_path.write_text(new_string)
self._hash[namespace] = new_hash
# get the configuration from the local file
def _get_local_cache(self, namespace="application"):
cache_file_path = os.path.join(self._cache_file_path, f"{self.app_id}_configuration_{namespace}.txt")
if os.path.isfile(cache_file_path):
with open(cache_file_path) as f:
result = json.loads(f.readline())
return result
return {}
def _long_poll(self):
notifications = []
for key in self._cache:
namespace_data = self._cache[key]
notification_id = -1
if NOTIFICATION_ID in namespace_data:
notification_id = self._cache[key][NOTIFICATION_ID]
notifications.append({NAMESPACE_NAME: key, NOTIFICATION_ID: notification_id})
try:
# if there are no namespaces to poll, return immediately
if len(notifications) == 0:
return
url = "{}/notifications/v2".format(self.config_url)
params = {
"appId": self.app_id,
"cluster": self.cluster,
"notifications": json.dumps(notifications, ensure_ascii=False),
}
param_str = url_encode_wrapper(params)
url = url + "?" + param_str
code, body = http_request(url, self._pull_timeout, headers=self._sign_headers(url))
http_code = code
if http_code == 304:
logger.debug("No change, loop...")
return
if http_code == 200:
if not body:
logger.error(f"_long_poll load configs failed,body is {body}")
return
data = json.loads(body)
for entry in data:
namespace = entry[NAMESPACE_NAME]
n_id = entry[NOTIFICATION_ID]
logger.info("%s has changes: notificationId=%d", namespace, n_id)
self._get_net_and_set_local(namespace, n_id, call_change=True)
return
else:
logger.warning("Sleep...")
except Exception as e:
logger.warning(str(e))
def _get_net_and_set_local(self, namespace, n_id, call_change=False):
namespace_data = self.get_json_from_net(namespace)
if not namespace_data:
return
namespace_data[NOTIFICATION_ID] = n_id
old_namespace = self._cache.get(namespace)
self._update_cache_and_file(namespace_data, namespace)
if self._change_listener is not None and call_change and old_namespace:
old_kv = old_namespace.get(CONFIGURATIONS)
new_kv = namespace_data.get(CONFIGURATIONS)
self._call_listener(namespace, old_kv, new_kv)
def _listener(self):
logger.info("start long_poll")
while not self._stopping:
self._long_poll()
time.sleep(self._cycle_time)
logger.info("stopped, long_poll")
# add the signature headers required when a secret is configured
def _sign_headers(self, url):
headers = {}
if self.secret == "":
return headers
uri = url[len(self.config_url) : len(url)]
time_unix_now = str(int(round(time.time() * 1000)))
headers["Authorization"] = "Apollo " + self.app_id + ":" + signature(time_unix_now, uri, self.secret)
headers["Timestamp"] = time_unix_now
return headers
def _heart_beat(self):
while not self._stopping:
for namespace in self._notification_map:
self._do_heart_beat(namespace)
time.sleep(60 * 10)  # 10 minutes
def _do_heart_beat(self, namespace):
url = "{}/configs/{}/{}/{}?ip={}".format(self.config_url, self.app_id, self.cluster, namespace, self.ip)
try:
code, body = http_request(url, timeout=3, headers=self._sign_headers(url))
if code == 200:
if not body:
logger.error(f"_do_heart_beat load configs failed,body is {body}")
return None
data = json.loads(body)
if self.last_release_key == data["releaseKey"]:
return None
self.last_release_key = data["releaseKey"]
data = data["configurations"]
self._update_cache_and_file(data, namespace)
else:
return None
except Exception:
logger.exception("an error occurred in _do_heart_beat")
return None
def get_all_dicts(self, namespace):
namespace_data = self._cache.get(namespace)
if namespace_data is None:
net_namespace_data = self.get_json_from_net(namespace)
if not net_namespace_data:
return namespace_data
namespace_data = net_namespace_data.get(CONFIGURATIONS)
if namespace_data:
self._update_cache_and_file(namespace_data, namespace)
return namespace_data
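A hedged usage sketch for the client above; the endpoint, app id, and keys are placeholders.
client = ApolloClient(
    config_url="http://apollo.example.com:8080",
    app_id="dify",
    cluster="default",
    start_hot_update=False,  # pull on demand instead of starting the long-poll thread
)
settings = client.get_all_dicts("application")  # dict of key/value pairs, or None
db_host = client.get_value("DB_HOST", default_val="localhost")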

View File

@@ -0,0 +1,41 @@
import logging
import os
import ssl
import urllib.request
from urllib import parse
from urllib.error import HTTPError
# Create an SSL context that allows for a lower level of security
ssl_context = ssl.create_default_context()
ssl_context.set_ciphers("HIGH:!DH:!aNULL")
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
# Create an opener object and pass in a custom SSL context
opener = urllib.request.build_opener(urllib.request.HTTPSHandler(context=ssl_context))
urllib.request.install_opener(opener)
logger = logging.getLogger(__name__)
def http_request(url, timeout, headers={}):
try:
request = urllib.request.Request(url, headers=headers)
res = urllib.request.urlopen(request, timeout=timeout)
body = res.read().decode("utf-8")
return res.code, body
except HTTPError as e:
if e.code == 304:
logger.warning("http_request error,code is 304, maybe you should check secret")
return 304, None
logger.warning("http_request error,code is %d, msg is %s", e.code, e.msg)
raise e
def url_encode(params):
return parse.urlencode(params)
def makedirs_wrapper(path):
os.makedirs(path, exist_ok=True)

View File

@@ -0,0 +1,51 @@
import hashlib
import socket
from .python_3x import url_encode
# define constants
CONFIGURATIONS = "configurations"
NOTIFICATION_ID = "notificationId"
NAMESPACE_NAME = "namespaceName"
# sign the timestamp and uri with the secret (HMAC-SHA1, base64-encoded)
def signature(timestamp, uri, secret):
import base64
import hmac
string_to_sign = "" + timestamp + "\n" + uri
hmac_code = hmac.new(secret.encode(), string_to_sign.encode(), hashlib.sha1).digest()
return base64.b64encode(hmac_code).decode()
def url_encode_wrapper(params):
return url_encode(params)
def no_key_cache_key(namespace, key):
return "{}{}{}".format(namespace, len(namespace), key)
# Return the value for the key from the namespace cache, or None if it is not present
def get_value_from_dict(namespace_cache, key):
if namespace_cache:
kv_data = namespace_cache.get(CONFIGURATIONS)
if kv_data is None:
return None
if key in kv_data:
return kv_data[key]
return None
def init_ip():
ip = ""
s = None
try:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 53))
ip = s.getsockname()[0]
finally:
if s:
s.close()
return ip

View File

@@ -0,0 +1,15 @@
from collections.abc import Mapping
from typing import Any
from pydantic.fields import FieldInfo
class RemoteSettingsSource:
def __init__(self, configs: Mapping[str, Any]):
pass
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
raise NotImplementedError
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
return value

View File

@@ -0,0 +1,5 @@
from enum import StrEnum
class RemoteSettingsSourceName(StrEnum):
APOLLO = "apollo"

View File

@@ -14,11 +14,11 @@ AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
if dify_config.ETL_TYPE == "Unstructured":
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls"]
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls"]
DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
if dify_config.UNSTRUCTURED_API_URL:
DOCUMENT_EXTENSIONS.append("ppt")
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
else:
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])

View File

@@ -1,5 +1,6 @@
from datetime import UTC, datetime
from flask import request
from flask_login import current_user
from flask_restful import Resource, inputs, marshal_with, reqparse
from sqlalchemy import and_
@@ -20,8 +21,17 @@ class InstalledAppsListApi(Resource):
@account_initialization_required
@marshal_with(installed_app_list_fields)
def get(self):
app_id = request.args.get("app_id", default=None, type=str)
current_tenant_id = current_user.current_tenant_id
installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
if app_id:
installed_apps = (
db.session.query(InstalledApp)
.filter(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id))
.all()
)
else:
installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
installed_apps = [

View File

@@ -368,6 +368,7 @@ class ToolWorkflowProviderCreateApi(Resource):
description=args["description"],
parameters=args["parameters"],
privacy_policy=args["privacy_policy"],
labels=args["labels"],
)

View File

@@ -3,7 +3,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Optional, Union
from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@@ -36,6 +36,29 @@ logger = logging.getLogger(__name__)
class AdvancedChatAppGenerator(MessageBasedAppGenerator):
_dialogue_count: int
@overload
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[str, None, None]: ...
@overload
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
app_model: App,
@@ -44,7 +67,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str, None, None]:
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
def generate(
self,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
):
"""
Generate App response.
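The point of these overloads, sketched on a toy function (not Dify code): a type checker narrows the return type from the literal value of streaming, so call sites need no casts.
from collections.abc import Generator
from typing import Literal, Union, overload
@overload
def generate(streaming: Literal[True]) -> Generator[str, None, None]: ...
@overload
def generate(streaming: Literal[False]) -> dict: ...
def generate(streaming: bool = True) -> Union[dict, Generator[str, None, None]]:
    return (chunk for chunk in ("a", "b")) if streaming else {"answer": "done"}
blocking = generate(streaming=False)  # inferred as dict
streamed = generate(streaming=True)   # inferred as Generator[str, None, None]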

View File

@@ -19,6 +19,7 @@ from core.app.entities.queue_entities import (
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueMessageReplaceEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
@@ -31,6 +32,7 @@ from core.app.entities.queue_entities import (
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
@@ -127,7 +129,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._conversation_name_generate_thread = None
self._recorded_files: list[Mapping[str, Any]] = []
self.total_tokens: int = 0
def process(self):
"""
@@ -318,7 +319,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if response:
yield response
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
response = self._workflow_node_finish_to_stream_response(
@@ -361,8 +362,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if not workflow_run:
raise Exception("Workflow run not initialized.")
# FIXME for issue #11221 quick fix maybe have a better solution
self.total_tokens += event.metadata.get("total_tokens", 0) if event.metadata else 0
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
@@ -376,7 +375,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens or self.total_tokens,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
conversation_id=self._conversation.id,
@@ -387,6 +386,29 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_partial_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
elif isinstance(event, QueueWorkflowFailedEvent):
if not workflow_run:
@@ -404,6 +426,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
error=event.error,
conversation_id=self._conversation.id,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count,
)
yield self._workflow_finish_to_stream_response(

View File

@@ -2,7 +2,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Union
from typing import Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@@ -28,6 +28,39 @@ logger = logging.getLogger(__name__)
class AgentChatAppGenerator(MessageBasedAppGenerator):
@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
) -> Generator[str, None, None]: ...
@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
*,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Mapping[str, Any] | Generator[str, None, None]: ...
def generate(
self,
*,
@@ -36,7 +69,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str, None, None]:
):
"""
Generate App response.

View File

@@ -82,7 +82,7 @@ class AppGenerateResponseConverter(ABC):
for resource in metadata["retriever_resources"]:
updated_resources.append(
{
"segment_id": resource["segment_id"],
"segment_id": resource.get("segment_id", ""),
"position": resource["position"],
"document_name": resource["document_name"],
"score": resource["score"],

View File

@@ -1,7 +1,7 @@
import logging
import threading
import uuid
from collections.abc import Generator
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload
from flask import Flask, current_app
@@ -34,9 +34,9 @@ class ChatAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
user: Union[Account, EndUser],
args: Any,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
stream: Literal[True] = True,
streaming: Literal[True],
) -> Generator[str, None, None]: ...
@overload
@@ -44,19 +44,29 @@ class ChatAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
user: Union[Account, EndUser],
args: Any,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
stream: Literal[False] = False,
) -> dict: ...
streaming: Literal[False],
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Union[Mapping[str, Any], Generator[str, None, None]]: ...
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Any,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[dict, Generator[str, None, None]]:
):
"""
Generate App response.

View File

@@ -1,7 +1,7 @@
import logging
import threading
import uuid
from collections.abc import Generator
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload
from flask import Flask, current_app
@@ -34,9 +34,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
user: Union[Account, EndUser],
args: dict,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
stream: Literal[True] = True,
streaming: Literal[True],
) -> Generator[str, None, None]: ...
@overload
@@ -44,14 +44,29 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
self,
app_model: App,
user: Union[Account, EndUser],
args: dict,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
stream: Literal[False] = False,
) -> dict: ...
streaming: Literal[False],
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
) -> Mapping[str, Any] | Generator[str, None, None]: ...
def generate(
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, streaming: bool = True
) -> Union[dict, Generator[str, None, None]]:
self,
app_model: App,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
):
"""
Generate App response.

View File

@@ -3,7 +3,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, Union
from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@@ -30,6 +30,35 @@ logger = logging.getLogger(__name__)
class WorkflowAppGenerator(BaseAppGenerator):
@overload
def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Generator[str, None, None]: ...
@overload
def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
*,
@@ -41,7 +70,20 @@ class WorkflowAppGenerator(BaseAppGenerator):
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Mapping[str, Any] | Generator[str, None, None]:
) -> Mapping[str, Any] | Generator[str, None, None]: ...
def generate(
self,
*,
app_model: App,
workflow: Workflow,
user: Account | EndUser,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
):
files: Sequence[Mapping[str, Any]] = args.get("files") or []
# parse files

View File

@@ -6,6 +6,7 @@ from core.app.entities.queue_entities import (
QueueMessageEndEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
)
@@ -34,7 +35,8 @@ class WorkflowAppQueueManager(AppQueueManager):
| QueueErrorEvent
| QueueMessageEndEvent
| QueueWorkflowSucceededEvent
| QueueWorkflowFailedEvent,
| QueueWorkflowFailedEvent
| QueueWorkflowPartialSuccessEvent,
):
self.stop_listen()

View File

@@ -15,6 +15,7 @@ from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
@@ -26,6 +27,7 @@ from core.app.entities.queue_entities import (
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
@@ -106,7 +108,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
self._task_state = WorkflowTaskState()
self._wip_workflow_node_executions = {}
self.total_tokens: int = 0
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@@ -258,36 +259,36 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event)
response = self._workflow_node_start_to_stream_response(
node_start_response = self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
if node_start_response:
yield node_start_response
elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)
response = self._workflow_node_finish_to_stream_response(
node_success_response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
if node_success_response:
yield node_success_response
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
response = self._workflow_node_finish_to_stream_response(
node_failed_response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
if node_failed_response:
yield node_failed_response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
@@ -320,8 +321,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not workflow_run:
raise Exception("Workflow run not initialized.")
# FIXME for issue #11221 quick fix maybe have a better solution
self.total_tokens += event.metadata.get("total_tokens", 0) if event.metadata else 0
yield self._workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event
)
@@ -335,7 +334,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
workflow_run = self._handle_workflow_run_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens or self.total_tokens,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
conversation_id=None,
@@ -345,6 +344,30 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# save workflow app log
self._save_workflow_app_log(workflow_run)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_partial_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
)
# save workflow app log
self._save_workflow_app_log(workflow_run)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
@@ -354,7 +377,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
@@ -366,6 +388,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
)
# save workflow app log

View File

@@ -8,6 +8,7 @@ from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
@@ -18,6 +19,7 @@ from core.app.entities.queue_entities import (
QueueRetrieverResourcesEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
@@ -25,6 +27,7 @@ from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
@@ -32,6 +35,7 @@ from core.workflow.graph_engine.entities.event import (
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeInIterationFailedEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
@@ -176,8 +180,12 @@ class WorkflowBasedAppRunner(AppRunner):
)
elif isinstance(event, GraphRunSucceededEvent):
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
elif isinstance(event, GraphRunPartialSucceededEvent):
self._publish_event(
QueueWorkflowPartialSuccessEvent(outputs=event.outputs, exceptions_count=event.exceptions_count)
)
elif isinstance(event, GraphRunFailedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.error))
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
elif isinstance(event, NodeRunStartedEvent):
self._publish_event(
QueueNodeStartedEvent(
@@ -253,6 +261,36 @@ class WorkflowBasedAppRunner(AppRunner):
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeRunExceptionEvent):
self._publish_event(
QueueNodeExceptionEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error",
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeInIterationFailedEvent):
self._publish_event(
QueueNodeInIterationFailedEvent(

View File

@@ -2,7 +2,7 @@ from datetime import datetime
from enum import Enum, StrEnum
from typing import Any, Optional
from pydantic import BaseModel, field_validator
from pydantic import BaseModel
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.node_entities import NodeRunMetadataKey
@@ -25,12 +25,14 @@ class QueueEvent(StrEnum):
WORKFLOW_STARTED = "workflow_started"
WORKFLOW_SUCCEEDED = "workflow_succeeded"
WORKFLOW_FAILED = "workflow_failed"
WORKFLOW_PARTIAL_SUCCEEDED = "workflow_partial_succeeded"
ITERATION_START = "iteration_start"
ITERATION_NEXT = "iteration_next"
ITERATION_COMPLETED = "iteration_completed"
NODE_STARTED = "node_started"
NODE_SUCCEEDED = "node_succeeded"
NODE_FAILED = "node_failed"
NODE_EXCEPTION = "node_exception"
RETRIEVER_RESOURCES = "retriever_resources"
ANNOTATION_REPLY = "annotation_reply"
AGENT_THOUGHT = "agent_thought"
@@ -113,18 +115,6 @@ class QueueIterationNextEvent(AppQueueEvent):
output: Optional[Any] = None # output for the current iteration
duration: Optional[float] = None
@field_validator("output", mode="before")
@classmethod
def set_output(cls, v):
"""
Set output
"""
if v is None:
return None
if isinstance(v, int | float | str | bool | dict | list):
return v
raise ValueError("output must be a valid type")
class QueueIterationCompletedEvent(AppQueueEvent):
"""
@@ -249,6 +239,17 @@ class QueueWorkflowFailedEvent(AppQueueEvent):
event: QueueEvent = QueueEvent.WORKFLOW_FAILED
error: str
exceptions_count: int
class QueueWorkflowPartialSuccessEvent(AppQueueEvent):
"""
QueueWorkflowPartialSuccessEvent entity
"""
event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED
exceptions_count: int
outputs: Optional[dict[str, Any]] = None
class QueueNodeStartedEvent(AppQueueEvent):
@@ -343,6 +344,37 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent):
error: str
class QueueNodeExceptionEvent(AppQueueEvent):
"""
QueueNodeExceptionEvent entity
"""
event: QueueEvent = QueueEvent.NODE_EXCEPTION
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
error: str
class QueueNodeFailedEvent(AppQueueEvent):
"""
QueueNodeFailedEvent entity
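
For orientation, a minimal pydantic sketch of the new partial-success event shape, trimmed to the fields the diff introduces (an illustration, not the real core.app.entities.queue_entities module):

from enum import StrEnum
from typing import Any, Optional

from pydantic import BaseModel

class QueueEvent(StrEnum):
    WORKFLOW_PARTIAL_SUCCEEDED = "workflow_partial_succeeded"

class QueueWorkflowPartialSuccessEvent(BaseModel):
    """Published when a workflow run finishes with some node-level exceptions handled."""
    event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED
    exceptions_count: int
    outputs: Optional[dict[str, Any]] = None

event = QueueWorkflowPartialSuccessEvent(exceptions_count=2, outputs={"result": "ok"})
print(event.model_dump(mode="json"))
# {'event': 'workflow_partial_succeeded', 'exceptions_count': 2, 'outputs': {'result': 'ok'}}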

View File

@@ -213,6 +213,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
created_by: Optional[dict] = None
created_at: int
finished_at: int
exceptions_count: Optional[int] = 0
files: Optional[Sequence[Mapping[str, Any]]] = []
event: StreamEvent = StreamEvent.WORKFLOW_FINISHED

View File

@@ -110,7 +110,7 @@ class RateLimitGenerator:
raise StopIteration
try:
return next(self.generator)
except StopIteration:
except Exception:
self.close()
raise
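
Catching Exception instead of StopIteration means the rate-limit wrapper now releases its slot whenever the wrapped generator stops, including on real errors, not only on normal exhaustion. A standalone sketch of that pattern with hypothetical names:

from collections.abc import Iterator

class ReleasingIterator:
    """Wrap an iterator and run a cleanup callback once it stops for any reason."""

    def __init__(self, inner: Iterator[str], on_close):
        self.inner = inner
        self.on_close = on_close
        self.closed = False

    def __iter__(self):
        return self

    def __next__(self) -> str:
        if self.closed:
            raise StopIteration
        try:
            return next(self.inner)
        except Exception:
            # release the slot on StopIteration *and* on real errors, then re-raise
            self.close()
            raise

    def close(self):
        if not self.closed:
            self.closed = True
            self.on_close()

def failing():
    yield "ok"
    raise RuntimeError("boom")

it = ReleasingIterator(failing(), on_close=lambda: print("slot released"))
try:
    for chunk in it:
        print(chunk)
except RuntimeError:
    pass
# prints "ok" and then "slot released" before the RuntimeError propagates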

View File

@@ -12,6 +12,7 @@ from core.app.entities.queue_entities import (
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
@@ -164,6 +165,55 @@ class WorkflowCycleManage:
return workflow_run
def _handle_workflow_run_partial_success(
self,
workflow_run: WorkflowRun,
start_at: float,
total_tokens: int,
total_steps: int,
outputs: Mapping[str, Any] | None = None,
exceptions_count: int = 0,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowRun:
"""
Workflow run partial success
:param workflow_run: workflow run
:param start_at: start time
:param total_tokens: total tokens
:param total_steps: total steps
:param outputs: outputs
:param exceptions_count: exceptions count
:param conversation_id: conversation id
:return:
"""
workflow_run = self._refetch_workflow_run(workflow_run.id)
outputs = WorkflowEntry.handle_special_values(outputs)
workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value
workflow_run.outputs = json.dumps(outputs or {})
workflow_run.elapsed_time = time.perf_counter() - start_at
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count
db.session.commit()
db.session.refresh(workflow_run)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_run=workflow_run,
conversation_id=conversation_id,
user_id=trace_manager.user_id,
)
)
db.session.close()
return workflow_run
def _handle_workflow_run_failed(
self,
workflow_run: WorkflowRun,
@@ -174,6 +224,7 @@ class WorkflowCycleManage:
error: str,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
exceptions_count: int = 0,
) -> WorkflowRun:
"""
Workflow run failed
@@ -193,7 +244,7 @@ class WorkflowCycleManage:
workflow_run.total_tokens = total_tokens
workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count
db.session.commit()
running_workflow_node_executions = (
@@ -318,7 +369,7 @@ class WorkflowCycleManage:
return workflow_node_execution
def _handle_workflow_node_execution_failed(
self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent
self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
@@ -337,7 +388,11 @@ class WorkflowCycleManage:
)
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
{
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
WorkflowNodeExecution.status: (
WorkflowNodeExecutionStatus.FAILED.value
if not isinstance(event, QueueNodeExceptionEvent)
else WorkflowNodeExecutionStatus.EXCEPTION.value
),
WorkflowNodeExecution.error: event.error,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(process_data) if process_data else None,
@@ -351,8 +406,11 @@ class WorkflowCycleManage:
db.session.commit()
db.session.close()
process_data = WorkflowEntry.handle_special_values(event.process_data)
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.status = (
WorkflowNodeExecutionStatus.FAILED.value
if not isinstance(event, QueueNodeExceptionEvent)
else WorkflowNodeExecutionStatus.EXCEPTION.value
)
workflow_node_execution.error = event.error
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
@@ -433,6 +491,7 @@ class WorkflowCycleManage:
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(workflow_run.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict),
exceptions_count=workflow_run.exceptions_count,
),
)
@@ -483,7 +542,10 @@ class WorkflowCycleManage:
def _workflow_node_finish_to_stream_response(
self,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent,
event: QueueNodeSucceededEvent
| QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent
| QueueNodeExceptionEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:

View File

@@ -141,7 +141,7 @@ def _to_url(f: File, /):
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
if f.related_id is None:
raise ValueError("Missing file related_id")
return helpers.get_signed_file_url(upload_file_id=f.related_id)
return f.remote_url or helpers.get_signed_file_url(upload_file_id=f.related_id)
elif f.transfer_method == FileTransferMethod.TOOL_FILE:
# add sign url
if f.related_id is None or f.extension is None:
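
The change prefers the stored remote_url and only falls back to a signed local URL when remote_url is empty or None. A tiny illustrative sketch (get_signed_file_url is stubbed here):

def resolve_url(remote_url: str | None, related_id: str) -> str:
    def get_signed_file_url(upload_file_id: str) -> str:  # stub for helpers.get_signed_file_url
        return f"https://files.example.com/{upload_file_id}?sig=placeholder"

    # empty string and None both fall through to the signed URL
    return remote_url or get_signed_file_url(upload_file_id=related_id)

print(resolve_url("https://cdn.example.com/a.png", "123"))  # remote URL wins
print(resolve_url(None, "123"))                             # falls back to the signed URL
print(resolve_url("", "123"))                               # empty string also falls back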

View File

@@ -24,6 +24,12 @@ BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]
class MaxRetriesExceededError(Exception):
"""Raised when the maximum number of retries is exceeded."""
pass
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if "allow_redirects" in kwargs:
allow_redirects = kwargs.pop("allow_redirects")
@@ -64,7 +70,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if retries <= max_retries:
time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1)))
raise Exception(f"Reached maximum retries ({max_retries}) for URL {url}")
raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")
def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
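
Raising a dedicated MaxRetriesExceededError instead of a bare Exception lets callers distinguish retry exhaustion from other failures. A hedged usage sketch with a stubbed request function (the real make_request lives in this ssrf helper module):

import time

class MaxRetriesExceededError(Exception):
    """Raised when the maximum number of retries is exceeded."""

def make_request(url: str, max_retries: int = 3) -> str:
    # stub: pretend every attempt fails with a retryable status
    for _attempt in range(max_retries):
        time.sleep(0)  # backoff placeholder
    raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}")

try:
    make_request("https://example.com/health")
except MaxRetriesExceededError as e:
    print(f"giving up after retries: {e}")
except Exception:
    print("some other failure")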

View File

@@ -10,6 +10,7 @@ from core.model_runtime.entities.llm_entities import (
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
@@ -105,7 +106,11 @@ class BaichuanLanguageModel(LargeLanguageModel):
if isinstance(message.content, str):
message_dict = {"role": "user", "content": message.content}
else:
raise ValueError("User message content must be str")
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
message_dict = {"role": "user", "content": message_content.data}
elif message_content.type == PromptMessageContentType.IMAGE:
raise ValueError("Content object type not support image_url")
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}

View File

@@ -0,0 +1,21 @@
import boto3
from botocore.config import Config
def get_bedrock_client(service_name, credentials=None):
client_config = Config(region_name=credentials["aws_region"])
aws_access_key_id = credentials["aws_access_key_id"]
aws_secret_access_key = credentials["aws_secret_access_key"]
if aws_access_key_id and aws_secret_access_key:
# use aksk to call bedrock
client = boto3.client(
service_name=service_name,
config=client_config,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)
else:
# use iam without aksk to call
client = boto3.client(service_name=service_name, config=client_config)
return client
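
The new helper centralizes Bedrock client construction: with an access key pair it passes explicit credentials, otherwise it lets boto3 fall back to the ambient IAM role. A usage sketch, assuming it runs inside this repo; note that the helper indexes the mapping directly, so aws_region and both key fields must be present even if empty:

from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client

credentials = {
    "aws_region": "us-east-1",
    "aws_access_key_id": "",      # empty -> boto3 resolves credentials from the environment / IAM role
    "aws_secret_access_key": "",
}

runtime_client = get_bedrock_client("bedrock-runtime", credentials)              # used by the LLM model
agent_runtime_client = get_bedrock_client("bedrock-agent-runtime", credentials)  # used by the rerank model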

View File

@@ -40,6 +40,7 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client
logger = logging.getLogger(__name__)
ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
@@ -173,13 +174,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
:param stream: is stream response
:return: full response or stream response chunk generator result
"""
bedrock_client = boto3.client(
service_name="bedrock-runtime",
aws_access_key_id=credentials.get("aws_access_key_id"),
aws_secret_access_key=credentials.get("aws_secret_access_key"),
region_name=credentials["aws_region"],
)
bedrock_client = get_bedrock_client("bedrock-runtime", credentials)
system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages)
inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)

View File

@@ -1,8 +1,5 @@
from typing import Optional
import boto3
from botocore.config import Config
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
@@ -14,6 +11,7 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client
class BedrockRerankModel(RerankModel):
@@ -48,13 +46,7 @@ class BedrockRerankModel(RerankModel):
return RerankResult(model=model, docs=docs)
# initialize client
client_config = Config(region_name=credentials["aws_region"])
bedrock_runtime = boto3.client(
service_name="bedrock-agent-runtime",
config=client_config,
aws_access_key_id=credentials.get("aws_access_key_id", ""),
aws_secret_access_key=credentials.get("aws_secret_access_key"),
)
bedrock_runtime = get_bedrock_client("bedrock-agent-runtime", credentials)
queries = [{"type": "TEXT", "textQuery": {"text": query}}]
text_sources = []
for text in docs:

View File

@@ -3,8 +3,6 @@ import logging
import time
from typing import Optional
import boto3
from botocore.config import Config
from botocore.exceptions import (
ClientError,
EndpointConnectionError,
@@ -25,6 +23,7 @@ from core.model_runtime.errors.invoke import (
InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client
logger = logging.getLogger(__name__)
@@ -48,14 +47,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
:param input_type: input type
:return: embeddings result
"""
client_config = Config(region_name=credentials["aws_region"])
bedrock_runtime = boto3.client(
service_name="bedrock-runtime",
config=client_config,
aws_access_key_id=credentials.get("aws_access_key_id"),
aws_secret_access_key=credentials.get("aws_secret_access_key"),
)
bedrock_runtime = get_bedrock_client("bedrock-runtime", credentials)
embeddings = []
token_usage = 0

View File

@@ -24,6 +24,9 @@ class DeepseekLargeLanguageModel(OAIAPICompatLargeLanguageModel):
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
self._add_custom_parameters(credentials)
# {"response_format": "xx"} need convert to {"response_format": {"type": "xx"}}
if "response_format" in model_parameters:
model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
def validate_credentials(self, model: str, credentials: dict) -> None:
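
DeepSeek's API expects response_format as an object, so the flat string coming from the parameter form is wrapped before delegating to the OpenAI-compatible base class. A minimal sketch of the conversion:

def normalize_response_format(model_parameters: dict) -> dict:
    # {"response_format": "json_object"} -> {"response_format": {"type": "json_object"}}
    if "response_format" in model_parameters:
        model_parameters["response_format"] = {"type": model_parameters.get("response_format")}
    return model_parameters

print(normalize_response_format({"temperature": 0.7, "response_format": "json_object"}))
# {'temperature': 0.7, 'response_format': {'type': 'json_object'}}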

View File

@@ -0,0 +1,39 @@
model: gemini-2.0-flash-exp
label:
en_US: Gemini 2.0 Flash Exp
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 1048576
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@@ -0,0 +1,38 @@
model: gemini-exp-1206
label:
en_US: Gemini exp 1206
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 2097152
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@@ -1,4 +1,5 @@
- llama-3.1-405b-reasoning
- llama-3.3-70b-versatile
- llama-3.1-70b-versatile
- llama-3.1-8b-instant
- llama3-70b-8192

View File

@@ -0,0 +1,25 @@
model: gemma-7b-it
label:
zh_Hans: Gemma 7B Instruction Tuned
en_US: Gemma 7B Instruction Tuned
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 8192
pricing:
input: '0.05'
output: '0.1'
unit: '0.000001'
currency: USD

View File

@@ -0,0 +1,25 @@
model: gemma2-9b-it
label:
zh_Hans: Gemma 2 9B Instruction Tuned
en_US: Gemma 2 9B Instruction Tuned
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 8192
pricing:
input: '0.05'
output: '0.1'
unit: '0.000001'
currency: USD

View File

@@ -1,7 +1,8 @@
model: llama-3.1-70b-versatile
deprecated: true
label:
zh_Hans: Llama-3.1-70b-versatile
en_US: Llama-3.1-70b-versatile
zh_Hans: Llama-3.1-70b-versatile (DEPRECATED)
en_US: Llama-3.1-70b-versatile (DEPRECATED)
model_type: llm
features:
- agent-thought

View File

@@ -1,4 +1,5 @@
model: llama-3.2-11b-text-preview
deprecated: true
label:
zh_Hans: Llama 3.2 11B Text (Preview)
en_US: Llama 3.2 11B Text (Preview)

View File

@@ -1,4 +1,5 @@
model: llama-3.2-90b-text-preview
deprecated: true
label:
zh_Hans: Llama 3.2 90B Text (Preview)
en_US: Llama 3.2 90B Text (Preview)

View File

@@ -0,0 +1,25 @@
model: llama-3.3-70b-specdec
label:
zh_Hans: Llama 3.3 70B Specdec
en_US: Llama 3.3 70B Specdec
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 32768
pricing:
input: "0.05"
output: "0.1"
unit: "0.000001"
currency: USD

View File

@@ -0,0 +1,25 @@
model: llama-3.3-70b-versatile
label:
zh_Hans: Llama 3.3 70B Versatile
en_US: Llama 3.3 70B Versatile
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 32768
pricing:
input: "0.05"
output: "0.1"
unit: "0.000001"
currency: USD

View File

@@ -0,0 +1,25 @@
model: llama3-groq-70b-8192-tool-use-preview
label:
zh_Hans: Llama3-groq-70b-8192-tool-use (PREVIEW)
en_US: Llama3-groq-70b-8192-tool-use (PREVIEW)
model_type: llm
features:
- agent-thought
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: max_tokens
use_template: max_tokens
default: 512
min: 1
max: 8192
pricing:
input: '0.05'
output: '0.08'
unit: '0.000001'
currency: USD

View File

@@ -35,6 +35,7 @@ from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage
class MinimaxLargeLanguageModel(LargeLanguageModel):
model_apis = {
"abab7-chat-preview": MinimaxChatCompletionPro,
"abab6.5t-chat": MinimaxChatCompletionPro,
"abab6.5s-chat": MinimaxChatCompletionPro,
"abab6.5-chat": MinimaxChatCompletionPro,
"abab6-chat": MinimaxChatCompletionPro,

View File

@@ -1,3 +1,5 @@
- pixtral-large-latest
- pixtral-large-2411
- pixtral-12b-2409
- codestral-latest
- mistral-embed

View File

@@ -5,6 +5,7 @@ label:
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 128000
@@ -21,7 +22,7 @@ parameter_rules:
max: 1
- name: max_tokens
use_template: max_tokens
default: 1024
default: 8192
min: 1
max: 8192
- name: safe_prompt

View File

@@ -0,0 +1,52 @@
model: pixtral-large-2411
label:
zh_Hans: pixtral-large-2411
en_US: pixtral-large-2411
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.7
min: 0
max: 1
- name: top_p
use_template: top_p
default: 1
min: 0
max: 1
- name: max_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: safe_prompt
default: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.
zh_Hans: 是否开启提示词审查
label:
en_US: SafePrompt
zh_Hans: 提示词审查
- name: random_seed
type: int
help:
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
label:
en_US: RandomSeed
zh_Hans: 随机数种子
default: 0
min: 0
max: 2147483647
pricing:
input: '0.008'
output: '0.024'
unit: '0.001'
currency: USD

View File

@@ -0,0 +1,52 @@
model: pixtral-large-latest
label:
zh_Hans: pixtral-large-latest
en_US: pixtral-large-latest
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
default: 0.7
min: 0
max: 1
- name: top_p
use_template: top_p
default: 1
min: 0
max: 1
- name: max_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: safe_prompt
default: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.
zh_Hans: 是否开启提示词审查
label:
en_US: SafePrompt
zh_Hans: 提示词审查
- name: random_seed
type: int
help:
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
label:
en_US: RandomSeed
zh_Hans: 随机数种子
default: 0
min: 0
max: 2147483647
pricing:
input: '0.008'
output: '0.024'
unit: '0.001'
currency: USD

View File

@@ -181,9 +181,11 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
# prepare the payload for a simple ping to the model
data = {"model": model, "stream": stream}
if "format" in model_parameters:
data["format"] = model_parameters["format"]
del model_parameters["format"]
if format_schema := model_parameters.pop("format", None):
try:
data["format"] = format_schema if format_schema == "json" else json.loads(format_schema)
except json.JSONDecodeError as e:
raise InvokeBadRequestError(f"Invalid format schema: {str(e)}")
if "keep_alive" in model_parameters:
data["keep_alive"] = model_parameters["keep_alive"]
@@ -733,12 +735,12 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
ParameterRule(
name="format",
label=I18nObject(en_US="Format", zh_Hans="返回格式"),
type=ParameterType.STRING,
type=ParameterType.TEXT,
default="json",
help=I18nObject(
en_US="the format to return a response in. Currently the only accepted value is json.",
zh_Hans="返回响应的格式。目前唯一接受的值是json。",
en_US="the format to return a response in. Format can be `json` or a JSON schema.",
zh_Hans="返回响应的格式。目前接受的值是字符串`json`或JSON schema.",
),
options=["json"],
),
],
pricing=PriceConfig(
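
With the parameter type changed to TEXT, format can now carry either the literal string json or a full JSON schema; anything other than json is parsed before being sent to Ollama. A standalone sketch of that branch:

import json

def build_format_payload(model_parameters: dict) -> dict:
    data: dict = {}
    if format_schema := model_parameters.pop("format", None):
        try:
            # "json" is passed through as-is; anything else is treated as a JSON schema string
            data["format"] = format_schema if format_schema == "json" else json.loads(format_schema)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid format schema: {e}")
    return data

print(build_format_payload({"format": "json"}))
print(build_format_payload({"format": '{"type": "object", "properties": {"name": {"type": "string"}}}'}))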

View File

@@ -478,6 +478,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
usage=usage,
)
break
# handle the error here. for issue #11629
if chunk_json.get("error") and chunk_json.get("choices") is None:
raise ValueError(chunk_json.get("error"))
if chunk_json:
if u := chunk_json.get("usage"):
usage = u
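
Some OpenAI-compatible backends stream an error object instead of choices when a request fails mid-stream; raising a ValueError surfaces a readable message rather than silently skipping the chunk. A sketch of the guard against a made-up error chunk (the chunk shape below is illustrative):

def check_stream_chunk(chunk_json: dict) -> dict:
    # mirrors the guard added for issue #11629: error present and choices absent -> raise
    if chunk_json.get("error") and chunk_json.get("choices") is None:
        raise ValueError(chunk_json.get("error"))
    return chunk_json

ok_chunk = {"choices": [{"delta": {"content": "hi"}}]}
error_chunk = {"error": {"message": "context length exceeded", "type": "invalid_request_error"}}

check_stream_chunk(ok_chunk)
try:
    check_stream_chunk(error_chunk)
except ValueError as e:
    print(f"stream aborted: {e}")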

View File

@@ -1,4 +1,5 @@
- Tencent/Hunyuan-A52B-Instruct
- Qwen/QwQ-32B-Preview
- Qwen/Qwen2.5-72B-Instruct
- Qwen/Qwen2.5-32B-Instruct
- Qwen/Qwen2.5-14B-Instruct
@@ -19,6 +20,7 @@
- 01-ai/Yi-1.5-6B-Chat
- internlm/internlm2_5-20b-chat
- internlm/internlm2_5-7b-chat
- meta-llama/Llama-3.3-70B-Instruct
- meta-llama/Meta-Llama-3.1-405B-Instruct
- meta-llama/Meta-Llama-3.1-70B-Instruct
- meta-llama/Meta-Llama-3.1-8B-Instruct

View File

@@ -0,0 +1,53 @@
model: meta-llama/Llama-3.3-70B-Instruct
label:
en_US: meta-llama/Llama-3.3-70B-Instruct
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 32768
parameter_rules:
- name: temperature
use_template: temperature
- name: max_tokens
use_template: max_tokens
type: int
default: 512
min: 1
max: 4096
help:
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: frequency_penalty
use_template: frequency_penalty
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '4.13'
output: '4.13'
unit: '0.000001'
currency: RMB

View File

@@ -0,0 +1,53 @@
model: Qwen/QwQ-32B-Preview
label:
en_US: Qwen/QwQ-32B-Preview
model_type: llm
features:
- agent-thought
- tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 32768
parameter_rules:
- name: temperature
use_template: temperature
- name: max_tokens
use_template: max_tokens
type: int
default: 512
min: 1
max: 4096
help:
zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。
en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: frequency_penalty
use_template: frequency_penalty
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '1.26'
output: '1.26'
unit: '0.000001'
currency: RMB

View File

@@ -59,8 +59,6 @@ parameter_rules:
help:
zh_Hans: 生成时使用的随机数种子用户控制模型生成内容的随机性。支持无符号64位整数默认值为 1234。在使用seed时模型将尽可能生成相同或相似的结果但目前不保证每次生成的结果完全相同。
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
- name: response_format
use_template: response_format
- name: repetition_penalty
required: false
type: float

View File

@@ -59,8 +59,6 @@ parameter_rules:
help:
zh_Hans: 生成时使用的随机数种子用户控制模型生成内容的随机性。支持无符号64位整数默认值为 1234。在使用seed时模型将尽可能生成相同或相似的结果但目前不保证每次生成的结果完全相同。
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
- name: response_format
use_template: response_format
- name: repetition_penalty
required: false
type: float

View File

@@ -58,8 +58,6 @@ parameter_rules:
help:
zh_Hans: 生成时使用的随机数种子用户控制模型生成内容的随机性。支持无符号64位整数默认值为 1234。在使用seed时模型将尽可能生成相同或相似的结果但目前不保证每次生成的结果完全相同。
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
- name: response_format
use_template: response_format
- name: repetition_penalty
required: false
type: float

View File

@@ -59,8 +59,6 @@ parameter_rules:
help:
zh_Hans: 生成时使用的随机数种子用户控制模型生成内容的随机性。支持无符号64位整数默认值为 1234。在使用seed时模型将尽可能生成相同或相似的结果但目前不保证每次生成的结果完全相同。
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
- name: response_format
use_template: response_format
- name: repetition_penalty
required: false
type: float

View File

@@ -59,8 +59,6 @@ parameter_rules:
help:
zh_Hans: 生成时使用的随机数种子用户控制模型生成内容的随机性。支持无符号64位整数默认值为 1234。在使用seed时模型将尽可能生成相同或相似的结果但目前不保证每次生成的结果完全相同。
en_US: The random number seed used when generating, the user controls the randomness of the content generated by the model. Supports unsigned 64-bit integers, default value is 1234. When using seed, the model will try its best to generate the same or similar results, but there is currently no guarantee that the results will be exactly the same every time.
- name: response_format
use_template: response_format
- name: repetition_penalty
required: false
type: float

View File

@@ -0,0 +1,39 @@
model: gemini-2.0-flash-exp
label:
en_US: Gemini 2.0 Flash Exp
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
- document
model_properties:
mode: chat
context_size: 1048576
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@@ -104,13 +104,14 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
"""
# use Anthropic official SDK references
# - https://github.com/anthropics/anthropic-sdk-python
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
service_account_key = credentials.get("vertex_service_account_key", "")
project_id = credentials["vertex_project_id"]
SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
token = ""
# get access token from service account credential
if service_account_info:
if service_account_key:
service_account_info = json.loads(base64.b64decode(service_account_key))
credentials = service_account.Credentials.from_service_account_info(service_account_info, scopes=SCOPES)
request = google.auth.transport.requests.Request()
credentials.refresh(request)
@@ -478,10 +479,11 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
if stop:
config_kwargs["stop_sequences"] = stop
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
service_account_key = credentials.get("vertex_service_account_key", "")
project_id = credentials["vertex_project_id"]
location = credentials["vertex_location"]
if service_account_info:
if service_account_key:
service_account_info = json.loads(base64.b64decode(service_account_key))
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
else:
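
Reading vertex_service_account_key with .get() makes the key optional: an empty value skips explicit credentials and relies on the ambient Google credentials instead. A small sketch of the decode step the guard protects, using placeholder key material:

import base64
import json

def decode_service_account_key(service_account_key: str) -> dict | None:
    """Return the parsed service-account info, or None when no key is configured."""
    if not service_account_key:
        return None
    return json.loads(base64.b64decode(service_account_key))

raw = json.dumps({"type": "service_account", "project_id": "demo-project"})
encoded = base64.b64encode(raw.encode()).decode()

print(decode_service_account_key(""))       # None -> fall back to application default credentials
print(decode_service_account_key(encoded))  # {'type': 'service_account', 'project_id': 'demo-project'}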

View File

@@ -48,10 +48,11 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
:param input_type: input type
:return: embeddings result
"""
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
service_account_key = credentials.get("vertex_service_account_key", "")
project_id = credentials["vertex_project_id"]
location = credentials["vertex_location"]
if service_account_info:
if service_account_key:
service_account_info = json.loads(base64.b64decode(service_account_key))
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
else:
@@ -100,10 +101,11 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
:return:
"""
try:
service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"]))
service_account_key = credentials.get("vertex_service_account_key", "")
project_id = credentials["vertex_project_id"]
location = credentials["vertex_location"]
if service_account_info:
if service_account_key:
service_account_info = json.loads(base64.b64decode(service_account_key))
service_accountSA = service_account.Credentials.from_service_account_info(service_account_info)
aiplatform.init(credentials=service_accountSA, project=project_id, location=location)
else:

View File

@@ -68,7 +68,12 @@ class MaaSClient(MaasService):
content = []
for message_content in message.content:
if message_content.type == PromptMessageContentType.TEXT:
raise ValueError("Content object type only support image_url")
content.append(
{
"type": "text",
"text": message_content.data,
}
)
elif message_content.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, message_content)
image_data = re.sub(r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data)

View File

@@ -0,0 +1,66 @@
model: grok-2-1212
label:
en_US: grok-2-1212
model_type: llm
features:
- agent-thought
- tool-call
- multi-tool-call
- stream-tool-call
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
label:
en_US: "Temperature"
zh_Hans: "采样温度"
type: float
default: 0.7
min: 0.0
max: 2.0
precision: 1
required: true
help:
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_p
label:
en_US: "Top P"
zh_Hans: "Top P"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: frequency_penalty
use_template: frequency_penalty
label:
en_US: "Frequency Penalty"
zh_Hans: "频率惩罚"
type: float
default: 0
min: 0
max: 2.0
precision: 1
required: false
help:
en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim."
zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。"
- name: user
use_template: text
label:
en_US: "User"
zh_Hans: "用户"
type: string
required: false
help:
en_US: "Used to track and differentiate conversation requests from different users."
zh_Hans: "用于追踪和区分不同用户的对话请求。"

View File

@@ -0,0 +1,64 @@
model: grok-2-vision-1212
label:
en_US: grok-2-vision-1212
model_type: llm
features:
- agent-thought
- vision
model_properties:
mode: chat
context_size: 8192
parameter_rules:
- name: temperature
label:
en_US: "Temperature"
zh_Hans: "采样温度"
type: float
default: 0.7
min: 0.0
max: 2.0
precision: 1
required: true
help:
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: top_p
label:
en_US: "Top P"
zh_Hans: "Top P"
type: float
default: 0.7
min: 0.0
max: 1.0
precision: 1
required: true
help:
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
- name: frequency_penalty
use_template: frequency_penalty
label:
en_US: "Frequency Penalty"
zh_Hans: "频率惩罚"
type: float
default: 0
min: 0
max: 2.0
precision: 1
required: false
help:
en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim."
zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。"
- name: user
use_template: text
label:
en_US: "User"
zh_Hans: "用户"
type: string
required: false
help:
en_US: "Used to track and differentiate conversation requests from different users."
zh_Hans: "用于追踪和区分不同用户的对话请求。"

View File

@@ -1,6 +1,6 @@
model: grok-beta
label:
en_US: Grok Beta
en_US: grok-beta
model_type: llm
features:
- agent-thought

View File

@@ -1,6 +1,6 @@
model: grok-vision-beta
label:
en_US: Grok Vision Beta
en_US: grok-vision-beta
model_type: llm
features:
- agent-thought

View File

@@ -0,0 +1,52 @@
model: glm-4v-flash
label:
en_US: glm-4v-flash
model_type: llm
model_properties:
mode: chat
context_size: 2048
features:
- vision
parameter_rules:
- name: temperature
use_template: temperature
default: 0.95
min: 0.0
max: 1.0
help:
zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
en_US: Sampling temperature, controls the randomness of the output, must be a positive number. The value range is (0.0,1.0], which cannot be equal to 0. The default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, The output will be more stable or certain. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time.
- name: top_p
use_template: top_p
default: 0.6
help:
zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
en_US: Another method of temperature sampling is called kernel sampling. The value range is (0.0, 1.0) open interval, which cannot be equal to 0 or 1. The default value is 0.7. The model considers the results with top_p probability mass tokens. For example 0.1 means The model decoder only considers tokens from the candidate set with the top 10% probability. It is recommended that you adjust the top_p or temperature parameters according to the application scenario, but do not adjust both parameters at the same time.
- name: do_sample
label:
zh_Hans: 采样策略
en_US: Sampling strategy
type: boolean
help:
zh_Hans: do_sample 为 true 时启用采样策略do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
default: true
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 1024
- name: web_search
type: boolean
label:
zh_Hans: 联网搜索
en_US: Web Search
default: false
help:
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: RMB

View File

@@ -144,7 +144,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL}:
if isinstance(copy_prompt_message.content, list):
# check if model is 'glm-4v'
if model not in {"glm-4v", "glm-4v-plus"}:
if not model.startswith("glm-4v"):
# not support list message
continue
# get image and
@@ -188,7 +188,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
else:
model_parameters["tools"] = [web_search_params]
if model in {"glm-4v", "glm-4v-plus"}:
if model.startswith("glm-4v"):
params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters)
else:
params = {"model": model, "messages": [], **model_parameters}
@@ -412,6 +412,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
human_prompt = "\n\nHuman:"
ai_prompt = "\n\nAssistant:"
content = message.content
if isinstance(content, list):
content = "".join(c.data for c in content if c.type == PromptMessageContentType.TEXT)
if isinstance(message, UserPromptMessage):
message_text = f"{human_prompt} {content}"

View File

@@ -4,7 +4,7 @@ import os
from datetime import datetime, timedelta
from typing import Optional
from langfuse import Langfuse
from langfuse import Langfuse # type: ignore
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import LangfuseConfig
@@ -65,8 +65,11 @@ class LangFuseDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id
trace_id = trace_info.workflow_run_id
user_id = trace_info.metadata.get("user_id")
metadata = trace_info.metadata
metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
if trace_info.message_id:
trace_id = trace_info.message_id
name = TraceTaskName.MESSAGE_TRACE.value
@@ -76,22 +79,20 @@ class LangFuseDataTrace(BaseTraceInstance):
name=name,
input=trace_info.workflow_run_inputs,
output=trace_info.workflow_run_outputs,
metadata=trace_info.metadata,
metadata=metadata,
session_id=trace_info.conversation_id,
tags=["message", "workflow"],
created_at=trace_info.start_time,
updated_at=trace_info.end_time,
)
self.add_trace(langfuse_trace_data=trace_data)
workflow_span_data = LangfuseSpan(
id=(trace_info.workflow_app_log_id or trace_info.workflow_run_id),
id=trace_info.workflow_run_id,
name=TraceTaskName.WORKFLOW_TRACE.value,
input=trace_info.workflow_run_inputs,
output=trace_info.workflow_run_outputs,
trace_id=trace_id,
start_time=trace_info.start_time,
end_time=trace_info.end_time,
metadata=trace_info.metadata,
metadata=metadata,
level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR,
status_message=trace_info.error or "",
)
@@ -103,7 +104,7 @@ class LangFuseDataTrace(BaseTraceInstance):
name=TraceTaskName.WORKFLOW_TRACE.value,
input=trace_info.workflow_run_inputs,
output=trace_info.workflow_run_outputs,
metadata=trace_info.metadata,
metadata=metadata,
session_id=trace_info.conversation_id,
tags=["workflow"],
)
@@ -192,7 +193,7 @@ class LangFuseDataTrace(BaseTraceInstance):
metadata=metadata,
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
status_message=trace_info.error or "",
parent_observation_id=(trace_info.workflow_app_log_id or trace_info.workflow_run_id),
parent_observation_id=trace_info.workflow_run_id,
)
else:
span_data = LangfuseSpan(
@@ -239,11 +240,13 @@ class LangFuseDataTrace(BaseTraceInstance):
file_list = trace_info.file_list
metadata = trace_info.metadata
message_data = trace_info.message_data
if message_data is None:
return
message_id = message_data.id
user_id = message_data.from_account_id
if message_data.from_end_user_id:
end_user_data: EndUser = (
end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
@@ -300,6 +303,8 @@ class LangFuseDataTrace(BaseTraceInstance):
self.add_generation(langfuse_generation_data)
def moderation_trace(self, trace_info: ModerationTraceInfo):
if trace_info.message_data is None:
return
span_data = LangfuseSpan(
name=TraceTaskName.MODERATION_TRACE.value,
input=trace_info.inputs,
@@ -319,9 +324,11 @@ class LangFuseDataTrace(BaseTraceInstance):
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
message_data = trace_info.message_data
if message_data is None:
return
generation_usage = GenerationUsage(
total=len(str(trace_info.suggested_question)),
input=len(trace_info.inputs),
input=len(trace_info.inputs) if trace_info.inputs else 0,
output=len(trace_info.suggested_question),
unit=UnitEnum.CHARACTERS,
)
@@ -342,6 +349,8 @@ class LangFuseDataTrace(BaseTraceInstance):
self.add_generation(langfuse_generation_data=generation_data)
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
if trace_info.message_data is None:
return
dataset_retrieval_span_data = LangfuseSpan(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
input=trace_info.inputs,

View File

@@ -62,15 +62,17 @@ class LangSmithDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = trace_info.message_id or trace_info.workflow_app_log_id or trace_info.workflow_run_id
trace_id = trace_info.message_id or trace_info.workflow_run_id
message_dotted_order = (
generate_dotted_order(trace_info.message_id, trace_info.start_time) if trace_info.message_id else None
)
workflow_dotted_order = generate_dotted_order(
trace_info.workflow_app_log_id or trace_info.workflow_run_id,
trace_info.workflow_run_id,
trace_info.workflow_data.created_at,
message_dotted_order,
)
metadata = trace_info.metadata
metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
if trace_info.message_id:
message_run = LangSmithRunModel(
@@ -82,7 +84,7 @@ class LangSmithDataTrace(BaseTraceInstance):
start_time=trace_info.start_time,
end_time=trace_info.end_time,
extra={
"metadata": trace_info.metadata,
"metadata": metadata,
},
tags=["message", "workflow"],
error=trace_info.error,
@@ -94,7 +96,7 @@ class LangSmithDataTrace(BaseTraceInstance):
langsmith_run = LangSmithRunModel(
file_list=trace_info.file_list,
total_tokens=trace_info.total_tokens,
id=trace_info.workflow_app_log_id or trace_info.workflow_run_id,
id=trace_info.workflow_run_id,
name=TraceTaskName.WORKFLOW_TRACE.value,
inputs=trace_info.workflow_run_inputs,
run_type=LangSmithRunType.tool,
@@ -102,7 +104,7 @@ class LangSmithDataTrace(BaseTraceInstance):
end_time=trace_info.workflow_data.finished_at,
outputs=trace_info.workflow_run_outputs,
extra={
"metadata": trace_info.metadata,
"metadata": metadata,
},
error=trace_info.error,
tags=["workflow"],
@@ -204,7 +206,7 @@ class LangSmithDataTrace(BaseTraceInstance):
extra={
"metadata": metadata,
},
parent_run_id=trace_info.workflow_app_log_id or trace_info.workflow_run_id,
parent_run_id=trace_info.workflow_run_id,
tags=["node_execution"],
id=node_execution_id,
trace_id=trace_id,

View File

@@ -1,13 +1,10 @@
import copy
import json
import logging
from collections.abc import Iterable
from typing import Any, Optional
from opensearchpy import OpenSearch
from opensearchpy.helpers import bulk
from pydantic import BaseModel, model_validator
from tenacity import retry, stop_after_attempt, wait_fixed
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@@ -23,11 +20,15 @@ logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logging.getLogger("lindorm").setLevel(logging.WARN)
ROUTING_FIELD = "routing_field"
UGC_INDEX_PREFIX = "ugc_index"
class LindormVectorStoreConfig(BaseModel):
hosts: str
username: Optional[str] = None
password: Optional[str] = None
using_ugc: Optional[bool] = False
@model_validator(mode="before")
@classmethod
@@ -41,9 +42,7 @@ class LindormVectorStoreConfig(BaseModel):
return values
def to_opensearch_params(self) -> dict[str, Any]:
params = {
"hosts": self.hosts,
}
params = {"hosts": self.hosts}
if self.username and self.password:
params["http_auth"] = (self.username, self.password)
return params
@@ -51,9 +50,21 @@ class LindormVectorStoreConfig(BaseModel):
class LindormVectorStore(BaseVector):
def __init__(self, collection_name: str, config: LindormVectorStoreConfig, **kwargs):
super().__init__(collection_name.lower())
self._routing = None
self._routing_field = None
if config.using_ugc:
routing_value: str = kwargs.get("routing_value")
if routing_value is None:
raise ValueError("UGC index should init vector with valid 'routing_value' parameter value")
self._routing = routing_value.lower()
self._routing_field = ROUTING_FIELD
ugc_index_name = collection_name
super().__init__(ugc_index_name.lower())
else:
super().__init__(collection_name.lower())
self._client_config = config
self._client = OpenSearch(**config.to_opensearch_params())
self._using_ugc = config.using_ugc
self.kwargs = kwargs
def get_type(self) -> str:
@@ -66,89 +77,37 @@ class LindormVectorStore(BaseVector):
def refresh(self):
self._client.indices.refresh(index=self._collection_name)
def __filter_existed_ids(
self,
texts: list[str],
metadatas: list[dict],
ids: list[str],
bulk_size: int = 1024,
) -> tuple[Iterable[str], Optional[list[dict]], Optional[list[str]]]:
@retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
def __fetch_existing_ids(batch_ids: list[str]) -> set[str]:
try:
existing_docs = self._client.mget(index=self._collection_name, body={"ids": batch_ids}, _source=False)
return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
except Exception as e:
logger.exception(f"Error fetching batch {batch_ids}")
return set()
@retry(stop=stop_after_attempt(3), wait=wait_fixed(60))
def __fetch_existing_routing_ids(batch_ids: list[str], route_ids: list[str]) -> set[str]:
try:
existing_docs = self._client.mget(
body={
"docs": [
{"_index": self._collection_name, "_id": id, "routing": routing}
for id, routing in zip(batch_ids, route_ids)
]
},
_source=False,
)
return {doc["_id"] for doc in existing_docs["docs"] if doc["found"]}
except Exception as e:
logger.exception(f"Error fetching batch ids: {batch_ids}")
return set()
if ids is None:
return texts, metadatas, ids
if len(texts) != len(ids):
raise RuntimeError(f"texts {len(texts)} != {ids}")
filtered_texts = []
filtered_metadatas = []
filtered_ids = []
def batch(iterable, n):
length = len(iterable)
for idx in range(0, length, n):
yield iterable[idx : min(idx + n, length)]
for ids_batch, texts_batch, metadatas_batch in zip(
batch(ids, bulk_size),
batch(texts, bulk_size),
batch(metadatas, bulk_size) if metadatas is not None else batch([None] * len(ids), bulk_size),
):
existing_ids_set = __fetch_existing_ids(ids_batch)
for text, metadata, doc_id in zip(texts_batch, metadatas_batch, ids_batch):
if doc_id not in existing_ids_set:
filtered_texts.append(text)
filtered_ids.append(doc_id)
if metadatas is not None:
filtered_metadatas.append(metadata)
return filtered_texts, metadatas if metadatas is None else filtered_metadatas, filtered_ids
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
actions = []
uuids = self._get_uuids(documents)
for i in range(len(documents)):
action = {
"_op_type": "index",
"_index": self._collection_name.lower(),
"_id": uuids[i],
"_source": {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
},
action_header = {
"index": {
"_index": self.collection_name.lower(),
"_id": uuids[i],
}
}
actions.append(action)
bulk(self._client, actions)
self.refresh()
action_values = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
}
if self._using_ugc:
action_header["index"]["routing"] = self._routing
action_values[self._routing_field] = self._routing
actions.append(action_header)
actions.append(action_values)
response = self._client.bulk(actions)
if response["errors"]:
for item in response["items"]:
print(f"{item['index']['status']}: {item['index']['error']['type']}")
else:
self.refresh()
def get_ids_by_metadata_field(self, key: str, value: str):
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}}
query = {"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY.value}.{key}.keyword": value}}]}}}
if self._using_ugc:
query["query"]["bool"]["must"].append({"term": {f"{self._routing_field}.keyword": self._routing}})
response = self._client.search(index=self._collection_name, body=query)
if response["hits"]["hits"]:
return [hit["_id"] for hit in response["hits"]["hits"]]
@@ -156,50 +115,62 @@ class LindormVectorStore(BaseVector):
return None
def delete_by_metadata_field(self, key: str, value: str):
query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
results = self._client.search(index=self._collection_name, body=query_str)
ids = [hit["_id"] for hit in results["hits"]["hits"]]
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self.delete_by_ids(ids)
def delete_by_ids(self, ids: list[str]) -> None:
params = {}
if self._using_ugc:
params["routing"] = self._routing
for id in ids:
if self._client.exists(index=self._collection_name, id=id):
self._client.delete(index=self._collection_name, id=id)
if self._client.exists(index=self._collection_name, id=id, params=params):
params = {}
if self._using_ugc:
params["routing"] = self._routing
self._client.delete(index=self._collection_name, id=id, params=params)
self.refresh()
else:
logger.warning(f"DELETE BY ID: ID {id} does not exist in the index.")
def delete(self) -> None:
try:
if self._using_ugc:
routing_filter_query = {
"query": {"bool": {"must": [{"term": {f"{self._routing_field}.keyword": self._routing}}]}}
}
self._client.delete_by_query(self._collection_name, body=routing_filter_query)
self.refresh()
else:
if self._client.indices.exists(index=self._collection_name):
self._client.indices.delete(index=self._collection_name, params={"timeout": 60})
logger.info("Delete index success")
else:
logger.warning(f"Index '{self._collection_name}' does not exist. No deletion performed.")
except Exception as e:
logger.exception(f"Error occurred while deleting the index: {self._collection_name}")
raise e
def text_exists(self, id: str) -> bool:
try:
self._client.get(index=self._collection_name, id=id)
params = {}
if self._using_ugc:
params["routing"] = self._routing
self._client.get(index=self._collection_name, id=id, params=params)
return True
except:
return False
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
# Make sure query_vector is a list
if not isinstance(query_vector, list):
raise ValueError("query_vector should be a list of floats")
# Check whether query_vector is a floating-point number list
if not all(isinstance(x, float) for x in query_vector):
raise ValueError("All elements in query_vector should be floats")
top_k = kwargs.get("top_k", 10)
query = default_vector_search_query(query_vector=query_vector, k=top_k, **kwargs)
try:
response = self._client.search(index=self._collection_name, body=query)
params = {}
if self._using_ugc:
params["routing"] = self._routing
response = self._client.search(index=self._collection_name, body=query, params=params)
except Exception as e:
logger.exception(f"Error executing vector search, query: {query}")
raise
@@ -232,7 +203,7 @@ class LindormVectorStore(BaseVector):
minimum_should_match = kwargs.get("minimum_should_match", 0)
top_k = kwargs.get("top_k", 10)
filters = kwargs.get("filter")
routing = kwargs.get("routing")
routing = self._routing
full_text_query = default_text_search_query(
query_text=query,
k=top_k,
@@ -243,6 +214,7 @@ class LindormVectorStore(BaseVector):
minimum_should_match=minimum_should_match,
filters=filters,
routing=routing,
routing_field=self._routing_field,
)
response = self._client.search(index=self._collection_name, body=full_text_query)
docs = []
@@ -265,17 +237,18 @@ class LindormVectorStore(BaseVector):
logger.info(f"Collection {self._collection_name} already exists.")
return
if self._client.indices.exists(index=self._collection_name):
logger.info("{self._collection_name.lower()} already exists.")
logger.info(f"{self._collection_name.lower()} already exists.")
redis_client.set(collection_exist_cache_key, 1, ex=3600)
return
if len(self.kwargs) == 0 and len(kwargs) != 0:
self.kwargs = copy.deepcopy(kwargs)
vector_field = kwargs.pop("vector_field", Field.VECTOR.value)
shards = kwargs.pop("shards", 2)
shards = kwargs.pop("shards", 4)
engine = kwargs.pop("engine", "lvector")
method_name = kwargs.pop("method_name", "hnsw")
method_name = kwargs.pop("method_name", dify_config.DEFAULT_INDEX_TYPE)
space_type = kwargs.pop("space_type", dify_config.DEFAULT_DISTANCE_TYPE)
data_type = kwargs.pop("data_type", "float")
space_type = kwargs.pop("space_type", "cosinesimil")
hnsw_m = kwargs.pop("hnsw_m", 24)
hnsw_ef_construction = kwargs.pop("hnsw_ef_construction", 500)
@@ -288,10 +261,10 @@ class LindormVectorStore(BaseVector):
mapping = default_text_mapping(
dimension,
method_name,
space_type=space_type,
shards=shards,
engine=engine,
data_type=data_type,
space_type=space_type,
vector_field=vector_field,
hnsw_m=hnsw_m,
hnsw_ef_construction=hnsw_ef_construction,
@@ -301,6 +274,7 @@ class LindormVectorStore(BaseVector):
centroids_hnsw_m=centroids_hnsw_m,
centroids_hnsw_ef_construct=centroids_hnsw_ef_construct,
centroids_hnsw_ef_search=centroids_hnsw_ef_search,
using_ugc=self._using_ugc,
**kwargs,
)
self._client.indices.create(index=self._collection_name.lower(), body=mapping)
@@ -309,15 +283,20 @@ class LindormVectorStore(BaseVector):
def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dict:
routing_field = kwargs.get("routing_field")
excludes_from_source = kwargs.get("excludes_from_source")
analyzer = kwargs.get("analyzer", "ik_max_word")
text_field = kwargs.get("text_field", Field.CONTENT_KEY.value)
engine = kwargs["engine"]
shard = kwargs["shards"]
space_type = kwargs["space_type"]
space_type = kwargs.get("space_type")
if space_type is None:
if method_name == "hnsw":
space_type = "l2"
else:
space_type = "cosine"
data_type = kwargs["data_type"]
vector_field = kwargs.get("vector_field", Field.VECTOR.value)
using_ugc = kwargs.get("using_ugc", False)
if method_name == "ivfpq":
ivfpq_m = kwargs["ivfpq_m"]
@@ -366,13 +345,11 @@ def default_text_mapping(dimension: int, method_name: str, **kwargs: Any) -> dic
if excludes_from_source:
mapping["mappings"]["_source"] = {"excludes": excludes_from_source} # e.g. {"excludes": ["vector_field"]}
if method_name == "ivfpq" and routing_field is not None:
if using_ugc and method_name == "ivfpq":
mapping["settings"]["index"]["knn_routing"] = True
mapping["settings"]["index"]["knn.offline.construction"] = True
if method_name == "flat" and routing_field is not None:
elif using_ugc and method_name == "hnsw" or using_ugc and method_name == "flat":
mapping["settings"]["index"]["knn_routing"] = True
return mapping
@@ -386,14 +363,12 @@ def default_text_search_query(
minimum_should_match: int = 0,
filters: Optional[list[dict]] = None,
routing: Optional[str] = None,
routing_field: Optional[str] = None,
**kwargs,
) -> dict:
if routing is not None:
routing_field = kwargs.get("routing_field", "routing_field")
query_clause = {
"bool": {
"must": [{"match": {text_field: query_text}}, {"term": {f"metadata.{routing_field}.keyword": routing}}]
}
"bool": {"must": [{"match": {text_field: query_text}}, {"term": {f"{routing_field}.keyword": routing}}]}
}
else:
query_clause = {"match": {text_field: query_text}}
@@ -449,7 +424,7 @@ def default_vector_search_query(
) -> dict:
if filters is not None:
filter_type = "post_filter" if filter_type is None else filter_type
if not isinstance(filter, list):
if not isinstance(filters, list):
raise RuntimeError(f"unexpected filter with {type(filters)}")
final_ext = {"lvector": {}}
if min_score != "0.0":
@@ -483,16 +458,40 @@ def default_vector_search_query(
class LindormVectorStoreFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> LindormVectorStore:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.LINDORM, collection_name))
lindorm_config = LindormVectorStoreConfig(
hosts=dify_config.LINDORM_URL,
username=dify_config.LINDORM_USERNAME,
password=dify_config.LINDORM_PASSWORD,
using_ugc=dify_config.USING_UGC_INDEX,
)
return LindormVectorStore(collection_name, lindorm_config)
using_ugc = dify_config.USING_UGC_INDEX
routing_value = None
if dataset.index_struct:
if using_ugc:
dimension = dataset.index_struct_dict["dimension"]
index_type = dataset.index_struct_dict["index_type"]
distance_type = dataset.index_struct_dict["distance_type"]
index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
routing_value = dataset.index_struct_dict["vector_store"]["class_prefix"]
else:
index_name = dataset.index_struct_dict["vector_store"]["class_prefix"]
else:
embedding_vector = embeddings.embed_query("hello word")
dimension = len(embedding_vector)
index_type = dify_config.DEFAULT_INDEX_TYPE
distance_type = dify_config.DEFAULT_DISTANCE_TYPE
class_prefix = Dataset.gen_collection_name_by_id(dataset.id)
index_struct_dict = {
"type": VectorType.LINDORM,
"vector_store": {"class_prefix": class_prefix},
"index_type": index_type,
"dimension": dimension,
"distance_type": distance_type,
}
dataset.index_struct = json.dumps(index_struct_dict)
if using_ugc:
index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
routing_value = class_prefix
else:
index_name = class_prefix
return LindormVectorStore(index_name, lindorm_config, routing_value=routing_value)
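Roughly, the UGC path of this factory lets many datasets share one physical Lindorm index whose name encodes dimension, index type, and distance type, while each dataset stays isolated behind a routing value taken from its class_prefix. A self-contained sketch of that naming convention (the example prefix below is hypothetical):
UGC_INDEX_PREFIX = "ugc_index"
def ugc_index_and_routing(class_prefix: str, dimension: int,
                          index_type: str, distance_type: str) -> tuple[str, str]:
    # datasets with the same (dimension, index_type, distance_type) share one index;
    # the per-dataset class_prefix becomes the routing value
    index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}"
    return index_name, class_prefix.lower()
# e.g. ugc_index_and_routing("Vector_index_abc_Node", 768, "hnsw", "l2")
# -> ("ugc_index_768_hnsw_l2", "vector_index_abc_node")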

View File

@@ -162,7 +162,7 @@ class TidbService:
clusters = []
tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list}
cluster_ids = [item.cluster_id for item in tidb_serverless_list]
params = {"clusterIds": cluster_ids, "view": "FULL"}
params = {"clusterIds": cluster_ids, "view": "BASIC"}
response = requests.get(
f"{api_url}/clusters:batchGet", params=params, auth=HTTPDigestAuth(public_key, private_key)
)

View File

@@ -37,8 +37,6 @@ class TiDBVectorConfig(BaseModel):
raise ValueError("config TIDB_VECTOR_PORT is required")
if not values["user"]:
raise ValueError("config TIDB_VECTOR_USER is required")
if not values["password"]:
raise ValueError("config TIDB_VECTOR_PASSWORD is required")
if not values["database"]:
raise ValueError("config TIDB_VECTOR_DATABASE is required")
if not values["program_name"]:

View File

@@ -103,7 +103,7 @@ class ExtractProcessor:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
extractor = PdfExtractor(file_path)
elif file_extension in {".md", ".markdown"}:
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = (
UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key)
if is_automatic
@@ -141,7 +141,7 @@ class ExtractProcessor:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
extractor = PdfExtractor(file_path)
elif file_extension in {".md", ".markdown"}:
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in {".htm", ".html"}:
extractor = HtmlExtractor(file_path)

View File

@@ -270,9 +270,6 @@ class ApiTool(Tool):
elif property["type"] == "object" or property["type"] == "array":
if isinstance(value, str):
try:
# an array string like '[1,2]' can also be converted to a list [1,2] via json.loads
# JSON does not support single quotes, but we can handle them
value = value.replace("'", '"')
return json.loads(value)
except ValueError:
return value
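The deleted single-quote substitution is easy to justify: rewriting every ' to " before json.loads corrupts payloads whose string values legitimately contain apostrophes. A small standalone illustration (not from the codebase):
import json
raw = '{"note": "it\'s fine", "ids": [1, 2]}'  # valid JSON containing an apostrophe
try:
    # the old blanket substitution turns this into invalid JSON
    json.loads(raw.replace("'", '"'))
except ValueError as exc:
    print("substitution broke valid JSON:", exc)
print(json.loads(raw))  # parsing the original string as-is works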

View File

@@ -4,6 +4,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
@@ -39,6 +40,8 @@ class WorkflowLoggingCallback(WorkflowCallback):
self.print_text("\n[GraphRunStartedEvent]", color="pink")
elif isinstance(event, GraphRunSucceededEvent):
self.print_text("\n[GraphRunSucceededEvent]", color="green")
elif isinstance(event, GraphRunPartialSucceededEvent):
self.print_text("\n[GraphRunPartialSucceededEvent]", color="pink")
elif isinstance(event, GraphRunFailedEvent):
self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red")
elif isinstance(event, NodeRunStartedEvent):

View File

@@ -25,6 +25,7 @@ class NodeRunMetadataKey(StrEnum):
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
class NodeRunResult(BaseModel):
@@ -43,3 +44,4 @@ class NodeRunResult(BaseModel):
edge_source_handle: Optional[str] = None # source handle id of node with multiple branches
error: Optional[str] = None # error message if status is failed
error_type: Optional[str] = None # error type if status is failed

View File

@@ -1,3 +1,4 @@
from collections.abc import Mapping
from datetime import datetime
from typing import Any, Optional
@@ -32,6 +33,12 @@ class GraphRunSucceededEvent(BaseGraphEvent):
class GraphRunFailedEvent(BaseGraphEvent):
error: str = Field(..., description="failed reason")
exceptions_count: Optional[int] = Field(description="exception count", default=0)
class GraphRunPartialSucceededEvent(BaseGraphEvent):
exceptions_count: int = Field(..., description="exception count")
outputs: Optional[dict[str, Any]] = None
###########################################
@@ -82,6 +89,10 @@ class NodeRunFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")
class NodeRunExceptionEvent(BaseNodeEvent):
error: str = Field(..., description="error")
class NodeInIterationFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")
@@ -140,8 +151,8 @@ class BaseIterationEvent(GraphEngineEvent):
class IterationRunStartedEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
inputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
predecessor_node_id: Optional[str] = None
@@ -153,18 +164,18 @@ class IterationRunNextEvent(BaseIterationEvent):
class IterationRunSucceededEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
steps: int = 0
iteration_duration_map: Optional[dict[str, float]] = None
class IterationRunFailedEvent(BaseIterationEvent):
start_at: datetime = Field(..., description="start at")
inputs: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
metadata: Optional[dict[str, Any]] = None
inputs: Optional[Mapping[str, Any]] = None
outputs: Optional[Mapping[str, Any]] = None
metadata: Optional[Mapping[str, Any]] = None
steps: int = 0
error: str = Field(..., description="failed reason")

View File

@@ -64,13 +64,21 @@ class Graph(BaseModel):
edge_configs = graph_config.get("edges")
if edge_configs is None:
edge_configs = []
# node configs
node_configs = graph_config.get("nodes")
if not node_configs:
raise ValueError("Graph must have at least one node")
edge_configs = cast(list, edge_configs)
node_configs = cast(list, node_configs)
# reorganize edges mapping
edge_mapping: dict[str, list[GraphEdge]] = {}
reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
target_edge_ids = set()
fail_branch_source_node_id = [
node["id"] for node in node_configs if node["data"].get("error_strategy") == "fail-branch"
]
for edge_config in edge_configs:
source_node_id = edge_config.get("source")
if not source_node_id:
@@ -90,8 +98,16 @@ class Graph(BaseModel):
# parse run condition
run_condition = None
if edge_config.get("sourceHandle") and edge_config.get("sourceHandle") != "source":
run_condition = RunCondition(type="branch_identify", branch_identify=edge_config.get("sourceHandle"))
if edge_config.get("sourceHandle"):
if (
edge_config.get("source") in fail_branch_source_node_id
and edge_config.get("sourceHandle") != "fail-branch"
):
run_condition = RunCondition(type="branch_identify", branch_identify="success-branch")
elif edge_config.get("sourceHandle") != "source":
run_condition = RunCondition(
type="branch_identify", branch_identify=edge_config.get("sourceHandle")
)
graph_edge = GraphEdge(
source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition
@@ -100,13 +116,6 @@ class Graph(BaseModel):
edge_mapping[source_node_id].append(graph_edge)
reverse_edge_mapping[target_node_id].append(graph_edge)
# node configs
node_configs = graph_config.get("nodes")
if not node_configs:
raise ValueError("Graph must have at least one node")
node_configs = cast(list, node_configs)
# fetch nodes that have no predecessor node
root_node_configs = []
all_node_id_config_mapping: dict[str, dict] = {}
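Condensed, the edge rule this hunk adds is: for a node configured with the fail-branch error strategy, any outgoing edge whose sourceHandle is not "fail-branch" should only fire on success. A standalone sketch of that decision (the function name is illustrative):
from typing import Optional
def branch_condition(source_handle: Optional[str], source_uses_fail_branch: bool) -> Optional[str]:
    """Return the branch identifier an edge should wait for, or None for an unconditional edge."""
    if not source_handle:
        return None
    if source_uses_fail_branch and source_handle != "fail-branch":
        # default/"source" edges of a fail-branch node become success-only
        return "success-branch"
    if source_handle != "source":
        return source_handle  # explicit branch, e.g. an if/else case or "fail-branch" itself
    return None  # plain edge, no run condition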

View File

@@ -15,6 +15,7 @@ class RouteNodeState(BaseModel):
SUCCESS = "success"
FAILED = "failed"
PAUSED = "paused"
EXCEPTION = "exception"
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
"""node state id"""
@@ -51,7 +52,11 @@ class RouteNodeState(BaseModel):
:param run_result: run result
"""
if self.status in {RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED}:
if self.status in {
RouteNodeState.Status.SUCCESS,
RouteNodeState.Status.FAILED,
RouteNodeState.Status.EXCEPTION,
}:
raise Exception(f"Route state {self.id} already finished")
if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
@@ -59,6 +64,9 @@ class RouteNodeState(BaseModel):
elif run_result.status == WorkflowNodeExecutionStatus.FAILED:
self.status = RouteNodeState.Status.FAILED
self.failed_reason = run_result.error
elif run_result.status == WorkflowNodeExecutionStatus.EXCEPTION:
self.status = RouteNodeState.Status.EXCEPTION
self.failed_reason = run_result.error
else:
raise Exception(f"Invalid route status {run_result.status}")

View File

@@ -5,21 +5,24 @@ import uuid
from collections.abc import Generator, Mapping
from concurrent.futures import ThreadPoolExecutor, wait
from copy import copy, deepcopy
from typing import Any, Optional
from typing import Any, Optional, cast
from flask import Flask, current_app
from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.event import (
BaseIterationEvent,
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
@@ -36,7 +39,9 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeSta
from core.workflow.nodes import NodeType
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from extensions.ext_database import db
@@ -48,7 +53,12 @@ logger = logging.getLogger(__name__)
class GraphEngineThreadPool(ThreadPoolExecutor):
def __init__(
self, max_workers=None, thread_name_prefix="", initializer=None, initargs=(), max_submit_count=100
self,
max_workers=None,
thread_name_prefix="",
initializer=None,
initargs=(),
max_submit_count=dify_config.MAX_SUBMIT_COUNT,
) -> None:
super().__init__(max_workers, thread_name_prefix, initializer, initargs)
self.max_submit_count = max_submit_count
@@ -88,7 +98,7 @@ class GraphEngine:
max_execution_time: int,
thread_pool_id: Optional[str] = None,
) -> None:
thread_pool_max_submit_count = 100
thread_pool_max_submit_count = dify_config.MAX_SUBMIT_COUNT
thread_pool_max_workers = 10
# init thread pool
@@ -128,6 +138,7 @@ class GraphEngine:
def run(self) -> Generator[GraphEngineEvent, None, None]:
# trigger graph run start event
yield GraphRunStartedEvent()
handle_exceptions = []
try:
if self.init_params.workflow_type == WorkflowType.CHAT:
@@ -140,13 +151,17 @@ class GraphEngine:
)
# run graph
generator = stream_processor.process(self._run(start_node_id=self.graph.root_node_id))
generator = stream_processor.process(
self._run(start_node_id=self.graph.root_node_id, handle_exceptions=handle_exceptions)
)
for item in generator:
try:
yield item
if isinstance(item, NodeRunFailedEvent):
yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or "Unknown error.")
yield GraphRunFailedEvent(
error=item.route_node_state.failed_reason or "Unknown error.",
exceptions_count=len(handle_exceptions),
)
return
elif isinstance(item, NodeRunSucceededEvent):
if item.node_type == NodeType.END:
@@ -172,19 +187,24 @@ class GraphEngine:
].strip()
except Exception as e:
logger.exception("Graph run failed")
yield GraphRunFailedEvent(error=str(e))
yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions))
return
# trigger graph run success event
yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs)
# count exceptions to determine partial success
if len(handle_exceptions) > 0:
yield GraphRunPartialSucceededEvent(
exceptions_count=len(handle_exceptions), outputs=self.graph_runtime_state.outputs
)
else:
# trigger graph run success event
yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs)
self._release_thread()
except GraphRunFailedError as e:
yield GraphRunFailedEvent(error=e.error)
yield GraphRunFailedEvent(error=e.error, exceptions_count=len(handle_exceptions))
self._release_thread()
return
except Exception as e:
logger.exception("Unknown Error when graph running")
yield GraphRunFailedEvent(error=str(e))
yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions))
self._release_thread()
raise e
@@ -198,6 +218,7 @@ class GraphEngine:
in_parallel_id: Optional[str] = None,
parent_parallel_id: Optional[str] = None,
parent_parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
) -> Generator[GraphEngineEvent, None, None]:
parallel_start_node_id = None
if in_parallel_id:
@@ -242,7 +263,7 @@ class GraphEngine:
previous_node_id=previous_node_id,
thread_pool_id=self.thread_pool_id,
)
node_instance = cast(BaseNode[BaseNodeData], node_instance)
try:
# run node
generator = self._run_node(
@@ -252,6 +273,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
handle_exceptions=handle_exceptions,
)
for item in generator:
@@ -301,7 +323,12 @@ class GraphEngine:
if len(edge_mappings) == 1:
edge = edge_mappings[0]
if (
previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
and node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
and edge.run_condition is None
):
break
if edge.run_condition:
result = ConditionManager.get_condition_handler(
init_params=self.init_params,
@@ -334,7 +361,7 @@ class GraphEngine:
if len(sub_edge_mappings) == 0:
continue
edge = sub_edge_mappings[0]
edge = cast(GraphEdge, sub_edge_mappings[0])
result = ConditionManager.get_condition_handler(
init_params=self.init_params,
@@ -355,6 +382,7 @@ class GraphEngine:
edge_mappings=sub_edge_mappings,
in_parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id,
handle_exceptions=handle_exceptions,
)
for item in parallel_generator:
@@ -369,11 +397,18 @@ class GraphEngine:
break
next_node_id = final_node_id
elif (
node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
and node_instance.should_continue_on_error
and previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
):
break
else:
parallel_generator = self._run_parallel_branches(
edge_mappings=edge_mappings,
in_parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id,
handle_exceptions=handle_exceptions,
)
for item in parallel_generator:
@@ -395,6 +430,7 @@ class GraphEngine:
edge_mappings: list[GraphEdge],
in_parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
) -> Generator[GraphEngineEvent | str, None, None]:
# if nodes has no run conditions, parallel run all nodes
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
@@ -438,6 +474,7 @@ class GraphEngine:
"parallel_start_node_id": edge.target_node_id,
"parent_parallel_id": in_parallel_id,
"parent_parallel_start_node_id": parallel_start_node_id,
"handle_exceptions": handle_exceptions,
},
)
@@ -481,6 +518,7 @@ class GraphEngine:
parallel_start_node_id: str,
parent_parallel_id: Optional[str] = None,
parent_parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
) -> None:
"""
Run parallel nodes
@@ -502,6 +540,7 @@ class GraphEngine:
in_parallel_id=parallel_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
handle_exceptions=handle_exceptions,
)
for item in generator:
@@ -548,6 +587,7 @@ class GraphEngine:
parallel_start_node_id: Optional[str] = None,
parent_parallel_id: Optional[str] = None,
parent_parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
) -> Generator[GraphEngineEvent, None, None]:
"""
Run node
@@ -587,19 +627,55 @@ class GraphEngine:
route_node_state.set_finished(run_result=run_result)
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or "Unknown error.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
if node_instance.should_continue_on_error:
# if run failed, handle error
run_result = self._handle_continue_on_error(
node_instance,
item.run_result,
self.graph_runtime_state.variable_pool,
handle_exceptions=handle_exceptions,
)
route_node_state.node_run_result = run_result
route_node_state.status = RouteNodeState.Status.EXCEPTION
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_instance.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
)
yield NodeRunExceptionEvent(
error=run_result.error or "System Error",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
else:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or "Unknown error.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
node_instance.node_id
):
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
# plus state total_tokens
self.graph_runtime_state.total_tokens += int(
@@ -735,6 +811,56 @@ class GraphEngine:
new_instance.graph_runtime_state.variable_pool = deepcopy(self.graph_runtime_state.variable_pool)
return new_instance
def _handle_continue_on_error(
self,
node_instance: BaseNode[BaseNodeData],
error_result: NodeRunResult,
variable_pool: VariablePool,
handle_exceptions: list[str] = [],
) -> NodeRunResult:
"""
handle continue on error when node_instance.should_continue_on_error is True
:param error_result (NodeRunResult): error run result
:param variable_pool (VariablePool): variable pool
:return: exception run result
"""
# add error message and error type to variable pool
variable_pool.add([node_instance.node_id, "error_message"], error_result.error)
variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type)
# add error message to handle_exceptions
handle_exceptions.append(error_result.error)
node_error_args = {
"status": WorkflowNodeExecutionStatus.EXCEPTION,
"error": error_result.error,
"inputs": error_result.inputs,
"metadata": {
NodeRunMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy,
},
}
if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
return NodeRunResult(
**node_error_args,
outputs={
**node_instance.node_data.default_value_dict,
"error_message": error_result.error,
"error_type": error_result.error_type,
},
)
elif node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH:
if self.graph.edge_mapping.get(node_instance.node_id):
node_error_args["edge_source_handle"] = FailBranchSourceHandle.FAILED
return NodeRunResult(
**node_error_args,
outputs={
"error_message": error_result.error,
"error_type": error_result.error_type,
},
)
return error_result
class GraphRunFailedError(Exception):
def __init__(self, error: str):
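Distilled to its two strategies, the continue-on-error handler above either substitutes configured default values or routes execution down a failure branch, in both cases surfacing error_message and error_type as outputs. A self-contained sketch with stand-in types (not the GraphEngine API):
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Optional
class ErrorStrategy(str, Enum):
    DEFAULT_VALUE = "default-value"
    FAIL_BRANCH = "fail-branch"
@dataclass
class ExceptionResult:  # stand-in for a NodeRunResult with EXCEPTION status
    error: str
    outputs: dict[str, Any] = field(default_factory=dict)
    edge_source_handle: Optional[str] = None
def continue_on_error(strategy: ErrorStrategy, error: str, error_type: str,
                      default_values: dict[str, Any], has_outgoing_edges: bool) -> ExceptionResult:
    error_fields = {"error_message": error, "error_type": error_type}
    if strategy is ErrorStrategy.DEFAULT_VALUE:
        # default-value: emit the configured defaults alongside the error fields
        return ExceptionResult(error=error, outputs={**default_values, **error_fields})
    # fail-branch: take the failure edge when the node has outgoing edges
    handle = "fail-branch" if has_outgoing_edges else None
    return ExceptionResult(error=error, outputs=error_fields, edge_source_handle=handle)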

Some files were not shown because too many files have changed in this diff.