Compare commits

..

107 Commits

Author SHA1 Message Date
Joel
7608eb1049 Merge branch 'main' into feat/plugin-auto-upgrade-fe 2025-07-10 14:20:34 +08:00
github-actions[bot]
b834131f50 chore: translate i18n files (#22132)
Co-authored-by: iamjoel <2120155+iamjoel@users.noreply.github.com>
2025-07-10 14:19:26 +08:00
Joel
5375d9bb27 feat: the frontend part of mcp (#22131)
Co-authored-by: jZonG <jzongcode@gmail.com>
Co-authored-by: Novice <novice12185727@gmail.com>
Co-authored-by: nite-knite <nkCoding@gmail.com>
Co-authored-by: Hanqing Zhao <sherry9277@gmail.com>
2025-07-10 14:14:02 +08:00
Novice
535fff62f3 feat: add MCP support (#20716)
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
2025-07-10 14:01:34 +08:00
baonudesifeizhai
18b58424ec Fix: Resolve issue with json_output (#22053) 2025-07-10 13:34:06 +08:00
Yongtao Huang
10858ea1dc Chore: rm useless import and vars (#22108) 2025-07-10 11:47:43 +08:00
Joel
95ce7b6f47 feat: add time zone 2025-07-10 11:34:05 +08:00
NeatGuyCoding
6f8c7a66c8 feat: add redis fallback mechanism #21043 (#21044)
Co-authored-by: tech <cto@sb>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-07-10 10:19:58 +08:00
luckylhb90
a371390d6c optimize: batch embedding and qdrant write_consistency_factor parameter (#21776)
Co-authored-by: hobo.l <hobo.l@binance.com>
2025-07-10 10:16:59 +08:00
Wu Tianwei
a316766ad7 chore: Update theme vars (#22113) 2025-07-10 10:11:31 +08:00
Minamiyama
a9cc19f530 feat(question-classifier): add drag-and-drop sorting for topics list (#22066)
Co-authored-by: crazywoola <427733928@qq.com>
2025-07-10 10:03:11 +08:00
Jason Young
881a151d30 test: add comprehensive unit tests for encrypter module (#22102) 2025-07-10 10:01:15 +08:00
NFish
785c4caa67 fix: allow update plugin install settings (#22111) 2025-07-10 09:58:48 +08:00
Heyang Wang
4403bc67a1 fix(Drawer): add overflow hidden to ensure copy button is always clickable (#21992) (#22103)
Co-authored-by: wangheyang <wangheyang@corp.netease.com>
2025-07-10 09:20:02 +08:00
wangsen3
b237113311 Update clean_document_task.py (#22090) 2025-07-10 09:18:50 +08:00
-LAN-
4cb50f1809 feat(libs): Introduce extract_tenant_id (#22086)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2025-07-09 17:45:56 +08:00
Davide Delbianco
1885426421 feat: Allow to change SSL verify in HTTP Node (#22052)
Co-authored-by: crazywoola <427733928@qq.com>
2025-07-09 15:53:24 +08:00
wlleiiwang
89b52471fb Optimize the memory usage of Tencent Vector Database (#22079)
Co-authored-by: wlleiiwang <wlleiiwang@tencent.com>
2025-07-09 15:53:06 +08:00
Minamiyama
3643ed1014 Feat: description field for env variables (#21556) 2025-07-09 15:18:23 +08:00
kurokobo
e39236186d feat: introduce new env ALLOW_UNSAFE_DATA_SCHEME to allow rendering data uri scheme (#21321) 2025-07-09 10:12:40 +08:00
Yongtao Huang
521488f926 Remove tow unused files (#22022) 2025-07-09 09:28:26 +08:00
Jason Young
d61ea5a2de test: add comprehensive unit tests for UrlSigner (#22030) 2025-07-08 21:22:37 +08:00
Davide Delbianco
816210d744 Expose LLM usage in workflows (#21766)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2025-07-08 21:18:00 +08:00
Joel
784a236280 Merge branch 'main' into feat/plugin-auto-upgrade-fe 2025-07-08 17:20:37 +08:00
Joel
1e0426ca6f chore: peroid not auto scroll 2025-07-08 17:15:00 +08:00
Minamiyama
f925869f61 fix(variable): ensure unique variable names in var-list (#22038) 2025-07-08 15:41:27 +08:00
NFish
f62b59a805 don't add search params when opening detail links from marketplace. (#22034) 2025-07-08 15:15:38 +08:00
Minamiyama
a4bdeba60d feat(question-classifier): add instanceId to class-item editor (#22002) 2025-07-08 10:04:05 +08:00
Jason Young
5c0cb7f912 test: add unit tests for password validation and hashing (#22003) 2025-07-08 10:00:00 +08:00
NeatGuyCoding
2ffbf5435d minro fix: fix duplicate local import of ToolProviderType (#22013)
Signed-off-by: neatguycoding <15627489+NeatGuyCoding@users.noreply.github.com>
2025-07-08 09:49:53 +08:00
Minamiyama
71385d594d fix(variables): Improve getNodeUsedVars implementation details (#21987) 2025-07-08 09:33:13 +08:00
NeatGuyCoding
53c4912cbb feat: add unit tests and validation for aliyun tracing (#22012)
Signed-off-by: neatguycoding <15627489+NeatGuyCoding@users.noreply.github.com>
2025-07-08 09:32:30 +08:00
NeatGuyCoding
1760179093 minro fix: fix a typo for aliyun (#22001)
Signed-off-by: neatguycoding <15627489+NeatGuyCoding@users.noreply.github.com>
2025-07-07 22:04:38 +08:00
鸽子
aded30b664 fix: resolve dropdown menu visibility issue caused by z-index conflict (#22000) 2025-07-07 21:58:05 +08:00
Yongtao Huang
de54f8d0ef Chore: remove unreachable code (#21986) 2025-07-07 21:55:34 +08:00
quicksand
5b0b64c7e5 fix: document delete image files check file exist (#21991) 2025-07-07 21:53:40 +08:00
Arcaner
b654c852a5 chore(docker): increase NGINX_CLIENT_MAX_BODY_SIZE from 15M to 100M i… (#21995) 2025-07-07 21:51:49 +08:00
Minamiyama
c48b32c9e3 ENH(ui): enhance check list (#21932)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-07-07 14:52:36 +08:00
-LAN-
8f723697ef refactor(graph_engine): Take GraphRuntimeState out of GraphEngine (#21882) 2025-07-07 13:15:18 +08:00
mizoo
de22648b9f feat: Add support for type="hidden" input elements in Markdown forms (#21922)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-07-07 10:35:30 +08:00
baonudesifeizhai
b9f56852dc fix: resolve JSON.parse precision issue causing 'list index out of ra… (#21253) 2025-07-07 10:05:54 +08:00
baonudesifeizhai
108cc3486f fix(agent): show agent run steps, fixes #21718 (#21945)
Co-authored-by: crazywoola <427733928@qq.com>
2025-07-07 09:59:47 +08:00
NeatGuyCoding
ac69b8b191 refactor: extract common url validator for config_entity.py (#21934)
Signed-off-by: neatguycoding <15627489+NeatGuyCoding@users.noreply.github.com>
2025-07-07 09:34:13 +08:00
허재원
8288145ee4 chore(i18n): fix typos and improve Korean translation (#21955) 2025-07-07 09:33:09 +08:00
NeatGuyCoding
51f6095be7 minor fix: translation for pause (#21949)
Signed-off-by: neatguycoding <15627489+NeatGuyCoding@users.noreply.github.com>
2025-07-05 12:45:29 +08:00
heyszt
a201e9faee feat: Add Aliyun LLM Observability Integration (#21471) 2025-07-04 21:54:33 +08:00
HyaCinth
fec6bafcda refactor(web): Restructure the operation buttons layout in the app information component (#21742) (#21818) 2025-07-04 21:53:21 +08:00
NeatGuyCoding
2639f950cc minor fix: removes the duplicated handling logic for TracingProviderEnum.ARIZE and TracingProviderEnum.PHOENIX from the OpsTraceProviderConfigMap (#21927)
Signed-off-by: neatguycoding <15627489+NeatGuyCoding@users.noreply.github.com>
2025-07-04 16:46:48 +08:00
Bowen Liang
6663187eca test:add unit test for api version config (#21919) 2025-07-04 15:33:20 +08:00
Nite Knite
13990f31a1 feat: update account menu style (#21916) 2025-07-04 14:52:30 +08:00
GuanMu
de39b737b6 Feat list query (#21907) 2025-07-04 14:18:31 +08:00
GuanMu
a66ed7157e feat: add document pause and resume functionality (#21894) 2025-07-04 14:06:47 +08:00
Ganondorf
c9c49200e0 use repair_json fix json parse error of HTTPRequestNode (#21909)
Co-authored-by: lizb <lizb@sugon.com>
2025-07-04 14:01:17 +08:00
Minamiyama
317d287458 fix(loop-variables): validate variable name input (#21888) 2025-07-03 23:30:56 +08:00
Joel
fd7396d8f9 chore: icon fixed 2025-07-03 17:48:22 +08:00
Joel
a0af33e945 Merge branch 'main' into feat/plugin-auto-upgrade-fe 2025-07-03 17:34:06 +08:00
非法操作
a79f37b686 fix: tts tool must choose a voice (#21877) 2025-07-03 17:10:01 +08:00
baonudesifeizhai
1c7404099d fix: prevent timeout in file encoding detection for large files (#21453)
Co-authored-by: crazywoola <427733928@qq.com>
2025-07-03 17:06:49 +08:00
Joel
ed54bd5121 fix: not search plugin if marketplace enabled (#21880) 2025-07-03 16:43:11 +08:00
GuanMu
06c3deff11 Fix: Add title attribute to edit time text for improved accessibility (#21871) 2025-07-03 16:07:07 +08:00
cutiechi
47954aa284 feat(api): validate and reject external datasets in document update (#21783) 2025-07-03 14:50:53 +08:00
Novice
f3c8625fe2 fix: The statistics page cannot display the tokens consumed by agent node (#21861) 2025-07-03 14:40:47 +08:00
NeatGuyCoding
ebc4fdc4b2 moving the MessageStatus class from the models.model module to models.enums module (#21867)
Signed-off-by: neatguycoding <15627489+NeatGuyCoding@users.noreply.github.com>
2025-07-03 13:56:23 +08:00
Ali Saleh
1af3d40c1a feat: Improve Observability with Arize & Phoenix Integration (#19840)
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: Gu <guchenhe@gmail.com>
2025-07-03 13:52:14 +08:00
jiangbo721
31eb8548ef fix: Before publish the app, preview the voice of tts, it raise an er… (#21821)
Co-authored-by: 刘江波 <jiangbo721@163.com>
2025-07-03 10:53:14 +08:00
Minamiyama
a45aa1e505 feat(variables): auto replace spaces with underscores in variable name inputs (#21843) 2025-07-03 10:36:38 +08:00
Minamiyama
cb0d4a1e15 style(config-var): update styling classes to use design system tokens (#21846) 2025-07-03 10:00:44 +08:00
crazywoola
21e68b9cf1 fix: nodeExtraData might be undefined (#21856) 2025-07-03 09:59:19 +08:00
HyaCinth
a3654c8fe9 fix(web): adjust HTTP node method and input layout (#21834) (#21855) 2025-07-03 09:26:38 +08:00
Yeuoly
980b0188d2 feat(tests): add structured output parser tests for LLM responses (#21838) 2025-07-03 09:10:04 +08:00
Kalo Chin
daab648c78 fix: plugin deamon start fail (#21841) 2025-07-03 09:09:02 +08:00
jiangbo721
e17b33e004 chore: add message status enum (#21825)
Co-authored-by: 刘江波 <jiangbo721@163.com>
2025-07-02 21:22:28 +08:00
Kalo Chin
4e7c9dd2ae feat: Retain llm setting for agent node (#21842) 2025-07-02 20:28:25 +08:00
Yeuoly
5487463385 fix: add list contents handling in structured LLM output (#21837) 2025-07-02 19:14:21 +08:00
Joel
8d8220b06c fix: utc time show 2025-06-30 18:28:09 +08:00
Joel
0625d6a361 fix: not use local time 2025-06-30 18:22:40 +08:00
Joel
63a1a1077e Merge branch 'main' into feat/plugin-auto-upgrade-fe 2025-06-30 14:01:29 +08:00
Joel
0af646d947 fix: fetch installed plugin instead of all plugins 2025-06-27 19:35:18 +08:00
Joel
07c99745fa feat: handle downgrade install 2025-06-27 19:05:12 +08:00
Joel
afd0d31354 fix: not the same as 2025-06-27 12:01:32 +08:00
Joel
18bbf1165d feat: exculde call api 2025-06-27 11:53:14 +08:00
Joel
5f17edc77f feat: downgrade detect 2025-06-27 11:42:28 +08:00
Joel
836027cb33 chore: add auto update show config 2025-06-27 11:36:08 +08:00
Joel
f3cbfe2223 feat: config can save 2025-06-27 10:49:22 +08:00
Joel
bc1e4c88e0 feat: no data placeholder 2025-06-27 10:36:54 +08:00
Joel
d114485abd feat: pluging loading 2025-06-27 10:10:02 +08:00
Joel
3e8a4a66fe feat: api to refernce settings 2025-06-27 09:55:25 +08:00
Joel
4c583f3d9a feat: can select plugins 2025-06-26 15:31:50 +08:00
Joel
52b845a5bb feat: select box setting 2025-06-26 10:48:11 +08:00
Joel
38d1c85c57 main 2025-06-26 10:15:41 +08:00
Joel
c43d992f2b feat: fetch plugin list 2025-06-25 18:40:12 +08:00
Joel
1ff5969b92 feat: select tool template 2025-06-25 17:41:33 +08:00
Joel
93a560ee54 chore: ui and clear 2025-06-25 16:45:21 +08:00
Joel
2f241d932c chore: temp i18n 2025-06-24 16:29:54 +08:00
Joel
a0804786fd feat: downgrade modal i18n 2025-06-24 16:15:26 +08:00
Joel
c6fa8102eb feat: downgrade modal 2025-06-24 15:36:10 +08:00
Joel
7ec5816513 feat: show downgrade warning logic 2025-06-24 11:21:57 +08:00
Joel
825fbcc6f8 feat: auto update button 2025-06-24 11:04:04 +08:00
Joel
ccef71626d feat: show list and select 2025-06-23 18:31:21 +08:00
Joel
29cac85b12 feat: plugin no data 2025-06-23 18:09:32 +08:00
Joel
8b290ac7a1 feat: only choose 15 time 2025-06-23 16:45:48 +08:00
Joel
01cdffaa08 feat: plugins picker holder 2025-06-19 18:13:56 +08:00
Joel
3061280f7a fat: auto update mode 2025-06-19 17:56:53 +08:00
Joel
bc75d810c4 feat: choose time 2025-06-19 17:47:31 +08:00
Joel
dc5e974a78 feat: choose auto update description and i18n 2025-06-19 16:27:17 +08:00
Joel
baff25c160 feat: auto update strategy picker 2025-06-19 16:11:02 +08:00
Joel
42b6524954 feat: type config 2025-06-18 15:04:40 +08:00
1171 changed files with 26788 additions and 41891 deletions

View File

@@ -6,7 +6,6 @@ on:
- "main"
- "deploy/dev"
- "deploy/enterprise"
- "deploy/rag-dev"
tags:
- "*"

View File

@@ -4,7 +4,7 @@ on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "deploy/rag-dev"
- "deploy/dev"
types:
- completed
@@ -12,13 +12,12 @@ jobs:
deploy:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/rag-dev'
github.event.workflow_run.conclusion == 'success'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
with:
host: ${{ secrets.RAG_SSH_HOST }}
host: ${{ secrets.SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |

View File

@@ -1,3 +1,4 @@
import os
import sys
@@ -16,20 +17,20 @@ else:
# It seems that JetBrains Python debugger does not work well with gevent,
# so we need to disable gevent in debug mode.
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
# if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
# from gevent import monkey
#
# # gevent
# monkey.patch_all()
#
# from grpc.experimental import gevent as grpc_gevent # type: ignore
#
# # grpc gevent
# grpc_gevent.init_gevent()
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
from gevent import monkey
# import psycogreen.gevent # type: ignore
#
# psycogreen.gevent.patch_psycopg()
# gevent
monkey.patch_all()
from grpc.experimental import gevent as grpc_gevent # type: ignore
# grpc gevent
grpc_gevent.init_gevent()
import psycogreen.gevent # type: ignore
psycogreen.gevent.patch_psycopg()
from app_factory import create_app

View File

@@ -222,28 +222,11 @@ class HostedFetchAppTemplateConfig(BaseSettings):
)
class HostedFetchPipelineTemplateConfig(BaseSettings):
"""
Configuration for fetching pipeline templates
"""
HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field(
description="Mode for fetching pipeline templates: remote, db, or builtin default to remote,",
default="database",
)
HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field(
description="Domain for fetching remote pipeline templates",
default="https://tmpl.dify.ai",
)
class HostedServiceConfig(
# place the configs in alphabet order
HostedAnthropicConfig,
HostedAzureOpenAiConfig,
HostedFetchAppTemplateConfig,
HostedFetchPipelineTemplateConfig,
HostedMinmaxConfig,
HostedOpenAiConfig,
HostedSparkConfig,

View File

@@ -3,7 +3,6 @@ from threading import Lock
from typing import TYPE_CHECKING
from contexts.wrapper import RecyclableContextVar
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
if TYPE_CHECKING:
from core.model_runtime.entities.model_entities import AIModelEntity
@@ -34,11 +33,3 @@ plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(Cont
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
ContextVar("plugin_model_schemas")
)
datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = (
RecyclableContextVar(ContextVar("datasource_plugin_providers"))
)
datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("datasource_plugin_providers_lock")
)

View File

@@ -56,6 +56,7 @@ from .app import (
conversation,
conversation_variables,
generator,
mcp_server,
message,
model_config,
ops_trace,
@@ -76,6 +77,7 @@ from .billing import billing, compliance
# Import datasets controllers
from .datasets import (
data_source,
datasets,
datasets_document,
datasets_segments,
@@ -84,14 +86,6 @@ from .datasets import (
metadata,
website,
)
from .datasets.rag_pipeline import (
datasource_auth,
datasource_content_preview,
rag_pipeline,
rag_pipeline_datasets,
rag_pipeline_import,
rag_pipeline_workflow,
)
# Import explore controllers
from .explore import (

View File

@@ -90,23 +90,11 @@ class ChatMessageTextApi(Resource):
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
and app_model.workflow
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
if text_to_speech is None:
raise ValueError("TTS is not enabled")
voice = args.get("voice") or text_to_speech.get("voice")
else:
try:
if app_model.app_model_config is None:
raise ValueError("AppModelConfig not found")
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
except Exception:
voice = None
response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)
voice = args.get("voice", None)
response = AudioService.transcript_tts(
app_model=app_model, text=text, voice=voice, message_id=message_id, is_draft=True
)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")

View File

@@ -0,0 +1,102 @@
import json
from enum import StrEnum
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.app_fields import app_server_fields
from libs.login import login_required
from models.model import AppMCPServer
class AppMCPServerStatus(StrEnum):
ACTIVE = "active"
INACTIVE = "inactive"
class AppMCPServerController(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(app_server_fields)
def get(self, app_model):
server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == app_model.id).first()
return server
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(app_server_fields)
def post(self, app_model):
# The role of the current user in the ta table must be editor, admin, or owner
if not current_user.is_editor:
raise NotFound()
parser = reqparse.RequestParser()
parser.add_argument("description", type=str, required=True, location="json")
parser.add_argument("parameters", type=dict, required=True, location="json")
args = parser.parse_args()
server = AppMCPServer(
name=app_model.name,
description=args["description"],
parameters=json.dumps(args["parameters"], ensure_ascii=False),
status=AppMCPServerStatus.ACTIVE,
app_id=app_model.id,
tenant_id=current_user.current_tenant_id,
server_code=AppMCPServer.generate_server_code(16),
)
db.session.add(server)
db.session.commit()
return server
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(app_server_fields)
def put(self, app_model):
if not current_user.is_editor:
raise NotFound()
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, location="json")
parser.add_argument("description", type=str, required=True, location="json")
parser.add_argument("parameters", type=dict, required=True, location="json")
parser.add_argument("status", type=str, required=False, location="json")
args = parser.parse_args()
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first()
if not server:
raise NotFound()
server.description = args["description"]
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
if args["status"]:
if args["status"] not in [status.value for status in AppMCPServerStatus]:
raise ValueError("Invalid status")
server.status = args["status"]
db.session.commit()
return server
class AppMCPServerRefreshController(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_server_fields)
def get(self, server_id):
if not current_user.is_editor:
raise NotFound()
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == server_id).first()
if not server:
raise NotFound()
server.server_code = AppMCPServer.generate_server_code(16)
db.session.commit()
return server
api.add_resource(AppMCPServerController, "/apps/<uuid:app_id>/server")
api.add_resource(AppMCPServerRefreshController, "/apps/<uuid:server_id>/server/refresh")

View File

@@ -283,15 +283,6 @@ class DatasetApi(Resource):
location="json",
help="Invalid external knowledge api id.",
)
parser.add_argument(
"icon_info",
type=dict,
required=False,
nullable=True,
location="json",
help="Invalid icon info.",
)
args = parser.parse_args()
data = request.get_json()

View File

@@ -1,4 +1,3 @@
import json
import logging
from argparse import ArgumentTypeError
from datetime import UTC, datetime
@@ -52,7 +51,6 @@ from fields.document_fields import (
)
from libs.login import login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
@@ -663,7 +661,7 @@ class DocumentDetailApi(DocumentResource):
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
elif metadata == "without":
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
document_process_rules = document.dataset_process_rule.to_dict()
data_source_info = document.data_source_detail_dict
response = {
"id": document.id,
@@ -1030,41 +1028,6 @@ class WebsiteDocumentSyncApi(DocumentResource):
return {"result": "success"}, 200
class DocumentPipelineExecutionLogApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, document_id):
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound("Document not found.")
log = (
db.session.query(DocumentPipelineExecutionLog)
.filter_by(document_id=document_id)
.order_by(DocumentPipelineExecutionLog.created_at.desc())
.first()
)
if not log:
return {
"datasource_info": None,
"datasource_type": None,
"input_data": None,
"datasource_node_id": None,
}, 200
return {
"datasource_info": json.loads(log.datasource_info),
"datasource_type": log.datasource_type,
"input_data": log.input_data,
"datasource_node_id": log.datasource_node_id,
}, 200
api.add_resource(GetProcessRuleApi, "/datasets/process-rule")
api.add_resource(DatasetDocumentListApi, "/datasets/<uuid:dataset_id>/documents")
api.add_resource(DatasetInitApi, "/datasets/init")
@@ -1087,6 +1050,3 @@ api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")
api.add_resource(DocumentRenameApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")
api.add_resource(WebsiteDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
api.add_resource(
DocumentPipelineExecutionLogApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log"
)

View File

@@ -101,9 +101,3 @@ class ChildChunkDeleteIndexError(BaseHTTPException):
error_code = "child_chunk_delete_index_error"
description = "Delete child chunk index failed: {message}"
code = 500
class PipelineNotFoundError(BaseHTTPException):
error_code = "pipeline_not_found"
description = "Pipeline not found."
code = 404

View File

@@ -1,197 +0,0 @@
from flask import redirect, request
from flask_login import current_user # type: ignore
from flask_restful import ( # type: ignore
Resource, # type: ignore
reqparse,
)
from werkzeug.exceptions import Forbidden, NotFound
from configs import dify_config
from controllers.console import api
from controllers.console.wraps import (
account_initialization_required,
setup_required,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.impl.oauth import OAuthHandler
from extensions.ext_database import db
from libs.login import login_required
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
from services.datasource_provider_service import DatasourceProviderService
class DatasourcePluginOauthApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
provider = args["provider"]
plugin_id = args["plugin_id"]
# Check user role first
if not current_user.is_editor:
raise Forbidden()
# get all plugin oauth configs
plugin_oauth_config = (
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
)
if not plugin_oauth_config:
raise NotFound()
oauth_handler = OAuthHandler()
redirect_url = (
f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?provider={provider}&plugin_id={plugin_id}"
)
system_credentials = plugin_oauth_config.system_credentials
if system_credentials:
system_credentials["redirect_url"] = redirect_url
response = oauth_handler.get_authorization_url(
current_user.current_tenant.id, current_user.id, plugin_id, provider, system_credentials=system_credentials
)
return response.model_dump()
class DatasourceOauthCallback(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
provider = args["provider"]
plugin_id = args["plugin_id"]
oauth_handler = OAuthHandler()
plugin_oauth_config = (
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
)
if not plugin_oauth_config:
raise NotFound()
credentials = oauth_handler.get_credentials(
current_user.current_tenant.id,
current_user.id,
plugin_id,
provider,
system_credentials=plugin_oauth_config.system_credentials,
request=request,
)
datasource_provider = DatasourceProvider(
plugin_id=plugin_id, provider=provider, auth_type="oauth", encrypted_credentials=credentials
)
db.session.add(datasource_provider)
db.session.commit()
return redirect(f"{dify_config.CONSOLE_WEB_URL}")
class DatasourceAuth(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=False, nullable=False, location="json", default="test")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
datasource_provider_service = DatasourceProviderService()
try:
datasource_provider_service.datasource_provider_credentials_validate(
tenant_id=current_user.current_tenant_id,
provider=args["provider"],
plugin_id=args["plugin_id"],
credentials=args["credentials"],
name=args["name"],
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {"result": "success"}, 201
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id, provider=args["provider"], plugin_id=args["plugin_id"]
)
return {"result": datasources}, 200
class DatasourceAuthUpdateDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, auth_id: str):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
if not current_user.is_editor:
raise Forbidden()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials(
tenant_id=current_user.current_tenant_id,
auth_id=auth_id,
provider=args["provider"],
plugin_id=args["plugin_id"],
)
return {"result": "success"}, 200
@setup_required
@login_required
@account_initialization_required
def patch(self, auth_id: str):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
if not current_user.is_editor:
raise Forbidden()
try:
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_credentials(
tenant_id=current_user.current_tenant_id,
auth_id=auth_id,
provider=args["provider"],
plugin_id=args["plugin_id"],
credentials=args["credentials"],
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {"result": "success"}, 201
# Import Rag Pipeline
api.add_resource(
DatasourcePluginOauthApi,
"/oauth/plugin/datasource",
)
api.add_resource(
DatasourceOauthCallback,
"/oauth/plugin/datasource/callback",
)
api.add_resource(
DatasourceAuth,
"/auth/plugin/datasource",
)
api.add_resource(
DatasourceAuthUpdateDeleteApi,
"/auth/plugin/datasource/<string:auth_id>",
)

View File

@@ -1,55 +0,0 @@
from flask_restful import ( # type: ignore
Resource, # type: ignore
reqparse,
)
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import current_user, login_required
from models import Account
from models.dataset import Pipeline
from services.rag_pipeline.rag_pipeline import RagPipelineService
class DataSourceContentPreviewApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_rag_pipeline
def post(self, pipeline: Pipeline, node_id: str):
"""
Run datasource content preview
"""
if not isinstance(current_user, Account):
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("datasource_type", type=str, required=True, location="json")
args = parser.parse_args()
inputs = args.get("inputs")
if inputs is None:
raise ValueError("missing inputs")
datasource_type = args.get("datasource_type")
if datasource_type is None:
raise ValueError("missing datasource_type")
rag_pipeline_service = RagPipelineService()
preview_content = rag_pipeline_service.run_datasource_node_preview(
pipeline=pipeline,
node_id=node_id,
user_inputs=inputs,
account=current_user,
datasource_type=datasource_type,
is_published=True,
)
return preview_content, 200
api.add_resource(
DataSourceContentPreviewApi,
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview",
)

View File

@@ -1,162 +0,0 @@
import logging
from flask import request
from flask_restful import Resource, reqparse
from sqlalchemy.orm import Session
from controllers.console import api
from controllers.console.wraps import (
account_initialization_required,
enterprise_license_required,
setup_required,
)
from extensions.ext_database import db
from libs.login import login_required
from models.dataset import PipelineCustomizedTemplate
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
from services.rag_pipeline.rag_pipeline import RagPipelineService
logger = logging.getLogger(__name__)
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class PipelineTemplateListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def get(self):
type = request.args.get("type", default="built-in", type=str)
language = request.args.get("language", default="en-US", type=str)
# get pipeline templates
pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
return pipeline_templates, 200
class PipelineTemplateDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def get(self, template_id: str):
type = request.args.get("type", default="built-in", type=str)
rag_pipeline_service = RagPipelineService()
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
return pipeline_template, 200
class CustomizedPipelineTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def patch(self, template_id: str):
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=str,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"icon_info",
type=dict,
location="json",
nullable=True,
)
args = parser.parse_args()
pipeline_template_info = PipelineTemplateInfoEntity(**args)
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
return 200
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def delete(self, template_id: str):
RagPipelineService.delete_customized_pipeline_template(template_id)
return 200
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def post(self, template_id: str):
with Session(db.engine) as session:
template = (
session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
)
if not template:
raise ValueError("Customized pipeline template not found.")
return {"data": template.yaml_content}, 200
class PublishCustomizedPipelineTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
@enterprise_license_required
def post(self, pipeline_id: str):
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=str,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"icon_info",
type=dict,
location="json",
nullable=True,
)
args = parser.parse_args()
rag_pipeline_service = RagPipelineService()
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
return {"result": "success"}
api.add_resource(
PipelineTemplateListApi,
"/rag/pipeline/templates",
)
api.add_resource(
PipelineTemplateDetailApi,
"/rag/pipeline/templates/<string:template_id>",
)
api.add_resource(
CustomizedPipelineTemplateApi,
"/rag/pipeline/customized/templates/<string:template_id>",
)
api.add_resource(
PublishCustomizedPipelineTemplateApi,
"/rag/pipelines/<string:pipeline_id>/customized/publish",
)

View File

@@ -1,171 +0,0 @@
from flask_login import current_user # type: ignore # type: ignore
from flask_restful import Resource, marshal, reqparse # type: ignore
from werkzeug.exceptions import Forbidden
import services
from controllers.console import api
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_rate_limit_check,
setup_required,
)
from fields.dataset_fields import dataset_detail_fields
from libs.login import login_required
from models.dataset import DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService
from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
def _validate_description_length(description):
if len(description) > 400:
raise ValueError("Description cannot exceed 400 characters.")
return description
class CreateRagPipelineDatasetApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=str,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"icon_info",
type=dict,
nullable=True,
required=False,
default={},
)
parser.add_argument(
"permission",
type=str,
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
nullable=True,
required=False,
default=DatasetPermissionEnum.ONLY_ME,
)
parser.add_argument(
"partial_member_list",
type=list,
nullable=True,
required=False,
default=[],
)
parser.add_argument(
"yaml_content",
type=str,
nullable=False,
required=True,
help="yaml_content is required.",
)
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
rag_pipeline_dataset_create_entity = RagPipelineDatasetCreateEntity(**args)
try:
import_info = RagPipelineDslService.create_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id,
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
)
if rag_pipeline_dataset_create_entity.permission == "partial_members":
DatasetPermissionService.update_partial_member_list(
current_user.current_tenant_id,
import_info["dataset_id"],
rag_pipeline_dataset_create_entity.partial_member_list,
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
return import_info, 201
class CreateEmptyRagPipelineDatasetApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument(
"name",
nullable=False,
required=True,
help="type is required. Name must be between 1 to 40 characters.",
type=_validate_name,
)
parser.add_argument(
"description",
type=str,
nullable=True,
required=False,
default="",
)
parser.add_argument(
"icon_info",
type=dict,
nullable=True,
required=False,
default={},
)
parser.add_argument(
"permission",
type=str,
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
nullable=True,
required=False,
default=DatasetPermissionEnum.ONLY_ME,
)
parser.add_argument(
"partial_member_list",
type=list,
nullable=True,
required=False,
default=[],
)
args = parser.parse_args()
dataset = DatasetService.create_empty_rag_pipeline_dataset(
tenant_id=current_user.current_tenant_id,
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(**args),
)
return marshal(dataset, dataset_detail_fields), 201
api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset")
api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset")

View File

@@ -1,146 +0,0 @@
from typing import cast
from flask_login import current_user # type: ignore
from flask_restful import Resource, marshal_with, reqparse # type: ignore
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.datasets.wraps import get_rag_pipeline
from controllers.console.wraps import (
account_initialization_required,
setup_required,
)
from extensions.ext_database import db
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
from libs.login import login_required
from models import Account
from models.dataset import Pipeline
from services.app_dsl_service import ImportStatus
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
class RagPipelineImportApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(pipeline_import_fields)
def post(self):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("mode", type=str, required=True, location="json")
parser.add_argument("yaml_content", type=str, location="json")
parser.add_argument("yaml_url", type=str, location="json")
parser.add_argument("name", type=str, location="json")
parser.add_argument("description", type=str, location="json")
parser.add_argument("icon_type", type=str, location="json")
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument("pipeline_id", type=str, location="json")
args = parser.parse_args()
# Create service with session
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
# Import app
account = cast(Account, current_user)
result = import_service.import_rag_pipeline(
account=account,
import_mode=args["mode"],
yaml_content=args.get("yaml_content"),
yaml_url=args.get("yaml_url"),
pipeline_id=args.get("pipeline_id"),
)
session.commit()
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED.value:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING.value:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200
class RagPipelineImportConfirmApi(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(pipeline_import_fields)
def post(self, import_id):
# Check user role first
if not current_user.is_editor:
raise Forbidden()
# Create service with session
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
# Confirm import
account = cast(Account, current_user)
result = import_service.confirm_import(import_id=import_id, account=account)
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED.value:
return result.model_dump(mode="json"), 400
return result.model_dump(mode="json"), 200
class RagPipelineImportCheckDependenciesApi(Resource):
@setup_required
@login_required
@get_rag_pipeline
@account_initialization_required
@marshal_with(pipeline_import_check_dependencies_fields)
def get(self, pipeline: Pipeline):
if not current_user.is_editor:
raise Forbidden()
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
result = import_service.check_dependencies(pipeline=pipeline)
return result.model_dump(mode="json"), 200
class RagPipelineExportApi(Resource):
@setup_required
@login_required
@get_rag_pipeline
@account_initialization_required
def get(self, pipeline: Pipeline):
if not current_user.is_editor:
raise Forbidden()
# Add include_secret params
parser = reqparse.RequestParser()
parser.add_argument("include_secret", type=bool, default=False, location="args")
args = parser.parse_args()
with Session(db.engine) as session:
export_service = RagPipelineDslService(session)
result = export_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=args["include_secret"])
return {"data": result}, 200
# Import Rag Pipeline
api.add_resource(
RagPipelineImportApi,
"/rag/pipelines/imports",
)
api.add_resource(
RagPipelineImportConfirmApi,
"/rag/pipelines/imports/<string:import_id>/confirm",
)
api.add_resource(
RagPipelineImportCheckDependenciesApi,
"/rag/pipelines/imports/<string:pipeline_id>/check-dependencies",
)
api.add_resource(
RagPipelineExportApi,
"/rag/pipelines/<string:pipeline_id>/exports",
)

View File

@@ -1,43 +0,0 @@
from collections.abc import Callable
from functools import wraps
from typing import Optional
from controllers.console.datasets.error import PipelineNotFoundError
from extensions.ext_database import db
from libs.login import current_user
from models.dataset import Pipeline
def get_rag_pipeline(
view: Optional[Callable] = None,
):
def decorator(view_func):
@wraps(view_func)
def decorated_view(*args, **kwargs):
if not kwargs.get("pipeline_id"):
raise ValueError("missing pipeline_id in path parameters")
pipeline_id = kwargs.get("pipeline_id")
pipeline_id = str(pipeline_id)
del kwargs["pipeline_id"]
pipeline = (
db.session.query(Pipeline)
.filter(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
.first()
)
if not pipeline:
raise PipelineNotFoundError()
kwargs["pipeline"] = pipeline
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)

View File

@@ -18,7 +18,6 @@ from controllers.console.app.error import (
from controllers.console.explore.wraps import InstalledAppResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from models.model import AppMode
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@@ -79,19 +78,9 @@ class ChatTextApi(InstalledAppResource):
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
and app_model.workflow
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
voice = args.get("voice") or text_to_speech.get("voice")
else:
try:
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
except Exception:
voice = None
response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text)
voice = args.get("voice", None)
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")

View File

@@ -1,6 +1,7 @@
import io
from urllib.parse import urlparse
from flask import send_file
from flask import redirect, send_file
from flask_login import current_user
from flask_restful import Resource, reqparse
from sqlalchemy.orm import Session
@@ -9,17 +10,34 @@ from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from libs.helper import alphanumeric, uuid_value
from libs.login import login_required
from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.tools.mcp_tools_mange_service import MCPToolManageService
from services.tools.tool_labels_service import ToolLabelsService
from services.tools.tools_manage_service import ToolCommonService
from services.tools.tools_transform_service import ToolTransformService
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
def is_valid_url(url: str) -> bool:
if not url:
return False
try:
parsed = urlparse(url)
return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]
except Exception:
return False
class ToolProviderListApi(Resource):
@setup_required
@login_required
@@ -34,7 +52,7 @@ class ToolProviderListApi(Resource):
req.add_argument(
"type",
type=str,
choices=["builtin", "model", "api", "workflow"],
choices=["builtin", "model", "api", "workflow", "mcp"],
required=False,
nullable=True,
location="args",
@@ -613,6 +631,166 @@ class ToolLabelsApi(Resource):
return jsonable_encoder(ToolLabelsService.list_tool_labels())
class ToolProviderMCPApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("server_url", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
user = current_user
if not is_valid_url(args["server_url"]):
raise ValueError("Server URL is not valid.")
return jsonable_encoder(
MCPToolManageService.create_mcp_provider(
tenant_id=user.current_tenant_id,
server_url=args["server_url"],
name=args["name"],
icon=args["icon"],
icon_type=args["icon_type"],
icon_background=args["icon_background"],
user_id=user.id,
server_identifier=args["server_identifier"],
)
)
@setup_required
@login_required
@account_initialization_required
def put(self):
parser = reqparse.RequestParser()
parser.add_argument("server_url", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
if not is_valid_url(args["server_url"]):
if "[__HIDDEN__]" in args["server_url"]:
pass
else:
raise ValueError("Server URL is not valid.")
MCPToolManageService.update_mcp_provider(
tenant_id=current_user.current_tenant_id,
provider_id=args["provider_id"],
server_url=args["server_url"],
name=args["name"],
icon=args["icon"],
icon_type=args["icon_type"],
icon_background=args["icon_background"],
server_identifier=args["server_identifier"],
)
return {"result": "success"}
@setup_required
@login_required
@account_initialization_required
def delete(self):
parser = reqparse.RequestParser()
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
MCPToolManageService.delete_mcp_tool(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"])
return {"result": "success"}
class ToolMCPAuthApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
provider_id = args["provider_id"]
tenant_id = current_user.current_tenant_id
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if not provider:
raise ValueError("provider not found")
try:
with MCPClient(
provider.decrypted_server_url,
provider_id,
tenant_id,
authed=False,
authorization_code=args["authorization_code"],
for_list=True,
):
MCPToolManageService.update_mcp_provider_credentials(
mcp_provider=provider,
credentials=provider.decrypted_credentials,
authed=True,
)
return {"result": "success"}
except MCPAuthError:
auth_provider = OAuthClientProvider(provider_id, tenant_id, for_list=True)
return auth(auth_provider, provider.decrypted_server_url, args["authorization_code"])
except MCPError as e:
MCPToolManageService.update_mcp_provider_credentials(
mcp_provider=provider,
credentials={},
authed=False,
)
raise ValueError(f"Failed to connect to MCP server: {e}") from e
class ToolMCPDetailApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_id):
user = current_user
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, user.current_tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
class ToolMCPListAllApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user = current_user
tenant_id = user.current_tenant_id
tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)
return [tool.to_dict() for tool in tools]
class ToolMCPUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_id):
tenant_id = current_user.current_tenant_id
tools = MCPToolManageService.list_mcp_tool_from_remote_server(
tenant_id=tenant_id,
provider_id=provider_id,
)
return jsonable_encoder(tools)
class ToolMCPCallbackApi(Resource):
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("code", type=str, required=True, nullable=False, location="args")
parser.add_argument("state", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
state_key = args["state"]
authorization_code = args["code"]
handle_callback(state_key, authorization_code)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
# tool provider
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
@@ -647,8 +825,15 @@ api.add_resource(ToolWorkflowProviderDeleteApi, "/workspaces/current/tool-provid
api.add_resource(ToolWorkflowProviderGetApi, "/workspaces/current/tool-provider/workflow/get")
api.add_resource(ToolWorkflowProviderListToolApi, "/workspaces/current/tool-provider/workflow/tools")
# mcp tool provider
api.add_resource(ToolMCPDetailApi, "/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
api.add_resource(ToolProviderMCPApi, "/workspaces/current/tool-provider/mcp")
api.add_resource(ToolMCPUpdateApi, "/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
api.add_resource(ToolMCPAuthApi, "/workspaces/current/tool-provider/mcp/auth")
api.add_resource(ToolMCPCallbackApi, "/mcp/oauth/callback")
api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin")
api.add_resource(ToolApiListApi, "/workspaces/current/tools/api")
api.add_resource(ToolMCPListAllApi, "/workspaces/current/tools/mcp")
api.add_resource(ToolWorkflowListApi, "/workspaces/current/tools/workflow")
api.add_resource(ToolLabelsApi, "/workspaces/current/tool-labels")

View File

@@ -87,7 +87,5 @@ class PluginUploadFileApi(Resource):
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return tool_file, 201
api.add_resource(PluginUploadFileApi, "/files/upload/for-plugin")

View File

@@ -0,0 +1,8 @@
from flask import Blueprint
from libs.external_api import ExternalApi
bp = Blueprint("mcp", __name__, url_prefix="/mcp")
api = ExternalApi(bp)
from . import mcp

104
api/controllers/mcp/mcp.py Normal file
View File

@@ -0,0 +1,104 @@
from flask_restful import Resource, reqparse
from pydantic import ValidationError
from controllers.console.app.mcp_server import AppMCPServerStatus
from controllers.mcp import api
from core.app.app_config.entities import VariableEntity
from core.mcp import types
from core.mcp.server.streamable_http import MCPServerStreamableHTTPRequestHandler
from core.mcp.types import ClientNotification, ClientRequest
from core.mcp.utils import create_mcp_error_response
from extensions.ext_database import db
from libs import helper
from models.model import App, AppMCPServer, AppMode
class MCPAppApi(Resource):
def post(self, server_code):
def int_or_str(value):
if isinstance(value, (int, str)):
return value
else:
return None
parser = reqparse.RequestParser()
parser.add_argument("jsonrpc", type=str, required=True, location="json")
parser.add_argument("method", type=str, required=True, location="json")
parser.add_argument("params", type=dict, required=False, location="json")
parser.add_argument("id", type=int_or_str, required=False, location="json")
args = parser.parse_args()
request_id = args.get("id")
server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first()
if not server:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found")
)
if server.status != AppMCPServerStatus.ACTIVE:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active")
)
app = db.session.query(App).filter(App.id == server.app_id).first()
if not app:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found")
)
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app.workflow
if workflow is None:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
)
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app.app_model_config
if app_model_config is None:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
)
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
converted_user_input_form: list[VariableEntity] = []
try:
for item in user_input_form:
variable_type = item.get("type", "") or list(item.keys())[0]
variable = item[variable_type]
converted_user_input_form.append(
VariableEntity(
type=variable_type,
variable=variable.get("variable"),
description=variable.get("description") or "",
label=variable.get("label"),
required=variable.get("required", False),
max_length=variable.get("max_length"),
options=variable.get("options") or [],
)
)
except ValidationError as e:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
)
try:
request: ClientRequest | ClientNotification = ClientRequest.model_validate(args)
except ValidationError as e:
try:
notification = ClientNotification.model_validate(args)
request = notification
except ValidationError as e:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
)
mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form)
response = mcp_server_handler.handle()
return helper.compact_generate_response(response)
api.add_resource(MCPAppApi, "/server/<string:server_code>/mcp")

View File

@@ -20,7 +20,7 @@ from controllers.service_api.app.error import (
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from models.model import App, AppMode, EndUser
from models.model import App, EndUser
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@@ -78,20 +78,9 @@ class TextApi(Resource):
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
and app_model.workflow
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech", {})
voice = args.get("voice") or text_to_speech.get("voice")
else:
try:
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
except Exception:
voice = None
voice = args.get("voice", None)
response = AudioService.transcript_tts(
app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
)
return response

View File

@@ -211,6 +211,9 @@ class DocumentAddByFileApi(DatasetApiResource):
if not dataset:
raise ValueError("Dataset does not exist.")
if dataset.provider == "external":
raise ValueError("External datasets are not supported.")
indexing_technique = args.get("indexing_technique") or dataset.indexing_technique
if not indexing_technique:
raise ValueError("indexing_technique is required.")
@@ -301,6 +304,9 @@ class DocumentUpdateByFileApi(DatasetApiResource):
if not dataset:
raise ValueError("Dataset does not exist.")
if dataset.provider == "external":
raise ValueError("External datasets are not supported.")
# indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique

View File

@@ -19,7 +19,7 @@ from controllers.web.error import (
from controllers.web.wraps import WebApiResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from models.model import App, AppMode
from models.model import App
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@@ -77,21 +77,9 @@ class TextApi(WebApiResource):
message_id = args.get("message_id", None)
text = args.get("text", None)
if (
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
and app_model.workflow
and app_model.workflow.features_dict
):
text_to_speech = app_model.workflow.features_dict.get("text_to_speech", {})
voice = args.get("voice") or text_to_speech.get("voice")
else:
try:
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
except Exception:
voice = None
voice = args.get("voice", None)
response = AudioService.transcript_tts(
app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
)
return response

View File

@@ -161,10 +161,14 @@ class BaseAgentRunner(AppRunner):
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options] if parameter.options else []
message_tool.parameters["properties"][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or "",
}
message_tool.parameters["properties"][parameter.name] = (
{
"type": parameter_type,
"description": parameter.llm_description or "",
}
if parameter.input_schema is None
else parameter.input_schema
)
if len(enum) > 0:
message_tool.parameters["properties"][parameter.name]["enum"] = enum
@@ -254,10 +258,14 @@ class BaseAgentRunner(AppRunner):
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options] if parameter.options else []
prompt_tool.parameters["properties"][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or "",
}
prompt_tool.parameters["properties"][parameter.name] = (
{
"type": parameter_type,
"description": parameter.llm_description or "",
}
if parameter.input_schema is None
else parameter.input_schema
)
if len(enum) > 0:
prompt_tool.parameters["properties"][parameter.name]["enum"] = enum

View File

@@ -85,7 +85,7 @@ class AgentStrategyEntity(BaseModel):
description: I18nObject = Field(..., description="The description of the agent strategy")
output_schema: Optional[dict] = None
features: Optional[list[AgentFeature]] = None
meta_version: Optional[str] = None
# pydantic configs
model_config = ConfigDict(protected_namespaces=())

View File

@@ -15,10 +15,12 @@ class PluginAgentStrategy(BaseAgentStrategy):
tenant_id: str
declaration: AgentStrategyEntity
meta_version: str | None = None
def __init__(self, tenant_id: str, declaration: AgentStrategyEntity):
def __init__(self, tenant_id: str, declaration: AgentStrategyEntity, meta_version: str | None):
self.tenant_id = tenant_id
self.declaration = declaration
self.meta_version = meta_version
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
return self.declaration.parameters

View File

@@ -113,9 +113,9 @@ class VariableEntity(BaseModel):
hide: bool = False
max_length: Optional[int] = None
options: Sequence[str] = Field(default_factory=list)
allowed_file_types: Optional[Sequence[FileType]] = Field(default_factory=list)
allowed_file_extensions: Optional[Sequence[str]] = Field(default_factory=list)
allowed_file_upload_methods: Optional[Sequence[FileTransferMethod]] = Field(default_factory=list)
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
allowed_file_extensions: Sequence[str] = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
@field_validator("description", mode="before")
@classmethod
@@ -128,16 +128,6 @@ class VariableEntity(BaseModel):
return v or []
class RagPipelineVariableEntity(VariableEntity):
"""
Rag Pipeline Variable Entity.
"""
tooltips: Optional[str] = None
placeholder: Optional[str] = None
belong_to_node_id: str
class ExternalDataVariableEntity(BaseModel):
"""
External Data Variable Entity.
@@ -295,7 +285,7 @@ class AppConfig(BaseModel):
tenant_id: str
app_id: str
app_mode: AppMode
additional_features: Optional[AppAdditionalFeatures] = None
additional_features: AppAdditionalFeatures
variables: list[VariableEntity] = []
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None

View File

@@ -1,4 +1,4 @@
from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
from core.app.app_config.entities import VariableEntity
from models.workflow import Workflow
@@ -20,19 +20,3 @@ class WorkflowVariablesConfigManager:
variables.append(VariableEntity.model_validate(variable))
return variables
@classmethod
def convert_rag_pipeline_variable(cls, workflow: Workflow) -> list[RagPipelineVariableEntity]:
"""
Convert workflow start variables to variables
:param workflow: workflow instance
"""
variables = []
user_input_form = workflow.rag_pipeline_user_input_form()
# variables
for variable in user_input_form:
variables.append(RagPipelineVariableEntity.model_validate(variable))
return variables

View File

@@ -43,13 +43,11 @@ from core.app.entities.task_entities import (
WorkflowStartStreamResponse,
)
from core.file import FILE_MODEL_IDENTITY, File
from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.tool_manager import ToolManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.nodes.datasource.entities import DatasourceNodeData
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models import (
@@ -185,14 +183,6 @@ class WorkflowResponseConverter:
provider_type=node_data.provider_type,
provider_id=node_data.provider_id,
)
elif event.node_type == NodeType.DATASOURCE:
node_data = cast(DatasourceNodeData, event.node_data)
manager = PluginDatasourceManager()
provider_entity = manager.fetch_datasource_provider(
self._application_generate_entity.app_config.tenant_id,
f"{node_data.plugin_id}/{node_data.provider_name}",
)
response.data.extras["icon"] = provider_entity.declaration.identity.icon
return response

View File

@@ -1,95 +0,0 @@
from collections.abc import Generator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
AppStreamResponse,
ErrorStreamResponse,
NodeFinishStreamResponse,
NodeStartStreamResponse,
PingStreamResponse,
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
)
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
:return:
"""
return dict(blocking_response.to_dict())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response
:return:
"""
return cls.convert_blocking_full_response(blocking_response)
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
:return:
"""
for chunk in stream_response:
chunk = cast(WorkflowAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
response_chunk = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.to_dict())
yield response_chunk
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response
:return:
"""
for chunk in stream_response:
chunk = cast(WorkflowAppStreamResponse, chunk)
sub_stream_response = chunk.stream_response
if isinstance(sub_stream_response, PingStreamResponse):
yield "ping"
continue
response_chunk = {
"event": sub_stream_response.event.value,
"workflow_run_id": chunk.workflow_run_id,
}
if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.to_dict())
yield response_chunk

View File

@@ -1,64 +0,0 @@
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
from models.dataset import Pipeline
from models.model import AppMode
from models.workflow import Workflow
class PipelineConfig(WorkflowUIBasedAppConfig):
"""
Pipeline Config Entity.
"""
rag_pipeline_variables: list[RagPipelineVariableEntity] = []
pass
class PipelineConfigManager(BaseAppConfigManager):
@classmethod
def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow) -> PipelineConfig:
pipeline_config = PipelineConfig(
tenant_id=pipeline.tenant_id,
app_id=pipeline.id,
app_mode=AppMode.RAG_PIPELINE,
workflow_id=workflow.id,
rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(workflow=workflow),
)
return pipeline_config
@classmethod
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
"""
Validate for pipeline config
:param tenant_id: tenant id
:param config: app model config args
:param only_structure_validate: only validate the structure of the config
"""
related_config_keys = []
# file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
related_config_keys.extend(current_related_config_keys)
# text_to_speech
config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
related_config_keys.extend(current_related_config_keys)
# moderation validation
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
)
related_config_keys.extend(current_related_config_keys)
related_config_keys = list(set(related_config_keys))
# Filter out extra parameters
filtered_config = {key: config.get(key) for key in related_config_keys}
return filtered_config

View File

@@ -1,621 +0,0 @@
import contextvars
import datetime
import json
import logging
import secrets
import threading
import time
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker
import contexts
from configs import dify_config
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
from core.app.apps.pipeline.pipeline_runner import PipelineRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from libs.flask_utils import preserve_flask_contexts
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.model import AppMode
from services.dataset_service import DocumentService
logger = logging.getLogger(__name__)
class PipelineGenerator(BaseAppGenerator):
@overload
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ...
@overload
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Mapping[str, Any]: ...
@overload
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
workflow_thread_pool_id: Optional[str],
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
def generate(
self,
*,
pipeline: Pipeline,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(
pipeline=pipeline,
workflow=workflow,
)
# Add null check for dataset
dataset = pipeline.dataset
if not dataset:
raise ValueError("Pipeline dataset is required")
inputs: Mapping[str, Any] = args["inputs"]
start_node_id: str = args["start_node_id"]
datasource_type: str = args["datasource_type"]
datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
documents = []
if invoke_from == InvokeFrom.PUBLISHED:
for datasource_info in datasource_info_list:
position = DocumentService.get_documents_position(dataset.id)
document = self._build_document(
tenant_id=pipeline.tenant_id,
dataset_id=dataset.id,
built_in_field_enabled=dataset.built_in_field_enabled,
datasource_type=datasource_type,
datasource_info=datasource_info,
created_from="rag-pipeline",
position=position,
account=user,
batch=batch,
document_form=dataset.chunk_structure,
)
db.session.add(document)
documents.append(document)
db.session.commit()
# run in child thread
for i, datasource_info in enumerate(datasource_info_list):
workflow_run_id = str(uuid.uuid4())
document_id = None
if invoke_from == InvokeFrom.PUBLISHED:
document_id = documents[i].id
document_pipeline_execution_log = DocumentPipelineExecutionLog(
document_id=document_id,
datasource_type=datasource_type,
datasource_info=json.dumps(datasource_info),
datasource_node_id=start_node_id,
input_data=inputs,
pipeline_id=pipeline.id,
created_by=user.id,
)
db.session.add(document_pipeline_execution_log)
db.session.commit()
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=datasource_type,
datasource_info=datasource_info,
dataset_id=dataset.id,
start_node_id=start_node_id,
batch=batch,
document_id=document_id,
inputs=self._prepare_user_inputs(
user_inputs=inputs,
variables=pipeline_config.rag_pipeline_variables,
tenant_id=pipeline.tenant_id,
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
),
files=[],
user_id=user.id,
stream=streaming,
invoke_from=invoke_from,
call_depth=call_depth,
workflow_execution_id=workflow_run_id,
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
if invoke_from == InvokeFrom.DEBUGGER:
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
)
if invoke_from == InvokeFrom.DEBUGGER:
return self._generate(
flask_app=current_app._get_current_object(), # type: ignore
context=contextvars.copy_context(),
pipeline=pipeline,
workflow_id=workflow.id,
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
else:
# run in child thread
context = contextvars.copy_context()
worker_thread = threading.Thread(
target=self._generate,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"context": context,
"pipeline": pipeline,
"workflow_id": workflow.id,
"user": user,
"application_generate_entity": application_generate_entity,
"invoke_from": invoke_from,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
"streaming": streaming,
"workflow_thread_pool_id": workflow_thread_pool_id,
},
)
worker_thread.start()
# return batch, dataset, documents
return {
"batch": batch,
"dataset": PipelineDataset(
id=dataset.id,
name=dataset.name,
description=dataset.description,
chunk_structure=dataset.chunk_structure,
).model_dump(),
"documents": [
PipelineDocument(
id=document.id,
position=document.position,
data_source_type=document.data_source_type,
data_source_info=json.loads(document.data_source_info) if document.data_source_info else None,
name=document.name,
indexing_status=document.indexing_status,
error=document.error,
enabled=document.enabled,
).model_dump()
for document in documents
],
}
def _generate(
self,
*,
flask_app: Flask,
context: contextvars.Context,
pipeline: Pipeline,
workflow_id: str,
user: Union[Account, EndUser],
application_generate_entity: RagPipelineGenerateEntity,
invoke_from: InvokeFrom,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
:param pipeline: Pipeline
:param workflow: Workflow
:param user: account or end user
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
:param workflow_execution_repository: repository for workflow execution
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
with preserve_flask_contexts(flask_app, context_vars=context):
# init queue manager
workflow = db.session.query(Workflow).filter(Workflow.id == workflow_id).first()
if not workflow:
raise ValueError(f"Workflow not found: {workflow_id}")
queue_manager = PipelineQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
app_mode=AppMode.RAG_PIPELINE,
)
context = contextvars.copy_context()
# new thread
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"context": context,
"queue_manager": queue_manager,
"application_generate_entity": application_generate_entity,
"workflow_thread_pool_id": workflow_thread_pool_id,
},
)
worker_thread.start()
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=streaming,
)
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
def single_iteration_generate(
self,
pipeline: Pipeline,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping[str, Any],
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param node_id: the node id
:param user: account or end user
:param args: request args
:param streaming: is streamed
"""
if not node_id:
raise ValueError("node_id is required")
if args.get("inputs") is None:
raise ValueError("inputs is required")
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
dataset = pipeline.dataset
if not dataset:
raise ValueError("Pipeline dataset is required")
# init application generate entity - use RagPipelineGenerateEntity instead
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=args.get("datasource_type", ""),
datasource_info=args.get("datasource_info", {}),
dataset_id=dataset.id,
batch=args.get("batch", ""),
document_id=args.get("document_id"),
inputs={},
files=[],
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
workflow_execution_id=str(uuid.uuid4()),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
return self._generate(
flask_app=current_app._get_current_object(), # type: ignore
pipeline=pipeline,
workflow_id=workflow.id,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
)
def single_loop_generate(
self,
pipeline: Pipeline,
workflow: Workflow,
node_id: str,
user: Account | EndUser,
args: Mapping[str, Any],
streaming: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
"""
Generate App response.
:param app_model: App
:param workflow: Workflow
:param node_id: the node id
:param user: account or end user
:param args: request args
:param streaming: is streamed
"""
if not node_id:
raise ValueError("node_id is required")
if args.get("inputs") is None:
raise ValueError("inputs is required")
dataset = pipeline.dataset
if not dataset:
raise ValueError("Pipeline dataset is required")
# convert to app config
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
# init application generate entity
application_generate_entity = RagPipelineGenerateEntity(
task_id=str(uuid.uuid4()),
app_config=pipeline_config,
pipeline_config=pipeline_config,
datasource_type=args.get("datasource_type", ""),
datasource_info=args.get("datasource_info", {}),
batch=args.get("batch", ""),
document_id=args.get("document_id"),
dataset_id=dataset.id,
inputs={},
files=[],
user_id=user.id,
stream=streaming,
invoke_from=InvokeFrom.DEBUGGER,
extras={"auto_generate_conversation_name": False},
single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
workflow_execution_id=str(uuid.uuid4()),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
# Create workflow node execution repository
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
return self._generate(
flask_app=current_app._get_current_object(), # type: ignore
pipeline=pipeline,
workflow=workflow,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
)
def _generate_worker(
self,
flask_app: Flask,
application_generate_entity: RagPipelineGenerateEntity,
queue_manager: AppQueueManager,
context: contextvars.Context,
workflow_thread_pool_id: Optional[str] = None,
) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param workflow_thread_pool_id: workflow thread pool id
:return:
"""
with preserve_flask_contexts(flask_app, context_vars=context):
try:
# workflow app
runner = PipelineRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id,
)
runner.run()
except GenerateTaskStoppedError:
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e:
logger.exception("Validation Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except ValueError as e:
if dify_config.DEBUG:
logger.exception("Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.close()
def _handle_response(
self,
application_generate_entity: RagPipelineGenerateEntity,
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
Handle response.
:param application_generate_entity: application generate entity
:param workflow: workflow
:param queue_manager: queue manager
:param user: account or end user
:param stream: is stream
:param workflow_node_execution_repository: optional repository for workflow node execution
:return:
"""
# init generate task pipeline
generate_task_pipeline = WorkflowAppGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
workflow=workflow,
queue_manager=queue_manager,
user=user,
stream=stream,
workflow_node_execution_repository=workflow_node_execution_repository,
workflow_execution_repository=workflow_execution_repository,
)
try:
return generate_task_pipeline.process()
except ValueError as e:
if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedError()
else:
logger.exception(
f"Fails to process generate task pipeline, task_id: {application_generate_entity.task_id}"
)
raise e
def _build_document(
self,
tenant_id: str,
dataset_id: str,
built_in_field_enabled: bool,
datasource_type: str,
datasource_info: Mapping[str, Any],
created_from: str,
position: int,
account: Union[Account, EndUser],
batch: str,
document_form: str,
):
if datasource_type == "local_file":
name = datasource_info["name"]
elif datasource_type == "online_document":
name = datasource_info["page"]["page_name"]
elif datasource_type == "website_crawl":
name = datasource_info["title"]
else:
raise ValueError(f"Unsupported datasource type: {datasource_type}")
document = Document(
tenant_id=tenant_id,
dataset_id=dataset_id,
position=position,
data_source_type=datasource_type,
data_source_info=json.dumps(datasource_info),
batch=batch,
name=name,
created_from=created_from,
created_by=account.id,
doc_form=document_form,
)
doc_metadata = {}
if built_in_field_enabled:
doc_metadata = {
BuiltInField.document_name: name,
BuiltInField.uploader: account.name,
BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
BuiltInField.source: datasource_type,
}
if doc_metadata:
document.doc_metadata = doc_metadata
return document

View File

@@ -1,44 +0,0 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueErrorEvent,
QueueMessageEndEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
)
class PipelineQueueManager(AppQueueManager):
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
super().__init__(task_id, user_id, invoke_from)
self._app_mode = app_mode
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
"""
Publish event to queue
:param event:
:param pub_from:
:return:
"""
message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event)
self._q.put(message)
if isinstance(
event,
QueueStopEvent
| QueueErrorEvent
| QueueMessageEndEvent
| QueueWorkflowSucceededEvent
| QueueWorkflowFailedEvent
| QueueWorkflowPartialSuccessEvent,
):
self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise GenerateTaskStoppedError()

View File

@@ -1,221 +0,0 @@
import logging
from collections.abc import Mapping
from typing import Any, Optional, cast
from configs import dify_config
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import (
InvokeFrom,
RagPipelineGenerateEntity,
)
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.dataset import Pipeline
from models.enums import UserFrom
from models.model import EndUser
from models.workflow import Workflow, WorkflowType
logger = logging.getLogger(__name__)
class PipelineRunner(WorkflowBasedAppRunner):
"""
Pipeline Application Runner
"""
def __init__(
self,
application_generate_entity: RagPipelineGenerateEntity,
queue_manager: AppQueueManager,
workflow_thread_pool_id: Optional[str] = None,
) -> None:
"""
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param workflow_thread_pool_id: workflow thread pool id
"""
self.application_generate_entity = application_generate_entity
self.queue_manager = queue_manager
self.workflow_thread_pool_id = workflow_thread_pool_id
def _get_app_id(self) -> str:
return self.application_generate_entity.app_config.app_id
def run(self) -> None:
"""
Run application
"""
app_config = self.application_generate_entity.app_config
app_config = cast(PipelineConfig, app_config)
user_id = None
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
pipeline = db.session.query(Pipeline).filter(Pipeline.id == app_config.app_id).first()
if not pipeline:
raise ValueError("Pipeline not found")
workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
db.session.close()
workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback())
# if only single iteration run is requested
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs,
)
else:
inputs = self.application_generate_entity.inputs
files = self.application_generate_entity.files
# Create a variable pool.
system_inputs = {
SystemVariableKey.FILES: files,
SystemVariableKey.USER_ID: user_id,
SystemVariableKey.APP_ID: app_config.app_id,
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_execution_id,
SystemVariableKey.DOCUMENT_ID: self.application_generate_entity.document_id,
SystemVariableKey.BATCH: self.application_generate_entity.batch,
SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id,
SystemVariableKey.DATASOURCE_TYPE: self.application_generate_entity.datasource_type,
SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info,
SystemVariableKey.INVOKE_FROM: self.application_generate_entity.invoke_from.value,
}
rag_pipeline_variables = []
if workflow.rag_pipeline_variables:
for v in workflow.rag_pipeline_variables:
rag_pipeline_variable = RAGPipelineVariable(**v)
if (
rag_pipeline_variable.belong_to_node_id
in (self.application_generate_entity.start_node_id, "shared")
) and rag_pipeline_variable.variable in inputs:
rag_pipeline_variables.append(
RAGPipelineVariableInput(
variable=rag_pipeline_variable,
value=inputs[rag_pipeline_variable.variable],
)
)
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=[],
rag_pipeline_variables=rag_pipeline_variables,
)
# init graph
graph = self._init_rag_pipeline_graph(
graph_config=workflow.graph_dict,
start_node_id=self.application_generate_entity.start_node_id,
)
# RUN WORKFLOW
workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(workflow.type),
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
thread_pool_id=self.workflow_thread_pool_id,
)
generator = workflow_entry.run(callbacks=workflow_callbacks)
for event in generator:
self._handle_event(workflow_entry, event)
def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow
def _init_rag_pipeline_graph(self, graph_config: Mapping[str, Any], start_node_id: Optional[str] = None) -> Graph:
"""
Init pipeline graph
"""
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
nodes = graph_config.get("nodes", [])
edges = graph_config.get("edges", [])
real_run_nodes = []
real_edges = []
exclude_node_ids = []
for node in nodes:
node_id = node.get("id")
node_type = node.get("data", {}).get("type", "")
if node_type == "datasource":
if start_node_id != node_id:
exclude_node_ids.append(node_id)
continue
real_run_nodes.append(node)
for edge in edges:
if edge.get("source") in exclude_node_ids:
continue
real_edges.append(edge)
graph_config = dict(graph_config)
graph_config["nodes"] = real_run_nodes
graph_config["edges"] = real_edges
# init graph
graph = Graph.init(graph_config=graph_config)
if not graph:
raise ValueError("graph not found in workflow")
return graph

View File

@@ -36,7 +36,6 @@ class InvokeFrom(Enum):
# DEBUGGER indicates that this invocation is from
# the workflow (or chatflow) edit page.
DEBUGGER = "debugger"
PUBLISHED = "published"
@classmethod
def value_of(cls, value: str):
@@ -241,38 +240,3 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
inputs: dict
single_loop_run: Optional[SingleLoopRunEntity] = None
class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
"""
RAG Pipeline Application Generate Entity.
"""
# pipeline config
pipeline_config: WorkflowUIBasedAppConfig
datasource_type: str
datasource_info: Mapping[str, Any]
dataset_id: str
batch: str
document_id: Optional[str] = None
start_node_id: Optional[str] = None
class SingleIterationRunEntity(BaseModel):
"""
Single Iteration Run Entity.
"""
node_id: str
inputs: dict
single_iteration_run: Optional[SingleIterationRunEntity] = None
class SingleLoopRunEntity(BaseModel):
"""
Single Loop Run Entity.
"""
node_id: str
inputs: dict
single_loop_run: Optional[SingleLoopRunEntity] = None

View File

@@ -19,6 +19,7 @@ from core.app.entities.task_entities import (
from core.errors.error import QuotaExceededError
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.moderation.output_moderation import ModerationRule, OutputModeration
from models.enums import MessageStatus
from models.model import Message
logger = logging.getLogger(__name__)
@@ -62,7 +63,7 @@ class BasedGenerateTaskPipeline:
return err
err_desc = self._error_to_desc(err)
message.status = "error"
message.status = MessageStatus.ERROR
message.error = err_desc
return err

View File

@@ -105,14 +105,6 @@ class DifyAgentCallbackHandler(BaseModel):
self.current_loop += 1
def on_datasource_start(self, datasource_name: str, datasource_inputs: Mapping[str, Any]) -> None:
"""Run on datasource start."""
if dify_config.DEBUG:
print_text(
"\n[on_datasource_start] DatasourceCall:" + datasource_name + "\n" + str(datasource_inputs) + "\n",
color=self.color,
)
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""

View File

@@ -1,33 +0,0 @@
from abc import ABC, abstractmethod
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceProviderType,
)
class DatasourcePlugin(ABC):
entity: DatasourceEntity
runtime: DatasourceRuntime
def __init__(
self,
entity: DatasourceEntity,
runtime: DatasourceRuntime,
) -> None:
self.entity = entity
self.runtime = runtime
@abstractmethod
def datasource_provider_type(self) -> str:
"""
returns the type of the datasource provider
"""
return DatasourceProviderType.LOCAL_FILE
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
return self.__class__(
entity=self.entity.model_copy(),
runtime=runtime,
)

View File

@@ -1,118 +0,0 @@
from abc import ABC, abstractmethod
from typing import Any
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.entities.provider_entities import ProviderConfig
from core.plugin.impl.tool import PluginToolManager
from core.tools.errors import ToolProviderCredentialValidationError
class DatasourcePluginProviderController(ABC):
entity: DatasourceProviderEntityWithPlugin
tenant_id: str
def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None:
self.entity = entity
self.tenant_id = tenant_id
@property
def need_credentials(self) -> bool:
"""
returns whether the provider needs credentials
:return: whether the provider needs credentials
"""
return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
"""
validate the credentials of the provider
"""
manager = PluginToolManager()
if not manager.validate_datasource_credentials(
tenant_id=self.tenant_id,
user_id=user_id,
provider=self.entity.identity.name,
credentials=credentials,
):
raise ToolProviderCredentialValidationError("Invalid credentials")
@property
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.LOCAL_FILE
@abstractmethod
def get_datasource(self, datasource_name: str) -> DatasourcePlugin:
"""
return datasource with given name
"""
pass
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
"""
validate the format of the credentials of the provider and set the default value if needed
:param credentials: the credentials of the tool
"""
credentials_schema = dict[str, ProviderConfig]()
if credentials_schema is None:
return
for credential in self.entity.credentials_schema:
credentials_schema[credential.name] = credential
credentials_need_to_validate: dict[str, ProviderConfig] = {}
for credential_name in credentials_schema:
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
for credential_name in credentials:
if credential_name not in credentials_need_to_validate:
raise ToolProviderCredentialValidationError(
f"credential {credential_name} not found in provider {self.entity.identity.name}"
)
# check type
credential_schema = credentials_need_to_validate[credential_name]
if not credential_schema.required and credentials[credential_name] is None:
continue
if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}:
if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
elif credential_schema.type == ProviderConfig.Type.SELECT:
if not isinstance(credentials[credential_name], str):
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
options = credential_schema.options
if not isinstance(options, list):
raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list")
if credentials[credential_name] not in [x.value for x in options]:
raise ToolProviderCredentialValidationError(
f"credential {credential_name} should be one of {options}"
)
credentials_need_to_validate.pop(credential_name)
for credential_name in credentials_need_to_validate:
credential_schema = credentials_need_to_validate[credential_name]
if credential_schema.required:
raise ToolProviderCredentialValidationError(f"credential {credential_name} is required")
# the credential is not set currently, set the default value if needed
if credential_schema.default is not None:
default_value = credential_schema.default
# parse default value into the correct type
if credential_schema.type in {
ProviderConfig.Type.SECRET_INPUT,
ProviderConfig.Type.TEXT_INPUT,
ProviderConfig.Type.SELECT,
}:
default_value = str(default_value)
credentials[credential_name] = default_value

View File

@@ -1,36 +0,0 @@
from typing import Any, Optional
from openai import BaseModel
from pydantic import Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.datasource.entities.datasource_entities import DatasourceInvokeFrom
class DatasourceRuntime(BaseModel):
"""
Meta data of a datasource call processing
"""
tenant_id: str
datasource_id: Optional[str] = None
invoke_from: Optional[InvokeFrom] = None
datasource_invoke_from: Optional[DatasourceInvokeFrom] = None
credentials: dict[str, Any] = Field(default_factory=dict)
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
class FakeDatasourceRuntime(DatasourceRuntime):
"""
Fake datasource runtime for testing
"""
def __init__(self):
super().__init__(
tenant_id="fake_tenant_id",
datasource_id="fake_datasource_id",
invoke_from=InvokeFrom.DEBUGGER,
datasource_invoke_from=DatasourceInvokeFrom.RAG_PIPELINE,
credentials={},
runtime_parameters={},
)

View File

@@ -1,244 +0,0 @@
import base64
import hashlib
import hmac
import logging
import os
import time
from mimetypes import guess_extension, guess_type
from typing import Optional, Union
from uuid import uuid4
import httpx
from configs import dify_config
from core.helper import ssrf_proxy
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.enums import CreatorUserRole
from models.model import MessageFile, UploadFile
from models.tools import ToolFile
logger = logging.getLogger(__name__)
class DatasourceFileManager:
@staticmethod
def sign_file(datasource_file_id: str, extension: str) -> str:
"""
sign file to get a temporary url
"""
base_url = dify_config.FILES_URL
file_preview_url = f"{base_url}/files/datasources/{datasource_file_id}{extension}"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
@staticmethod
def verify_file(datasource_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature
"""
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
@staticmethod
def create_file_by_raw(
*,
user_id: str,
tenant_id: str,
conversation_id: Optional[str],
file_binary: bytes,
mimetype: str,
filename: Optional[str] = None,
) -> UploadFile:
extension = guess_extension(mimetype) or ".bin"
unique_name = uuid4().hex
unique_filename = f"{unique_name}{extension}"
# default just as before
present_filename = unique_filename
if filename is not None:
has_extension = len(filename.split(".")) > 1
# Add extension flexibly
present_filename = filename if has_extension else f"{filename}{extension}"
filepath = f"datasources/{tenant_id}/{unique_filename}"
storage.save(filepath, file_binary)
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type=dify_config.STORAGE_TYPE,
key=filepath,
name=present_filename,
size=len(file_binary),
extension=extension,
mime_type=mimetype,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=user_id,
used=False,
hash=hashlib.sha3_256(file_binary).hexdigest(),
source_url="",
)
db.session.add(upload_file)
db.session.commit()
db.session.refresh(upload_file)
return upload_file
@staticmethod
def create_file_by_url(
user_id: str,
tenant_id: str,
file_url: str,
conversation_id: Optional[str] = None,
) -> UploadFile:
# try to download image
try:
response = ssrf_proxy.get(file_url)
response.raise_for_status()
blob = response.content
except httpx.TimeoutException:
raise ValueError(f"timeout when downloading file from {file_url}")
mimetype = (
guess_type(file_url)[0]
or response.headers.get("Content-Type", "").split(";")[0].strip()
or "application/octet-stream"
)
extension = guess_extension(mimetype) or ".bin"
unique_name = uuid4().hex
filename = f"{unique_name}{extension}"
filepath = f"tools/{tenant_id}/{filename}"
storage.save(filepath, blob)
upload_file = UploadFile(
tenant_id=tenant_id,
storage_type=dify_config.STORAGE_TYPE,
key=filepath,
name=filename,
size=len(blob),
extension=extension,
mime_type=mimetype,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=user_id,
used=False,
hash=hashlib.sha3_256(blob).hexdigest(),
source_url=file_url,
)
db.session.add(upload_file)
db.session.commit()
return upload_file
@staticmethod
def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
"""
get file binary
:param id: the id of the file
:return: the binary of the file, mime type
"""
upload_file: UploadFile | None = (
db.session.query(UploadFile)
.filter(
UploadFile.id == id,
)
.first()
)
if not upload_file:
return None
blob = storage.load_once(upload_file.key)
return blob, upload_file.mime_type
@staticmethod
def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]:
"""
get file binary
:param id: the id of the file
:return: the binary of the file, mime type
"""
message_file: MessageFile | None = (
db.session.query(MessageFile)
.filter(
MessageFile.id == id,
)
.first()
)
# Check if message_file is not None
if message_file is not None:
# get tool file id
if message_file.url is not None:
tool_file_id = message_file.url.split("/")[-1]
# trim extension
tool_file_id = tool_file_id.split(".")[0]
else:
tool_file_id = None
else:
tool_file_id = None
tool_file: ToolFile | None = (
db.session.query(ToolFile)
.filter(
ToolFile.id == tool_file_id,
)
.first()
)
if not tool_file:
return None
blob = storage.load_once(tool_file.file_key)
return blob, tool_file.mimetype
@staticmethod
def get_file_generator_by_upload_file_id(upload_file_id: str):
"""
get file binary
:param tool_file_id: the id of the tool file
:return: the binary of the file, mime type
"""
upload_file: UploadFile | None = (
db.session.query(UploadFile)
.filter(
UploadFile.id == upload_file_id,
)
.first()
)
if not upload_file:
return None, None
stream = storage.load_stream(upload_file.key)
return stream, upload_file.mime_type
# init tool_file_parser
# from core.file.datasource_file_parser import datasource_file_manager
#
# datasource_file_manager["manager"] = DatasourceFileManager

View File

@@ -1,100 +0,0 @@
import logging
from threading import Lock
from typing import Union
import contexts
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.entities.common_entities import I18nObject
from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.datasource.errors import DatasourceProviderNotFoundError
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
from core.plugin.impl.datasource import PluginDatasourceManager
logger = logging.getLogger(__name__)
class DatasourceManager:
_builtin_provider_lock = Lock()
_hardcoded_providers: dict[str, DatasourcePluginProviderController] = {}
_builtin_providers_loaded = False
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
@classmethod
def get_datasource_plugin_provider(
cls, provider_id: str, tenant_id: str, datasource_type: DatasourceProviderType
) -> DatasourcePluginProviderController:
"""
get the datasource plugin provider
"""
# check if context is set
try:
contexts.datasource_plugin_providers.get()
except LookupError:
contexts.datasource_plugin_providers.set({})
contexts.datasource_plugin_providers_lock.set(Lock())
with contexts.datasource_plugin_providers_lock.get():
datasource_plugin_providers = contexts.datasource_plugin_providers.get()
if provider_id in datasource_plugin_providers:
return datasource_plugin_providers[provider_id]
manager = PluginDatasourceManager()
provider_entity = manager.fetch_datasource_provider(tenant_id, provider_id)
if not provider_entity:
raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found")
match datasource_type:
case DatasourceProviderType.ONLINE_DOCUMENT:
controller = OnlineDocumentDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case DatasourceProviderType.WEBSITE_CRAWL:
controller = WebsiteCrawlDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case DatasourceProviderType.LOCAL_FILE:
controller = LocalFileDatasourcePluginProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
case _:
raise ValueError(f"Unsupported datasource type: {datasource_type}")
datasource_plugin_providers[provider_id] = controller
return controller
@classmethod
def get_datasource_runtime(
cls,
provider_id: str,
datasource_name: str,
tenant_id: str,
datasource_type: DatasourceProviderType,
) -> DatasourcePlugin:
"""
get the datasource runtime
:param provider_type: the type of the provider
:param provider_id: the id of the provider
:param datasource_name: the name of the datasource
:param tenant_id: the tenant id
:return: the datasource plugin
"""
return cls.get_datasource_plugin_provider(
provider_id,
tenant_id,
datasource_type,
).get_datasource(datasource_name)

View File

@@ -1,71 +0,0 @@
from typing import Literal, Optional
from pydantic import BaseModel, Field, field_validator
from core.datasource.entities.datasource_entities import DatasourceParameter
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.common_entities import I18nObject
class DatasourceApiEntity(BaseModel):
author: str
name: str # identifier
label: I18nObject # label
description: I18nObject
parameters: Optional[list[DatasourceParameter]] = None
labels: list[str] = Field(default_factory=list)
output_schema: Optional[dict] = None
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
class DatasourceProviderApiEntity(BaseModel):
id: str
author: str
name: str # identifier
description: I18nObject
icon: str | dict
label: I18nObject # label
type: str
masked_credentials: Optional[dict] = None
original_credentials: Optional[dict] = None
is_team_authorization: bool = False
allow_delete: bool = True
plugin_id: Optional[str] = Field(default="", description="The plugin id of the datasource")
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the datasource")
datasources: list[DatasourceApiEntity] = Field(default_factory=list)
labels: list[str] = Field(default_factory=list)
@field_validator("datasources", mode="before")
@classmethod
def convert_none_to_empty_list(cls, v):
return v if v is not None else []
def to_dict(self) -> dict:
# -------------
# overwrite datasource parameter types for temp fix
datasources = jsonable_encoder(self.datasources)
for datasource in datasources:
if datasource.get("parameters"):
for parameter in datasource.get("parameters"):
if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES.value:
parameter["type"] = "files"
# -------------
return {
"id": self.id,
"author": self.author,
"name": self.name,
"plugin_id": self.plugin_id,
"plugin_unique_identifier": self.plugin_unique_identifier,
"description": self.description.to_dict(),
"icon": self.icon,
"label": self.label.to_dict(),
"type": self.type.value,
"team_credentials": self.masked_credentials,
"is_team_authorization": self.is_team_authorization,
"allow_delete": self.allow_delete,
"datasources": datasources,
"labels": self.labels,
}

View File

@@ -1,23 +0,0 @@
from typing import Optional
from pydantic import BaseModel, Field
class I18nObject(BaseModel):
"""
Model class for i18n object.
"""
en_US: str
zh_Hans: Optional[str] = Field(default=None)
pt_BR: Optional[str] = Field(default=None)
ja_JP: Optional[str] = Field(default=None)
def __init__(self, **data):
super().__init__(**data)
self.zh_Hans = self.zh_Hans or self.en_US
self.pt_BR = self.pt_BR or self.en_US
self.ja_JP = self.ja_JP or self.en_US
def to_dict(self) -> dict:
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}

View File

@@ -1,361 +0,0 @@
import enum
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel, Field, ValidationInfo, field_validator
from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.oauth import OAuthSchema
from core.plugin.entities.parameters import (
PluginParameter,
PluginParameterOption,
PluginParameterType,
as_normal_type,
cast_parameter_value,
init_frontend_parameter,
)
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolLabelEnum
class DatasourceProviderType(enum.StrEnum):
"""
Enum class for datasource provider
"""
ONLINE_DOCUMENT = "online_document"
LOCAL_FILE = "local_file"
WEBSITE_CRAWL = "website_crawl"
ONLINE_DRIVE = "online_drive"
@classmethod
def value_of(cls, value: str) -> "DatasourceProviderType":
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f"invalid mode value {value}")
class DatasourceParameter(PluginParameter):
"""
Overrides type
"""
class DatasourceParameterType(enum.StrEnum):
"""
removes TOOLS_SELECTOR from PluginParameterType
"""
STRING = PluginParameterType.STRING.value
NUMBER = PluginParameterType.NUMBER.value
BOOLEAN = PluginParameterType.BOOLEAN.value
SELECT = PluginParameterType.SELECT.value
SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
FILE = PluginParameterType.FILE.value
FILES = PluginParameterType.FILES.value
# deprecated, should not use.
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
def as_normal_type(self):
return as_normal_type(self)
def cast_value(self, value: Any):
return cast_parameter_value(self, value)
type: DatasourceParameterType = Field(..., description="The type of the parameter")
description: I18nObject = Field(..., description="The description of the parameter")
@classmethod
def get_simple_instance(
cls,
name: str,
typ: DatasourceParameterType,
required: bool,
options: Optional[list[str]] = None,
) -> "DatasourceParameter":
"""
get a simple datasource parameter
:param name: the name of the parameter
:param llm_description: the description presented to the LLM
:param typ: the type of the parameter
:param required: if the parameter is required
:param options: the options of the parameter
"""
# convert options to ToolParameterOption
# FIXME fix the type error
if options:
option_objs = [
PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
for option in options
]
else:
option_objs = []
return cls(
name=name,
label=I18nObject(en_US="", zh_Hans=""),
placeholder=None,
type=typ,
required=required,
options=option_objs,
description=I18nObject(en_US="", zh_Hans=""),
)
def init_frontend_parameter(self, value: Any):
return init_frontend_parameter(self, self.type, value)
class DatasourceIdentity(BaseModel):
author: str = Field(..., description="The author of the datasource")
name: str = Field(..., description="The name of the datasource")
label: I18nObject = Field(..., description="The label of the datasource")
provider: str = Field(..., description="The provider of the datasource")
icon: Optional[str] = None
class DatasourceEntity(BaseModel):
identity: DatasourceIdentity
parameters: list[DatasourceParameter] = Field(default_factory=list)
description: I18nObject = Field(..., description="The label of the datasource")
@field_validator("parameters", mode="before")
@classmethod
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]:
return v or []
class DatasourceProviderIdentity(BaseModel):
author: str = Field(..., description="The author of the tool")
name: str = Field(..., description="The name of the tool")
description: I18nObject = Field(..., description="The description of the tool")
icon: str = Field(..., description="The icon of the tool")
label: I18nObject = Field(..., description="The label of the tool")
tags: Optional[list[ToolLabelEnum]] = Field(
default=[],
description="The tags of the tool",
)
class DatasourceProviderEntity(BaseModel):
"""
Datasource provider entity
"""
identity: DatasourceProviderIdentity
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
oauth_schema: Optional[OAuthSchema] = None
provider_type: DatasourceProviderType
class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity):
datasources: list[DatasourceEntity] = Field(default_factory=list)
class DatasourceInvokeMeta(BaseModel):
"""
Datasource invoke meta
"""
time_cost: float = Field(..., description="The time cost of the tool invoke")
error: Optional[str] = None
tool_config: Optional[dict] = None
@classmethod
def empty(cls) -> "DatasourceInvokeMeta":
"""
Get an empty instance of DatasourceInvokeMeta
"""
return cls(time_cost=0.0, error=None, tool_config={})
@classmethod
def error_instance(cls, error: str) -> "DatasourceInvokeMeta":
"""
Get an instance of DatasourceInvokeMeta with error
"""
return cls(time_cost=0.0, error=error, tool_config={})
def to_dict(self) -> dict:
return {
"time_cost": self.time_cost,
"error": self.error,
"tool_config": self.tool_config,
}
class DatasourceLabel(BaseModel):
"""
Datasource label
"""
name: str = Field(..., description="The name of the tool")
label: I18nObject = Field(..., description="The label of the tool")
icon: str = Field(..., description="The icon of the tool")
class DatasourceInvokeFrom(Enum):
"""
Enum class for datasource invoke
"""
RAG_PIPELINE = "rag_pipeline"
class OnlineDocumentPage(BaseModel):
"""
Online document page
"""
page_id: str = Field(..., description="The page id")
page_name: str = Field(..., description="The page title")
page_icon: Optional[dict] = Field(None, description="The page icon")
type: str = Field(..., description="The type of the page")
last_edited_time: str = Field(..., description="The last edited time")
parent_id: Optional[str] = Field(None, description="The parent page id")
class OnlineDocumentInfo(BaseModel):
"""
Online document info
"""
workspace_id: str = Field(..., description="The workspace id")
workspace_name: str = Field(..., description="The workspace name")
workspace_icon: str = Field(..., description="The workspace icon")
total: int = Field(..., description="The total number of documents")
pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document")
class OnlineDocumentPagesMessage(BaseModel):
"""
Get online document pages response
"""
result: list[OnlineDocumentInfo]
class GetOnlineDocumentPageContentRequest(BaseModel):
"""
Get online document page content request
"""
workspace_id: str = Field(..., description="The workspace id")
page_id: str = Field(..., description="The page id")
type: str = Field(..., description="The type of the page")
class OnlineDocumentPageContent(BaseModel):
"""
Online document page content
"""
workspace_id: str = Field(..., description="The workspace id")
page_id: str = Field(..., description="The page id")
content: str = Field(..., description="The content of the page")
class GetOnlineDocumentPageContentResponse(BaseModel):
"""
Get online document page content response
"""
result: OnlineDocumentPageContent
class GetWebsiteCrawlRequest(BaseModel):
"""
Get website crawl request
"""
crawl_parameters: dict = Field(..., description="The crawl parameters")
class WebSiteInfoDetail(BaseModel):
source_url: str = Field(..., description="The url of the website")
content: str = Field(..., description="The content of the website")
title: str = Field(..., description="The title of the website")
description: str = Field(..., description="The description of the website")
class WebSiteInfo(BaseModel):
"""
Website info
"""
status: Optional[str] = Field(..., description="crawl job status")
web_info_list: Optional[list[WebSiteInfoDetail]] = []
total: Optional[int] = Field(default=0, description="The total number of websites")
completed: Optional[int] = Field(default=0, description="The number of completed websites")
class WebsiteCrawlMessage(BaseModel):
"""
Get website crawl response
"""
result: WebSiteInfo = WebSiteInfo(status="", web_info_list=[], total=0, completed=0)
class DatasourceMessage(ToolInvokeMessage):
pass
#########################
# Online driver file
#########################
class OnlineDriveFile(BaseModel):
"""
Online driver file
"""
key: str = Field(..., description="The key of the file")
size: int = Field(..., description="The size of the file")
class OnlineDriveFileBucket(BaseModel):
"""
Online driver file bucket
"""
bucket: Optional[str] = Field(None, description="The bucket of the file")
files: list[OnlineDriveFile] = Field(..., description="The files of the bucket")
is_truncated: bool = Field(False, description="Whether the bucket has more files")
class OnlineDriveBrowseFilesRequest(BaseModel):
"""
Get online driver file list request
"""
prefix: Optional[str] = Field(None, description="File path prefix for filtering eg: 'docs/dify/'")
bucket: Optional[str] = Field(None, description="Storage bucket name")
max_keys: int = Field(20, description="Maximum number of files to return")
start_after: Optional[str] = Field(
None, description="Pagination token for continuing from a specific file eg: 'docs/dify/1.txt'"
)
class OnlineDriveBrowseFilesResponse(BaseModel):
"""
Get online driver file list response
"""
result: list[OnlineDriveFileBucket] = Field(..., description="The bucket of the files")
class OnlineDriveDownloadFileRequest(BaseModel):
"""
Get online driver file
"""
key: str = Field(..., description="The name of the file")
bucket: Optional[str] = Field(None, description="The name of the bucket")

View File

@@ -1,37 +0,0 @@
from core.datasource.entities.datasource_entities import DatasourceInvokeMeta
class DatasourceProviderNotFoundError(ValueError):
pass
class DatasourceNotFoundError(ValueError):
pass
class DatasourceParameterValidationError(ValueError):
pass
class DatasourceProviderCredentialValidationError(ValueError):
pass
class DatasourceNotSupportedError(ValueError):
pass
class DatasourceInvokeError(ValueError):
pass
class DatasourceApiSchemaError(ValueError):
pass
class DatasourceEngineInvokeError(Exception):
meta: DatasourceInvokeMeta
def __init__(self, meta, **kwargs):
self.meta = meta
super().__init__(**kwargs)

View File

@@ -1,28 +0,0 @@
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceProviderType,
)
class LocalFileDatasourcePlugin(DatasourcePlugin):
tenant_id: str
icon: str
plugin_unique_identifier: str
def __init__(
self,
entity: DatasourceEntity,
runtime: DatasourceRuntime,
tenant_id: str,
icon: str,
plugin_unique_identifier: str,
) -> None:
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier
def datasource_provider_type(self) -> str:
return DatasourceProviderType.LOCAL_FILE

View File

@@ -1,56 +0,0 @@
from typing import Any
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin
class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderController):
entity: DatasourceProviderEntityWithPlugin
plugin_id: str
plugin_unique_identifier: str
def __init__(
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
super().__init__(entity, tenant_id)
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
@property
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.LOCAL_FILE
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
"""
validate the credentials of the provider
"""
pass
def get_datasource(self, datasource_name: str) -> LocalFileDatasourcePlugin: # type: ignore
"""
return datasource with given name
"""
datasource_entity = next(
(
datasource_entity
for datasource_entity in self.entity.datasources
if datasource_entity.identity.name == datasource_name
),
None,
)
if not datasource_entity:
raise ValueError(f"Datasource with name {datasource_name} not found")
return LocalFileDatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

View File

@@ -1,73 +0,0 @@
from collections.abc import Generator, Mapping
from typing import Any
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceMessage,
DatasourceProviderType,
GetOnlineDocumentPageContentRequest,
OnlineDocumentPagesMessage,
)
from core.plugin.impl.datasource import PluginDatasourceManager
class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
tenant_id: str
icon: str
plugin_unique_identifier: str
entity: DatasourceEntity
runtime: DatasourceRuntime
def __init__(
self,
entity: DatasourceEntity,
runtime: DatasourceRuntime,
tenant_id: str,
icon: str,
plugin_unique_identifier: str,
) -> None:
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier
def get_online_document_pages(
self,
user_id: str,
datasource_parameters: Mapping[str, Any],
provider_type: str,
) -> Generator[OnlineDocumentPagesMessage, None, None]:
manager = PluginDatasourceManager()
return manager.get_online_document_pages(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
datasource_parameters=datasource_parameters,
provider_type=provider_type,
)
def get_online_document_page_content(
self,
user_id: str,
datasource_parameters: GetOnlineDocumentPageContentRequest,
provider_type: str,
) -> Generator[DatasourceMessage, None, None]:
manager = PluginDatasourceManager()
return manager.get_online_document_page_content(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
datasource_parameters=datasource_parameters,
provider_type=provider_type,
)
def datasource_provider_type(self) -> str:
return DatasourceProviderType.ONLINE_DOCUMENT

View File

@@ -1,48 +0,0 @@
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderController):
entity: DatasourceProviderEntityWithPlugin
plugin_id: str
plugin_unique_identifier: str
def __init__(
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
super().__init__(entity, tenant_id)
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
@property
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.ONLINE_DOCUMENT
def get_datasource(self, datasource_name: str) -> OnlineDocumentDatasourcePlugin: # type: ignore
"""
return datasource with given name
"""
datasource_entity = next(
(
datasource_entity
for datasource_entity in self.entity.datasources
if datasource_entity.identity.name == datasource_name
),
None,
)
if not datasource_entity:
raise ValueError(f"Datasource with name {datasource_name} not found")
return OnlineDocumentDatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

View File

@@ -1,73 +0,0 @@
from collections.abc import Generator
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceMessage,
DatasourceProviderType,
OnlineDriveBrowseFilesRequest,
OnlineDriveBrowseFilesResponse,
OnlineDriveDownloadFileRequest,
)
from core.plugin.impl.datasource import PluginDatasourceManager
class OnlineDriveDatasourcePlugin(DatasourcePlugin):
tenant_id: str
icon: str
plugin_unique_identifier: str
entity: DatasourceEntity
runtime: DatasourceRuntime
def __init__(
self,
entity: DatasourceEntity,
runtime: DatasourceRuntime,
tenant_id: str,
icon: str,
plugin_unique_identifier: str,
) -> None:
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier
def online_drive_browse_files(
self,
user_id: str,
request: OnlineDriveBrowseFilesRequest,
provider_type: str,
) -> Generator[OnlineDriveBrowseFilesResponse, None, None]:
manager = PluginDatasourceManager()
return manager.online_drive_browse_files(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
request=request,
provider_type=provider_type,
)
def online_drive_download_file(
self,
user_id: str,
request: OnlineDriveDownloadFileRequest,
provider_type: str,
) -> Generator[DatasourceMessage, None, None]:
manager = PluginDatasourceManager()
return manager.online_drive_download_file(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
request=request,
provider_type=provider_type,
)
def datasource_provider_type(self) -> str:
return DatasourceProviderType.ONLINE_DRIVE

View File

@@ -1,48 +0,0 @@
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
class OnlineDriveDatasourcePluginProviderController(DatasourcePluginProviderController):
entity: DatasourceProviderEntityWithPlugin
plugin_id: str
plugin_unique_identifier: str
def __init__(
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
super().__init__(entity, tenant_id)
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
@property
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.ONLINE_DRIVE
def get_datasource(self, datasource_name: str) -> OnlineDriveDatasourcePlugin: # type: ignore
"""
return datasource with given name
"""
datasource_entity = next(
(
datasource_entity
for datasource_entity in self.entity.datasources
if datasource_entity.identity.name == datasource_name
),
None,
)
if not datasource_entity:
raise ValueError(f"Datasource with name {datasource_name} not found")
return OnlineDriveDatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

View File

@@ -1,265 +0,0 @@
from copy import deepcopy
from typing import Any
from pydantic import BaseModel
from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolParameter,
ToolProviderType,
)
class ProviderConfigEncrypter(BaseModel):
tenant_id: str
config: list[BasicProviderConfig]
provider_type: str
provider_identity: str
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
"""
deep copy data
"""
return deepcopy(data)
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
encrypt tool credentials with tenant id
return a deep copy of credentials with encrypted values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
data[field_name] = encrypted
return data
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
"""
mask tool credentials
return a deep copy of credentials with masked values
"""
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
if len(data[field_name]) > 6:
data[field_name] = (
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
)
else:
data[field_name] = "*" * len(data[field_name])
return data
def decrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
decrypt tool credentials with tenant id
return a deep copy of credentials with decrypted values
"""
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f"{self.provider_type}.{self.provider_identity}",
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cached_credentials = cache.get()
if cached_credentials:
return cached_credentials
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
try:
# if the value is None or empty string, skip decrypt
if not data[field_name]:
continue
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
except Exception:
pass
cache.set(data)
return data
def delete_tool_credentials_cache(self):
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f"{self.provider_type}.{self.provider_identity}",
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cache.delete()
class ToolParameterConfigurationManager:
"""
Tool parameter configuration manager
"""
tenant_id: str
tool_runtime: Tool
provider_name: str
provider_type: ToolProviderType
identity_id: str
def __init__(
self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str
) -> None:
self.tenant_id = tenant_id
self.tool_runtime = tool_runtime
self.provider_name = provider_name
self.provider_type = provider_type
self.identity_id = identity_id
def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
deep copy parameters
"""
return deepcopy(parameters)
def _merge_parameters(self) -> list[ToolParameter]:
"""
merge parameters
"""
# get tool parameters
tool_parameters = self.tool_runtime.entity.parameters or []
# get tool runtime parameters
runtime_parameters = self.tool_runtime.get_runtime_parameters()
# override parameters
current_parameters = tool_parameters.copy()
for runtime_parameter in runtime_parameters:
found = False
for index, parameter in enumerate(current_parameters):
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
current_parameters[index] = runtime_parameter
found = True
break
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
current_parameters.append(runtime_parameter)
return current_parameters
def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
mask tool parameters
return a deep copy of parameters with masked values
"""
parameters = self._deep_copy(parameters)
# override parameters
current_parameters = self._merge_parameters()
for parameter in current_parameters:
if (
parameter.form == ToolParameter.ToolParameterForm.FORM
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
):
if parameter.name in parameters:
if len(parameters[parameter.name]) > 6:
parameters[parameter.name] = (
parameters[parameter.name][:2]
+ "*" * (len(parameters[parameter.name]) - 4)
+ parameters[parameter.name][-2:]
)
else:
parameters[parameter.name] = "*" * len(parameters[parameter.name])
return parameters
def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
encrypt tool parameters with tenant id
return a deep copy of parameters with encrypted values
"""
# override parameters
current_parameters = self._merge_parameters()
parameters = self._deep_copy(parameters)
for parameter in current_parameters:
if (
parameter.form == ToolParameter.ToolParameterForm.FORM
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
):
if parameter.name in parameters:
encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
parameters[parameter.name] = encrypted
return parameters
def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
decrypt tool parameters with tenant id
return a deep copy of parameters with decrypted values
"""
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f"{self.provider_type.value}.{self.provider_name}",
tool_name=self.tool_runtime.entity.identity.name,
cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id,
)
cached_parameters = cache.get()
if cached_parameters:
return cached_parameters
# override parameters
current_parameters = self._merge_parameters()
has_secret_input = False
for parameter in current_parameters:
if (
parameter.form == ToolParameter.ToolParameterForm.FORM
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
):
if parameter.name in parameters:
try:
has_secret_input = True
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
except Exception:
pass
if has_secret_input:
cache.set(parameters)
return parameters
def delete_tool_parameters_cache(self):
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f"{self.provider_type.value}.{self.provider_name}",
tool_name=self.tool_runtime.entity.identity.name,
cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id,
)
cache.delete()

View File

@@ -1,121 +0,0 @@
import logging
from collections.abc import Generator
from mimetypes import guess_extension
from typing import Optional
from core.datasource.datasource_file_manager import DatasourceFileManager
from core.datasource.entities.datasource_entities import DatasourceMessage
from core.file import File, FileTransferMethod, FileType
logger = logging.getLogger(__name__)
class DatasourceFileMessageTransformer:
@classmethod
def transform_datasource_invoke_messages(
cls,
messages: Generator[DatasourceMessage, None, None],
user_id: str,
tenant_id: str,
conversation_id: Optional[str] = None,
) -> Generator[DatasourceMessage, None, None]:
"""
Transform datasource message and handle file download
"""
for message in messages:
if message.type in {DatasourceMessage.MessageType.TEXT, DatasourceMessage.MessageType.LINK}:
yield message
elif message.type == DatasourceMessage.MessageType.IMAGE and isinstance(
message.message, DatasourceMessage.TextMessage
):
# try to download image
try:
assert isinstance(message.message, DatasourceMessage.TextMessage)
file = DatasourceFileManager.create_file_by_url(
user_id=user_id,
tenant_id=tenant_id,
file_url=message.message.text,
conversation_id=conversation_id,
)
url = f"/files/datasources/{file.id}{guess_extension(file.mime_type) or '.png'}"
yield DatasourceMessage(
type=DatasourceMessage.MessageType.IMAGE_LINK,
message=DatasourceMessage.TextMessage(text=url),
meta=message.meta.copy() if message.meta is not None else {},
)
except Exception as e:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.TEXT,
message=DatasourceMessage.TextMessage(
text=f"Failed to download image: {message.message.text}: {e}"
),
meta=message.meta.copy() if message.meta is not None else {},
)
elif message.type == DatasourceMessage.MessageType.BLOB:
# get mime type and save blob to storage
meta = message.meta or {}
mimetype = meta.get("mime_type", "application/octet-stream")
# get filename from meta
filename = meta.get("file_name", None)
# if message is str, encode it to bytes
if not isinstance(message.message, DatasourceMessage.BlobMessage):
raise ValueError("unexpected message type")
# FIXME: should do a type check here.
assert isinstance(message.message.blob, bytes)
file = DatasourceFileManager.create_file_by_raw(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
file_binary=message.message.blob,
mimetype=mimetype,
filename=filename,
)
url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mime_type))
# check if file is image
if "image" in mimetype:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.IMAGE_LINK,
message=DatasourceMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
else:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.BINARY_LINK,
message=DatasourceMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
elif message.type == DatasourceMessage.MessageType.FILE:
meta = message.meta or {}
file = meta.get("file", None)
if isinstance(file, File):
if file.transfer_method == FileTransferMethod.TOOL_FILE:
assert file.related_id is not None
url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension)
if file.type == FileType.IMAGE:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.IMAGE_LINK,
message=DatasourceMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
else:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.LINK,
message=DatasourceMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
else:
yield message
else:
yield message
@classmethod
def get_datasource_file_url(cls, datasource_file_id: str, extension: Optional[str]) -> str:
return f"/files/datasources/{datasource_file_id}{extension or '.bin'}"

View File

@@ -1,389 +0,0 @@
import re
import uuid
from json import dumps as json_dumps
from json import loads as json_loads
from json.decoder import JSONDecodeError
from typing import Optional
from flask import request
from requests import get
from yaml import YAMLError, safe_load # type: ignore
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
class ApiBasedToolSchemaParser:
@staticmethod
def parse_openapi_to_tool_bundle(
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
# set description to extra_info
extra_info["description"] = openapi["info"].get("description", "")
if len(openapi["servers"]) == 0:
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
server_url = openapi["servers"][0]["url"]
request_env = request.headers.get("X-Request-Env")
if request_env:
matched_servers = [server["url"] for server in openapi["servers"] if server["env"] == request_env]
server_url = matched_servers[0] if matched_servers else server_url
# list all interfaces
interfaces = []
for path, path_item in openapi["paths"].items():
methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"]
for method in methods:
if method in path_item:
interfaces.append(
{
"path": path,
"method": method,
"operation": path_item[method],
}
)
# get all parameters
bundles = []
for interface in interfaces:
# convert parameters
parameters = []
if "parameters" in interface["operation"]:
for parameter in interface["operation"]["parameters"]:
tool_parameter = ToolParameter(
name=parameter["name"],
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
human_description=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
),
type=ToolParameter.ToolParameterType.STRING,
required=parameter.get("required", False),
form=ToolParameter.ToolParameterForm.LLM,
llm_description=parameter.get("description"),
default=parameter["schema"]["default"]
if "schema" in parameter and "default" in parameter["schema"]
else None,
placeholder=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
),
)
# check if there is a type
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter)
if typ:
tool_parameter.type = typ
parameters.append(tool_parameter)
# create tool bundle
# check if there is a request body
if "requestBody" in interface["operation"]:
request_body = interface["operation"]["requestBody"]
if "content" in request_body:
for content_type, content in request_body["content"].items():
# if there is a reference, get the reference and overwrite the content
if "schema" not in content:
continue
if "$ref" in content["schema"]:
# get the reference
root = openapi
reference = content["schema"]["$ref"].split("/")[1:]
for ref in reference:
root = root[ref]
# overwrite the content
interface["operation"]["requestBody"]["content"][content_type]["schema"] = root
# parse body parameters
if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
required = body_schema.get("required", [])
properties = body_schema.get("properties", {})
for name, property in properties.items():
tool = ToolParameter(
name=name,
label=I18nObject(en_US=name, zh_Hans=name),
human_description=I18nObject(
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
),
type=ToolParameter.ToolParameterType.STRING,
required=name in required,
form=ToolParameter.ToolParameterForm.LLM,
llm_description=property.get("description", ""),
default=property.get("default", None),
placeholder=I18nObject(
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
),
)
# check if there is a type
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
if typ:
tool.type = typ
parameters.append(tool)
# check if parameters is duplicated
parameters_count = {}
for parameter in parameters:
if parameter.name not in parameters_count:
parameters_count[parameter.name] = 0
parameters_count[parameter.name] += 1
for name, count in parameters_count.items():
if count > 1:
warning["duplicated_parameter"] = f"Parameter {name} is duplicated."
# check if there is a operation id, use $path_$method as operation id if not
if "operationId" not in interface["operation"]:
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
path = interface["path"]
if interface["path"].startswith("/"):
path = interface["path"][1:]
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
if not path:
path = str(uuid.uuid4())
interface["operation"]["operationId"] = f"{path}_{interface['method']}"
bundles.append(
ApiToolBundle(
server_url=server_url + interface["path"],
method=interface["method"],
summary=interface["operation"]["description"]
if "description" in interface["operation"]
else interface["operation"].get("summary", None),
operation_id=interface["operation"]["operationId"],
parameters=parameters,
author="",
icon=None,
openapi=interface["operation"],
)
)
return bundles
@staticmethod
def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]:
parameter = parameter or {}
typ: Optional[str] = None
if parameter.get("format") == "binary":
return ToolParameter.ToolParameterType.FILE
if "type" in parameter:
typ = parameter["type"]
elif "schema" in parameter and "type" in parameter["schema"]:
typ = parameter["schema"]["type"]
if typ in {"integer", "number"}:
return ToolParameter.ToolParameterType.NUMBER
elif typ == "boolean":
return ToolParameter.ToolParameterType.BOOLEAN
elif typ == "string":
return ToolParameter.ToolParameterType.STRING
elif typ == "array":
items = parameter.get("items") or parameter.get("schema", {}).get("items")
return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None
else:
return None
@staticmethod
def parse_openapi_yaml_to_tool_bundle(
yaml: str, extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
"""
parse openapi yaml to tool bundle
:param yaml: the yaml string
:param extra_info: the extra info
:param warning: the warning message
:return: the tool bundle
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
openapi: dict = safe_load(yaml)
if openapi is None:
raise ToolApiSchemaError("Invalid openapi yaml.")
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
@staticmethod
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict:
warning = warning or {}
"""
parse swagger to openapi
:param swagger: the swagger dict
:return: the openapi dict
"""
# convert swagger to openapi
info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"})
servers = swagger.get("servers", [])
if len(servers) == 0:
raise ToolApiSchemaError("No server found in the swagger yaml.")
openapi = {
"openapi": "3.0.0",
"info": {
"title": info.get("title", "Swagger"),
"description": info.get("description", "Swagger"),
"version": info.get("version", "1.0.0"),
},
"servers": swagger["servers"],
"paths": {},
"components": {"schemas": {}},
}
# check paths
if "paths" not in swagger or len(swagger["paths"]) == 0:
raise ToolApiSchemaError("No paths found in the swagger yaml.")
# convert paths
for path, path_item in swagger["paths"].items():
openapi["paths"][path] = {}
for method, operation in path_item.items():
if "operationId" not in operation:
raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")
if ("summary" not in operation or len(operation["summary"]) == 0) and (
"description" not in operation or len(operation["description"]) == 0
):
if warning is not None:
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
openapi["paths"][path][method] = {
"operationId": operation["operationId"],
"summary": operation.get("summary", ""),
"description": operation.get("description", ""),
"parameters": operation.get("parameters", []),
"responses": operation.get("responses", {}),
}
if "requestBody" in operation:
openapi["paths"][path][method]["requestBody"] = operation["requestBody"]
# convert definitions
for name, definition in swagger["definitions"].items():
openapi["components"]["schemas"][name] = definition
return openapi
@staticmethod
def parse_openai_plugin_json_to_tool_bundle(
json: str, extra_info: dict | None = None, warning: dict | None = None
) -> list[ApiToolBundle]:
"""
parse openapi plugin yaml to tool bundle
:param json: the json string
:param extra_info: the extra info
:param warning: the warning message
:return: the tool bundle
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
try:
openai_plugin = json_loads(json)
api = openai_plugin["api"]
api_url = api["url"]
api_type = api["type"]
except JSONDecodeError:
raise ToolProviderNotFoundError("Invalid openai plugin json.")
if api_type != "openapi":
raise ToolNotSupportedError("Only openapi is supported now.")
# get openapi yaml
response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5)
if response.status_code != 200:
raise ToolProviderNotFoundError("cannot get openapi yaml from url.")
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(
response.text, extra_info=extra_info, warning=warning
)
@staticmethod
def auto_parse_to_tool_bundle(
content: str, extra_info: dict | None = None, warning: dict | None = None
) -> tuple[list[ApiToolBundle], str]:
"""
auto parse to tool bundle
:param content: the content
:param extra_info: the extra info
:param warning: the warning message
:return: tools bundle, schema_type
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
content = content.strip()
loaded_content = None
json_error = None
yaml_error = None
try:
loaded_content = json_loads(content)
except JSONDecodeError as e:
json_error = e
if loaded_content is None:
try:
loaded_content = safe_load(content)
except YAMLError as e:
yaml_error = e
if loaded_content is None:
raise ToolApiSchemaError(
f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)},"
f" yaml error: {str(yaml_error)}"
)
swagger_error = None
openapi_error = None
openapi_plugin_error = None
schema_type = None
try:
openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
loaded_content, extra_info=extra_info, warning=warning
)
schema_type = ApiProviderSchemaType.OPENAPI.value
return openapi, schema_type
except ToolApiSchemaError as e:
openapi_error = e
# openai parse error, fallback to swagger
try:
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
loaded_content, extra_info=extra_info, warning=warning
)
schema_type = ApiProviderSchemaType.SWAGGER.value
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
converted_swagger, extra_info=extra_info, warning=warning
), schema_type
except ToolApiSchemaError as e:
swagger_error = e
# swagger parse error, fallback to openai plugin
try:
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
json_dumps(loaded_content), extra_info=extra_info, warning=warning
)
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value
except ToolNotSupportedError as e:
# maybe it's not plugin at all
openapi_plugin_error = e
raise ToolApiSchemaError(
f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)},"
f" openapi plugin error: {str(openapi_plugin_error)}"
)

View File

@@ -1,17 +0,0 @@
import re
def remove_leading_symbols(text: str) -> str:
"""
Remove leading punctuation or symbols from the given text.
Args:
text (str): The input text to process.
Returns:
str: The text with leading punctuation or symbols removed.
"""
# Match Unicode ranges for punctuation and symbols
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"
return re.sub(pattern, "", text)

View File

@@ -1,9 +0,0 @@
import uuid
def is_valid_uuid(uuid_str: str) -> bool:
try:
uuid.UUID(uuid_str)
return True
except Exception:
return False

View File

@@ -1,43 +0,0 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.app.app_config.entities import VariableEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
class WorkflowToolConfigurationUtils:
@classmethod
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
for configuration in configurations:
WorkflowToolParameterConfiguration.model_validate(configuration)
@classmethod
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
"""
get workflow graph variables
"""
nodes = graph.get("nodes", [])
start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)
if not start_node:
return []
return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]
@classmethod
def check_is_synced(
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
):
"""
check is synced
raise ValueError if not synced
"""
variable_names = [variable.variable for variable in variables]
if len(tool_configurations) != len(variables):
raise ValueError("parameter configuration mismatch, please republish the tool to update")
for parameter in tool_configurations:
if parameter.name not in variable_names:
raise ValueError("parameter configuration mismatch, please republish the tool to update")

View File

@@ -1,35 +0,0 @@
import logging
from pathlib import Path
from typing import Any
import yaml # type: ignore
from yaml import YAMLError
logger = logging.getLogger(__name__)
def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any:
"""
Safe loading a YAML file
:param file_path: the path of the YAML file
:param ignore_error:
if True, return default_value if error occurs and the error will be logged in debug level
if False, raise error if error occurs
:param default_value: the value returned when errors ignored
:return: an object of the YAML content
"""
if not file_path or not Path(file_path).exists():
if ignore_error:
return default_value
else:
raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, encoding="utf-8") as yaml_file:
try:
yaml_content = yaml.safe_load(yaml_file)
return yaml_content or default_value
except Exception as e:
if ignore_error:
return default_value
else:
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e

View File

@@ -1,53 +0,0 @@
from collections.abc import Generator, Mapping
from typing import Any
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import (
DatasourceEntity,
DatasourceProviderType,
WebsiteCrawlMessage,
)
from core.plugin.impl.datasource import PluginDatasourceManager
class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
tenant_id: str
icon: str
plugin_unique_identifier: str
entity: DatasourceEntity
runtime: DatasourceRuntime
def __init__(
self,
entity: DatasourceEntity,
runtime: DatasourceRuntime,
tenant_id: str,
icon: str,
plugin_unique_identifier: str,
) -> None:
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier
def get_website_crawl(
self,
user_id: str,
datasource_parameters: Mapping[str, Any],
provider_type: str,
) -> Generator[WebsiteCrawlMessage, None, None]:
manager = PluginDatasourceManager()
return manager.get_website_crawl(
tenant_id=self.tenant_id,
user_id=user_id,
datasource_provider=self.entity.identity.provider,
datasource_name=self.entity.identity.name,
credentials=self.runtime.credentials,
datasource_parameters=datasource_parameters,
provider_type=provider_type,
)
def datasource_provider_type(self) -> str:
return DatasourceProviderType.WEBSITE_CRAWL

View File

@@ -1,52 +0,0 @@
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.__base.datasource_runtime import DatasourceRuntime
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderController):
entity: DatasourceProviderEntityWithPlugin
plugin_id: str
plugin_unique_identifier: str
def __init__(
self,
entity: DatasourceProviderEntityWithPlugin,
plugin_id: str,
plugin_unique_identifier: str,
tenant_id: str,
) -> None:
super().__init__(entity, tenant_id)
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
@property
def provider_type(self) -> DatasourceProviderType:
"""
returns the type of the provider
"""
return DatasourceProviderType.WEBSITE_CRAWL
def get_datasource(self, datasource_name: str) -> WebsiteCrawlDatasourcePlugin: # type: ignore
"""
return datasource with given name
"""
datasource_entity = next(
(
datasource_entity
for datasource_entity in self.entity.datasources
if datasource_entity.identity.name == datasource_name
),
None,
)
if not datasource_entity:
raise ValueError(f"Datasource with name {datasource_name} not found")
return WebsiteCrawlDatasourcePlugin(
entity=datasource_entity,
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)

View File

@@ -17,27 +17,3 @@ class IndexingEstimate(BaseModel):
total_segments: int
preview: list[PreviewDetail]
qa_preview: Optional[list[QAPreviewDetail]] = None
class PipelineDataset(BaseModel):
id: str
name: str
description: str
chunk_structure: str
class PipelineDocument(BaseModel):
id: str
position: int
data_source_type: str
data_source_info: Optional[dict] = None
name: str
indexing_status: str
error: Optional[str] = None
enabled: bool
class PipelineGenerateResponse(BaseModel):
batch: str
dataset: PipelineDataset
documents: list[PipelineDocument]

View File

@@ -21,6 +21,9 @@ class CommonParameterType(StrEnum):
DYNAMIC_SELECT = "dynamic-select"
# TOOL_SELECTOR = "tool-selector"
# MCP object and array type parameters
ARRAY = "array"
OBJECT = "object"
class AppSelectorScope(StrEnum):

View File

@@ -1,15 +0,0 @@
from typing import TYPE_CHECKING, Any, cast
from core.datasource import datasource_file_manager
from core.datasource.datasource_file_manager import DatasourceFileManager
if TYPE_CHECKING:
from core.datasource.datasource_file_manager import DatasourceFileManager
tool_file_manager: dict[str, Any] = {"manager": None}
class DatasourceFileParser:
@staticmethod
def get_datasource_file_manager() -> "DatasourceFileManager":
return cast("DatasourceFileManager", datasource_file_manager["manager"])

View File

@@ -20,7 +20,6 @@ class FileTransferMethod(StrEnum):
REMOTE_URL = "remote_url"
LOCAL_FILE = "local_file"
TOOL_FILE = "tool_file"
DATASOURCE_FILE = "datasource_file"
@staticmethod
def value_of(value):

View File

@@ -51,7 +51,7 @@ class File(BaseModel):
# It should be set to `ToolFile.id` when `transfer_method` is `tool_file`.
related_id: Optional[str] = None
filename: Optional[str] = None
extension: Optional[str] = Field(default=None, description="File extension, should contains dot")
extension: Optional[str] = Field(default=None, description="File extension, should contain dot")
mime_type: Optional[str] = None
size: int = -1

View File

@@ -1,67 +0,0 @@
import base64
import logging
import time
from typing import Optional
from configs import dify_config
from constants import IMAGE_EXTENSIONS
from core.helper.url_signer import UrlSigner
from extensions.ext_storage import storage
class UploadFileParser:
@classmethod
def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
if not upload_file:
return None
if upload_file.extension not in IMAGE_EXTENSIONS:
return None
if dify_config.MULTIMODAL_SEND_FORMAT == "url" or force_url:
return cls.get_signed_temp_image_url(upload_file.id)
else:
# get image file base64
try:
data = storage.load(upload_file.key)
except FileNotFoundError:
logging.exception(f"File not found: {upload_file.key}")
return None
encoded_string = base64.b64encode(data).decode("utf-8")
return f"data:{upload_file.mime_type};base64,{encoded_string}"
@classmethod
def get_signed_temp_image_url(cls, upload_file_id) -> str:
"""
get signed url from upload file
:param upload_file_id: the id of UploadFile object
:return:
"""
base_url = dify_config.FILES_URL
image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
return UrlSigner.get_signed_url(url=image_preview_url, sign_key=upload_file_id, prefix="image-preview")
@classmethod
def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature
:param upload_file_id: file id
:param timestamp: timestamp
:param nonce: nonce
:param sign: signature
:return:
"""
result = UrlSigner.verify(
sign_key=upload_file_id, timestamp=timestamp, nonce=nonce, sign=sign, prefix="image-preview"
)
# verify signature
if not result:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT

View File

@@ -28,7 +28,7 @@ class TemplateTransformer(ABC):
def extract_result_str_from_response(cls, response: str):
result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL)
if not result:
raise ValueError("Failed to parse result")
raise ValueError(f"Failed to parse result: no result tag found in response. Response: {response[:200]}...")
return result.group(1)
@classmethod
@@ -38,16 +38,53 @@ class TemplateTransformer(ABC):
:param response: response
:return:
"""
try:
result = json.loads(cls.extract_result_str_from_response(response))
except json.JSONDecodeError:
raise ValueError("failed to parse response")
result_str = cls.extract_result_str_from_response(response)
result = json.loads(result_str)
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse JSON response: {str(e)}. Response content: {result_str[:200]}...")
except ValueError as e:
# Re-raise ValueError from extract_result_str_from_response
raise e
except Exception as e:
raise ValueError(f"Unexpected error during response transformation: {str(e)}")
# Check if the result contains an error
if isinstance(result, dict) and "error" in result:
raise ValueError(f"JavaScript execution error: {result['error']}")
if not isinstance(result, dict):
raise ValueError("result must be a dict")
raise ValueError(f"Result must be a dict, got {type(result).__name__}")
if not all(isinstance(k, str) for k in result):
raise ValueError("result keys must be strings")
raise ValueError("Result keys must be strings")
# Post-process the result to convert scientific notation strings back to numbers
result = cls._post_process_result(result)
return result
@classmethod
def _post_process_result(cls, result: dict[Any, Any]) -> dict[Any, Any]:
"""
Post-process the result to convert scientific notation strings back to numbers
"""
def convert_scientific_notation(value):
if isinstance(value, str):
# Check if the string looks like scientific notation
if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE):
try:
return float(value)
except ValueError:
pass
elif isinstance(value, dict):
return {k: convert_scientific_notation(v) for k, v in value.items()}
elif isinstance(value, list):
return [convert_scientific_notation(v) for v in value]
return value
return convert_scientific_notation(result) # type: ignore[no-any-return]
@classmethod
@abstractmethod
def get_runner_script(cls) -> str:

View File

@@ -1,22 +0,0 @@
from collections import OrderedDict
from typing import Any
class LRUCache:
def __init__(self, capacity: int):
self.cache: OrderedDict[Any, Any] = OrderedDict()
self.capacity = capacity
def get(self, key: Any) -> Any:
if key not in self.cache:
return None
else:
self.cache.move_to_end(key) # move the key to the end of the OrderedDict
return self.cache[key]
def put(self, key: Any, value: Any) -> None:
if key in self.cache:
self.cache.move_to_end(key)
self.cache[key] = value
if len(self.cache) > self.capacity:
self.cache.popitem(last=False) # pop the first item

View File

@@ -317,9 +317,10 @@ class IndexingRunner:
image_upload_file_ids = get_image_upload_file_ids(document.page_content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
if image_file is None:
continue
try:
if image_file:
storage.delete(image_file.key)
storage.delete(image_file.key)
except Exception:
logging.exception(
"Delete image_files failed while indexing_estimate, \

View File

@@ -23,6 +23,7 @@ from core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
)
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule
@@ -170,10 +171,15 @@ def invoke_llm_with_structured_output(
system_fingerprint: Optional[str] = None
for event in llm_result:
if isinstance(event, LLMResultChunk):
prompt_messages = event.prompt_messages
system_fingerprint = event.system_fingerprint
if isinstance(event.delta.message.content, str):
result_text += event.delta.message.content
prompt_messages = event.prompt_messages
system_fingerprint = event.system_fingerprint
elif isinstance(event.delta.message.content, list):
for item in event.delta.message.content:
if isinstance(item, TextPromptMessageContent):
result_text += item.data
yield LLMResultChunkWithStructuredOutput(
model=model_schema.model,

View File

@@ -0,0 +1,342 @@
import base64
import hashlib
import json
import os
import secrets
import urllib.parse
from typing import Optional
from urllib.parse import urljoin
import requests
from pydantic import BaseModel, ValidationError
from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.types import (
OAuthClientInformation,
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthMetadata,
OAuthTokens,
)
from extensions.ext_redis import redis_client
LATEST_PROTOCOL_VERSION = "1.0"
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
class OAuthCallbackState(BaseModel):
provider_id: str
tenant_id: str
server_url: str
metadata: OAuthMetadata | None = None
client_information: OAuthClientInformation
code_verifier: str
redirect_uri: str
def generate_pkce_challenge() -> tuple[str, str]:
"""Generate PKCE challenge and verifier."""
code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
code_verifier = code_verifier.replace("=", "").replace("+", "-").replace("/", "_")
code_challenge_hash = hashlib.sha256(code_verifier.encode("utf-8")).digest()
code_challenge = base64.urlsafe_b64encode(code_challenge_hash).decode("utf-8")
code_challenge = code_challenge.replace("=", "").replace("+", "-").replace("/", "_")
return code_verifier, code_challenge
def _create_secure_redis_state(state_data: OAuthCallbackState) -> str:
"""Create a secure state parameter by storing state data in Redis and returning a random state key."""
# Generate a secure random state key
state_key = secrets.token_urlsafe(32)
# Store the state data in Redis with expiration
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
redis_client.setex(redis_key, OAUTH_STATE_EXPIRY_SECONDS, state_data.model_dump_json())
return state_key
def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
"""Retrieve and decode OAuth state data from Redis using the state key, then delete it."""
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
# Get state data from Redis
state_data = redis_client.get(redis_key)
if not state_data:
raise ValueError("State parameter has expired or does not exist")
# Delete the state data from Redis immediately after retrieval to prevent reuse
redis_client.delete(redis_key)
try:
# Parse and validate the state data
oauth_state = OAuthCallbackState.model_validate_json(state_data)
return oauth_state
except ValidationError as e:
raise ValueError(f"Invalid state parameter: {str(e)}")
def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState:
"""Handle the callback from the OAuth provider."""
# Retrieve state data from Redis (state is automatically deleted after retrieval)
full_state_data = _retrieve_redis_state(state_key)
tokens = exchange_authorization(
full_state_data.server_url,
full_state_data.metadata,
full_state_data.client_information,
authorization_code,
full_state_data.code_verifier,
full_state_data.redirect_uri,
)
provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True)
provider.save_tokens(tokens)
return full_state_data
def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]:
"""Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata."""
url = urljoin(server_url, "/.well-known/oauth-authorization-server")
try:
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
response = requests.get(url, headers=headers)
if response.status_code == 404:
return None
if not response.ok:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json())
except requests.RequestException as e:
if isinstance(e, requests.ConnectionError):
response = requests.get(url)
if response.status_code == 404:
return None
if not response.ok:
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
return OAuthMetadata.model_validate(response.json())
raise
def start_authorization(
server_url: str,
metadata: Optional[OAuthMetadata],
client_information: OAuthClientInformation,
redirect_url: str,
provider_id: str,
tenant_id: str,
) -> tuple[str, str]:
"""Begins the authorization flow with secure Redis state storage."""
response_type = "code"
code_challenge_method = "S256"
if metadata:
authorization_url = metadata.authorization_endpoint
if response_type not in metadata.response_types_supported:
raise ValueError(f"Incompatible auth server: does not support response type {response_type}")
if (
not metadata.code_challenge_methods_supported
or code_challenge_method not in metadata.code_challenge_methods_supported
):
raise ValueError(
f"Incompatible auth server: does not support code challenge method {code_challenge_method}"
)
else:
authorization_url = urljoin(server_url, "/authorize")
code_verifier, code_challenge = generate_pkce_challenge()
# Prepare state data with all necessary information
state_data = OAuthCallbackState(
provider_id=provider_id,
tenant_id=tenant_id,
server_url=server_url,
metadata=metadata,
client_information=client_information,
code_verifier=code_verifier,
redirect_uri=redirect_url,
)
# Store state data in Redis and generate secure state key
state_key = _create_secure_redis_state(state_data)
params = {
"response_type": response_type,
"client_id": client_information.client_id,
"code_challenge": code_challenge,
"code_challenge_method": code_challenge_method,
"redirect_uri": redirect_url,
"state": state_key,
}
authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
return authorization_url, code_verifier
def exchange_authorization(
server_url: str,
metadata: Optional[OAuthMetadata],
client_information: OAuthClientInformation,
authorization_code: str,
code_verifier: str,
redirect_uri: str,
) -> OAuthTokens:
"""Exchanges an authorization code for an access token."""
grant_type = "authorization_code"
if metadata:
token_url = metadata.token_endpoint
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
else:
token_url = urljoin(server_url, "/token")
params = {
"grant_type": grant_type,
"client_id": client_information.client_id,
"code": authorization_code,
"code_verifier": code_verifier,
"redirect_uri": redirect_uri,
}
if client_information.client_secret:
params["client_secret"] = client_information.client_secret
response = requests.post(token_url, data=params)
if not response.ok:
raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
return OAuthTokens.model_validate(response.json())
def refresh_authorization(
server_url: str,
metadata: Optional[OAuthMetadata],
client_information: OAuthClientInformation,
refresh_token: str,
) -> OAuthTokens:
"""Exchange a refresh token for an updated access token."""
grant_type = "refresh_token"
if metadata:
token_url = metadata.token_endpoint
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
else:
token_url = urljoin(server_url, "/token")
params = {
"grant_type": grant_type,
"client_id": client_information.client_id,
"refresh_token": refresh_token,
}
if client_information.client_secret:
params["client_secret"] = client_information.client_secret
response = requests.post(token_url, data=params)
if not response.ok:
raise ValueError(f"Token refresh failed: HTTP {response.status_code}")
return OAuthTokens.parse_obj(response.json())
def register_client(
server_url: str,
metadata: Optional[OAuthMetadata],
client_metadata: OAuthClientMetadata,
) -> OAuthClientInformationFull:
"""Performs OAuth 2.0 Dynamic Client Registration."""
if metadata:
if not metadata.registration_endpoint:
raise ValueError("Incompatible auth server: does not support dynamic client registration")
registration_url = metadata.registration_endpoint
else:
registration_url = urljoin(server_url, "/register")
response = requests.post(
registration_url,
json=client_metadata.model_dump(),
headers={"Content-Type": "application/json"},
)
if not response.ok:
response.raise_for_status()
return OAuthClientInformationFull.model_validate(response.json())
def auth(
provider: OAuthClientProvider,
server_url: str,
authorization_code: Optional[str] = None,
state_param: Optional[str] = None,
for_list: bool = False,
) -> dict[str, str]:
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
metadata = discover_oauth_metadata(server_url)
# Handle client registration if needed
client_information = provider.client_information()
if not client_information:
if authorization_code is not None:
raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
try:
full_information = register_client(server_url, metadata, provider.client_metadata)
except requests.RequestException as e:
raise ValueError(f"Could not register OAuth client: {e}")
provider.save_client_information(full_information)
client_information = full_information
# Exchange authorization code for tokens
if authorization_code is not None:
if not state_param:
raise ValueError("State parameter is required when exchanging authorization code")
try:
# Retrieve state data from Redis using state key
full_state_data = _retrieve_redis_state(state_param)
code_verifier = full_state_data.code_verifier
redirect_uri = full_state_data.redirect_uri
if not code_verifier or not redirect_uri:
raise ValueError("Missing code_verifier or redirect_uri in state data")
except (json.JSONDecodeError, ValueError) as e:
raise ValueError(f"Invalid state parameter: {e}")
tokens = exchange_authorization(
server_url,
metadata,
client_information,
authorization_code,
code_verifier,
redirect_uri,
)
provider.save_tokens(tokens)
return {"result": "success"}
provider_tokens = provider.tokens()
# Handle token refresh or new authorization
if provider_tokens and provider_tokens.refresh_token:
try:
new_tokens = refresh_authorization(server_url, metadata, client_information, provider_tokens.refresh_token)
provider.save_tokens(new_tokens)
return {"result": "success"}
except Exception as e:
raise ValueError(f"Could not refresh OAuth tokens: {e}")
# Start new authorization flow
authorization_url, code_verifier = start_authorization(
server_url,
metadata,
client_information,
provider.redirect_url,
provider.mcp_provider.id,
provider.mcp_provider.tenant_id,
)
provider.save_code_verifier(code_verifier)
return {"authorization_url": authorization_url}

View File

@@ -0,0 +1,81 @@
from typing import Optional
from configs import dify_config
from core.mcp.types import (
OAuthClientInformation,
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthTokens,
)
from models.tools import MCPToolProvider
from services.tools.mcp_tools_mange_service import MCPToolManageService
LATEST_PROTOCOL_VERSION = "1.0"
class OAuthClientProvider:
mcp_provider: MCPToolProvider
def __init__(self, provider_id: str, tenant_id: str, for_list: bool = False):
if for_list:
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
else:
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_server_identifier(provider_id, tenant_id)
@property
def redirect_url(self) -> str:
"""The URL to redirect the user agent to after authorization."""
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
@property
def client_metadata(self) -> OAuthClientMetadata:
"""Metadata about this OAuth client."""
return OAuthClientMetadata(
redirect_uris=[self.redirect_url],
token_endpoint_auth_method="none",
grant_types=["authorization_code", "refresh_token"],
response_types=["code"],
client_name="Dify",
client_uri="https://github.com/langgenius/dify",
)
def client_information(self) -> Optional[OAuthClientInformation]:
"""Loads information about this OAuth client."""
client_information = self.mcp_provider.decrypted_credentials.get("client_information", {})
if not client_information:
return None
return OAuthClientInformation.model_validate(client_information)
def save_client_information(self, client_information: OAuthClientInformationFull) -> None:
"""Saves client information after dynamic registration."""
MCPToolManageService.update_mcp_provider_credentials(
self.mcp_provider,
{"client_information": client_information.model_dump()},
)
def tokens(self) -> Optional[OAuthTokens]:
"""Loads any existing OAuth tokens for the current session."""
credentials = self.mcp_provider.decrypted_credentials
if not credentials:
return None
return OAuthTokens(
access_token=credentials.get("access_token", ""),
token_type=credentials.get("token_type", "Bearer"),
expires_in=int(credentials.get("expires_in", "3600") or 3600),
refresh_token=credentials.get("refresh_token", ""),
)
def save_tokens(self, tokens: OAuthTokens) -> None:
"""Stores new OAuth tokens for the current session."""
# update mcp provider credentials
token_dict = tokens.model_dump()
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True)
def save_code_verifier(self, code_verifier: str) -> None:
"""Saves a PKCE code verifier for the current session."""
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier})
def code_verifier(self) -> str:
"""Loads the PKCE code verifier for the current session."""
# get code verifier from mcp provider credentials
return str(self.mcp_provider.decrypted_credentials.get("code_verifier", ""))

View File

@@ -0,0 +1,361 @@
import logging
import queue
from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, TypeAlias, final
from urllib.parse import urljoin, urlparse
import httpx
from sseclient import SSEClient
from core.mcp import types
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.types import SessionMessage
from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
logger = logging.getLogger(__name__)
DEFAULT_QUEUE_READ_TIMEOUT = 3
@final
class _StatusReady:
def __init__(self, endpoint_url: str):
self._endpoint_url = endpoint_url
@final
class _StatusError:
def __init__(self, exc: Exception):
self._exc = exc
# Type aliases for better readability
ReadQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError]
def remove_request_params(url: str) -> str:
"""Remove request parameters from URL, keeping only the path."""
return urljoin(url, urlparse(url).path)
class SSETransport:
"""SSE client transport implementation."""
def __init__(
self,
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5.0,
sse_read_timeout: float = 5 * 60,
) -> None:
"""Initialize the SSE transport.
Args:
url: The SSE endpoint URL.
headers: Optional headers to include in requests.
timeout: HTTP timeout for regular operations.
sse_read_timeout: Timeout for SSE read operations.
"""
self.url = url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self.endpoint_url: str | None = None
def _validate_endpoint_url(self, endpoint_url: str) -> bool:
"""Validate that the endpoint URL matches the connection origin.
Args:
endpoint_url: The endpoint URL to validate.
Returns:
True if valid, False otherwise.
"""
url_parsed = urlparse(self.url)
endpoint_parsed = urlparse(endpoint_url)
return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme
def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue) -> None:
"""Handle an 'endpoint' SSE event.
Args:
sse_data: The SSE event data.
status_queue: Queue to put status updates.
"""
endpoint_url = urljoin(self.url, sse_data)
logger.info(f"Received endpoint URL: {endpoint_url}")
if not self._validate_endpoint_url(endpoint_url):
error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}"
logger.error(error_msg)
status_queue.put(_StatusError(ValueError(error_msg)))
return
status_queue.put(_StatusReady(endpoint_url))
def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None:
"""Handle a 'message' SSE event.
Args:
sse_data: The SSE event data.
read_queue: Queue to put parsed messages.
"""
try:
message = types.JSONRPCMessage.model_validate_json(sse_data)
logger.debug(f"Received server message: {message}")
session_message = SessionMessage(message)
read_queue.put(session_message)
except Exception as exc:
logger.exception("Error parsing server message")
read_queue.put(exc)
def _handle_sse_event(self, sse, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
"""Handle a single SSE event.
Args:
sse: The SSE event object.
read_queue: Queue for message events.
status_queue: Queue for status events.
"""
match sse.event:
case "endpoint":
self._handle_endpoint_event(sse.data, status_queue)
case "message":
self._handle_message_event(sse.data, read_queue)
case _:
logger.warning(f"Unknown SSE event: {sse.event}")
def sse_reader(self, event_source, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
"""Read and process SSE events.
Args:
event_source: The SSE event source.
read_queue: Queue to put received messages.
status_queue: Queue to put status updates.
"""
try:
for sse in event_source.iter_sse():
self._handle_sse_event(sse, read_queue, status_queue)
except httpx.ReadError as exc:
logger.debug(f"SSE reader shutting down normally: {exc}")
except Exception as exc:
read_queue.put(exc)
finally:
read_queue.put(None)
def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage) -> None:
"""Send a single message to the server.
Args:
client: HTTP client to use.
endpoint_url: The endpoint URL to send to.
message: The message to send.
"""
response = client.post(
endpoint_url,
json=message.message.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
),
)
response.raise_for_status()
logger.debug(f"Client message sent successfully: {response.status_code}")
def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue) -> None:
"""Handle writing messages to the server.
Args:
client: HTTP client to use.
endpoint_url: The endpoint URL to send messages to.
write_queue: Queue to read messages from.
"""
try:
while True:
try:
message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
if message is None:
break
if isinstance(message, Exception):
write_queue.put(message)
continue
self._send_message(client, endpoint_url, message)
except queue.Empty:
continue
except httpx.ReadError as exc:
logger.debug(f"Post writer shutting down normally: {exc}")
except Exception as exc:
logger.exception("Error writing messages")
write_queue.put(exc)
finally:
write_queue.put(None)
def _wait_for_endpoint(self, status_queue: StatusQueue) -> str:
"""Wait for the endpoint URL from the status queue.
Args:
status_queue: Queue to read status from.
Returns:
The endpoint URL.
Raises:
ValueError: If endpoint URL is not received or there's an error.
"""
try:
status = status_queue.get(timeout=1)
except queue.Empty:
raise ValueError("failed to get endpoint URL")
if isinstance(status, _StatusReady):
return status._endpoint_url
elif isinstance(status, _StatusError):
raise status._exc
else:
raise ValueError("failed to get endpoint URL")
def connect(
self,
executor: ThreadPoolExecutor,
client: httpx.Client,
event_source,
) -> tuple[ReadQueue, WriteQueue]:
"""Establish connection and start worker threads.
Args:
executor: Thread pool executor.
client: HTTP client.
event_source: SSE event source.
Returns:
Tuple of (read_queue, write_queue).
"""
read_queue: ReadQueue = queue.Queue()
write_queue: WriteQueue = queue.Queue()
status_queue: StatusQueue = queue.Queue()
# Start SSE reader thread
executor.submit(self.sse_reader, event_source, read_queue, status_queue)
# Wait for endpoint URL
endpoint_url = self._wait_for_endpoint(status_queue)
self.endpoint_url = endpoint_url
# Start post writer thread
executor.submit(self.post_writer, client, endpoint_url, write_queue)
return read_queue, write_queue
@contextmanager
def sse_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5.0,
sse_read_timeout: float = 5 * 60,
) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
"""
Client transport for SSE.
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.
Args:
url: The SSE endpoint URL.
headers: Optional headers to include in requests.
timeout: HTTP timeout for regular operations.
sse_read_timeout: Timeout for SSE read operations.
Yields:
Tuple of (read_queue, write_queue) for message communication.
"""
transport = SSETransport(url, headers, timeout, sse_read_timeout)
read_queue: ReadQueue | None = None
write_queue: WriteQueue | None = None
with ThreadPoolExecutor() as executor:
try:
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
with ssrf_proxy_sse_connect(
url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
) as event_source:
event_source.response.raise_for_status()
read_queue, write_queue = transport.connect(executor, client, event_source)
yield read_queue, write_queue
except httpx.HTTPStatusError as exc:
if exc.response.status_code == 401:
raise MCPAuthError()
raise MCPConnectionError()
except Exception:
logger.exception("Error connecting to SSE endpoint")
raise
finally:
# Clean up queues
if read_queue:
read_queue.put(None)
if write_queue:
write_queue.put(None)
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage) -> None:
"""
Send a message to the server using the provided HTTP client.
Args:
http_client: The HTTP client to use for sending
endpoint_url: The endpoint URL to send the message to
session_message: The message to send
"""
try:
response = http_client.post(
endpoint_url,
json=session_message.message.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
),
)
response.raise_for_status()
logger.debug(f"Client message sent successfully: {response.status_code}")
except Exception as exc:
logger.exception("Error sending message")
raise
def read_messages(
sse_client: SSEClient,
) -> Generator[SessionMessage | Exception, None, None]:
"""
Read messages from the SSE client.
Args:
sse_client: The SSE client to read from
Yields:
SessionMessage or Exception for each event received
"""
try:
for sse in sse_client.events():
if sse.event == "message":
try:
message = types.JSONRPCMessage.model_validate_json(sse.data)
logger.debug(f"Received server message: {message}")
yield SessionMessage(message)
except Exception as exc:
logger.exception("Error parsing server message")
yield exc
else:
logger.warning(f"Unknown SSE event: {sse.event}")
except Exception as exc:
logger.exception("Error reading SSE messages")
yield exc

View File

@@ -0,0 +1,476 @@
"""
StreamableHTTP Client Transport Module
This module implements the StreamableHTTP transport for MCP clients,
providing support for HTTP POST requests with optional SSE streaming responses
and session management.
"""
import logging
import queue
from collections.abc import Callable, Generator
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, cast
import httpx
from httpx_sse import EventSource, ServerSentEvent
from core.mcp.types import (
ClientMessageMetadata,
ErrorData,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
RequestId,
SessionMessage,
)
from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
logger = logging.getLogger(__name__)
SessionMessageOrError = SessionMessage | Exception | None
# Queue types with clearer names for their roles
ServerToClientQueue = queue.Queue[SessionMessageOrError] # Server to client messages
ClientToServerQueue = queue.Queue[SessionMessage | None] # Client to server messages
GetSessionIdCallback = Callable[[], str | None]
MCP_SESSION_ID = "mcp-session-id"
LAST_EVENT_ID = "last-event-id"
CONTENT_TYPE = "content-type"
ACCEPT = "Accept"
JSON = "application/json"
SSE = "text/event-stream"
DEFAULT_QUEUE_READ_TIMEOUT = 3
class StreamableHTTPError(Exception):
"""Base exception for StreamableHTTP transport errors."""
pass
class ResumptionError(StreamableHTTPError):
"""Raised when resumption request is invalid."""
pass
@dataclass
class RequestContext:
"""Context for a request operation."""
client: httpx.Client
headers: dict[str, str]
session_id: str | None
session_message: SessionMessage
metadata: ClientMessageMetadata | None
server_to_client_queue: ServerToClientQueue # Renamed for clarity
sse_read_timeout: timedelta
class StreamableHTTPTransport:
"""StreamableHTTP client transport implementation."""
def __init__(
self,
url: str,
headers: dict[str, Any] | None = None,
timeout: timedelta = timedelta(seconds=30),
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
) -> None:
"""Initialize the StreamableHTTP transport.
Args:
url: The endpoint URL.
headers: Optional headers to include in requests.
timeout: HTTP timeout for regular operations.
sse_read_timeout: Timeout for SSE read operations.
"""
self.url = url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self.session_id: str | None = None
self.request_headers = {
ACCEPT: f"{JSON}, {SSE}",
CONTENT_TYPE: JSON,
**self.headers,
}
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
"""Update headers with session ID if available."""
headers = base_headers.copy()
if self.session_id:
headers[MCP_SESSION_ID] = self.session_id
return headers
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
"""Check if the message is an initialization request."""
return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
"""Check if the message is an initialized notification."""
return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized"
def _maybe_extract_session_id_from_response(
self,
response: httpx.Response,
) -> None:
"""Extract and store session ID from response headers."""
new_session_id = response.headers.get(MCP_SESSION_ID)
if new_session_id:
self.session_id = new_session_id
logger.info(f"Received session ID: {self.session_id}")
def _handle_sse_event(
self,
sse: ServerSentEvent,
server_to_client_queue: ServerToClientQueue,
original_request_id: RequestId | None = None,
resumption_callback: Callable[[str], None] | None = None,
) -> bool:
"""Handle an SSE event, returning True if the response is complete."""
if sse.event == "message":
try:
message = JSONRPCMessage.model_validate_json(sse.data)
logger.debug(f"SSE message: {message}")
# If this is a response and we have original_request_id, replace it
if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
message.root.id = original_request_id
session_message = SessionMessage(message)
# Put message in queue that goes to client
server_to_client_queue.put(session_message)
# Call resumption token callback if we have an ID
if sse.id and resumption_callback:
resumption_callback(sse.id)
# If this is a response or error return True indicating completion
# Otherwise, return False to continue listening
return isinstance(message.root, JSONRPCResponse | JSONRPCError)
except Exception as exc:
# Put exception in queue that goes to client
server_to_client_queue.put(exc)
return False
elif sse.event == "ping":
logger.debug("Received ping event")
return False
else:
logger.warning(f"Unknown SSE event: {sse.event}")
return False
def handle_get_stream(
self,
client: httpx.Client,
server_to_client_queue: ServerToClientQueue,
) -> None:
"""Handle GET stream for server-initiated messages."""
try:
if not self.session_id:
return
headers = self._update_headers_with_session(self.request_headers)
with ssrf_proxy_sse_connect(
self.url,
headers=headers,
timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds),
client=client,
method="GET",
) as event_source:
event_source.response.raise_for_status()
logger.debug("GET SSE connection established")
for sse in event_source.iter_sse():
self._handle_sse_event(sse, server_to_client_queue)
except Exception as exc:
logger.debug(f"GET stream error (non-fatal): {exc}")
def _handle_resumption_request(self, ctx: RequestContext) -> None:
"""Handle a resumption request using GET with SSE."""
headers = self._update_headers_with_session(ctx.headers)
if ctx.metadata and ctx.metadata.resumption_token:
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
else:
raise ResumptionError("Resumption request requires a resumption token")
# Extract original request ID to map responses
original_request_id = None
if isinstance(ctx.session_message.message.root, JSONRPCRequest):
original_request_id = ctx.session_message.message.root.id
with ssrf_proxy_sse_connect(
self.url,
headers=headers,
timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds),
client=ctx.client,
method="GET",
) as event_source:
event_source.response.raise_for_status()
logger.debug("Resumption GET SSE connection established")
for sse in event_source.iter_sse():
is_complete = self._handle_sse_event(
sse,
ctx.server_to_client_queue,
original_request_id,
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
)
if is_complete:
break
def _handle_post_request(self, ctx: RequestContext) -> None:
"""Handle a POST request with response processing."""
headers = self._update_headers_with_session(ctx.headers)
message = ctx.session_message.message
is_initialization = self._is_initialization_request(message)
with ctx.client.stream(
"POST",
self.url,
json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
headers=headers,
) as response:
if response.status_code == 202:
logger.debug("Received 202 Accepted")
return
if response.status_code == 404:
if isinstance(message.root, JSONRPCRequest):
self._send_session_terminated_error(
ctx.server_to_client_queue,
message.root.id,
)
return
response.raise_for_status()
if is_initialization:
self._maybe_extract_session_id_from_response(response)
content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
if content_type.startswith(JSON):
self._handle_json_response(response, ctx.server_to_client_queue)
elif content_type.startswith(SSE):
self._handle_sse_response(response, ctx)
else:
self._handle_unexpected_content_type(
content_type,
ctx.server_to_client_queue,
)
def _handle_json_response(
self,
response: httpx.Response,
server_to_client_queue: ServerToClientQueue,
) -> None:
"""Handle JSON response from the server."""
try:
content = response.read()
message = JSONRPCMessage.model_validate_json(content)
session_message = SessionMessage(message)
server_to_client_queue.put(session_message)
except Exception as exc:
server_to_client_queue.put(exc)
def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None:
"""Handle SSE response from the server."""
try:
event_source = EventSource(response)
for sse in event_source.iter_sse():
is_complete = self._handle_sse_event(
sse,
ctx.server_to_client_queue,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
)
if is_complete:
break
except Exception as e:
ctx.server_to_client_queue.put(e)
def _handle_unexpected_content_type(
self,
content_type: str,
server_to_client_queue: ServerToClientQueue,
) -> None:
"""Handle unexpected content type in response."""
error_msg = f"Unexpected content type: {content_type}"
logger.error(error_msg)
server_to_client_queue.put(ValueError(error_msg))
def _send_session_terminated_error(
self,
server_to_client_queue: ServerToClientQueue,
request_id: RequestId,
) -> None:
"""Send a session terminated error response."""
jsonrpc_error = JSONRPCError(
jsonrpc="2.0",
id=request_id,
error=ErrorData(code=32600, message="Session terminated by server"),
)
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
server_to_client_queue.put(session_message)
def post_writer(
self,
client: httpx.Client,
client_to_server_queue: ClientToServerQueue,
server_to_client_queue: ServerToClientQueue,
start_get_stream: Callable[[], None],
) -> None:
"""Handle writing requests to the server.
This method processes messages from the client_to_server_queue and sends them to the server.
Responses are written to the server_to_client_queue.
"""
while True:
try:
# Read message from client queue with timeout to check stop_event periodically
session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
if session_message is None:
break
message = session_message.message
metadata = (
session_message.metadata if isinstance(session_message.metadata, ClientMessageMetadata) else None
)
# Check if this is a resumption request
is_resumption = bool(metadata and metadata.resumption_token)
logger.debug(f"Sending client message: {message}")
# Handle initialized notification
if self._is_initialized_notification(message):
start_get_stream()
ctx = RequestContext(
client=client,
headers=self.request_headers,
session_id=self.session_id,
session_message=session_message,
metadata=metadata,
server_to_client_queue=server_to_client_queue, # Queue to write responses to client
sse_read_timeout=self.sse_read_timeout,
)
if is_resumption:
self._handle_resumption_request(ctx)
else:
self._handle_post_request(ctx)
except queue.Empty:
continue
except Exception as exc:
server_to_client_queue.put(exc)
def terminate_session(self, client: httpx.Client) -> None:
"""Terminate the session by sending a DELETE request."""
if not self.session_id:
return
try:
headers = self._update_headers_with_session(self.request_headers)
response = client.delete(self.url, headers=headers)
if response.status_code == 405:
logger.debug("Server does not allow session termination")
elif response.status_code != 200:
logger.warning(f"Session termination failed: {response.status_code}")
except Exception as exc:
logger.warning(f"Session termination failed: {exc}")
def get_session_id(self) -> str | None:
"""Get the current session ID."""
return self.session_id
@contextmanager
def streamablehttp_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: timedelta = timedelta(seconds=30),
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
terminate_on_close: bool = True,
) -> Generator[
tuple[
ServerToClientQueue, # Queue for receiving messages FROM server
ClientToServerQueue, # Queue for sending messages TO server
GetSessionIdCallback,
],
None,
None,
]:
"""
Client transport for StreamableHTTP.
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.
Yields:
Tuple containing:
- server_to_client_queue: Queue for reading messages FROM the server
- client_to_server_queue: Queue for sending messages TO the server
- get_session_id_callback: Function to retrieve the current session ID
"""
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout)
# Create queues with clear directional meaning
server_to_client_queue: ServerToClientQueue = queue.Queue() # For messages FROM server TO client
client_to_server_queue: ClientToServerQueue = queue.Queue() # For messages FROM client TO server
with ThreadPoolExecutor(max_workers=2) as executor:
try:
with create_ssrf_proxy_mcp_http_client(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds),
) as client:
# Define callbacks that need access to thread pool
def start_get_stream() -> None:
"""Start a worker thread to handle server-initiated messages."""
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
# Start the post_writer worker thread
executor.submit(
transport.post_writer,
client,
client_to_server_queue, # Queue for messages FROM client TO server
server_to_client_queue, # Queue for messages FROM server TO client
start_get_stream,
)
try:
yield (
server_to_client_queue, # Queue for receiving messages FROM server
client_to_server_queue, # Queue for sending messages TO server
transport.get_session_id,
)
finally:
if transport.session_id and terminate_on_close:
transport.terminate_session(client)
# Signal threads to stop
client_to_server_queue.put(None)
finally:
# Clear any remaining items and add None sentinel to unblock any waiting threads
try:
while not client_to_server_queue.empty():
client_to_server_queue.get_nowait()
except queue.Empty:
pass
client_to_server_queue.put(None)
server_to_client_queue.put(None)

19
api/core/mcp/entities.py Normal file
View File

@@ -0,0 +1,19 @@
from dataclasses import dataclass
from typing import Any, Generic, TypeVar
from core.mcp.session.base_session import BaseSession
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestId, RequestParams
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", LATEST_PROTOCOL_VERSION]
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
LifespanContextT = TypeVar("LifespanContextT")
@dataclass
class RequestContext(Generic[SessionT, LifespanContextT]):
request_id: RequestId
meta: RequestParams.Meta | None
session: SessionT
lifespan_context: LifespanContextT

10
api/core/mcp/error.py Normal file
View File

@@ -0,0 +1,10 @@
class MCPError(Exception):
pass
class MCPConnectionError(MCPError):
pass
class MCPAuthError(MCPConnectionError):
pass

150
api/core/mcp/mcp_client.py Normal file
View File

@@ -0,0 +1,150 @@
import logging
from collections.abc import Callable
from contextlib import AbstractContextManager, ExitStack
from types import TracebackType
from typing import Any, Optional, cast
from urllib.parse import urlparse
from core.mcp.client.sse_client import sse_client
from core.mcp.client.streamable_client import streamablehttp_client
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.session.client_session import ClientSession
from core.mcp.types import Tool
logger = logging.getLogger(__name__)
class MCPClient:
def __init__(
self,
server_url: str,
provider_id: str,
tenant_id: str,
authed: bool = True,
authorization_code: Optional[str] = None,
for_list: bool = False,
):
# Initialize info
self.provider_id = provider_id
self.tenant_id = tenant_id
self.client_type = "streamable"
self.server_url = server_url
# Authentication info
self.authed = authed
self.authorization_code = authorization_code
if authed:
from core.mcp.auth.auth_provider import OAuthClientProvider
self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list)
self.token = self.provider.tokens()
# Initialize session and client objects
self._session: Optional[ClientSession] = None
self._streams_context: Optional[AbstractContextManager[Any]] = None
self._session_context: Optional[ClientSession] = None
self.exit_stack = ExitStack()
# Whether the client has been initialized
self._initialized = False
def __enter__(self):
self._initialize()
self._initialized = True
return self
def __exit__(
self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[TracebackType]
):
self.cleanup()
def _initialize(
self,
):
"""Initialize the client with fallback to SSE if streamable connection fails"""
connection_methods: dict[str, Callable[..., AbstractContextManager[Any]]] = {
"mcp": streamablehttp_client,
"sse": sse_client,
}
parsed_url = urlparse(self.server_url)
path = parsed_url.path
method_name = path.rstrip("/").split("/")[-1] if path else ""
try:
client_factory = connection_methods[method_name]
self.connect_server(client_factory, method_name)
except KeyError:
try:
self.connect_server(sse_client, "sse")
except MCPConnectionError:
self.connect_server(streamablehttp_client, "mcp")
def connect_server(
self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str, first_try: bool = True
):
from core.mcp.auth.auth_flow import auth
try:
headers = (
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
if self.authed and self.token
else {}
)
self._streams_context = client_factory(url=self.server_url, headers=headers)
if self._streams_context is None:
raise MCPConnectionError("Failed to create connection context")
# Use exit_stack to manage context managers properly
if method_name == "mcp":
read_stream, write_stream, _ = self.exit_stack.enter_context(self._streams_context)
streams = (read_stream, write_stream)
else: # sse_client
streams = self.exit_stack.enter_context(self._streams_context)
self._session_context = ClientSession(*streams)
self._session = self.exit_stack.enter_context(self._session_context)
session = cast(ClientSession, self._session)
session.initialize()
return
except MCPAuthError:
if not self.authed:
raise
try:
auth(self.provider, self.server_url, self.authorization_code)
except Exception as e:
raise ValueError(f"Failed to authenticate: {e}")
self.token = self.provider.tokens()
if first_try:
return self.connect_server(client_factory, method_name, first_try=False)
except MCPConnectionError:
raise
def list_tools(self) -> list[Tool]:
"""Connect to an MCP server running with SSE transport"""
# List available tools to verify connection
if not self._initialized or not self._session:
raise ValueError("Session not initialized.")
response = self._session.list_tools()
tools = response.tools
return tools
def invoke_tool(self, tool_name: str, tool_args: dict):
"""Call a tool"""
if not self._initialized or not self._session:
raise ValueError("Session not initialized.")
return self._session.call_tool(tool_name, tool_args)
def cleanup(self):
"""Clean up resources"""
try:
# ExitStack will handle proper cleanup of all managed context managers
self.exit_stack.close()
self._session = None
self._session_context = None
self._streams_context = None
self._initialized = False
except Exception as e:
logging.exception("Error during cleanup")
raise ValueError(f"Error during cleanup: {e}")

View File

@@ -0,0 +1,224 @@
import json
import logging
from collections.abc import Mapping
from typing import Any, cast
from configs import dify_config
from controllers.web.passport import generate_session_id
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.mcp import types
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
from core.mcp.utils import create_mcp_error_response
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from models.model import App, AppMCPServer, AppMode, EndUser
from services.app_generate_service import AppGenerateService
"""
Apply to MCP HTTP streamable server with stateless http
"""
logger = logging.getLogger(__name__)
class MCPServerStreamableHTTPRequestHandler:
def __init__(
self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
):
self.app = app
self.request = request
mcp_server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == self.app.id).first()
if not mcp_server:
raise ValueError("MCP server not found")
self.mcp_server: AppMCPServer = mcp_server
self.end_user = self.retrieve_end_user()
self.user_input_form = user_input_form
@property
def request_type(self):
return type(self.request.root)
@property
def parameter_schema(self):
parameters, required = self._convert_input_form_to_parameters(self.user_input_form)
if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
return {
"type": "object",
"properties": parameters,
"required": required,
}
return {
"type": "object",
"properties": {
"query": {"type": "string", "description": "User Input/Question content"},
**parameters,
},
"required": ["query", *required],
}
@property
def capabilities(self):
return types.ServerCapabilities(
tools=types.ToolsCapability(listChanged=False),
)
def response(self, response: types.Result | str):
if isinstance(response, str):
sse_content = f"event: ping\ndata: {response}\n\n".encode()
yield sse_content
return
json_response = types.JSONRPCResponse(
jsonrpc="2.0",
id=(self.request.root.model_extra or {}).get("id", 1),
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
)
json_data = json.dumps(jsonable_encoder(json_response))
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
yield sse_content
def error_response(self, code: int, message: str, data=None):
request_id = (self.request.root.model_extra or {}).get("id", 1) or 1
return create_mcp_error_response(request_id, code, message, data)
def handle(self):
handle_map = {
types.InitializeRequest: self.initialize,
types.ListToolsRequest: self.list_tools,
types.CallToolRequest: self.invoke_tool,
types.InitializedNotification: self.handle_notification,
}
try:
if self.request_type in handle_map:
return self.response(handle_map[self.request_type]())
else:
return self.error_response(METHOD_NOT_FOUND, f"Method not found: {self.request_type}")
except ValueError as e:
logger.exception("Invalid params")
return self.error_response(INVALID_PARAMS, str(e))
except Exception as e:
logger.exception("Internal server error")
return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}")
def handle_notification(self):
return "ping"
def initialize(self):
request = cast(types.InitializeRequest, self.request.root)
client_info = request.params.clientInfo
clinet_name = f"{client_info.name}@{client_info.version}"
if not self.end_user:
end_user = EndUser(
tenant_id=self.app.tenant_id,
app_id=self.app.id,
type="mcp",
name=clinet_name,
session_id=generate_session_id(),
external_user_id=self.mcp_server.id,
)
db.session.add(end_user)
db.session.commit()
return types.InitializeResult(
protocolVersion=types.SERVER_LATEST_PROTOCOL_VERSION,
capabilities=self.capabilities,
serverInfo=types.Implementation(name="Dify", version=dify_config.project.version),
instructions=self.mcp_server.description,
)
def list_tools(self):
if not self.end_user:
raise ValueError("User not found")
return types.ListToolsResult(
tools=[
types.Tool(
name=self.app.name,
description=self.mcp_server.description,
inputSchema=self.parameter_schema,
)
],
)
def invoke_tool(self):
if not self.end_user:
raise ValueError("User not found")
request = cast(types.CallToolRequest, self.request.root)
args = request.params.arguments
if not args:
raise ValueError("No arguments provided")
if self.app.mode in {AppMode.WORKFLOW.value}:
args = {"inputs": args}
elif self.app.mode in {AppMode.COMPLETION.value}:
args = {"query": "", "inputs": args}
else:
args = {"query": args["query"], "inputs": {k: v for k, v in args.items() if k != "query"}}
response = AppGenerateService.generate(
self.app,
self.end_user,
args,
InvokeFrom.SERVICE_API,
streaming=self.app.mode == AppMode.AGENT_CHAT.value,
)
answer = ""
if isinstance(response, RateLimitGenerator):
for item in response.generator:
data = item
if isinstance(data, str) and data.startswith("data: "):
try:
json_str = data[6:].strip()
parsed_data = json.loads(json_str)
if parsed_data.get("event") == "agent_thought":
answer += parsed_data.get("thought", "")
except json.JSONDecodeError:
continue
if isinstance(response, Mapping):
if self.app.mode in {
AppMode.ADVANCED_CHAT.value,
AppMode.COMPLETION.value,
AppMode.CHAT.value,
AppMode.AGENT_CHAT.value,
}:
answer = response["answer"]
elif self.app.mode in {AppMode.WORKFLOW.value}:
answer = json.dumps(response["data"]["outputs"], ensure_ascii=False)
else:
raise ValueError("Invalid app mode")
# Not support image yet
return types.CallToolResult(content=[types.TextContent(text=answer, type="text")])
def retrieve_end_user(self):
return (
db.session.query(EndUser)
.filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
.first()
)
def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]):
parameters: dict[str, dict[str, Any]] = {}
required = []
for item in user_input_form:
parameters[item.variable] = {}
if item.type in (
VariableEntityType.FILE,
VariableEntityType.FILE_LIST,
VariableEntityType.EXTERNAL_DATA_TOOL,
):
continue
if item.required:
required.append(item.variable)
# if the workflow republished, the parameters not changed
# we should not raise error here
try:
description = self.mcp_server.parameters_dict[item.variable]
except KeyError:
description = ""
parameters[item.variable]["description"] = description
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
parameters[item.variable]["type"] = "string"
elif item.type == VariableEntityType.SELECT:
parameters[item.variable]["type"] = "string"
parameters[item.variable]["enum"] = item.options
elif item.type == VariableEntityType.NUMBER:
parameters[item.variable]["type"] = "float"
return parameters, required

View File

@@ -0,0 +1,397 @@
import logging
import queue
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack
from datetime import timedelta
from types import TracebackType
from typing import Any, Generic, Self, TypeVar
from httpx import HTTPStatusError
from pydantic import BaseModel
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.types import (
CancelledNotification,
ClientNotification,
ClientRequest,
ClientResult,
ErrorData,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
MessageMetadata,
RequestId,
RequestParams,
ServerMessageMetadata,
ServerNotification,
ServerRequest,
ServerResult,
SessionMessage,
)
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
DEFAULT_RESPONSE_READ_TIMEOUT = 1.0
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
"""Handles responding to MCP requests and manages request lifecycle.
This class MUST be used as a context manager to ensure proper cleanup and
cancellation handling:
Example:
with request_responder as resp:
resp.respond(result)
The context manager ensures:
1. Proper cancellation scope setup and cleanup
2. Request completion tracking
3. Cleanup of in-flight requests
"""
request: ReceiveRequestT
_session: Any
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any]
def __init__(
self,
request_id: RequestId,
request_meta: RequestParams.Meta | None,
request: ReceiveRequestT,
session: """BaseSession[
SendRequestT,
SendNotificationT,
SendResultT,
ReceiveRequestT,
ReceiveNotificationT
]""",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
) -> None:
self.request_id = request_id
self.request_meta = request_meta
self.request = request
self._session = session
self._completed = False
self._on_complete = on_complete
self._entered = False # Track if we're in a context manager
def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
"""Enter the context manager, enabling request cancellation tracking."""
self._entered = True
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Exit the context manager, performing cleanup and notifying completion."""
try:
if self._completed:
self._on_complete(self)
finally:
self._entered = False
def respond(self, response: SendResultT | ErrorData) -> None:
"""Send a response for this request.
Must be called within a context manager block.
Raises:
RuntimeError: If not used within a context manager
AssertionError: If request was already responded to
"""
if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager")
assert not self._completed, "Request already responded to"
self._completed = True
self._session._send_response(request_id=self.request_id, response=response)
def cancel(self) -> None:
"""Cancel this request and mark it as completed."""
if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager")
self._completed = True # Mark as completed so it's removed from in_flight
# Send an error response to indicate cancellation
self._session._send_response(
request_id=self.request_id,
response=ErrorData(code=0, message="Request cancelled", data=None),
)
class BaseSession(
Generic[
SendRequestT,
SendNotificationT,
SendResultT,
ReceiveRequestT,
ReceiveNotificationT,
],
):
"""
Implements an MCP "session" on top of read/write streams, including features
like request/response linking, notifications, and progress.
This class is a context manager that automatically starts processing
messages when entered.
"""
_response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError]]
_request_id: int
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
_receive_request_type: type[ReceiveRequestT]
_receive_notification_type: type[ReceiveNotificationT]
def __init__(
self,
read_stream: queue.Queue,
write_stream: queue.Queue,
receive_request_type: type[ReceiveRequestT],
receive_notification_type: type[ReceiveNotificationT],
# If none, reading will never time out
read_timeout_seconds: timedelta | None = None,
) -> None:
self._read_stream = read_stream
self._write_stream = write_stream
self._response_streams = {}
self._request_id = 0
self._receive_request_type = receive_request_type
self._receive_notification_type = receive_notification_type
self._session_read_timeout_seconds = read_timeout_seconds
self._in_flight = {}
self._exit_stack = ExitStack()
def __enter__(self) -> Self:
self._executor = ThreadPoolExecutor()
self._receiver_future = self._executor.submit(self._receive_loop)
return self
def check_receiver_status(self) -> None:
if self._receiver_future.done():
self._receiver_future.result()
def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None:
self._exit_stack.close()
self._read_stream.put(None)
self._write_stream.put(None)
def send_request(
self,
request: SendRequestT,
result_type: type[ReceiveResultT],
request_read_timeout_seconds: timedelta | None = None,
metadata: MessageMetadata = None,
) -> ReceiveResultT:
"""
Sends a request and wait for a response. Raises an McpError if the
response contains an error. If a request read timeout is provided, it
will take precedence over the session read timeout.
Do not use this method to emit notifications! Use send_notification()
instead.
"""
self.check_receiver_status()
request_id = self._request_id
self._request_id = request_id + 1
response_queue: queue.Queue[JSONRPCResponse | JSONRPCError] = queue.Queue()
self._response_streams[request_id] = response_queue
try:
jsonrpc_request = JSONRPCRequest(
jsonrpc="2.0",
id=request_id,
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
)
self._write_stream.put(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))
timeout = DEFAULT_RESPONSE_READ_TIMEOUT
if request_read_timeout_seconds is not None:
timeout = float(request_read_timeout_seconds.total_seconds())
elif self._session_read_timeout_seconds is not None:
timeout = float(self._session_read_timeout_seconds.total_seconds())
while True:
try:
response_or_error = response_queue.get(timeout=timeout)
break
except queue.Empty:
self.check_receiver_status()
continue
if response_or_error is None:
raise MCPConnectionError(
ErrorData(
code=500,
message="No response received",
)
)
elif isinstance(response_or_error, JSONRPCError):
if response_or_error.error.code == 401:
raise MCPAuthError(
ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
)
else:
raise MCPConnectionError(
ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
)
else:
return result_type.model_validate(response_or_error.result)
finally:
self._response_streams.pop(request_id, None)
def send_notification(
self,
notification: SendNotificationT,
related_request_id: RequestId | None = None,
) -> None:
"""
Emits a notification, which is a one-way message that does not expect
a response.
"""
self.check_receiver_status()
# Some transport implementations may need to set the related_request_id
# to attribute to the notifications to the request that triggered them.
jsonrpc_notification = JSONRPCNotification(
jsonrpc="2.0",
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
)
session_message = SessionMessage(
message=JSONRPCMessage(jsonrpc_notification),
metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None,
)
self._write_stream.put(session_message)
def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None:
if isinstance(response, ErrorData):
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
self._write_stream.put(session_message)
else:
jsonrpc_response = JSONRPCResponse(
jsonrpc="2.0",
id=request_id,
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
)
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
self._write_stream.put(session_message)
def _receive_loop(self) -> None:
"""
Main message processing loop.
In a real synchronous implementation, this would likely run in a separate thread.
"""
while True:
try:
# Attempt to receive a message (this would be blocking in a synchronous context)
message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT)
if message is None:
break
if isinstance(message, HTTPStatusError):
response_queue = self._response_streams.get(self._request_id - 1)
if response_queue is not None:
response_queue.put(
JSONRPCError(
jsonrpc="2.0",
id=self._request_id - 1,
error=ErrorData(code=message.response.status_code, message=message.args[0]),
)
)
else:
self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
elif isinstance(message, Exception):
self._handle_incoming(message)
elif isinstance(message.message.root, JSONRPCRequest):
validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
responder = RequestResponder(
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
)
self._in_flight[responder.request_id] = responder
self._received_request(responder)
if not responder._completed:
self._handle_incoming(responder)
elif isinstance(message.message.root, JSONRPCNotification):
try:
notification = self._receive_notification_type.model_validate(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
self._in_flight[cancelled_id].cancel()
else:
self._received_notification(notification)
self._handle_incoming(notification)
except Exception as e:
# For other validation errors, log and continue
logging.warning(f"Failed to validate notification: {e}. Message was: {message.message.root}")
else: # Response or error
response_queue = self._response_streams.get(message.message.root.id)
if response_queue is not None:
response_queue.put(message.message.root)
else:
self._handle_incoming(RuntimeError(f"Server Error: {message}"))
except queue.Empty:
continue
except Exception as e:
logging.exception("Error in message processing loop")
raise
def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
"""
Can be overridden by subclasses to handle a request without needing to
listen on the message stream.
If the request is responded to within this method, it will not be
forwarded on to the message stream.
"""
pass
def _received_notification(self, notification: ReceiveNotificationT) -> None:
"""
Can be overridden by subclasses to handle a notification without needing
to listen on the message stream.
"""
pass
def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None
) -> None:
"""
Sends a progress notification for a request that is currently being
processed.
"""
pass
def _handle_incoming(
self,
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
) -> None:
"""A generic handler for incoming messages. Overwritten by subclasses."""
pass

View File

@@ -0,0 +1,365 @@
from datetime import timedelta
from typing import Any, Protocol
from pydantic import AnyUrl, TypeAdapter
from configs import dify_config
from core.mcp import types
from core.mcp.entities import SUPPORTED_PROTOCOL_VERSIONS, RequestContext
from core.mcp.session.base_session import BaseSession, RequestResponder
DEFAULT_CLIENT_INFO = types.Implementation(name="Dify", version=dify_config.project.version)
class SamplingFnT(Protocol):
def __call__(
self,
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.ErrorData: ...
class ListRootsFnT(Protocol):
def __call__(self, context: RequestContext["ClientSession", Any]) -> types.ListRootsResult | types.ErrorData: ...
class LoggingFnT(Protocol):
def __call__(
self,
params: types.LoggingMessageNotificationParams,
) -> None: ...
class MessageHandlerFnT(Protocol):
def __call__(
self,
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None: ...
def _default_message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, Exception):
raise ValueError(str(message))
elif isinstance(message, (types.ServerNotification | RequestResponder)):
pass
def _default_sampling_callback(
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
message="Sampling not supported",
)
def _default_list_roots_callback(
context: RequestContext["ClientSession", Any],
) -> types.ListRootsResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
message="List roots not supported",
)
def _default_logging_callback(
params: types.LoggingMessageNotificationParams,
) -> None:
pass
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
class ClientSession(
BaseSession[
types.ClientRequest,
types.ClientNotification,
types.ClientResult,
types.ServerRequest,
types.ServerNotification,
]
):
def __init__(
self,
read_stream,
write_stream,
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
client_info: types.Implementation | None = None,
) -> None:
super().__init__(
read_stream,
write_stream,
types.ServerRequest,
types.ServerNotification,
read_timeout_seconds=read_timeout_seconds,
)
self._client_info = client_info or DEFAULT_CLIENT_INFO
self._sampling_callback = sampling_callback or _default_sampling_callback
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
self._logging_callback = logging_callback or _default_logging_callback
self._message_handler = message_handler or _default_message_handler
def initialize(self) -> types.InitializeResult:
sampling = types.SamplingCapability()
roots = types.RootsCapability(
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
listChanged=True,
)
result = self.send_request(
types.ClientRequest(
types.InitializeRequest(
method="initialize",
params=types.InitializeRequestParams(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=types.ClientCapabilities(
sampling=sampling,
experimental=None,
roots=roots,
),
clientInfo=self._client_info,
),
)
),
types.InitializeResult,
)
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}")
self.send_notification(
types.ClientNotification(types.InitializedNotification(method="notifications/initialized"))
)
return result
def send_ping(self) -> types.EmptyResult:
"""Send a ping request."""
return self.send_request(
types.ClientRequest(
types.PingRequest(
method="ping",
)
),
types.EmptyResult,
)
def send_progress_notification(
self, progress_token: str | int, progress: float, total: float | None = None
) -> None:
"""Send a progress notification."""
self.send_notification(
types.ClientNotification(
types.ProgressNotification(
method="notifications/progress",
params=types.ProgressNotificationParams(
progressToken=progress_token,
progress=progress,
total=total,
),
),
)
)
def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
"""Send a logging/setLevel request."""
return self.send_request(
types.ClientRequest(
types.SetLevelRequest(
method="logging/setLevel",
params=types.SetLevelRequestParams(level=level),
)
),
types.EmptyResult,
)
def list_resources(self) -> types.ListResourcesResult:
"""Send a resources/list request."""
return self.send_request(
types.ClientRequest(
types.ListResourcesRequest(
method="resources/list",
)
),
types.ListResourcesResult,
)
def list_resource_templates(self) -> types.ListResourceTemplatesResult:
"""Send a resources/templates/list request."""
return self.send_request(
types.ClientRequest(
types.ListResourceTemplatesRequest(
method="resources/templates/list",
)
),
types.ListResourceTemplatesResult,
)
def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
"""Send a resources/read request."""
return self.send_request(
types.ClientRequest(
types.ReadResourceRequest(
method="resources/read",
params=types.ReadResourceRequestParams(uri=uri),
)
),
types.ReadResourceResult,
)
def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
"""Send a resources/subscribe request."""
return self.send_request(
types.ClientRequest(
types.SubscribeRequest(
method="resources/subscribe",
params=types.SubscribeRequestParams(uri=uri),
)
),
types.EmptyResult,
)
def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
"""Send a resources/unsubscribe request."""
return self.send_request(
types.ClientRequest(
types.UnsubscribeRequest(
method="resources/unsubscribe",
params=types.UnsubscribeRequestParams(uri=uri),
)
),
types.EmptyResult,
)
def call_tool(
self,
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
) -> types.CallToolResult:
"""Send a tools/call request."""
return self.send_request(
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(name=name, arguments=arguments),
)
),
types.CallToolResult,
request_read_timeout_seconds=read_timeout_seconds,
)
def list_prompts(self) -> types.ListPromptsResult:
"""Send a prompts/list request."""
return self.send_request(
types.ClientRequest(
types.ListPromptsRequest(
method="prompts/list",
)
),
types.ListPromptsResult,
)
def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
"""Send a prompts/get request."""
return self.send_request(
types.ClientRequest(
types.GetPromptRequest(
method="prompts/get",
params=types.GetPromptRequestParams(name=name, arguments=arguments),
)
),
types.GetPromptResult,
)
def complete(
self,
ref: types.ResourceReference | types.PromptReference,
argument: dict[str, str],
) -> types.CompleteResult:
"""Send a completion/complete request."""
return self.send_request(
types.ClientRequest(
types.CompleteRequest(
method="completion/complete",
params=types.CompleteRequestParams(
ref=ref,
argument=types.CompletionArgument(**argument),
),
)
),
types.CompleteResult,
)
def list_tools(self) -> types.ListToolsResult:
"""Send a tools/list request."""
return self.send_request(
types.ClientRequest(
types.ListToolsRequest(
method="tools/list",
)
),
types.ListToolsResult,
)
def send_roots_list_changed(self) -> None:
"""Send a roots/list_changed notification."""
self.send_notification(
types.ClientNotification(
types.RootsListChangedNotification(
method="notifications/roots/list_changed",
)
)
)
def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
ctx = RequestContext[ClientSession, Any](
request_id=responder.request_id,
meta=responder.request_meta,
session=self,
lifespan_context=None,
)
match responder.request.root:
case types.CreateMessageRequest(params=params):
with responder:
response = self._sampling_callback(ctx, params)
client_response = ClientResponse.validate_python(response)
responder.respond(client_response)
case types.ListRootsRequest():
with responder:
list_roots_response = self._list_roots_callback(ctx)
client_response = ClientResponse.validate_python(list_roots_response)
responder.respond(client_response)
case types.PingRequest():
with responder:
return responder.respond(types.ClientResult(root=types.EmptyResult()))
def _handle_incoming(
self,
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
"""Handle incoming messages by forwarding to the message handler."""
self._message_handler(req)
def _received_notification(self, notification: types.ServerNotification) -> None:
"""Handle notifications from the server."""
# Process specific notification types
match notification.root:
case types.LoggingMessageNotification(params=params):
self._logging_callback(params)
case _:
pass

1217
api/core/mcp/types.py Normal file

File diff suppressed because it is too large Load Diff

114
api/core/mcp/utils.py Normal file
View File

@@ -0,0 +1,114 @@
import json
import httpx
from configs import dify_config
from core.mcp.types import ErrorData, JSONRPCError
from core.model_runtime.utils.encoders import jsonable_encoder
HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
STATUS_FORCELIST = [429, 500, 502, 503, 504]
def create_ssrf_proxy_mcp_http_client(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
) -> httpx.Client:
"""Create an HTTPX client with SSRF proxy configuration for MCP connections.
Args:
headers: Optional headers to include in the client
timeout: Optional timeout configuration
Returns:
Configured httpx.Client with proxy settings
"""
if dify_config.SSRF_PROXY_ALL_URL:
return httpx.Client(
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
proxy=dify_config.SSRF_PROXY_ALL_URL,
)
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxy_mounts = {
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY),
"https://": httpx.HTTPTransport(
proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
),
}
return httpx.Client(
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
mounts=proxy_mounts,
)
else:
return httpx.Client(
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
)
def ssrf_proxy_sse_connect(url, **kwargs):
"""Connect to SSE endpoint with SSRF proxy protection.
This function creates an SSE connection using the configured proxy settings
to prevent SSRF attacks when connecting to external endpoints.
Args:
url: The SSE endpoint URL
**kwargs: Additional arguments passed to the SSE connection
Returns:
EventSource object for SSE streaming
"""
from httpx_sse import connect_sse
# Extract client if provided, otherwise create one
client = kwargs.pop("client", None)
if client is None:
# Create client with SSRF proxy configuration
timeout = kwargs.pop(
"timeout",
httpx.Timeout(
timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
),
)
headers = kwargs.pop("headers", {})
client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout)
client_provided = False
else:
client_provided = True
# Extract method if provided, default to GET
method = kwargs.pop("method", "GET")
try:
return connect_sse(client, method, url, **kwargs)
except Exception:
# If we created the client, we need to clean it up on error
if not client_provided:
client.close()
raise
def create_mcp_error_response(request_id: int | str | None, code: int, message: str, data=None):
"""Create MCP error response"""
error_data = ErrorData(code=code, message=message, data=data)
json_response = JSONRPCError(
jsonrpc="2.0",
id=request_id or 1,
error=error_data,
)
json_data = json.dumps(jsonable_encoder(json_response))
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
yield sse_content

View File

@@ -53,6 +53,37 @@ class LLMUsage(ModelUsage):
latency=0.0,
)
@classmethod
def from_metadata(cls, metadata: dict) -> "LLMUsage":
"""
Create LLMUsage instance from metadata dictionary with default values.
Args:
metadata: Dictionary containing usage metadata
Returns:
LLMUsage instance with values from metadata or defaults
"""
total_tokens = metadata.get("total_tokens", 0)
completion_tokens = metadata.get("completion_tokens", 0)
if total_tokens > 0 and completion_tokens == 0:
completion_tokens = total_tokens
return cls(
prompt_tokens=metadata.get("prompt_tokens", 0),
completion_tokens=completion_tokens,
total_tokens=total_tokens,
prompt_unit_price=Decimal(str(metadata.get("prompt_unit_price", 0))),
completion_unit_price=Decimal(str(metadata.get("completion_unit_price", 0))),
total_price=Decimal(str(metadata.get("total_price", 0))),
currency=metadata.get("currency", "USD"),
prompt_price_unit=Decimal(str(metadata.get("prompt_price_unit", 0))),
completion_price_unit=Decimal(str(metadata.get("completion_price_unit", 0))),
prompt_price=Decimal(str(metadata.get("prompt_price", 0))),
completion_price=Decimal(str(metadata.get("completion_price", 0))),
latency=metadata.get("latency", 0.0),
)
def plus(self, other: "LLMUsage") -> "LLMUsage":
"""
Add two LLMUsage instances together.

View File

@@ -0,0 +1,487 @@
import json
import logging
from collections.abc import Sequence
from typing import Optional
from urllib.parse import urljoin
from opentelemetry.trace import Status, StatusCode
from sqlalchemy.orm import Session, sessionmaker
from core.ops.aliyun_trace.data_exporter.traceclient import (
TraceClient,
convert_datetime_to_nanoseconds,
convert_to_span_id,
convert_to_trace_id,
generate_span_id,
)
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
from core.ops.aliyun_trace.entities.semconv import (
GEN_AI_COMPLETION,
GEN_AI_FRAMEWORK,
GEN_AI_MODEL_NAME,
GEN_AI_PROMPT,
GEN_AI_PROMPT_TEMPLATE_TEMPLATE,
GEN_AI_PROMPT_TEMPLATE_VARIABLE,
GEN_AI_RESPONSE_FINISH_REASON,
GEN_AI_SESSION_ID,
GEN_AI_SPAN_KIND,
GEN_AI_SYSTEM,
GEN_AI_USAGE_INPUT_TOKENS,
GEN_AI_USAGE_OUTPUT_TOKENS,
GEN_AI_USAGE_TOTAL_TOKENS,
GEN_AI_USER_ID,
INPUT_VALUE,
OUTPUT_VALUE,
RETRIEVAL_DOCUMENT,
RETRIEVAL_QUERY,
TOOL_DESCRIPTION,
TOOL_NAME,
TOOL_PARAMETERS,
GenAISpanKind,
)
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import AliyunConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.rag.models.document import Document
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.nodes import NodeType
from models import Account, App, EndUser, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom, db
logger = logging.getLogger(__name__)
class AliyunDataTrace(BaseTraceInstance):
def __init__(
self,
aliyun_config: AliyunConfig,
):
super().__init__(aliyun_config)
base_url = aliyun_config.endpoint.rstrip("/")
endpoint = urljoin(base_url, f"adapt_{aliyun_config.license_key}/api/otlp/traces")
self.trace_client = TraceClient(service_name=aliyun_config.app_name, endpoint=endpoint)
def trace(self, trace_info: BaseTraceInfo):
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
if isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
if isinstance(trace_info, ModerationTraceInfo):
pass
if isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
if isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
if isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
if isinstance(trace_info, GenerateNameTraceInfo):
pass
def api_check(self):
return self.trace_client.api_check()
def get_project_url(self):
try:
return self.trace_client.get_project_url()
except Exception as e:
logger.info(f"Aliyun get run url failed: {str(e)}", exc_info=True)
raise ValueError(f"Aliyun get run url failed: {str(e)}")
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = convert_to_trace_id(trace_info.workflow_run_id)
workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow")
self.add_workflow_span(trace_id, workflow_span_id, trace_info)
workflow_node_executions = self.get_workflow_node_executions(trace_info)
for node_execution in workflow_node_executions:
node_span = self.build_workflow_node_span(node_execution, trace_id, trace_info, workflow_span_id)
self.trace_client.add_span(node_span)
def message_trace(self, trace_info: MessageTraceInfo):
message_data = trace_info.message_data
if message_data is None:
return
message_id = trace_info.message_id
user_id = message_data.from_account_id
if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
user_id = end_user_data.session_id
status: Status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
trace_id = convert_to_trace_id(message_id)
message_span_id = convert_to_span_id(message_id, "message")
message_span = SpanData(
trace_id=trace_id,
parent_span_id=None,
span_id=message_span_id,
name="message",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
GEN_AI_FRAMEWORK: "dify",
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
OUTPUT_VALUE: str(trace_info.outputs),
},
status=status,
)
self.trace_client.add_span(message_span)
app_model_config = getattr(trace_info.message_data, "app_model_config", {})
pre_prompt = getattr(app_model_config, "pre_prompt", "")
inputs_data = getattr(trace_info.message_data, "inputs", {})
llm_span = SpanData(
trace_id=trace_id,
parent_span_id=message_span_id,
span_id=convert_to_span_id(message_id, "llm"),
name="llm",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name", ""),
GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider", ""),
GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens),
GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens),
GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens),
GEN_AI_PROMPT_TEMPLATE_VARIABLE: json.dumps(inputs_data, ensure_ascii=False),
GEN_AI_PROMPT_TEMPLATE_TEMPLATE: pre_prompt,
GEN_AI_PROMPT: json.dumps(trace_info.inputs, ensure_ascii=False),
GEN_AI_COMPLETION: str(trace_info.outputs),
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
OUTPUT_VALUE: str(trace_info.outputs),
},
status=status,
)
self.trace_client.add_span(llm_span)
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
if trace_info.message_data is None:
return
message_id = trace_info.message_id
documents_data = extract_retrieval_documents(trace_info.documents)
dataset_retrieval_span = SpanData(
trace_id=convert_to_trace_id(message_id),
parent_span_id=convert_to_span_id(message_id, "message"),
span_id=generate_span_id(),
name="dataset_retrieval",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
GEN_AI_FRAMEWORK: "dify",
RETRIEVAL_QUERY: str(trace_info.inputs),
RETRIEVAL_DOCUMENT: json.dumps(documents_data, ensure_ascii=False),
INPUT_VALUE: str(trace_info.inputs),
OUTPUT_VALUE: json.dumps(documents_data, ensure_ascii=False),
},
)
self.trace_client.add_span(dataset_retrieval_span)
def tool_trace(self, trace_info: ToolTraceInfo):
if trace_info.message_data is None:
return
message_id = trace_info.message_id
status: Status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
tool_span = SpanData(
trace_id=convert_to_trace_id(message_id),
parent_span_id=convert_to_span_id(message_id, "message"),
span_id=generate_span_id(),
name=trace_info.tool_name,
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
GEN_AI_FRAMEWORK: "dify",
TOOL_NAME: trace_info.tool_name,
TOOL_DESCRIPTION: json.dumps(trace_info.tool_config, ensure_ascii=False),
TOOL_PARAMETERS: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
OUTPUT_VALUE: str(trace_info.tool_outputs),
},
status=status,
)
self.trace_client.add_span(tool_span)
def get_workflow_node_executions(self, trace_info: WorkflowTraceInfo) -> Sequence[WorkflowNodeExecution]:
# through workflow_run_id get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
# Find the app's creator account
with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
app = session.query(App).filter(App.id == app_id).first()
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
service_account = session.query(Account).filter(Account.id == app.created_by).first()
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
current_tenant = (
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
)
if not current_tenant:
raise ValueError(f"Current tenant not found for account {service_account.id}")
service_account.set_tenant_id(current_tenant.tenant_id)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
session_factory=session_factory,
user=service_account,
app_id=trace_info.metadata.get("app_id"),
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
# Get all executions for this workflow run
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
workflow_run_id=trace_info.workflow_run_id
)
return workflow_node_executions
def build_workflow_node_span(
self, node_execution: WorkflowNodeExecution, trace_id: int, trace_info: WorkflowTraceInfo, workflow_span_id: int
):
try:
if node_execution.node_type == NodeType.LLM:
node_span = self.build_workflow_llm_span(trace_id, workflow_span_id, trace_info, node_execution)
elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
node_span = self.build_workflow_retrieval_span(trace_id, workflow_span_id, trace_info, node_execution)
elif node_execution.node_type == NodeType.TOOL:
node_span = self.build_workflow_tool_span(trace_id, workflow_span_id, trace_info, node_execution)
else:
node_span = self.build_workflow_task_span(trace_id, workflow_span_id, trace_info, node_execution)
return node_span
except Exception:
return None
def get_workflow_node_status(self, node_execution: WorkflowNodeExecution) -> Status:
span_status: Status = Status(StatusCode.UNSET)
if node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED:
span_status = Status(StatusCode.OK)
elif node_execution.status in [WorkflowNodeExecutionStatus.FAILED, WorkflowNodeExecutionStatus.EXCEPTION]:
span_status = Status(StatusCode.ERROR, str(node_execution.error))
return span_status
def build_workflow_task_span(
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_SPAN_KIND: GenAISpanKind.TASK.value,
GEN_AI_FRAMEWORK: "dify",
INPUT_VALUE: json.dumps(node_execution.inputs, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
},
status=self.get_workflow_node_status(node_execution),
)
def build_workflow_tool_span(
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
tool_des = {}
if node_execution.metadata:
tool_des = node_execution.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {})
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
GEN_AI_FRAMEWORK: "dify",
TOOL_NAME: node_execution.title,
TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False),
TOOL_PARAMETERS: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False),
INPUT_VALUE: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
},
status=self.get_workflow_node_status(node_execution),
)
def build_workflow_retrieval_span(
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
input_value = ""
if node_execution.inputs:
input_value = str(node_execution.inputs.get("query", ""))
output_value = ""
if node_execution.outputs:
output_value = json.dumps(node_execution.outputs.get("result", []), ensure_ascii=False)
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
GEN_AI_FRAMEWORK: "dify",
RETRIEVAL_QUERY: input_value,
RETRIEVAL_DOCUMENT: output_value,
INPUT_VALUE: input_value,
OUTPUT_VALUE: output_value,
},
status=self.get_workflow_node_status(node_execution),
)
def build_workflow_llm_span(
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
) -> SpanData:
process_data = node_execution.process_data or {}
outputs = node_execution.outputs or {}
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
return SpanData(
trace_id=trace_id,
parent_span_id=workflow_span_id,
span_id=convert_to_span_id(node_execution.id, "node"),
name=node_execution.title,
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_MODEL_NAME: process_data.get("model_name", ""),
GEN_AI_SYSTEM: process_data.get("model_provider", ""),
GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
GEN_AI_PROMPT: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
GEN_AI_COMPLETION: str(outputs.get("text", "")),
GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason", ""),
INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
OUTPUT_VALUE: str(outputs.get("text", "")),
},
status=self.get_workflow_node_status(node_execution),
)
def add_workflow_span(self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo):
message_span_id = None
if trace_info.message_id:
message_span_id = convert_to_span_id(trace_info.message_id, "message")
user_id = trace_info.metadata.get("user_id")
status: Status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
if message_span_id: # chatflow
message_span = SpanData(
trace_id=trace_id,
parent_span_id=None,
span_id=message_span_id,
name="message",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
GEN_AI_FRAMEWORK: "dify",
INPUT_VALUE: trace_info.workflow_run_inputs.get("sys.query", ""),
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
},
status=status,
)
self.trace_client.add_span(message_span)
workflow_span = SpanData(
trace_id=trace_id,
parent_span_id=message_span_id,
span_id=workflow_span_id,
name="workflow",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
GEN_AI_FRAMEWORK: "dify",
INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
},
status=status,
)
self.trace_client.add_span(workflow_span)
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
message_id = trace_info.message_id
status: Status = Status(StatusCode.OK)
if trace_info.error:
status = Status(StatusCode.ERROR, trace_info.error)
suggested_question_span = SpanData(
trace_id=convert_to_trace_id(message_id),
parent_span_id=convert_to_span_id(message_id, "message"),
span_id=convert_to_span_id(message_id, "suggested_question"),
name="suggested_question",
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name", ""),
GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider", ""),
GEN_AI_PROMPT: json.dumps(trace_info.inputs, ensure_ascii=False),
GEN_AI_COMPLETION: json.dumps(trace_info.suggested_question, ensure_ascii=False),
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
},
status=status,
)
self.trace_client.add_span(suggested_question_span)
def extract_retrieval_documents(documents: list[Document]):
documents_data = []
for document in documents:
document_data = {
"content": document.page_content,
"metadata": {
"dataset_id": document.metadata.get("dataset_id"),
"doc_id": document.metadata.get("doc_id"),
"document_id": document.metadata.get("document_id"),
},
"score": document.metadata.get("score"),
}
documents_data.append(document_data)
return documents_data

View File

@@ -0,0 +1,200 @@
import hashlib
import logging
import random
import socket
import threading
import uuid
from collections import deque
from collections.abc import Sequence
from datetime import datetime
from typing import Optional
import requests
from opentelemetry import trace as trace_api
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.util.instrumentation import InstrumentationScope
from opentelemetry.semconv.resource import ResourceAttributes
from configs import dify_config
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
INVALID_SPAN_ID = 0x0000000000000000
INVALID_TRACE_ID = 0x00000000000000000000000000000000
logger = logging.getLogger(__name__)
class TraceClient:
def __init__(
self,
service_name: str,
endpoint: str,
max_queue_size: int = 1000,
schedule_delay_sec: int = 5,
max_export_batch_size: int = 50,
):
self.endpoint = endpoint
self.resource = Resource(
attributes={
ResourceAttributes.SERVICE_NAME: service_name,
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
ResourceAttributes.HOST_NAME: socket.gethostname(),
}
)
self.span_builder = SpanBuilder(self.resource)
self.exporter = OTLPSpanExporter(endpoint=endpoint)
self.max_queue_size = max_queue_size
self.schedule_delay_sec = schedule_delay_sec
self.max_export_batch_size = max_export_batch_size
self.queue: deque = deque(maxlen=max_queue_size)
self.condition = threading.Condition(threading.Lock())
self.done = False
self.worker_thread = threading.Thread(target=self._worker, daemon=True)
self.worker_thread.start()
self._spans_dropped = False
def export(self, spans: Sequence[ReadableSpan]):
self.exporter.export(spans)
def api_check(self):
try:
response = requests.head(self.endpoint, timeout=5)
if response.status_code == 405:
return True
else:
logger.debug(f"AliyunTrace API check failed: Unexpected status code: {response.status_code}")
return False
except requests.exceptions.RequestException as e:
logger.debug(f"AliyunTrace API check failed: {str(e)}")
raise ValueError(f"AliyunTrace API check failed: {str(e)}")
def get_project_url(self):
return "https://arms.console.aliyun.com/#/llm"
def add_span(self, span_data: SpanData):
if span_data is None:
return
span: ReadableSpan = self.span_builder.build_span(span_data)
with self.condition:
if len(self.queue) == self.max_queue_size:
if not self._spans_dropped:
logger.warning("Queue is full, likely spans will be dropped.")
self._spans_dropped = True
self.queue.appendleft(span)
if len(self.queue) >= self.max_export_batch_size:
self.condition.notify()
def _worker(self):
while not self.done:
with self.condition:
if len(self.queue) < self.max_export_batch_size and not self.done:
self.condition.wait(timeout=self.schedule_delay_sec)
self._export_batch()
def _export_batch(self):
spans_to_export: list[ReadableSpan] = []
with self.condition:
while len(spans_to_export) < self.max_export_batch_size and self.queue:
spans_to_export.append(self.queue.pop())
if spans_to_export:
try:
self.exporter.export(spans_to_export)
except Exception as e:
logger.debug(f"Error exporting spans: {e}")
def shutdown(self):
with self.condition:
self.done = True
self.condition.notify_all()
self.worker_thread.join()
self._export_batch()
self.exporter.shutdown()
class SpanBuilder:
def __init__(self, resource):
self.resource = resource
self.instrumentation_scope = InstrumentationScope(
__name__,
"",
None,
None,
)
def build_span(self, span_data: SpanData) -> ReadableSpan:
span_context = trace_api.SpanContext(
trace_id=span_data.trace_id,
span_id=span_data.span_id,
is_remote=False,
trace_flags=trace_api.TraceFlags(trace_api.TraceFlags.SAMPLED),
trace_state=None,
)
parent_span_context = None
if span_data.parent_span_id is not None:
parent_span_context = trace_api.SpanContext(
trace_id=span_data.trace_id,
span_id=span_data.parent_span_id,
is_remote=False,
trace_flags=trace_api.TraceFlags(trace_api.TraceFlags.SAMPLED),
trace_state=None,
)
span = ReadableSpan(
name=span_data.name,
context=span_context,
parent=parent_span_context,
resource=self.resource,
attributes=span_data.attributes,
events=span_data.events,
links=span_data.links,
kind=trace_api.SpanKind.INTERNAL,
status=span_data.status,
start_time=span_data.start_time,
end_time=span_data.end_time,
instrumentation_scope=self.instrumentation_scope,
)
return span
def generate_span_id() -> int:
span_id = random.getrandbits(64)
while span_id == INVALID_SPAN_ID:
span_id = random.getrandbits(64)
return span_id
def convert_to_trace_id(uuid_v4: Optional[str]) -> int:
try:
uuid_obj = uuid.UUID(uuid_v4)
return uuid_obj.int
except Exception as e:
raise ValueError(f"Invalid UUID input: {e}")
def convert_to_span_id(uuid_v4: Optional[str], span_type: str) -> int:
try:
uuid_obj = uuid.UUID(uuid_v4)
except Exception as e:
raise ValueError(f"Invalid UUID input: {e}")
combined_key = f"{uuid_obj.hex}-{span_type}"
hash_bytes = hashlib.sha256(combined_key.encode("utf-8")).digest()
span_id = int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
return span_id
def convert_datetime_to_nanoseconds(start_time_a: Optional[datetime]) -> Optional[int]:
if start_time_a is None:
return None
timestamp_in_seconds = start_time_a.timestamp()
timestamp_in_nanoseconds = int(timestamp_in_seconds * 1e9)
return timestamp_in_nanoseconds

View File

@@ -0,0 +1,21 @@
from collections.abc import Sequence
from typing import Optional
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event, Status, StatusCode
from pydantic import BaseModel, Field
class SpanData(BaseModel):
model_config = {"arbitrary_types_allowed": True}
trace_id: int = Field(..., description="The unique identifier for the trace.")
parent_span_id: Optional[int] = Field(None, description="The ID of the parent span, if any.")
span_id: int = Field(..., description="The unique identifier for this span.")
name: str = Field(..., description="The name of the span.")
attributes: dict[str, str] = Field(default_factory=dict, description="Attributes associated with the span.")
events: Sequence[Event] = Field(default_factory=list, description="Events recorded in the span.")
links: Sequence[trace_api.Link] = Field(default_factory=list, description="Links to other spans.")
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
start_time: Optional[int] = Field(..., description="The start time of the span in nanoseconds.")
end_time: Optional[int] = Field(..., description="The end time of the span in nanoseconds.")

View File

@@ -0,0 +1,64 @@
from enum import Enum
# public
GEN_AI_SESSION_ID = "gen_ai.session.id"
GEN_AI_USER_ID = "gen_ai.user.id"
GEN_AI_USER_NAME = "gen_ai.user.name"
GEN_AI_SPAN_KIND = "gen_ai.span.kind"
GEN_AI_FRAMEWORK = "gen_ai.framework"
# Chain
INPUT_VALUE = "input.value"
OUTPUT_VALUE = "output.value"
# Retriever
RETRIEVAL_QUERY = "retrieval.query"
RETRIEVAL_DOCUMENT = "retrieval.document"
# LLM
GEN_AI_MODEL_NAME = "gen_ai.model_name"
GEN_AI_SYSTEM = "gen_ai.system"
GEN_AI_USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
GEN_AI_USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
GEN_AI_PROMPT_TEMPLATE_TEMPLATE = "gen_ai.prompt_template.template"
GEN_AI_PROMPT_TEMPLATE_VARIABLE = "gen_ai.prompt_template.variable"
GEN_AI_PROMPT = "gen_ai.prompt"
GEN_AI_COMPLETION = "gen_ai.completion"
GEN_AI_RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason"
# Tool
TOOL_NAME = "tool.name"
TOOL_DESCRIPTION = "tool.description"
TOOL_PARAMETERS = "tool.parameters"
class GenAISpanKind(Enum):
CHAIN = "CHAIN"
RETRIEVER = "RETRIEVER"
RERANKER = "RERANKER"
LLM = "LLM"
EMBEDDING = "EMBEDDING"
TOOL = "TOOL"
AGENT = "AGENT"
TASK = "TASK"

View File

@@ -0,0 +1,726 @@
import hashlib
import json
import logging
import os
from datetime import datetime, timedelta
from typing import Optional, Union, cast
from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcOTLPSpanExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
from opentelemetry.trace import SpanContext, TraceFlags, TraceState
from core.ops.base_trace_instance import BaseTraceInstance
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
TraceTaskName,
WorkflowTraceInfo,
)
from extensions.ext_database import db
from models.model import EndUser, MessageFile
from models.workflow import WorkflowNodeExecutionModel
logger = logging.getLogger(__name__)
def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[trace_sdk.Tracer, SimpleSpanProcessor]:
"""Configure OpenTelemetry tracer with OTLP exporter for Arize/Phoenix."""
try:
# Choose the appropriate exporter based on config type
exporter: Union[GrpcOTLPSpanExporter, HttpOTLPSpanExporter]
if isinstance(arize_phoenix_config, ArizeConfig):
arize_endpoint = f"{arize_phoenix_config.endpoint}/v1"
arize_headers = {
"api_key": arize_phoenix_config.api_key or "",
"space_id": arize_phoenix_config.space_id or "",
"authorization": f"Bearer {arize_phoenix_config.api_key or ''}",
}
exporter = GrpcOTLPSpanExporter(
endpoint=arize_endpoint,
headers=arize_headers,
timeout=30,
)
else:
phoenix_endpoint = f"{arize_phoenix_config.endpoint}/v1/traces"
phoenix_headers = {
"api_key": arize_phoenix_config.api_key or "",
"authorization": f"Bearer {arize_phoenix_config.api_key or ''}",
}
exporter = HttpOTLPSpanExporter(
endpoint=phoenix_endpoint,
headers=phoenix_headers,
timeout=30,
)
attributes = {
"openinference.project.name": arize_phoenix_config.project or "",
"model_id": arize_phoenix_config.project or "",
}
resource = Resource(attributes=attributes)
provider = trace_sdk.TracerProvider(resource=resource)
processor = SimpleSpanProcessor(
exporter,
)
provider.add_span_processor(processor)
# Create a named tracer instead of setting the global provider
tracer_name = f"arize_phoenix_tracer_{arize_phoenix_config.project}"
logger.info(f"[Arize/Phoenix] Created tracer with name: {tracer_name}")
return cast(trace_sdk.Tracer, provider.get_tracer(tracer_name)), processor
except Exception as e:
logger.error(f"[Arize/Phoenix] Failed to setup the tracer: {str(e)}", exc_info=True)
raise
def datetime_to_nanos(dt: Optional[datetime]) -> int:
"""Convert datetime to nanoseconds since epoch. If None, use current time."""
if dt is None:
dt = datetime.now()
return int(dt.timestamp() * 1_000_000_000)
def uuid_to_trace_id(string: Optional[str]) -> int:
"""Convert UUID string to a valid trace ID (16-byte integer)."""
if string is None:
string = ""
hash_object = hashlib.sha256(string.encode())
# Take the first 16 bytes (128 bits) of the hash
digest = hash_object.digest()[:16]
# Convert to integer (128 bits)
return int.from_bytes(digest, byteorder="big")
class ArizePhoenixDataTrace(BaseTraceInstance):
def __init__(
self,
arize_phoenix_config: ArizeConfig | PhoenixConfig,
):
super().__init__(arize_phoenix_config)
import logging
logging.basicConfig()
logging.getLogger().setLevel(logging.DEBUG)
self.arize_phoenix_config = arize_phoenix_config
self.tracer, self.processor = setup_tracer(arize_phoenix_config)
self.project = arize_phoenix_config.project
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
def trace(self, trace_info: BaseTraceInfo):
logger.info(f"[Arize/Phoenix] Trace: {trace_info}")
try:
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
if isinstance(trace_info, MessageTraceInfo):
self.message_trace(trace_info)
if isinstance(trace_info, ModerationTraceInfo):
self.moderation_trace(trace_info)
if isinstance(trace_info, SuggestedQuestionTraceInfo):
self.suggested_question_trace(trace_info)
if isinstance(trace_info, DatasetRetrievalTraceInfo):
self.dataset_retrieval_trace(trace_info)
if isinstance(trace_info, ToolTraceInfo):
self.tool_trace(trace_info)
if isinstance(trace_info, GenerateNameTraceInfo):
self.generate_name_trace(trace_info)
except Exception as e:
logger.error(f"[Arize/Phoenix] Error in the trace: {str(e)}", exc_info=True)
raise
def workflow_trace(self, trace_info: WorkflowTraceInfo):
if trace_info.message_data is None:
return
workflow_metadata = {
"workflow_id": trace_info.workflow_run_id or "",
"message_id": trace_info.message_id or "",
"workflow_app_log_id": trace_info.workflow_app_log_id or "",
"status": trace_info.workflow_run_status or "",
"status_message": trace_info.error or "",
"level": "ERROR" if trace_info.error else "DEFAULT",
"total_tokens": trace_info.total_tokens or 0,
}
workflow_metadata.update(trace_info.metadata)
trace_id = uuid_to_trace_id(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
workflow_span = self.tracer.start_span(
name=TraceTaskName.WORKFLOW_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(workflow_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
)
try:
# Process workflow nodes
for node_execution in self._get_workflow_nodes(trace_info.workflow_run_id):
created_at = node_execution.created_at or datetime.now()
elapsed_time = node_execution.elapsed_time
finished_at = created_at + timedelta(seconds=elapsed_time)
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
node_metadata = {
"node_id": node_execution.id,
"node_type": node_execution.node_type,
"node_status": node_execution.status,
"tenant_id": node_execution.tenant_id,
"app_id": node_execution.app_id,
"app_name": node_execution.title,
"status": node_execution.status,
"level": "ERROR" if node_execution.status != "succeeded" else "DEFAULT",
}
if node_execution.execution_metadata:
node_metadata.update(json.loads(node_execution.execution_metadata))
# Determine the correct span kind based on node type
span_kind = OpenInferenceSpanKindValues.CHAIN.value
if node_execution.node_type == "llm":
span_kind = OpenInferenceSpanKindValues.LLM.value
provider = process_data.get("model_provider")
model = process_data.get("model_name")
if provider:
node_metadata["ls_provider"] = provider
if model:
node_metadata["ls_model_name"] = model
outputs = json.loads(node_execution.outputs).get("usage", {})
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
if usage_data:
node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0)
elif node_execution.node_type == "dataset_retrieval":
span_kind = OpenInferenceSpanKindValues.RETRIEVER.value
elif node_execution.node_type == "tool":
span_kind = OpenInferenceSpanKindValues.TOOL.value
else:
span_kind = OpenInferenceSpanKindValues.CHAIN.value
node_span = self.tracer.start_span(
name=node_execution.node_type,
attributes={
SpanAttributes.INPUT_VALUE: node_execution.inputs or "{}",
SpanAttributes.OUTPUT_VALUE: node_execution.outputs or "{}",
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind,
SpanAttributes.METADATA: json.dumps(node_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},
start_time=datetime_to_nanos(created_at),
)
try:
if node_execution.node_type == "llm":
provider = process_data.get("model_provider")
model = process_data.get("model_name")
if provider:
node_span.set_attribute(SpanAttributes.LLM_PROVIDER, provider)
if model:
node_span.set_attribute(SpanAttributes.LLM_MODEL_NAME, model)
outputs = json.loads(node_execution.outputs).get("usage", {})
usage_data = (
process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
)
if usage_data:
node_span.set_attribute(
SpanAttributes.LLM_TOKEN_COUNT_TOTAL, usage_data.get("total_tokens", 0)
)
node_span.set_attribute(
SpanAttributes.LLM_TOKEN_COUNT_PROMPT, usage_data.get("prompt_tokens", 0)
)
node_span.set_attribute(
SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, usage_data.get("completion_tokens", 0)
)
finally:
node_span.end(end_time=datetime_to_nanos(finished_at))
finally:
workflow_span.end(end_time=datetime_to_nanos(trace_info.end_time))
def message_trace(self, trace_info: MessageTraceInfo):
if trace_info.message_data is None:
return
file_list = cast(list[str], trace_info.file_list) or []
message_file_data: Optional[MessageFile] = trace_info.message_file_data
if message_file_data is not None:
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
file_list.append(file_url)
message_metadata = {
"message_id": trace_info.message_id or "",
"conversation_mode": str(trace_info.conversation_mode or ""),
"user_id": trace_info.message_data.from_account_id or "",
"file_list": json.dumps(file_list),
"status": trace_info.message_data.status or "",
"status_message": trace_info.error or "",
"level": "ERROR" if trace_info.error else "DEFAULT",
"total_tokens": trace_info.total_tokens or 0,
"prompt_tokens": trace_info.message_tokens or 0,
"completion_tokens": trace_info.answer_tokens or 0,
"ls_provider": trace_info.message_data.model_provider or "",
"ls_model_name": trace_info.message_data.model_id or "",
}
message_metadata.update(trace_info.metadata)
# Add end user data if available
if trace_info.message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
db.session.query(EndUser).filter(EndUser.id == trace_info.message_data.from_end_user_id).first()
)
if end_user_data is not None:
message_metadata["end_user_id"] = end_user_data.session_id
attributes = {
SpanAttributes.INPUT_VALUE: trace_info.message_data.query,
SpanAttributes.OUTPUT_VALUE: trace_info.message_data.answer,
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
}
trace_id = uuid_to_trace_id(trace_info.message_id)
message_span_id = RandomIdGenerator().generate_span_id()
span_context = SpanContext(
trace_id=trace_id,
span_id=message_span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
message_span = self.tracer.start_span(
name=TraceTaskName.MESSAGE_TRACE.value,
attributes=attributes,
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
)
try:
if trace_info.error:
message_span.add_event(
"exception",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
# Convert outputs to string based on type
if isinstance(trace_info.outputs, dict | list):
outputs_str = json.dumps(trace_info.outputs, ensure_ascii=False)
elif isinstance(trace_info.outputs, str):
outputs_str = trace_info.outputs
else:
outputs_str = str(trace_info.outputs)
llm_attributes = {
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.LLM.value,
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: outputs_str,
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
}
if isinstance(trace_info.inputs, list):
for i, msg in enumerate(trace_info.inputs):
if isinstance(msg, dict):
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get(
"role", "user"
)
# todo: handle assistant and tool role messages, as they don't always
# have a text field, but may have a tool_calls field instead
# e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
# 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
elif isinstance(trace_info.inputs, dict):
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(trace_info.inputs)
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
elif isinstance(trace_info.inputs, str):
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = trace_info.inputs
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
if trace_info.total_tokens is not None and trace_info.total_tokens > 0:
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = trace_info.total_tokens
if trace_info.message_tokens is not None and trace_info.message_tokens > 0:
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_PROMPT] = trace_info.message_tokens
if trace_info.answer_tokens is not None and trace_info.answer_tokens > 0:
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION] = trace_info.answer_tokens
if trace_info.message_data.model_id is not None:
llm_attributes[SpanAttributes.LLM_MODEL_NAME] = trace_info.message_data.model_id
if trace_info.message_data.model_provider is not None:
llm_attributes[SpanAttributes.LLM_PROVIDER] = trace_info.message_data.model_provider
if trace_info.message_data and trace_info.message_data.message_metadata:
metadata_dict = json.loads(trace_info.message_data.message_metadata)
if model_params := metadata_dict.get("model_parameters"):
llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps(model_params)
llm_span = self.tracer.start_span(
name="llm",
attributes=llm_attributes,
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
)
try:
if trace_info.error:
llm_span.add_event(
"exception",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
finally:
llm_span.end(end_time=datetime_to_nanos(trace_info.end_time))
finally:
message_span.end(end_time=datetime_to_nanos(trace_info.end_time))
def moderation_trace(self, trace_info: ModerationTraceInfo):
if trace_info.message_data is None:
return
metadata = {
"message_id": trace_info.message_id,
"tool_name": "moderation",
"status": trace_info.message_data.status,
"status_message": trace_info.message_data.error or "",
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
}
metadata.update(trace_info.metadata)
trace_id = uuid_to_trace_id(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
span = self.tracer.start_span(
name=TraceTaskName.MODERATION_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps(
{
"action": trace_info.action,
"flagged": trace_info.flagged,
"preset_response": trace_info.preset_response,
"inputs": trace_info.inputs,
},
ensure_ascii=False,
),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
},
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
)
try:
if trace_info.message_data.error:
span.add_event(
"exception",
attributes={
"exception.message": trace_info.message_data.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.message_data.error,
},
)
finally:
span.end(end_time=datetime_to_nanos(trace_info.end_time))
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
if trace_info.message_data is None:
return
start_time = trace_info.start_time or trace_info.message_data.created_at
end_time = trace_info.end_time or trace_info.message_data.updated_at
metadata = {
"message_id": trace_info.message_id,
"tool_name": "suggested_question",
"status": trace_info.status,
"status_message": trace_info.error or "",
"level": "ERROR" if trace_info.error else "DEFAULT",
"total_tokens": trace_info.total_tokens,
"ls_provider": trace_info.model_provider or "",
"ls_model_name": trace_info.model_id or "",
}
metadata.update(trace_info.metadata)
trace_id = uuid_to_trace_id(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
span = self.tracer.start_span(
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
},
start_time=datetime_to_nanos(start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
)
try:
if trace_info.error:
span.add_event(
"exception",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
finally:
span.end(end_time=datetime_to_nanos(end_time))
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
if trace_info.message_data is None:
return
start_time = trace_info.start_time or trace_info.message_data.created_at
end_time = trace_info.end_time or trace_info.message_data.updated_at
metadata = {
"message_id": trace_info.message_id,
"tool_name": "dataset_retrieval",
"status": trace_info.message_data.status,
"status_message": trace_info.message_data.error or "",
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
"ls_provider": trace_info.message_data.model_provider or "",
"ls_model_name": trace_info.message_data.model_id or "",
}
metadata.update(trace_info.metadata)
trace_id = uuid_to_trace_id(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
span = self.tracer.start_span(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps({"documents": trace_info.documents}, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.RETRIEVER.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
"start_time": start_time.isoformat() if start_time else "",
"end_time": end_time.isoformat() if end_time else "",
},
start_time=datetime_to_nanos(start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
)
try:
if trace_info.message_data.error:
span.add_event(
"exception",
attributes={
"exception.message": trace_info.message_data.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.message_data.error,
},
)
finally:
span.end(end_time=datetime_to_nanos(end_time))
def tool_trace(self, trace_info: ToolTraceInfo):
if trace_info.message_data is None:
logger.warning("[Arize/Phoenix] Message data is None, skipping tool trace.")
return
metadata = {
"message_id": trace_info.message_id,
"tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False),
}
trace_id = uuid_to_trace_id(trace_info.message_id)
tool_span_id = RandomIdGenerator().generate_span_id()
logger.info(f"[Arize/Phoenix] Creating tool trace with trace_id: {trace_id}, span_id: {tool_span_id}")
# Create span context with the same trace_id as the parent
# todo: Create with the appropriate parent span context, so that the tool span is
# a child of the appropriate span (e.g. message span)
span_context = SpanContext(
trace_id=trace_id,
span_id=tool_span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
tool_params_str = (
json.dumps(trace_info.tool_parameters, ensure_ascii=False)
if isinstance(trace_info.tool_parameters, dict)
else str(trace_info.tool_parameters)
)
span = self.tracer.start_span(
name=trace_info.tool_name,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: trace_info.tool_outputs,
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
SpanAttributes.TOOL_NAME: trace_info.tool_name,
SpanAttributes.TOOL_PARAMETERS: tool_params_str,
},
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
)
try:
if trace_info.error:
span.add_event(
"exception",
attributes={
"exception.message": trace_info.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.error,
},
)
finally:
span.end(end_time=datetime_to_nanos(trace_info.end_time))
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
if trace_info.message_data is None:
return
metadata = {
"project_name": self.project,
"message_id": trace_info.message_id,
"status": trace_info.message_data.status,
"status_message": trace_info.message_data.error or "",
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
}
metadata.update(trace_info.metadata)
trace_id = uuid_to_trace_id(trace_info.message_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
span_id=span_id,
is_remote=False,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
trace_state=TraceState(),
)
span = self.tracer.start_span(
name=TraceTaskName.GENERATE_NAME_TRACE.value,
attributes={
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.outputs, ensure_ascii=False),
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
"start_time": trace_info.start_time.isoformat() if trace_info.start_time else "",
"end_time": trace_info.end_time.isoformat() if trace_info.end_time else "",
},
start_time=datetime_to_nanos(trace_info.start_time),
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
)
try:
if trace_info.message_data.error:
span.add_event(
"exception",
attributes={
"exception.message": trace_info.message_data.error,
"exception.type": "Error",
"exception.stacktrace": trace_info.message_data.error,
},
)
finally:
span.end(end_time=datetime_to_nanos(trace_info.end_time))
def api_check(self):
try:
with self.tracer.start_span("api_check") as span:
span.set_attribute("test", "true")
return True
except Exception as e:
logger.info(f"[Arize/Phoenix] API check failed: {str(e)}", exc_info=True)
raise ValueError(f"[Arize/Phoenix] API check failed: {str(e)}")
def get_project_url(self):
try:
if self.arize_phoenix_config.endpoint == "https://otlp.arize.com":
return "https://app.arize.com/"
else:
return f"{self.arize_phoenix_config.endpoint}/projects/"
except Exception as e:
logger.info(f"[Arize/Phoenix] Get run url failed: {str(e)}", exc_info=True)
raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}")
def _get_workflow_nodes(self, workflow_run_id: str):
"""Helper method to get workflow nodes"""
workflow_nodes = (
db.session.query(
WorkflowNodeExecutionModel.id,
WorkflowNodeExecutionModel.tenant_id,
WorkflowNodeExecutionModel.app_id,
WorkflowNodeExecutionModel.title,
WorkflowNodeExecutionModel.node_type,
WorkflowNodeExecutionModel.status,
WorkflowNodeExecutionModel.inputs,
WorkflowNodeExecutionModel.outputs,
WorkflowNodeExecutionModel.created_at,
WorkflowNodeExecutionModel.elapsed_time,
WorkflowNodeExecutionModel.process_data,
WorkflowNodeExecutionModel.execution_metadata,
)
.filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
.all()
)
return workflow_nodes

View File

@@ -2,20 +2,92 @@ from enum import StrEnum
from pydantic import BaseModel, ValidationInfo, field_validator
from core.ops.utils import validate_project_name, validate_url, validate_url_with_path
class TracingProviderEnum(StrEnum):
ARIZE = "arize"
PHOENIX = "phoenix"
LANGFUSE = "langfuse"
LANGSMITH = "langsmith"
OPIK = "opik"
WEAVE = "weave"
ALIYUN = "aliyun"
class BaseTracingConfig(BaseModel):
"""
Base model class for tracing
Base model class for tracing configurations
"""
...
@classmethod
def validate_endpoint_url(cls, v: str, default_url: str) -> str:
"""
Common endpoint URL validation logic
Args:
v: URL value to validate
default_url: Default URL to use if input is None or empty
Returns:
Validated and normalized URL
"""
return validate_url(v, default_url)
@classmethod
def validate_project_field(cls, v: str, default_name: str) -> str:
"""
Common project name validation logic
Args:
v: Project name to validate
default_name: Default name to use if input is None or empty
Returns:
Validated project name
"""
return validate_project_name(v, default_name)
class ArizeConfig(BaseTracingConfig):
"""
Model class for Arize tracing config.
"""
api_key: str | None = None
space_id: str | None = None
project: str | None = None
endpoint: str = "https://otlp.arize.com"
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "default")
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://otlp.arize.com")
class PhoenixConfig(BaseTracingConfig):
"""
Model class for Phoenix tracing config.
"""
api_key: str | None = None
project: str | None = None
endpoint: str = "https://app.phoenix.arize.com"
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "default")
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://app.phoenix.arize.com")
class LangfuseConfig(BaseTracingConfig):
@@ -29,13 +101,8 @@ class LangfuseConfig(BaseTracingConfig):
@field_validator("host")
@classmethod
def set_value(cls, v, info: ValidationInfo):
if v is None or v == "":
v = "https://api.langfuse.com"
if not v.startswith("https://") and not v.startswith("http://"):
raise ValueError("host must start with https:// or http://")
return v
def host_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://api.langfuse.com")
class LangSmithConfig(BaseTracingConfig):
@@ -49,13 +116,9 @@ class LangSmithConfig(BaseTracingConfig):
@field_validator("endpoint")
@classmethod
def set_value(cls, v, info: ValidationInfo):
if v is None or v == "":
v = "https://api.smith.langchain.com"
if not v.startswith("https://"):
raise ValueError("endpoint must start with https://")
return v
def endpoint_validator(cls, v, info: ValidationInfo):
# LangSmith only allows HTTPS
return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",))
class OpikConfig(BaseTracingConfig):
@@ -71,22 +134,12 @@ class OpikConfig(BaseTracingConfig):
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
if v is None or v == "":
v = "Default Project"
return v
return cls.validate_project_field(v, "Default Project")
@field_validator("url")
@classmethod
def url_validator(cls, v, info: ValidationInfo):
if v is None or v == "":
v = "https://www.comet.com/opik/api/"
if not v.startswith(("https://", "http://")):
raise ValueError("url must start with https:// or http://")
if not v.endswith("/api/"):
raise ValueError("url should ends with /api/")
return v
return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/")
class WeaveConfig(BaseTracingConfig):
@@ -102,22 +155,44 @@ class WeaveConfig(BaseTracingConfig):
@field_validator("endpoint")
@classmethod
def set_value(cls, v, info: ValidationInfo):
if v is None or v == "":
v = "https://trace.wandb.ai"
if not v.startswith("https://"):
raise ValueError("endpoint must start with https://")
return v
def endpoint_validator(cls, v, info: ValidationInfo):
# Weave only allows HTTPS for endpoint
return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",))
@field_validator("host")
@classmethod
def validate_host(cls, v, info: ValidationInfo):
if v is not None and v != "":
if not v.startswith(("https://", "http://")):
raise ValueError("host must start with https:// or http://")
def host_validator(cls, v, info: ValidationInfo):
if v is not None and v.strip() != "":
return validate_url(v, v, allowed_schemes=("https", "http"))
return v
class AliyunConfig(BaseTracingConfig):
"""
Model class for Aliyun tracing config.
"""
app_name: str = "dify_app"
license_key: str
endpoint: str
@field_validator("app_name")
@classmethod
def app_name_validator(cls, v, info: ValidationInfo):
return cls.validate_project_field(v, "dify_app")
@field_validator("license_key")
@classmethod
def license_key_validator(cls, v, info: ValidationInfo):
if not v or v.strip() == "":
raise ValueError("License key cannot be empty")
return v
@field_validator("endpoint")
@classmethod
def endpoint_validator(cls, v, info: ValidationInfo):
return cls.validate_endpoint_url(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"

View File

@@ -135,4 +135,3 @@ class TraceTaskName(StrEnum):
DATASET_RETRIEVAL_TRACE = "dataset_retrieval"
TOOL_TRACE = "tool"
GENERATE_NAME_TRACE = "generate_conversation_name"
DATASOURCE_TRACE = "datasource"

View File

@@ -32,6 +32,7 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
from models.enums import MessageStatus
logger = logging.getLogger(__name__)
@@ -180,12 +181,9 @@ class LangFuseDataTrace(BaseTraceInstance):
prompt_tokens = 0
completion_tokens = 0
try:
if outputs.get("usage"):
prompt_tokens = outputs.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = outputs.get("usage", {}).get("completion_tokens", 0)
else:
prompt_tokens = process_data.get("usage", {}).get("prompt_tokens", 0)
completion_tokens = process_data.get("usage", {}).get("completion_tokens", 0)
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
prompt_tokens = usage_data.get("prompt_tokens", 0)
completion_tokens = usage_data.get("completion_tokens", 0)
except Exception:
logger.error("Failed to extract usage", exc_info=True)
@@ -293,7 +291,7 @@ class LangFuseDataTrace(BaseTraceInstance):
input=trace_info.inputs,
output=message_data.answer,
metadata=metadata,
level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
level=(LevelEnum.DEFAULT if message_data.status != MessageStatus.ERROR else LevelEnum.ERROR),
status_message=message_data.error or "",
usage=generation_usage,
)
@@ -339,7 +337,7 @@ class LangFuseDataTrace(BaseTraceInstance):
start_time=trace_info.start_time,
end_time=trace_info.end_time,
metadata=trace_info.metadata,
level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
level=(LevelEnum.DEFAULT if message_data.status != MessageStatus.ERROR else LevelEnum.ERROR),
status_message=message_data.error or "",
usage=generation_usage,
)

Some files were not shown because too many files have changed in this diff Show More