mirror of
https://github.com/langgenius/dify.git
synced 2025-12-31 19:47:19 +00:00
Compare commits
107 Commits
feat/r2
...
feat/plugi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7608eb1049 | ||
|
|
b834131f50 | ||
|
|
5375d9bb27 | ||
|
|
535fff62f3 | ||
|
|
18b58424ec | ||
|
|
10858ea1dc | ||
|
|
95ce7b6f47 | ||
|
|
6f8c7a66c8 | ||
|
|
a371390d6c | ||
|
|
a316766ad7 | ||
|
|
a9cc19f530 | ||
|
|
881a151d30 | ||
|
|
785c4caa67 | ||
|
|
4403bc67a1 | ||
|
|
b237113311 | ||
|
|
4cb50f1809 | ||
|
|
1885426421 | ||
|
|
89b52471fb | ||
|
|
3643ed1014 | ||
|
|
e39236186d | ||
|
|
521488f926 | ||
|
|
d61ea5a2de | ||
|
|
816210d744 | ||
|
|
784a236280 | ||
|
|
1e0426ca6f | ||
|
|
f925869f61 | ||
|
|
f62b59a805 | ||
|
|
a4bdeba60d | ||
|
|
5c0cb7f912 | ||
|
|
2ffbf5435d | ||
|
|
71385d594d | ||
|
|
53c4912cbb | ||
|
|
1760179093 | ||
|
|
aded30b664 | ||
|
|
de54f8d0ef | ||
|
|
5b0b64c7e5 | ||
|
|
b654c852a5 | ||
|
|
c48b32c9e3 | ||
|
|
8f723697ef | ||
|
|
de22648b9f | ||
|
|
b9f56852dc | ||
|
|
108cc3486f | ||
|
|
ac69b8b191 | ||
|
|
8288145ee4 | ||
|
|
51f6095be7 | ||
|
|
a201e9faee | ||
|
|
fec6bafcda | ||
|
|
2639f950cc | ||
|
|
6663187eca | ||
|
|
13990f31a1 | ||
|
|
de39b737b6 | ||
|
|
a66ed7157e | ||
|
|
c9c49200e0 | ||
|
|
317d287458 | ||
|
|
fd7396d8f9 | ||
|
|
a0af33e945 | ||
|
|
a79f37b686 | ||
|
|
1c7404099d | ||
|
|
ed54bd5121 | ||
|
|
06c3deff11 | ||
|
|
47954aa284 | ||
|
|
f3c8625fe2 | ||
|
|
ebc4fdc4b2 | ||
|
|
1af3d40c1a | ||
|
|
31eb8548ef | ||
|
|
a45aa1e505 | ||
|
|
cb0d4a1e15 | ||
|
|
21e68b9cf1 | ||
|
|
a3654c8fe9 | ||
|
|
980b0188d2 | ||
|
|
daab648c78 | ||
|
|
e17b33e004 | ||
|
|
4e7c9dd2ae | ||
|
|
5487463385 | ||
|
|
8d8220b06c | ||
|
|
0625d6a361 | ||
|
|
63a1a1077e | ||
|
|
0af646d947 | ||
|
|
07c99745fa | ||
|
|
afd0d31354 | ||
|
|
18bbf1165d | ||
|
|
5f17edc77f | ||
|
|
836027cb33 | ||
|
|
f3cbfe2223 | ||
|
|
bc1e4c88e0 | ||
|
|
d114485abd | ||
|
|
3e8a4a66fe | ||
|
|
4c583f3d9a | ||
|
|
52b845a5bb | ||
|
|
38d1c85c57 | ||
|
|
c43d992f2b | ||
|
|
1ff5969b92 | ||
|
|
93a560ee54 | ||
|
|
2f241d932c | ||
|
|
a0804786fd | ||
|
|
c6fa8102eb | ||
|
|
7ec5816513 | ||
|
|
825fbcc6f8 | ||
|
|
ccef71626d | ||
|
|
29cac85b12 | ||
|
|
8b290ac7a1 | ||
|
|
01cdffaa08 | ||
|
|
3061280f7a | ||
|
|
bc75d810c4 | ||
|
|
dc5e974a78 | ||
|
|
baff25c160 | ||
|
|
42b6524954 |
1
.github/workflows/build-push.yml
vendored
1
.github/workflows/build-push.yml
vendored
@@ -6,7 +6,6 @@ on:
|
||||
- "main"
|
||||
- "deploy/dev"
|
||||
- "deploy/enterprise"
|
||||
- "deploy/rag-dev"
|
||||
tags:
|
||||
- "*"
|
||||
|
||||
|
||||
7
.github/workflows/deploy-dev.yml
vendored
7
.github/workflows/deploy-dev.yml
vendored
@@ -4,7 +4,7 @@ on:
|
||||
workflow_run:
|
||||
workflows: ["Build and Push API & Web"]
|
||||
branches:
|
||||
- "deploy/rag-dev"
|
||||
- "deploy/dev"
|
||||
types:
|
||||
- completed
|
||||
|
||||
@@ -12,13 +12,12 @@ jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'deploy/rag-dev'
|
||||
github.event.workflow_run.conclusion == 'success'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v0.1.8
|
||||
with:
|
||||
host: ${{ secrets.RAG_SSH_HOST }}
|
||||
host: ${{ secrets.SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
key: ${{ secrets.SSH_PRIVATE_KEY }}
|
||||
script: |
|
||||
|
||||
27
api/app.py
27
api/app.py
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
@@ -16,20 +17,20 @@ else:
|
||||
# It seems that JetBrains Python debugger does not work well with gevent,
|
||||
# so we need to disable gevent in debug mode.
|
||||
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
|
||||
# if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
|
||||
# from gevent import monkey
|
||||
#
|
||||
# # gevent
|
||||
# monkey.patch_all()
|
||||
#
|
||||
# from grpc.experimental import gevent as grpc_gevent # type: ignore
|
||||
#
|
||||
# # grpc gevent
|
||||
# grpc_gevent.init_gevent()
|
||||
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
|
||||
from gevent import monkey
|
||||
|
||||
# import psycogreen.gevent # type: ignore
|
||||
#
|
||||
# psycogreen.gevent.patch_psycopg()
|
||||
# gevent
|
||||
monkey.patch_all()
|
||||
|
||||
from grpc.experimental import gevent as grpc_gevent # type: ignore
|
||||
|
||||
# grpc gevent
|
||||
grpc_gevent.init_gevent()
|
||||
|
||||
import psycogreen.gevent # type: ignore
|
||||
|
||||
psycogreen.gevent.patch_psycopg()
|
||||
|
||||
from app_factory import create_app
|
||||
|
||||
|
||||
@@ -222,28 +222,11 @@ class HostedFetchAppTemplateConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class HostedFetchPipelineTemplateConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for fetching pipeline templates
|
||||
"""
|
||||
|
||||
HOSTED_FETCH_PIPELINE_TEMPLATES_MODE: str = Field(
|
||||
description="Mode for fetching pipeline templates: remote, db, or builtin default to remote,",
|
||||
default="database",
|
||||
)
|
||||
|
||||
HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN: str = Field(
|
||||
description="Domain for fetching remote pipeline templates",
|
||||
default="https://tmpl.dify.ai",
|
||||
)
|
||||
|
||||
|
||||
class HostedServiceConfig(
|
||||
# place the configs in alphabet order
|
||||
HostedAnthropicConfig,
|
||||
HostedAzureOpenAiConfig,
|
||||
HostedFetchAppTemplateConfig,
|
||||
HostedFetchPipelineTemplateConfig,
|
||||
HostedMinmaxConfig,
|
||||
HostedOpenAiConfig,
|
||||
HostedSparkConfig,
|
||||
|
||||
@@ -3,7 +3,6 @@ from threading import Lock
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
@@ -34,11 +33,3 @@ plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(Cont
|
||||
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
|
||||
ContextVar("plugin_model_schemas")
|
||||
)
|
||||
|
||||
datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = (
|
||||
RecyclableContextVar(ContextVar("datasource_plugin_providers"))
|
||||
)
|
||||
|
||||
datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
|
||||
ContextVar("datasource_plugin_providers_lock")
|
||||
)
|
||||
|
||||
@@ -56,6 +56,7 @@ from .app import (
|
||||
conversation,
|
||||
conversation_variables,
|
||||
generator,
|
||||
mcp_server,
|
||||
message,
|
||||
model_config,
|
||||
ops_trace,
|
||||
@@ -76,6 +77,7 @@ from .billing import billing, compliance
|
||||
|
||||
# Import datasets controllers
|
||||
from .datasets import (
|
||||
data_source,
|
||||
datasets,
|
||||
datasets_document,
|
||||
datasets_segments,
|
||||
@@ -84,14 +86,6 @@ from .datasets import (
|
||||
metadata,
|
||||
website,
|
||||
)
|
||||
from .datasets.rag_pipeline import (
|
||||
datasource_auth,
|
||||
datasource_content_preview,
|
||||
rag_pipeline,
|
||||
rag_pipeline_datasets,
|
||||
rag_pipeline_import,
|
||||
rag_pipeline_workflow,
|
||||
)
|
||||
|
||||
# Import explore controllers
|
||||
from .explore import (
|
||||
|
||||
@@ -90,23 +90,11 @@ class ChatMessageTextApi(Resource):
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
if (
|
||||
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict
|
||||
):
|
||||
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
||||
if text_to_speech is None:
|
||||
raise ValueError("TTS is not enabled")
|
||||
voice = args.get("voice") or text_to_speech.get("voice")
|
||||
else:
|
||||
try:
|
||||
if app_model.app_model_config is None:
|
||||
raise ValueError("AppModelConfig not found")
|
||||
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
except Exception:
|
||||
voice = None
|
||||
response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)
|
||||
voice = args.get("voice", None)
|
||||
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model, text=text, voice=voice, message_id=message_id, is_draft=True
|
||||
)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
|
||||
102
api/controllers/console/app/mcp_server.py
Normal file
102
api/controllers/console/app/mcp_server.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import json
|
||||
from enum import StrEnum
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_server_fields
|
||||
from libs.login import login_required
|
||||
from models.model import AppMCPServer
|
||||
|
||||
|
||||
class AppMCPServerStatus(StrEnum):
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
|
||||
|
||||
class AppMCPServerController(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
def get(self, app_model):
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == app_model.id).first()
|
||||
return server
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
def post(self, app_model):
|
||||
# The role of the current user in the ta table must be editor, admin, or owner
|
||||
if not current_user.is_editor:
|
||||
raise NotFound()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("description", type=str, required=True, location="json")
|
||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
server = AppMCPServer(
|
||||
name=app_model.name,
|
||||
description=args["description"],
|
||||
parameters=json.dumps(args["parameters"], ensure_ascii=False),
|
||||
status=AppMCPServerStatus.ACTIVE,
|
||||
app_id=app_model.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
server_code=AppMCPServer.generate_server_code(16),
|
||||
)
|
||||
db.session.add(server)
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
def put(self, app_model):
|
||||
if not current_user.is_editor:
|
||||
raise NotFound()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=True, location="json")
|
||||
parser.add_argument("description", type=str, required=True, location="json")
|
||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
||||
parser.add_argument("status", type=str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first()
|
||||
if not server:
|
||||
raise NotFound()
|
||||
server.description = args["description"]
|
||||
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
|
||||
if args["status"]:
|
||||
if args["status"] not in [status.value for status in AppMCPServerStatus]:
|
||||
raise ValueError("Invalid status")
|
||||
server.status = args["status"]
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
|
||||
class AppMCPServerRefreshController(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(app_server_fields)
|
||||
def get(self, server_id):
|
||||
if not current_user.is_editor:
|
||||
raise NotFound()
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == server_id).first()
|
||||
if not server:
|
||||
raise NotFound()
|
||||
server.server_code = AppMCPServer.generate_server_code(16)
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
|
||||
api.add_resource(AppMCPServerController, "/apps/<uuid:app_id>/server")
|
||||
api.add_resource(AppMCPServerRefreshController, "/apps/<uuid:server_id>/server/refresh")
|
||||
@@ -283,15 +283,6 @@ class DatasetApi(Resource):
|
||||
location="json",
|
||||
help="Invalid external knowledge api id.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid icon info.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
data = request.get_json()
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
from argparse import ArgumentTypeError
|
||||
from datetime import UTC, datetime
|
||||
@@ -52,7 +51,6 @@ from fields.document_fields import (
|
||||
)
|
||||
from libs.login import login_required
|
||||
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
|
||||
|
||||
@@ -663,7 +661,7 @@ class DocumentDetailApi(DocumentResource):
|
||||
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
|
||||
elif metadata == "without":
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
|
||||
document_process_rules = document.dataset_process_rule.to_dict()
|
||||
data_source_info = document.data_source_detail_dict
|
||||
response = {
|
||||
"id": document.id,
|
||||
@@ -1030,41 +1028,6 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
class DocumentPipelineExecutionLogApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
log = (
|
||||
db.session.query(DocumentPipelineExecutionLog)
|
||||
.filter_by(document_id=document_id)
|
||||
.order_by(DocumentPipelineExecutionLog.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
if not log:
|
||||
return {
|
||||
"datasource_info": None,
|
||||
"datasource_type": None,
|
||||
"input_data": None,
|
||||
"datasource_node_id": None,
|
||||
}, 200
|
||||
return {
|
||||
"datasource_info": json.loads(log.datasource_info),
|
||||
"datasource_type": log.datasource_type,
|
||||
"input_data": log.input_data,
|
||||
"datasource_node_id": log.datasource_node_id,
|
||||
}, 200
|
||||
|
||||
|
||||
api.add_resource(GetProcessRuleApi, "/datasets/process-rule")
|
||||
api.add_resource(DatasetDocumentListApi, "/datasets/<uuid:dataset_id>/documents")
|
||||
api.add_resource(DatasetInitApi, "/datasets/init")
|
||||
@@ -1087,6 +1050,3 @@ api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")
|
||||
api.add_resource(DocumentRenameApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename")
|
||||
|
||||
api.add_resource(WebsiteDocumentSyncApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync")
|
||||
api.add_resource(
|
||||
DocumentPipelineExecutionLogApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/pipeline-execution-log"
|
||||
)
|
||||
|
||||
@@ -101,9 +101,3 @@ class ChildChunkDeleteIndexError(BaseHTTPException):
|
||||
error_code = "child_chunk_delete_index_error"
|
||||
description = "Delete child chunk index failed: {message}"
|
||||
code = 500
|
||||
|
||||
|
||||
class PipelineNotFoundError(BaseHTTPException):
|
||||
error_code = "pipeline_not_found"
|
||||
description = "Pipeline not found."
|
||||
code = 404
|
||||
|
||||
@@ -1,197 +0,0 @@
|
||||
from flask import redirect, request
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import ( # type: ignore
|
||||
Resource, # type: ignore
|
||||
reqparse,
|
||||
)
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
|
||||
|
||||
class DatasourcePluginOauthApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
|
||||
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
|
||||
args = parser.parse_args()
|
||||
provider = args["provider"]
|
||||
plugin_id = args["plugin_id"]
|
||||
# Check user role first
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
# get all plugin oauth configs
|
||||
plugin_oauth_config = (
|
||||
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
|
||||
)
|
||||
if not plugin_oauth_config:
|
||||
raise NotFound()
|
||||
oauth_handler = OAuthHandler()
|
||||
redirect_url = (
|
||||
f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?provider={provider}&plugin_id={plugin_id}"
|
||||
)
|
||||
system_credentials = plugin_oauth_config.system_credentials
|
||||
if system_credentials:
|
||||
system_credentials["redirect_url"] = redirect_url
|
||||
response = oauth_handler.get_authorization_url(
|
||||
current_user.current_tenant.id, current_user.id, plugin_id, provider, system_credentials=system_credentials
|
||||
)
|
||||
return response.model_dump()
|
||||
|
||||
|
||||
class DatasourceOauthCallback(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
|
||||
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
|
||||
args = parser.parse_args()
|
||||
provider = args["provider"]
|
||||
plugin_id = args["plugin_id"]
|
||||
oauth_handler = OAuthHandler()
|
||||
plugin_oauth_config = (
|
||||
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
|
||||
)
|
||||
if not plugin_oauth_config:
|
||||
raise NotFound()
|
||||
credentials = oauth_handler.get_credentials(
|
||||
current_user.current_tenant.id,
|
||||
current_user.id,
|
||||
plugin_id,
|
||||
provider,
|
||||
system_credentials=plugin_oauth_config.system_credentials,
|
||||
request=request,
|
||||
)
|
||||
datasource_provider = DatasourceProvider(
|
||||
plugin_id=plugin_id, provider=provider, auth_type="oauth", encrypted_credentials=credentials
|
||||
)
|
||||
db.session.add(datasource_provider)
|
||||
db.session.commit()
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}")
|
||||
|
||||
|
||||
class DatasourceAuth(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=str, required=False, nullable=False, location="json", default="test")
|
||||
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
|
||||
try:
|
||||
datasource_provider_service.datasource_provider_credentials_validate(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=args["provider"],
|
||||
plugin_id=args["plugin_id"],
|
||||
credentials=args["credentials"],
|
||||
name=args["name"],
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
|
||||
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id, provider=args["provider"], plugin_id=args["plugin_id"]
|
||||
)
|
||||
return {"result": datasources}, 200
|
||||
|
||||
|
||||
class DatasourceAuthUpdateDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, auth_id: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
|
||||
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
|
||||
args = parser.parse_args()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
auth_id=auth_id,
|
||||
provider=args["provider"],
|
||||
plugin_id=args["plugin_id"],
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, auth_id: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
|
||||
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
try:
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
auth_id=auth_id,
|
||||
provider=args["provider"],
|
||||
plugin_id=args["plugin_id"],
|
||||
credentials=args["credentials"],
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
|
||||
return {"result": "success"}, 201
|
||||
|
||||
|
||||
# Import Rag Pipeline
|
||||
api.add_resource(
|
||||
DatasourcePluginOauthApi,
|
||||
"/oauth/plugin/datasource",
|
||||
)
|
||||
api.add_resource(
|
||||
DatasourceOauthCallback,
|
||||
"/oauth/plugin/datasource/callback",
|
||||
)
|
||||
api.add_resource(
|
||||
DatasourceAuth,
|
||||
"/auth/plugin/datasource",
|
||||
)
|
||||
|
||||
api.add_resource(
|
||||
DatasourceAuthUpdateDeleteApi,
|
||||
"/auth/plugin/datasource/<string:auth_id>",
|
||||
)
|
||||
@@ -1,55 +0,0 @@
|
||||
from flask_restful import ( # type: ignore
|
||||
Resource, # type: ignore
|
||||
reqparse,
|
||||
)
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
|
||||
class DataSourceContentPreviewApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def post(self, pipeline: Pipeline, node_id: str):
|
||||
"""
|
||||
Run datasource content preview
|
||||
"""
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs is None:
|
||||
raise ValueError("missing inputs")
|
||||
datasource_type = args.get("datasource_type")
|
||||
if datasource_type is None:
|
||||
raise ValueError("missing datasource_type")
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
preview_content = rag_pipeline_service.run_datasource_node_preview(
|
||||
pipeline=pipeline,
|
||||
node_id=node_id,
|
||||
user_inputs=inputs,
|
||||
account=current_user,
|
||||
datasource_type=datasource_type,
|
||||
is_published=True,
|
||||
)
|
||||
return preview_content, 200
|
||||
|
||||
|
||||
api.add_resource(
|
||||
DataSourceContentPreviewApi,
|
||||
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview",
|
||||
)
|
||||
@@ -1,162 +0,0 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
enterprise_license_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
from models.dataset import PipelineCustomizedTemplate
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
class PipelineTemplateListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
type = request.args.get("type", default="built-in", type=str)
|
||||
language = request.args.get("language", default="en-US", type=str)
|
||||
# get pipeline templates
|
||||
pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
|
||||
return pipeline_templates, 200
|
||||
|
||||
|
||||
class PipelineTemplateDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self, template_id: str):
|
||||
type = request.args.get("type", default="built-in", type=str)
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
|
||||
return pipeline_template, 200
|
||||
|
||||
|
||||
class CustomizedPipelineTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def patch(self, template_id: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
pipeline_template_info = PipelineTemplateInfoEntity(**args)
|
||||
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
|
||||
return 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def delete(self, template_id: str):
|
||||
RagPipelineService.delete_customized_pipeline_template(template_id)
|
||||
return 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def post(self, template_id: str):
|
||||
with Session(db.engine) as session:
|
||||
template = (
|
||||
session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
|
||||
)
|
||||
if not template:
|
||||
raise ValueError("Customized pipeline template not found.")
|
||||
|
||||
return {"data": template.yaml_content}, 200
|
||||
|
||||
|
||||
class PublishCustomizedPipelineTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def post(self, pipeline_id: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
api.add_resource(
|
||||
PipelineTemplateListApi,
|
||||
"/rag/pipeline/templates",
|
||||
)
|
||||
api.add_resource(
|
||||
PipelineTemplateDetailApi,
|
||||
"/rag/pipeline/templates/<string:template_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
CustomizedPipelineTemplateApi,
|
||||
"/rag/pipeline/customized/templates/<string:template_id>",
|
||||
)
|
||||
api.add_resource(
|
||||
PublishCustomizedPipelineTemplateApi,
|
||||
"/rag/pipelines/<string:pipeline_id>/customized/publish",
|
||||
)
|
||||
@@ -1,171 +0,0 @@
|
||||
from flask_login import current_user # type: ignore # type: ignore
|
||||
from flask_restful import Resource, marshal, reqparse # type: ignore
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
setup_required,
|
||||
)
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import login_required
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
class CreateRagPipelineDatasetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="type is required. Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default={},
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"permission",
|
||||
type=str,
|
||||
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
|
||||
nullable=True,
|
||||
required=False,
|
||||
default=DatasetPermissionEnum.ONLY_ME,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"partial_member_list",
|
||||
type=list,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default=[],
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"yaml_content",
|
||||
type=str,
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="yaml_content is required.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
rag_pipeline_dataset_create_entity = RagPipelineDatasetCreateEntity(**args)
|
||||
try:
|
||||
import_info = RagPipelineDslService.create_rag_pipeline_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
|
||||
)
|
||||
if rag_pipeline_dataset_create_entity.permission == "partial_members":
|
||||
DatasetPermissionService.update_partial_member_list(
|
||||
current_user.current_tenant_id,
|
||||
import_info["dataset_id"],
|
||||
rag_pipeline_dataset_create_entity.partial_member_list,
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
return import_info, 201
|
||||
|
||||
|
||||
class CreateEmptyRagPipelineDatasetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="type is required. Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default={},
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"permission",
|
||||
type=str,
|
||||
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
|
||||
nullable=True,
|
||||
required=False,
|
||||
default=DatasetPermissionEnum.ONLY_ME,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"partial_member_list",
|
||||
type=list,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default=[],
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
dataset = DatasetService.create_empty_rag_pipeline_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(**args),
|
||||
)
|
||||
return marshal(dataset, dataset_detail_fields), 201
|
||||
|
||||
|
||||
api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset")
|
||||
api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset")
|
||||
@@ -1,146 +0,0 @@
|
||||
from typing import cast
|
||||
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import Resource, marshal_with, reqparse # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.app_dsl_service import ImportStatus
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
class RagPipelineImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(pipeline_import_fields)
|
||||
def post(self):
|
||||
# Check user role first
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("mode", type=str, required=True, location="json")
|
||||
parser.add_argument("yaml_content", type=str, location="json")
|
||||
parser.add_argument("yaml_url", type=str, location="json")
|
||||
parser.add_argument("name", type=str, location="json")
|
||||
parser.add_argument("description", type=str, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
parser.add_argument("pipeline_id", type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
# Import app
|
||||
account = cast(Account, current_user)
|
||||
result = import_service.import_rag_pipeline(
|
||||
account=account,
|
||||
import_mode=args["mode"],
|
||||
yaml_content=args.get("yaml_content"),
|
||||
yaml_url=args.get("yaml_url"),
|
||||
pipeline_id=args.get("pipeline_id"),
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
status = result.status
|
||||
if status == ImportStatus.FAILED.value:
|
||||
return result.model_dump(mode="json"), 400
|
||||
elif status == ImportStatus.PENDING.value:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
class RagPipelineImportConfirmApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(pipeline_import_fields)
|
||||
def post(self, import_id):
|
||||
# Check user role first
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
# Confirm import
|
||||
account = cast(Account, current_user)
|
||||
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
if result.status == ImportStatus.FAILED.value:
|
||||
return result.model_dump(mode="json"), 400
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
class RagPipelineImportCheckDependenciesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_rag_pipeline
|
||||
@account_initialization_required
|
||||
@marshal_with(pipeline_import_check_dependencies_fields)
|
||||
def get(self, pipeline: Pipeline):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
result = import_service.check_dependencies(pipeline=pipeline)
|
||||
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
class RagPipelineExportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_rag_pipeline
|
||||
@account_initialization_required
|
||||
def get(self, pipeline: Pipeline):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# Add include_secret params
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("include_secret", type=bool, default=False, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
export_service = RagPipelineDslService(session)
|
||||
result = export_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=args["include_secret"])
|
||||
|
||||
return {"data": result}, 200
|
||||
|
||||
|
||||
# Import Rag Pipeline
|
||||
api.add_resource(
|
||||
RagPipelineImportApi,
|
||||
"/rag/pipelines/imports",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineImportConfirmApi,
|
||||
"/rag/pipelines/imports/<string:import_id>/confirm",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineImportCheckDependenciesApi,
|
||||
"/rag/pipelines/imports/<string:pipeline_id>/check-dependencies",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineExportApi,
|
||||
"/rag/pipelines/<string:pipeline_id>/exports",
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,43 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
|
||||
from controllers.console.datasets.error import PipelineNotFoundError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from models.dataset import Pipeline
|
||||
|
||||
|
||||
def get_rag_pipeline(
|
||||
view: Optional[Callable] = None,
|
||||
):
|
||||
def decorator(view_func):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args, **kwargs):
|
||||
if not kwargs.get("pipeline_id"):
|
||||
raise ValueError("missing pipeline_id in path parameters")
|
||||
|
||||
pipeline_id = kwargs.get("pipeline_id")
|
||||
pipeline_id = str(pipeline_id)
|
||||
|
||||
del kwargs["pipeline_id"]
|
||||
|
||||
pipeline = (
|
||||
db.session.query(Pipeline)
|
||||
.filter(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not pipeline:
|
||||
raise PipelineNotFoundError()
|
||||
|
||||
kwargs["pipeline"] = pipeline
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
@@ -18,7 +18,6 @@ from controllers.console.app.error import (
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import AppMode
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
@@ -79,19 +78,9 @@ class ChatTextApi(InstalledAppResource):
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
if (
|
||||
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict
|
||||
):
|
||||
text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
|
||||
voice = args.get("voice") or text_to_speech.get("voice")
|
||||
else:
|
||||
try:
|
||||
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
except Exception:
|
||||
voice = None
|
||||
response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text)
|
||||
voice = args.get("voice", None)
|
||||
|
||||
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logging.exception("App model config broken.")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import io
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from flask import send_file
|
||||
from flask import redirect, send_file
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -9,17 +10,34 @@ from werkzeug.exceptions import Forbidden
|
||||
from configs import dify_config
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||
from core.mcp.auth.auth_flow import auth, handle_callback
|
||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
||||
from core.mcp.error import MCPAuthError, MCPError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import alphanumeric, uuid_value
|
||||
from libs.login import login_required
|
||||
from services.tools.api_tools_manage_service import ApiToolManageService
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
from services.tools.mcp_tools_mange_service import MCPToolManageService
|
||||
from services.tools.tool_labels_service import ToolLabelsService
|
||||
from services.tools.tools_manage_service import ToolCommonService
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
|
||||
|
||||
|
||||
def is_valid_url(url: str) -> bool:
|
||||
if not url:
|
||||
return False
|
||||
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class ToolProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@@ -34,7 +52,7 @@ class ToolProviderListApi(Resource):
|
||||
req.add_argument(
|
||||
"type",
|
||||
type=str,
|
||||
choices=["builtin", "model", "api", "workflow"],
|
||||
choices=["builtin", "model", "api", "workflow", "mcp"],
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="args",
|
||||
@@ -613,6 +631,166 @@ class ToolLabelsApi(Resource):
|
||||
return jsonable_encoder(ToolLabelsService.list_tool_labels())
|
||||
|
||||
|
||||
class ToolProviderMCPApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("server_url", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
|
||||
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
user = current_user
|
||||
if not is_valid_url(args["server_url"]):
|
||||
raise ValueError("Server URL is not valid.")
|
||||
return jsonable_encoder(
|
||||
MCPToolManageService.create_mcp_provider(
|
||||
tenant_id=user.current_tenant_id,
|
||||
server_url=args["server_url"],
|
||||
name=args["name"],
|
||||
icon=args["icon"],
|
||||
icon_type=args["icon_type"],
|
||||
icon_background=args["icon_background"],
|
||||
user_id=user.id,
|
||||
server_identifier=args["server_identifier"],
|
||||
)
|
||||
)
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("server_url", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
if not is_valid_url(args["server_url"]):
|
||||
if "[__HIDDEN__]" in args["server_url"]:
|
||||
pass
|
||||
else:
|
||||
raise ValueError("Server URL is not valid.")
|
||||
MCPToolManageService.update_mcp_provider(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider_id=args["provider_id"],
|
||||
server_url=args["server_url"],
|
||||
name=args["name"],
|
||||
icon=args["icon"],
|
||||
icon_type=args["icon_type"],
|
||||
icon_background=args["icon_background"],
|
||||
server_identifier=args["server_identifier"],
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
MCPToolManageService.delete_mcp_tool(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"])
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
class ToolMCPAuthApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
|
||||
args = parser.parse_args()
|
||||
provider_id = args["provider_id"]
|
||||
tenant_id = current_user.current_tenant_id
|
||||
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
if not provider:
|
||||
raise ValueError("provider not found")
|
||||
try:
|
||||
with MCPClient(
|
||||
provider.decrypted_server_url,
|
||||
provider_id,
|
||||
tenant_id,
|
||||
authed=False,
|
||||
authorization_code=args["authorization_code"],
|
||||
for_list=True,
|
||||
):
|
||||
MCPToolManageService.update_mcp_provider_credentials(
|
||||
mcp_provider=provider,
|
||||
credentials=provider.decrypted_credentials,
|
||||
authed=True,
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
except MCPAuthError:
|
||||
auth_provider = OAuthClientProvider(provider_id, tenant_id, for_list=True)
|
||||
return auth(auth_provider, provider.decrypted_server_url, args["authorization_code"])
|
||||
except MCPError as e:
|
||||
MCPToolManageService.update_mcp_provider_credentials(
|
||||
mcp_provider=provider,
|
||||
credentials={},
|
||||
authed=False,
|
||||
)
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
||||
|
||||
|
||||
class ToolMCPDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id):
|
||||
user = current_user
|
||||
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, user.current_tenant_id)
|
||||
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
|
||||
|
||||
|
||||
class ToolMCPListAllApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user = current_user
|
||||
tenant_id = user.current_tenant_id
|
||||
|
||||
tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)
|
||||
|
||||
return [tool.to_dict() for tool in tools]
|
||||
|
||||
|
||||
class ToolMCPUpdateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
tools = MCPToolManageService.list_mcp_tool_from_remote_server(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
return jsonable_encoder(tools)
|
||||
|
||||
|
||||
class ToolMCPCallbackApi(Resource):
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("code", type=str, required=True, nullable=False, location="args")
|
||||
parser.add_argument("state", type=str, required=True, nullable=False, location="args")
|
||||
args = parser.parse_args()
|
||||
state_key = args["state"]
|
||||
authorization_code = args["code"]
|
||||
handle_callback(state_key, authorization_code)
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
# tool provider
|
||||
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
|
||||
|
||||
@@ -647,8 +825,15 @@ api.add_resource(ToolWorkflowProviderDeleteApi, "/workspaces/current/tool-provid
|
||||
api.add_resource(ToolWorkflowProviderGetApi, "/workspaces/current/tool-provider/workflow/get")
|
||||
api.add_resource(ToolWorkflowProviderListToolApi, "/workspaces/current/tool-provider/workflow/tools")
|
||||
|
||||
# mcp tool provider
|
||||
api.add_resource(ToolMCPDetailApi, "/workspaces/current/tool-provider/mcp/tools/<path:provider_id>")
|
||||
api.add_resource(ToolProviderMCPApi, "/workspaces/current/tool-provider/mcp")
|
||||
api.add_resource(ToolMCPUpdateApi, "/workspaces/current/tool-provider/mcp/update/<path:provider_id>")
|
||||
api.add_resource(ToolMCPAuthApi, "/workspaces/current/tool-provider/mcp/auth")
|
||||
api.add_resource(ToolMCPCallbackApi, "/mcp/oauth/callback")
|
||||
|
||||
api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin")
|
||||
api.add_resource(ToolApiListApi, "/workspaces/current/tools/api")
|
||||
api.add_resource(ToolMCPListAllApi, "/workspaces/current/tools/mcp")
|
||||
api.add_resource(ToolWorkflowListApi, "/workspaces/current/tools/workflow")
|
||||
|
||||
api.add_resource(ToolLabelsApi, "/workspaces/current/tool-labels")
|
||||
|
||||
@@ -87,7 +87,5 @@ class PluginUploadFileApi(Resource):
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
return tool_file, 201
|
||||
|
||||
|
||||
api.add_resource(PluginUploadFileApi, "/files/upload/for-plugin")
|
||||
|
||||
8
api/controllers/mcp/__init__.py
Normal file
8
api/controllers/mcp/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from flask import Blueprint
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
bp = Blueprint("mcp", __name__, url_prefix="/mcp")
|
||||
api = ExternalApi(bp)
|
||||
|
||||
from . import mcp
|
||||
104
api/controllers/mcp/mcp.py
Normal file
104
api/controllers/mcp/mcp.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from flask_restful import Resource, reqparse
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.console.app.mcp_server import AppMCPServerStatus
|
||||
from controllers.mcp import api
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.mcp import types
|
||||
from core.mcp.server.streamable_http import MCPServerStreamableHTTPRequestHandler
|
||||
from core.mcp.types import ClientNotification, ClientRequest
|
||||
from core.mcp.utils import create_mcp_error_response
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.model import App, AppMCPServer, AppMode
|
||||
|
||||
|
||||
class MCPAppApi(Resource):
|
||||
def post(self, server_code):
|
||||
def int_or_str(value):
|
||||
if isinstance(value, (int, str)):
|
||||
return value
|
||||
else:
|
||||
return None
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("jsonrpc", type=str, required=True, location="json")
|
||||
parser.add_argument("method", type=str, required=True, location="json")
|
||||
parser.add_argument("params", type=dict, required=False, location="json")
|
||||
parser.add_argument("id", type=int_or_str, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
request_id = args.get("id")
|
||||
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first()
|
||||
if not server:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found")
|
||||
)
|
||||
|
||||
if server.status != AppMCPServerStatus.ACTIVE:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active")
|
||||
)
|
||||
|
||||
app = db.session.query(App).filter(App.id == server.app_id).first()
|
||||
if not app:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found")
|
||||
)
|
||||
|
||||
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
workflow = app.workflow
|
||||
if workflow is None:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
|
||||
)
|
||||
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app.app_model_config
|
||||
if app_model_config is None:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
|
||||
)
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
converted_user_input_form: list[VariableEntity] = []
|
||||
try:
|
||||
for item in user_input_form:
|
||||
variable_type = item.get("type", "") or list(item.keys())[0]
|
||||
variable = item[variable_type]
|
||||
converted_user_input_form.append(
|
||||
VariableEntity(
|
||||
type=variable_type,
|
||||
variable=variable.get("variable"),
|
||||
description=variable.get("description") or "",
|
||||
label=variable.get("label"),
|
||||
required=variable.get("required", False),
|
||||
max_length=variable.get("max_length"),
|
||||
options=variable.get("options") or [],
|
||||
)
|
||||
)
|
||||
except ValidationError as e:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
|
||||
)
|
||||
|
||||
try:
|
||||
request: ClientRequest | ClientNotification = ClientRequest.model_validate(args)
|
||||
except ValidationError as e:
|
||||
try:
|
||||
notification = ClientNotification.model_validate(args)
|
||||
request = notification
|
||||
except ValidationError as e:
|
||||
return helper.compact_generate_response(
|
||||
create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
|
||||
)
|
||||
|
||||
mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form)
|
||||
response = mcp_server_handler.handle()
|
||||
return helper.compact_generate_response(response)
|
||||
|
||||
|
||||
api.add_resource(MCPAppApi, "/server/<string:server_code>/mcp")
|
||||
@@ -20,7 +20,7 @@ from controllers.service_api.app.error import (
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import App, AppMode, EndUser
|
||||
from models.model import App, EndUser
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
@@ -78,20 +78,9 @@ class TextApi(Resource):
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
if (
|
||||
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict
|
||||
):
|
||||
text_to_speech = app_model.workflow.features_dict.get("text_to_speech", {})
|
||||
voice = args.get("voice") or text_to_speech.get("voice")
|
||||
else:
|
||||
try:
|
||||
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
except Exception:
|
||||
voice = None
|
||||
voice = args.get("voice", None)
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text
|
||||
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@@ -211,6 +211,9 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
if dataset.provider == "external":
|
||||
raise ValueError("External datasets are not supported.")
|
||||
|
||||
indexing_technique = args.get("indexing_technique") or dataset.indexing_technique
|
||||
if not indexing_technique:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
@@ -301,6 +304,9 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
if dataset.provider == "external":
|
||||
raise ValueError("External datasets are not supported.")
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
args["indexing_technique"] = dataset.indexing_technique
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from controllers.web.error import (
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from models.model import App, AppMode
|
||||
from models.model import App
|
||||
from services.audio_service import AudioService
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
@@ -77,21 +77,9 @@ class TextApi(WebApiResource):
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
if (
|
||||
app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
|
||||
and app_model.workflow
|
||||
and app_model.workflow.features_dict
|
||||
):
|
||||
text_to_speech = app_model.workflow.features_dict.get("text_to_speech", {})
|
||||
voice = args.get("voice") or text_to_speech.get("voice")
|
||||
else:
|
||||
try:
|
||||
voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
|
||||
except Exception:
|
||||
voice = None
|
||||
|
||||
voice = args.get("voice", None)
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text
|
||||
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@@ -161,10 +161,14 @@ class BaseAgentRunner(AppRunner):
|
||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||
enum = [option.value for option in parameter.options] if parameter.options else []
|
||||
|
||||
message_tool.parameters["properties"][parameter.name] = {
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or "",
|
||||
}
|
||||
message_tool.parameters["properties"][parameter.name] = (
|
||||
{
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or "",
|
||||
}
|
||||
if parameter.input_schema is None
|
||||
else parameter.input_schema
|
||||
)
|
||||
|
||||
if len(enum) > 0:
|
||||
message_tool.parameters["properties"][parameter.name]["enum"] = enum
|
||||
@@ -254,10 +258,14 @@ class BaseAgentRunner(AppRunner):
|
||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||
enum = [option.value for option in parameter.options] if parameter.options else []
|
||||
|
||||
prompt_tool.parameters["properties"][parameter.name] = {
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or "",
|
||||
}
|
||||
prompt_tool.parameters["properties"][parameter.name] = (
|
||||
{
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or "",
|
||||
}
|
||||
if parameter.input_schema is None
|
||||
else parameter.input_schema
|
||||
)
|
||||
|
||||
if len(enum) > 0:
|
||||
prompt_tool.parameters["properties"][parameter.name]["enum"] = enum
|
||||
|
||||
@@ -85,7 +85,7 @@ class AgentStrategyEntity(BaseModel):
|
||||
description: I18nObject = Field(..., description="The description of the agent strategy")
|
||||
output_schema: Optional[dict] = None
|
||||
features: Optional[list[AgentFeature]] = None
|
||||
|
||||
meta_version: Optional[str] = None
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
@@ -15,10 +15,12 @@ class PluginAgentStrategy(BaseAgentStrategy):
|
||||
|
||||
tenant_id: str
|
||||
declaration: AgentStrategyEntity
|
||||
meta_version: str | None = None
|
||||
|
||||
def __init__(self, tenant_id: str, declaration: AgentStrategyEntity):
|
||||
def __init__(self, tenant_id: str, declaration: AgentStrategyEntity, meta_version: str | None):
|
||||
self.tenant_id = tenant_id
|
||||
self.declaration = declaration
|
||||
self.meta_version = meta_version
|
||||
|
||||
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
|
||||
return self.declaration.parameters
|
||||
|
||||
@@ -113,9 +113,9 @@ class VariableEntity(BaseModel):
|
||||
hide: bool = False
|
||||
max_length: Optional[int] = None
|
||||
options: Sequence[str] = Field(default_factory=list)
|
||||
allowed_file_types: Optional[Sequence[FileType]] = Field(default_factory=list)
|
||||
allowed_file_extensions: Optional[Sequence[str]] = Field(default_factory=list)
|
||||
allowed_file_upload_methods: Optional[Sequence[FileTransferMethod]] = Field(default_factory=list)
|
||||
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
|
||||
allowed_file_extensions: Sequence[str] = Field(default_factory=list)
|
||||
allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
|
||||
|
||||
@field_validator("description", mode="before")
|
||||
@classmethod
|
||||
@@ -128,16 +128,6 @@ class VariableEntity(BaseModel):
|
||||
return v or []
|
||||
|
||||
|
||||
class RagPipelineVariableEntity(VariableEntity):
|
||||
"""
|
||||
Rag Pipeline Variable Entity.
|
||||
"""
|
||||
|
||||
tooltips: Optional[str] = None
|
||||
placeholder: Optional[str] = None
|
||||
belong_to_node_id: str
|
||||
|
||||
|
||||
class ExternalDataVariableEntity(BaseModel):
|
||||
"""
|
||||
External Data Variable Entity.
|
||||
@@ -295,7 +285,7 @@ class AppConfig(BaseModel):
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
app_mode: AppMode
|
||||
additional_features: Optional[AppAdditionalFeatures] = None
|
||||
additional_features: AppAdditionalFeatures
|
||||
variables: list[VariableEntity] = []
|
||||
sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
@@ -20,19 +20,3 @@ class WorkflowVariablesConfigManager:
|
||||
variables.append(VariableEntity.model_validate(variable))
|
||||
|
||||
return variables
|
||||
|
||||
@classmethod
|
||||
def convert_rag_pipeline_variable(cls, workflow: Workflow) -> list[RagPipelineVariableEntity]:
|
||||
"""
|
||||
Convert workflow start variables to variables
|
||||
|
||||
:param workflow: workflow instance
|
||||
"""
|
||||
variables = []
|
||||
|
||||
user_input_form = workflow.rag_pipeline_user_input_form()
|
||||
# variables
|
||||
for variable in user_input_form:
|
||||
variables.append(RagPipelineVariableEntity.model_validate(variable))
|
||||
|
||||
return variables
|
||||
|
||||
@@ -43,13 +43,11 @@ from core.app.entities.task_entities import (
|
||||
WorkflowStartStreamResponse,
|
||||
)
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.workflow.entities.workflow_execution import WorkflowExecution
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.datasource.entities import DatasourceNodeData
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from models import (
|
||||
@@ -185,14 +183,6 @@ class WorkflowResponseConverter:
|
||||
provider_type=node_data.provider_type,
|
||||
provider_id=node_data.provider_id,
|
||||
)
|
||||
elif event.node_type == NodeType.DATASOURCE:
|
||||
node_data = cast(DatasourceNodeData, event.node_data)
|
||||
manager = PluginDatasourceManager()
|
||||
provider_entity = manager.fetch_datasource_provider(
|
||||
self._application_generate_entity.app_config.tenant_id,
|
||||
f"{node_data.plugin_id}/{node_data.provider_name}",
|
||||
)
|
||||
response.data.extras["icon"] = provider_entity.declaration.identity.icon
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
AppStreamResponse,
|
||||
ErrorStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
PingStreamResponse,
|
||||
WorkflowAppBlockingResponse,
|
||||
WorkflowAppStreamResponse,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = WorkflowAppBlockingResponse
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
:return:
|
||||
"""
|
||||
return dict(blocking_response.to_dict())
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
:return:
|
||||
"""
|
||||
return cls.convert_blocking_full_response(blocking_response)
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
:param stream_response: stream response
|
||||
:return:
|
||||
"""
|
||||
for chunk in stream_response:
|
||||
chunk = cast(WorkflowAppStreamResponse, chunk)
|
||||
sub_stream_response = chunk.stream_response
|
||||
|
||||
if isinstance(sub_stream_response, PingStreamResponse):
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"workflow_run_id": chunk.workflow_run_id,
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
yield response_chunk
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
:param stream_response: stream response
|
||||
:return:
|
||||
"""
|
||||
for chunk in stream_response:
|
||||
chunk = cast(WorkflowAppStreamResponse, chunk)
|
||||
sub_stream_response = chunk.stream_response
|
||||
|
||||
if isinstance(sub_stream_response, PingStreamResponse):
|
||||
yield "ping"
|
||||
continue
|
||||
|
||||
response_chunk = {
|
||||
"event": sub_stream_response.event.value,
|
||||
"workflow_run_id": chunk.workflow_run_id,
|
||||
}
|
||||
|
||||
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||
response_chunk.update(data)
|
||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||
else:
|
||||
response_chunk.update(sub_stream_response.to_dict())
|
||||
yield response_chunk
|
||||
@@ -1,64 +0,0 @@
|
||||
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
|
||||
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
|
||||
from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
|
||||
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
|
||||
from models.dataset import Pipeline
|
||||
from models.model import AppMode
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
class PipelineConfig(WorkflowUIBasedAppConfig):
|
||||
"""
|
||||
Pipeline Config Entity.
|
||||
"""
|
||||
|
||||
rag_pipeline_variables: list[RagPipelineVariableEntity] = []
|
||||
pass
|
||||
|
||||
|
||||
class PipelineConfigManager(BaseAppConfigManager):
|
||||
@classmethod
|
||||
def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow) -> PipelineConfig:
|
||||
pipeline_config = PipelineConfig(
|
||||
tenant_id=pipeline.tenant_id,
|
||||
app_id=pipeline.id,
|
||||
app_mode=AppMode.RAG_PIPELINE,
|
||||
workflow_id=workflow.id,
|
||||
rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(workflow=workflow),
|
||||
)
|
||||
|
||||
return pipeline_config
|
||||
|
||||
@classmethod
|
||||
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
|
||||
"""
|
||||
Validate for pipeline config
|
||||
|
||||
:param tenant_id: tenant id
|
||||
:param config: app model config args
|
||||
:param only_structure_validate: only validate the structure of the config
|
||||
"""
|
||||
related_config_keys = []
|
||||
|
||||
# file upload validation
|
||||
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
# text_to_speech
|
||||
config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
# moderation validation
|
||||
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
|
||||
tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
|
||||
)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
related_config_keys = list(set(related_config_keys))
|
||||
|
||||
# Filter out extra parameters
|
||||
filtered_config = {key: config.get(key) for key in related_config_keys}
|
||||
|
||||
return filtered_config
|
||||
@@ -1,621 +0,0 @@
|
||||
import contextvars
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import secrets
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Literal, Optional, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
import contexts
|
||||
from configs import dify_config
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
|
||||
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
|
||||
from core.app.apps.pipeline.pipeline_runner import PipelineRunner
|
||||
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
|
||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
|
||||
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
||||
from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from extensions.ext_database import db
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import AppMode
|
||||
from services.dataset_service import DocumentService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PipelineGenerator(BaseAppGenerator):
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: Optional[str],
|
||||
) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: Optional[str],
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: Optional[str],
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
|
||||
# convert to app config
|
||||
pipeline_config = PipelineConfigManager.get_pipeline_config(
|
||||
pipeline=pipeline,
|
||||
workflow=workflow,
|
||||
)
|
||||
# Add null check for dataset
|
||||
dataset = pipeline.dataset
|
||||
if not dataset:
|
||||
raise ValueError("Pipeline dataset is required")
|
||||
inputs: Mapping[str, Any] = args["inputs"]
|
||||
start_node_id: str = args["start_node_id"]
|
||||
datasource_type: str = args["datasource_type"]
|
||||
datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
|
||||
batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
|
||||
documents = []
|
||||
if invoke_from == InvokeFrom.PUBLISHED:
|
||||
for datasource_info in datasource_info_list:
|
||||
position = DocumentService.get_documents_position(dataset.id)
|
||||
document = self._build_document(
|
||||
tenant_id=pipeline.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
built_in_field_enabled=dataset.built_in_field_enabled,
|
||||
datasource_type=datasource_type,
|
||||
datasource_info=datasource_info,
|
||||
created_from="rag-pipeline",
|
||||
position=position,
|
||||
account=user,
|
||||
batch=batch,
|
||||
document_form=dataset.chunk_structure,
|
||||
)
|
||||
db.session.add(document)
|
||||
documents.append(document)
|
||||
db.session.commit()
|
||||
|
||||
# run in child thread
|
||||
for i, datasource_info in enumerate(datasource_info_list):
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
document_id = None
|
||||
if invoke_from == InvokeFrom.PUBLISHED:
|
||||
document_id = documents[i].id
|
||||
document_pipeline_execution_log = DocumentPipelineExecutionLog(
|
||||
document_id=document_id,
|
||||
datasource_type=datasource_type,
|
||||
datasource_info=json.dumps(datasource_info),
|
||||
datasource_node_id=start_node_id,
|
||||
input_data=inputs,
|
||||
pipeline_id=pipeline.id,
|
||||
created_by=user.id,
|
||||
)
|
||||
db.session.add(document_pipeline_execution_log)
|
||||
db.session.commit()
|
||||
application_generate_entity = RagPipelineGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=pipeline_config,
|
||||
pipeline_config=pipeline_config,
|
||||
datasource_type=datasource_type,
|
||||
datasource_info=datasource_info,
|
||||
dataset_id=dataset.id,
|
||||
start_node_id=start_node_id,
|
||||
batch=batch,
|
||||
document_id=document_id,
|
||||
inputs=self._prepare_user_inputs(
|
||||
user_inputs=inputs,
|
||||
variables=pipeline_config.rag_pipeline_variables,
|
||||
tenant_id=pipeline.tenant_id,
|
||||
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
|
||||
),
|
||||
files=[],
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
call_depth=call_depth,
|
||||
workflow_execution_id=workflow_run_id,
|
||||
)
|
||||
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING
|
||||
else:
|
||||
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN
|
||||
# Create workflow node execution repository
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=workflow_triggered_from,
|
||||
)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
|
||||
)
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
return self._generate(
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
context=contextvars.copy_context(),
|
||||
pipeline=pipeline,
|
||||
workflow_id=workflow.id,
|
||||
user=user,
|
||||
application_generate_entity=application_generate_entity,
|
||||
invoke_from=invoke_from,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
)
|
||||
else:
|
||||
# run in child thread
|
||||
context = contextvars.copy_context()
|
||||
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"context": context,
|
||||
"pipeline": pipeline,
|
||||
"workflow_id": workflow.id,
|
||||
"user": user,
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"invoke_from": invoke_from,
|
||||
"workflow_execution_repository": workflow_execution_repository,
|
||||
"workflow_node_execution_repository": workflow_node_execution_repository,
|
||||
"streaming": streaming,
|
||||
"workflow_thread_pool_id": workflow_thread_pool_id,
|
||||
},
|
||||
)
|
||||
|
||||
worker_thread.start()
|
||||
# return batch, dataset, documents
|
||||
return {
|
||||
"batch": batch,
|
||||
"dataset": PipelineDataset(
|
||||
id=dataset.id,
|
||||
name=dataset.name,
|
||||
description=dataset.description,
|
||||
chunk_structure=dataset.chunk_structure,
|
||||
).model_dump(),
|
||||
"documents": [
|
||||
PipelineDocument(
|
||||
id=document.id,
|
||||
position=document.position,
|
||||
data_source_type=document.data_source_type,
|
||||
data_source_info=json.loads(document.data_source_info) if document.data_source_info else None,
|
||||
name=document.name,
|
||||
indexing_status=document.indexing_status,
|
||||
error=document.error,
|
||||
enabled=document.enabled,
|
||||
).model_dump()
|
||||
for document in documents
|
||||
],
|
||||
}
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
*,
|
||||
flask_app: Flask,
|
||||
context: contextvars.Context,
|
||||
pipeline: Pipeline,
|
||||
workflow_id: str,
|
||||
user: Union[Account, EndUser],
|
||||
application_generate_entity: RagPipelineGenerateEntity,
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
streaming: bool = True,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
:param pipeline: Pipeline
|
||||
:param workflow: Workflow
|
||||
:param user: account or end user
|
||||
:param application_generate_entity: application generate entity
|
||||
:param invoke_from: invoke from source
|
||||
:param workflow_execution_repository: repository for workflow execution
|
||||
:param workflow_node_execution_repository: repository for workflow node execution
|
||||
:param streaming: is stream
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
"""
|
||||
with preserve_flask_contexts(flask_app, context_vars=context):
|
||||
# init queue manager
|
||||
workflow = db.session.query(Workflow).filter(Workflow.id == workflow_id).first()
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow not found: {workflow_id}")
|
||||
queue_manager = PipelineQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
app_mode=AppMode.RAG_PIPELINE,
|
||||
)
|
||||
context = contextvars.copy_context()
|
||||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(
|
||||
target=self._generate_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"context": context,
|
||||
"queue_manager": queue_manager,
|
||||
"application_generate_entity": application_generate_entity,
|
||||
"workflow_thread_pool_id": workflow_thread_pool_id,
|
||||
},
|
||||
)
|
||||
|
||||
worker_thread.start()
|
||||
|
||||
# return response or stream generator
|
||||
response = self._handle_response(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=queue_manager,
|
||||
user=user,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
stream=streaming,
|
||||
)
|
||||
|
||||
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||
|
||||
def single_iteration_generate(
|
||||
self,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
:param app_model: App
|
||||
:param workflow: Workflow
|
||||
:param node_id: the node id
|
||||
:param user: account or end user
|
||||
:param args: request args
|
||||
:param streaming: is streamed
|
||||
"""
|
||||
if not node_id:
|
||||
raise ValueError("node_id is required")
|
||||
|
||||
if args.get("inputs") is None:
|
||||
raise ValueError("inputs is required")
|
||||
|
||||
# convert to app config
|
||||
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
|
||||
|
||||
dataset = pipeline.dataset
|
||||
if not dataset:
|
||||
raise ValueError("Pipeline dataset is required")
|
||||
|
||||
# init application generate entity - use RagPipelineGenerateEntity instead
|
||||
application_generate_entity = RagPipelineGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=pipeline_config,
|
||||
pipeline_config=pipeline_config,
|
||||
datasource_type=args.get("datasource_type", ""),
|
||||
datasource_info=args.get("datasource_info", {}),
|
||||
dataset_id=dataset.id,
|
||||
batch=args.get("batch", ""),
|
||||
document_id=args.get("document_id"),
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
workflow_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
# Create workflow node execution repository
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
|
||||
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
|
||||
)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
pipeline=pipeline,
|
||||
workflow_id=workflow.id,
|
||||
user=user,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
def single_loop_generate(
|
||||
self,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
:param app_model: App
|
||||
:param workflow: Workflow
|
||||
:param node_id: the node id
|
||||
:param user: account or end user
|
||||
:param args: request args
|
||||
:param streaming: is streamed
|
||||
"""
|
||||
if not node_id:
|
||||
raise ValueError("node_id is required")
|
||||
|
||||
if args.get("inputs") is None:
|
||||
raise ValueError("inputs is required")
|
||||
|
||||
dataset = pipeline.dataset
|
||||
if not dataset:
|
||||
raise ValueError("Pipeline dataset is required")
|
||||
|
||||
# convert to app config
|
||||
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = RagPipelineGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=pipeline_config,
|
||||
pipeline_config=pipeline_config,
|
||||
datasource_type=args.get("datasource_type", ""),
|
||||
datasource_info=args.get("datasource_info", {}),
|
||||
batch=args.get("batch", ""),
|
||||
document_id=args.get("document_id"),
|
||||
dataset_id=dataset.id,
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
extras={"auto_generate_conversation_name": False},
|
||||
single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
|
||||
workflow_execution_id=str(uuid.uuid4()),
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
|
||||
# Create workflow node execution repository
|
||||
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
|
||||
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
|
||||
)
|
||||
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=application_generate_entity.app_config.app_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||
)
|
||||
|
||||
return self._generate(
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
pipeline=pipeline,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
def _generate_worker(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
application_generate_entity: RagPipelineGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
context: contextvars.Context,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
:param flask_app: Flask app
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
:return:
|
||||
"""
|
||||
|
||||
with preserve_flask_contexts(flask_app, context_vars=context):
|
||||
try:
|
||||
# workflow app
|
||||
runner = PipelineRunner(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||
)
|
||||
|
||||
runner.run()
|
||||
except GenerateTaskStoppedError:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
except ValidationError as e:
|
||||
logger.exception("Validation Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except ValueError as e:
|
||||
if dify_config.DEBUG:
|
||||
logger.exception("Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
except Exception as e:
|
||||
logger.exception("Unknown Error when generating")
|
||||
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
def _handle_response(
|
||||
self,
|
||||
application_generate_entity: RagPipelineGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
user: Union[Account, EndUser],
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
stream: bool = False,
|
||||
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
"""
|
||||
Handle response.
|
||||
:param application_generate_entity: application generate entity
|
||||
:param workflow: workflow
|
||||
:param queue_manager: queue manager
|
||||
:param user: account or end user
|
||||
:param stream: is stream
|
||||
:param workflow_node_execution_repository: optional repository for workflow node execution
|
||||
:return:
|
||||
"""
|
||||
# init generate task pipeline
|
||||
generate_task_pipeline = WorkflowAppGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=queue_manager,
|
||||
user=user,
|
||||
stream=stream,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
)
|
||||
|
||||
try:
|
||||
return generate_task_pipeline.process()
|
||||
except ValueError as e:
|
||||
if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||
raise GenerateTaskStoppedError()
|
||||
else:
|
||||
logger.exception(
|
||||
f"Fails to process generate task pipeline, task_id: {application_generate_entity.task_id}"
|
||||
)
|
||||
raise e
|
||||
|
||||
def _build_document(
|
||||
self,
|
||||
tenant_id: str,
|
||||
dataset_id: str,
|
||||
built_in_field_enabled: bool,
|
||||
datasource_type: str,
|
||||
datasource_info: Mapping[str, Any],
|
||||
created_from: str,
|
||||
position: int,
|
||||
account: Union[Account, EndUser],
|
||||
batch: str,
|
||||
document_form: str,
|
||||
):
|
||||
if datasource_type == "local_file":
|
||||
name = datasource_info["name"]
|
||||
elif datasource_type == "online_document":
|
||||
name = datasource_info["page"]["page_name"]
|
||||
elif datasource_type == "website_crawl":
|
||||
name = datasource_info["title"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported datasource type: {datasource_type}")
|
||||
|
||||
document = Document(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
position=position,
|
||||
data_source_type=datasource_type,
|
||||
data_source_info=json.dumps(datasource_info),
|
||||
batch=batch,
|
||||
name=name,
|
||||
created_from=created_from,
|
||||
created_by=account.id,
|
||||
doc_form=document_form,
|
||||
)
|
||||
doc_metadata = {}
|
||||
if built_in_field_enabled:
|
||||
doc_metadata = {
|
||||
BuiltInField.document_name: name,
|
||||
BuiltInField.uploader: account.name,
|
||||
BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
|
||||
BuiltInField.source: datasource_type,
|
||||
}
|
||||
if doc_metadata:
|
||||
document.doc_metadata = doc_metadata
|
||||
return document
|
||||
@@ -1,44 +0,0 @@
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueErrorEvent,
|
||||
QueueMessageEndEvent,
|
||||
QueueStopEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
WorkflowQueueMessage,
|
||||
)
|
||||
|
||||
|
||||
class PipelineQueueManager(AppQueueManager):
|
||||
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
|
||||
super().__init__(task_id, user_id, invoke_from)
|
||||
|
||||
self._app_mode = app_mode
|
||||
|
||||
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||
"""
|
||||
Publish event to queue
|
||||
:param event:
|
||||
:param pub_from:
|
||||
:return:
|
||||
"""
|
||||
message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event)
|
||||
|
||||
self._q.put(message)
|
||||
|
||||
if isinstance(
|
||||
event,
|
||||
QueueStopEvent
|
||||
| QueueErrorEvent
|
||||
| QueueMessageEndEvent
|
||||
| QueueWorkflowSucceededEvent
|
||||
| QueueWorkflowFailedEvent
|
||||
| QueueWorkflowPartialSuccessEvent,
|
||||
):
|
||||
self.stop_listen()
|
||||
|
||||
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
||||
raise GenerateTaskStoppedError()
|
||||
@@ -1,221 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
RagPipelineGenerateEntity,
|
||||
)
|
||||
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
|
||||
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Pipeline
|
||||
from models.enums import UserFrom
|
||||
from models.model import EndUser
|
||||
from models.workflow import Workflow, WorkflowType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PipelineRunner(WorkflowBasedAppRunner):
|
||||
"""
|
||||
Pipeline Application Runner
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: RagPipelineGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: application queue manager
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
"""
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self.queue_manager = queue_manager
|
||||
self.workflow_thread_pool_id = workflow_thread_pool_id
|
||||
|
||||
def _get_app_id(self) -> str:
|
||||
return self.application_generate_entity.app_config.app_id
|
||||
|
||||
def run(self) -> None:
|
||||
"""
|
||||
Run application
|
||||
"""
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(PipelineConfig, app_config)
|
||||
|
||||
user_id = None
|
||||
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = self.application_generate_entity.user_id
|
||||
|
||||
pipeline = db.session.query(Pipeline).filter(Pipeline.id == app_config.app_id).first()
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
|
||||
workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
db.session.close()
|
||||
|
||||
workflow_callbacks: list[WorkflowCallback] = []
|
||||
if dify_config.DEBUG:
|
||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||
|
||||
# if only single iteration run is requested
|
||||
if self.application_generate_entity.single_iteration_run:
|
||||
# if only single iteration run is requested
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||
)
|
||||
elif self.application_generate_entity.single_loop_run:
|
||||
# if only single loop run is requested
|
||||
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||
workflow=workflow,
|
||||
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||
user_inputs=self.application_generate_entity.single_loop_run.inputs,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
files = self.application_generate_entity.files
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = {
|
||||
SystemVariableKey.FILES: files,
|
||||
SystemVariableKey.USER_ID: user_id,
|
||||
SystemVariableKey.APP_ID: app_config.app_id,
|
||||
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
|
||||
SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_execution_id,
|
||||
SystemVariableKey.DOCUMENT_ID: self.application_generate_entity.document_id,
|
||||
SystemVariableKey.BATCH: self.application_generate_entity.batch,
|
||||
SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id,
|
||||
SystemVariableKey.DATASOURCE_TYPE: self.application_generate_entity.datasource_type,
|
||||
SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info,
|
||||
SystemVariableKey.INVOKE_FROM: self.application_generate_entity.invoke_from.value,
|
||||
}
|
||||
rag_pipeline_variables = []
|
||||
if workflow.rag_pipeline_variables:
|
||||
for v in workflow.rag_pipeline_variables:
|
||||
rag_pipeline_variable = RAGPipelineVariable(**v)
|
||||
if (
|
||||
rag_pipeline_variable.belong_to_node_id
|
||||
in (self.application_generate_entity.start_node_id, "shared")
|
||||
) and rag_pipeline_variable.variable in inputs:
|
||||
rag_pipeline_variables.append(
|
||||
RAGPipelineVariableInput(
|
||||
variable=rag_pipeline_variable,
|
||||
value=inputs[rag_pipeline_variable.variable],
|
||||
)
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
environment_variables=workflow.environment_variables,
|
||||
conversation_variables=[],
|
||||
rag_pipeline_variables=rag_pipeline_variables,
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = self._init_rag_pipeline_graph(
|
||||
graph_config=workflow.graph_dict,
|
||||
start_node_id=self.application_generate_entity.start_node_id,
|
||||
)
|
||||
|
||||
# RUN WORKFLOW
|
||||
workflow_entry = WorkflowEntry(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
workflow_id=workflow.id,
|
||||
workflow_type=WorkflowType.value_of(workflow.type),
|
||||
graph=graph,
|
||||
graph_config=workflow.graph_dict,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=(
|
||||
UserFrom.ACCOUNT
|
||||
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
||||
else UserFrom.END_USER
|
||||
),
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
call_depth=self.application_generate_entity.call_depth,
|
||||
variable_pool=variable_pool,
|
||||
thread_pool_id=self.workflow_thread_pool_id,
|
||||
)
|
||||
|
||||
generator = workflow_entry.run(callbacks=workflow_callbacks)
|
||||
|
||||
for event in generator:
|
||||
self._handle_event(workflow_entry, event)
|
||||
|
||||
def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Optional[Workflow]:
|
||||
"""
|
||||
Get workflow
|
||||
"""
|
||||
# fetch workflow by workflow_id
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.filter(
|
||||
Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# return workflow
|
||||
return workflow
|
||||
|
||||
def _init_rag_pipeline_graph(self, graph_config: Mapping[str, Any], start_node_id: Optional[str] = None) -> Graph:
|
||||
"""
|
||||
Init pipeline graph
|
||||
"""
|
||||
if "nodes" not in graph_config or "edges" not in graph_config:
|
||||
raise ValueError("nodes or edges not found in workflow graph")
|
||||
|
||||
if not isinstance(graph_config.get("nodes"), list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
nodes = graph_config.get("nodes", [])
|
||||
edges = graph_config.get("edges", [])
|
||||
real_run_nodes = []
|
||||
real_edges = []
|
||||
exclude_node_ids = []
|
||||
for node in nodes:
|
||||
node_id = node.get("id")
|
||||
node_type = node.get("data", {}).get("type", "")
|
||||
if node_type == "datasource":
|
||||
if start_node_id != node_id:
|
||||
exclude_node_ids.append(node_id)
|
||||
continue
|
||||
real_run_nodes.append(node)
|
||||
for edge in edges:
|
||||
if edge.get("source") in exclude_node_ids:
|
||||
continue
|
||||
real_edges.append(edge)
|
||||
graph_config = dict(graph_config)
|
||||
graph_config["nodes"] = real_run_nodes
|
||||
graph_config["edges"] = real_edges
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
return graph
|
||||
@@ -36,7 +36,6 @@ class InvokeFrom(Enum):
|
||||
# DEBUGGER indicates that this invocation is from
|
||||
# the workflow (or chatflow) edit page.
|
||||
DEBUGGER = "debugger"
|
||||
PUBLISHED = "published"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str):
|
||||
@@ -241,38 +240,3 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
||||
inputs: dict
|
||||
|
||||
single_loop_run: Optional[SingleLoopRunEntity] = None
|
||||
|
||||
|
||||
class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
|
||||
"""
|
||||
RAG Pipeline Application Generate Entity.
|
||||
"""
|
||||
|
||||
# pipeline config
|
||||
pipeline_config: WorkflowUIBasedAppConfig
|
||||
datasource_type: str
|
||||
datasource_info: Mapping[str, Any]
|
||||
dataset_id: str
|
||||
batch: str
|
||||
document_id: Optional[str] = None
|
||||
start_node_id: Optional[str] = None
|
||||
|
||||
class SingleIterationRunEntity(BaseModel):
|
||||
"""
|
||||
Single Iteration Run Entity.
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: dict
|
||||
|
||||
single_iteration_run: Optional[SingleIterationRunEntity] = None
|
||||
|
||||
class SingleLoopRunEntity(BaseModel):
|
||||
"""
|
||||
Single Loop Run Entity.
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: dict
|
||||
|
||||
single_loop_run: Optional[SingleLoopRunEntity] = None
|
||||
|
||||
@@ -19,6 +19,7 @@ from core.app.entities.task_entities import (
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.moderation.output_moderation import ModerationRule, OutputModeration
|
||||
from models.enums import MessageStatus
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -62,7 +63,7 @@ class BasedGenerateTaskPipeline:
|
||||
return err
|
||||
|
||||
err_desc = self._error_to_desc(err)
|
||||
message.status = "error"
|
||||
message.status = MessageStatus.ERROR
|
||||
message.error = err_desc
|
||||
return err
|
||||
|
||||
|
||||
@@ -105,14 +105,6 @@ class DifyAgentCallbackHandler(BaseModel):
|
||||
|
||||
self.current_loop += 1
|
||||
|
||||
def on_datasource_start(self, datasource_name: str, datasource_inputs: Mapping[str, Any]) -> None:
|
||||
"""Run on datasource start."""
|
||||
if dify_config.DEBUG:
|
||||
print_text(
|
||||
"\n[on_datasource_start] DatasourceCall:" + datasource_name + "\n" + str(datasource_inputs) + "\n",
|
||||
color=self.color,
|
||||
)
|
||||
|
||||
@property
|
||||
def ignore_agent(self) -> bool:
|
||||
"""Whether to ignore agent callbacks."""
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceProviderType,
|
||||
)
|
||||
|
||||
|
||||
class DatasourcePlugin(ABC):
|
||||
entity: DatasourceEntity
|
||||
runtime: DatasourceRuntime
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entity: DatasourceEntity,
|
||||
runtime: DatasourceRuntime,
|
||||
) -> None:
|
||||
self.entity = entity
|
||||
self.runtime = runtime
|
||||
|
||||
@abstractmethod
|
||||
def datasource_provider_type(self) -> str:
|
||||
"""
|
||||
returns the type of the datasource provider
|
||||
"""
|
||||
return DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
|
||||
return self.__class__(
|
||||
entity=self.entity.model_copy(),
|
||||
runtime=runtime,
|
||||
)
|
||||
@@ -1,118 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
|
||||
class DatasourcePluginProviderController(ABC):
|
||||
entity: DatasourceProviderEntityWithPlugin
|
||||
tenant_id: str
|
||||
|
||||
def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None:
|
||||
self.entity = entity
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
@property
|
||||
def need_credentials(self) -> bool:
|
||||
"""
|
||||
returns whether the provider needs credentials
|
||||
|
||||
:return: whether the provider needs credentials
|
||||
"""
|
||||
return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
manager = PluginToolManager()
|
||||
if not manager.validate_datasource_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
provider=self.entity.identity.name,
|
||||
credentials=credentials,
|
||||
):
|
||||
raise ToolProviderCredentialValidationError("Invalid credentials")
|
||||
|
||||
@property
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
"""
|
||||
return DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
@abstractmethod
|
||||
def get_datasource(self, datasource_name: str) -> DatasourcePlugin:
|
||||
"""
|
||||
return datasource with given name
|
||||
"""
|
||||
pass
|
||||
|
||||
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the format of the credentials of the provider and set the default value if needed
|
||||
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
credentials_schema = dict[str, ProviderConfig]()
|
||||
if credentials_schema is None:
|
||||
return
|
||||
|
||||
for credential in self.entity.credentials_schema:
|
||||
credentials_schema[credential.name] = credential
|
||||
|
||||
credentials_need_to_validate: dict[str, ProviderConfig] = {}
|
||||
for credential_name in credentials_schema:
|
||||
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
|
||||
|
||||
for credential_name in credentials:
|
||||
if credential_name not in credentials_need_to_validate:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
f"credential {credential_name} not found in provider {self.entity.identity.name}"
|
||||
)
|
||||
|
||||
# check type
|
||||
credential_schema = credentials_need_to_validate[credential_name]
|
||||
if not credential_schema.required and credentials[credential_name] is None:
|
||||
continue
|
||||
|
||||
if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||
|
||||
elif credential_schema.type == ProviderConfig.Type.SELECT:
|
||||
if not isinstance(credentials[credential_name], str):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||
|
||||
options = credential_schema.options
|
||||
if not isinstance(options, list):
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list")
|
||||
|
||||
if credentials[credential_name] not in [x.value for x in options]:
|
||||
raise ToolProviderCredentialValidationError(
|
||||
f"credential {credential_name} should be one of {options}"
|
||||
)
|
||||
|
||||
credentials_need_to_validate.pop(credential_name)
|
||||
|
||||
for credential_name in credentials_need_to_validate:
|
||||
credential_schema = credentials_need_to_validate[credential_name]
|
||||
if credential_schema.required:
|
||||
raise ToolProviderCredentialValidationError(f"credential {credential_name} is required")
|
||||
|
||||
# the credential is not set currently, set the default value if needed
|
||||
if credential_schema.default is not None:
|
||||
default_value = credential_schema.default
|
||||
# parse default value into the correct type
|
||||
if credential_schema.type in {
|
||||
ProviderConfig.Type.SECRET_INPUT,
|
||||
ProviderConfig.Type.TEXT_INPUT,
|
||||
ProviderConfig.Type.SELECT,
|
||||
}:
|
||||
default_value = str(default_value)
|
||||
|
||||
credentials[credential_name] = default_value
|
||||
@@ -1,36 +0,0 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from openai import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.datasource.entities.datasource_entities import DatasourceInvokeFrom
|
||||
|
||||
|
||||
class DatasourceRuntime(BaseModel):
|
||||
"""
|
||||
Meta data of a datasource call processing
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
datasource_id: Optional[str] = None
|
||||
invoke_from: Optional[InvokeFrom] = None
|
||||
datasource_invoke_from: Optional[DatasourceInvokeFrom] = None
|
||||
credentials: dict[str, Any] = Field(default_factory=dict)
|
||||
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class FakeDatasourceRuntime(DatasourceRuntime):
|
||||
"""
|
||||
Fake datasource runtime for testing
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
tenant_id="fake_tenant_id",
|
||||
datasource_id="fake_datasource_id",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
datasource_invoke_from=DatasourceInvokeFrom.RAG_PIPELINE,
|
||||
credentials={},
|
||||
runtime_parameters={},
|
||||
)
|
||||
@@ -1,244 +0,0 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from mimetypes import guess_extension, guess_type
|
||||
from typing import Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import MessageFile, UploadFile
|
||||
from models.tools import ToolFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatasourceFileManager:
|
||||
@staticmethod
|
||||
def sign_file(datasource_file_id: str, extension: str) -> str:
|
||||
"""
|
||||
sign file to get a temporary url
|
||||
"""
|
||||
base_url = dify_config.FILES_URL
|
||||
file_preview_url = f"{base_url}/files/datasources/{datasource_file_id}{extension}"
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||
|
||||
@staticmethod
|
||||
def verify_file(datasource_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
"""
|
||||
verify signature
|
||||
"""
|
||||
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||
|
||||
# verify signature
|
||||
if sign != recalculated_encoded_sign:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
|
||||
@staticmethod
|
||||
def create_file_by_raw(
|
||||
*,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: Optional[str],
|
||||
file_binary: bytes,
|
||||
mimetype: str,
|
||||
filename: Optional[str] = None,
|
||||
) -> UploadFile:
|
||||
extension = guess_extension(mimetype) or ".bin"
|
||||
unique_name = uuid4().hex
|
||||
unique_filename = f"{unique_name}{extension}"
|
||||
# default just as before
|
||||
present_filename = unique_filename
|
||||
if filename is not None:
|
||||
has_extension = len(filename.split(".")) > 1
|
||||
# Add extension flexibly
|
||||
present_filename = filename if has_extension else f"{filename}{extension}"
|
||||
filepath = f"datasources/{tenant_id}/{unique_filename}"
|
||||
storage.save(filepath, file_binary)
|
||||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
key=filepath,
|
||||
name=present_filename,
|
||||
size=len(file_binary),
|
||||
extension=extension,
|
||||
mime_type=mimetype,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=user_id,
|
||||
used=False,
|
||||
hash=hashlib.sha3_256(file_binary).hexdigest(),
|
||||
source_url="",
|
||||
)
|
||||
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
db.session.refresh(upload_file)
|
||||
|
||||
return upload_file
|
||||
|
||||
@staticmethod
|
||||
def create_file_by_url(
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
file_url: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
) -> UploadFile:
|
||||
# try to download image
|
||||
try:
|
||||
response = ssrf_proxy.get(file_url)
|
||||
response.raise_for_status()
|
||||
blob = response.content
|
||||
except httpx.TimeoutException:
|
||||
raise ValueError(f"timeout when downloading file from {file_url}")
|
||||
|
||||
mimetype = (
|
||||
guess_type(file_url)[0]
|
||||
or response.headers.get("Content-Type", "").split(";")[0].strip()
|
||||
or "application/octet-stream"
|
||||
)
|
||||
extension = guess_extension(mimetype) or ".bin"
|
||||
unique_name = uuid4().hex
|
||||
filename = f"{unique_name}{extension}"
|
||||
filepath = f"tools/{tenant_id}/{filename}"
|
||||
storage.save(filepath, blob)
|
||||
|
||||
upload_file = UploadFile(
|
||||
tenant_id=tenant_id,
|
||||
storage_type=dify_config.STORAGE_TYPE,
|
||||
key=filepath,
|
||||
name=filename,
|
||||
size=len(blob),
|
||||
extension=extension,
|
||||
mime_type=mimetype,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=user_id,
|
||||
used=False,
|
||||
hash=hashlib.sha3_256(blob).hexdigest(),
|
||||
source_url=file_url,
|
||||
)
|
||||
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
|
||||
return upload_file
|
||||
|
||||
@staticmethod
|
||||
def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
|
||||
"""
|
||||
get file binary
|
||||
|
||||
:param id: the id of the file
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
upload_file: UploadFile | None = (
|
||||
db.session.query(UploadFile)
|
||||
.filter(
|
||||
UploadFile.id == id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not upload_file:
|
||||
return None
|
||||
|
||||
blob = storage.load_once(upload_file.key)
|
||||
|
||||
return blob, upload_file.mime_type
|
||||
|
||||
@staticmethod
|
||||
def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]:
|
||||
"""
|
||||
get file binary
|
||||
|
||||
:param id: the id of the file
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
message_file: MessageFile | None = (
|
||||
db.session.query(MessageFile)
|
||||
.filter(
|
||||
MessageFile.id == id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Check if message_file is not None
|
||||
if message_file is not None:
|
||||
# get tool file id
|
||||
if message_file.url is not None:
|
||||
tool_file_id = message_file.url.split("/")[-1]
|
||||
# trim extension
|
||||
tool_file_id = tool_file_id.split(".")[0]
|
||||
else:
|
||||
tool_file_id = None
|
||||
else:
|
||||
tool_file_id = None
|
||||
|
||||
tool_file: ToolFile | None = (
|
||||
db.session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == tool_file_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tool_file:
|
||||
return None
|
||||
|
||||
blob = storage.load_once(tool_file.file_key)
|
||||
|
||||
return blob, tool_file.mimetype
|
||||
|
||||
@staticmethod
|
||||
def get_file_generator_by_upload_file_id(upload_file_id: str):
|
||||
"""
|
||||
get file binary
|
||||
|
||||
:param tool_file_id: the id of the tool file
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
upload_file: UploadFile | None = (
|
||||
db.session.query(UploadFile)
|
||||
.filter(
|
||||
UploadFile.id == upload_file_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not upload_file:
|
||||
return None, None
|
||||
|
||||
stream = storage.load_stream(upload_file.key)
|
||||
|
||||
return stream, upload_file.mime_type
|
||||
|
||||
|
||||
# init tool_file_parser
|
||||
# from core.file.datasource_file_parser import datasource_file_manager
|
||||
#
|
||||
# datasource_file_manager["manager"] = DatasourceFileManager
|
||||
@@ -1,100 +0,0 @@
|
||||
import logging
|
||||
from threading import Lock
|
||||
from typing import Union
|
||||
|
||||
import contexts
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.entities.common_entities import I18nObject
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||
from core.datasource.errors import DatasourceProviderNotFoundError
|
||||
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
|
||||
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
|
||||
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatasourceManager:
|
||||
_builtin_provider_lock = Lock()
|
||||
_hardcoded_providers: dict[str, DatasourcePluginProviderController] = {}
|
||||
_builtin_providers_loaded = False
|
||||
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
|
||||
|
||||
@classmethod
|
||||
def get_datasource_plugin_provider(
|
||||
cls, provider_id: str, tenant_id: str, datasource_type: DatasourceProviderType
|
||||
) -> DatasourcePluginProviderController:
|
||||
"""
|
||||
get the datasource plugin provider
|
||||
"""
|
||||
# check if context is set
|
||||
try:
|
||||
contexts.datasource_plugin_providers.get()
|
||||
except LookupError:
|
||||
contexts.datasource_plugin_providers.set({})
|
||||
contexts.datasource_plugin_providers_lock.set(Lock())
|
||||
|
||||
with contexts.datasource_plugin_providers_lock.get():
|
||||
datasource_plugin_providers = contexts.datasource_plugin_providers.get()
|
||||
if provider_id in datasource_plugin_providers:
|
||||
return datasource_plugin_providers[provider_id]
|
||||
|
||||
manager = PluginDatasourceManager()
|
||||
provider_entity = manager.fetch_datasource_provider(tenant_id, provider_id)
|
||||
if not provider_entity:
|
||||
raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found")
|
||||
|
||||
match datasource_type:
|
||||
case DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
controller = OnlineDocumentDatasourcePluginProviderController(
|
||||
entity=provider_entity.declaration,
|
||||
plugin_id=provider_entity.plugin_id,
|
||||
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
case DatasourceProviderType.WEBSITE_CRAWL:
|
||||
controller = WebsiteCrawlDatasourcePluginProviderController(
|
||||
entity=provider_entity.declaration,
|
||||
plugin_id=provider_entity.plugin_id,
|
||||
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
case DatasourceProviderType.LOCAL_FILE:
|
||||
controller = LocalFileDatasourcePluginProviderController(
|
||||
entity=provider_entity.declaration,
|
||||
plugin_id=provider_entity.plugin_id,
|
||||
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
case _:
|
||||
raise ValueError(f"Unsupported datasource type: {datasource_type}")
|
||||
|
||||
datasource_plugin_providers[provider_id] = controller
|
||||
|
||||
return controller
|
||||
|
||||
@classmethod
|
||||
def get_datasource_runtime(
|
||||
cls,
|
||||
provider_id: str,
|
||||
datasource_name: str,
|
||||
tenant_id: str,
|
||||
datasource_type: DatasourceProviderType,
|
||||
) -> DatasourcePlugin:
|
||||
"""
|
||||
get the datasource runtime
|
||||
|
||||
:param provider_type: the type of the provider
|
||||
:param provider_id: the id of the provider
|
||||
:param datasource_name: the name of the datasource
|
||||
:param tenant_id: the tenant id
|
||||
|
||||
:return: the datasource plugin
|
||||
"""
|
||||
return cls.get_datasource_plugin_provider(
|
||||
provider_id,
|
||||
tenant_id,
|
||||
datasource_type,
|
||||
).get_datasource(datasource_name)
|
||||
@@ -1,71 +0,0 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.datasource.entities.datasource_entities import DatasourceParameter
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
class DatasourceApiEntity(BaseModel):
|
||||
author: str
|
||||
name: str # identifier
|
||||
label: I18nObject # label
|
||||
description: I18nObject
|
||||
parameters: Optional[list[DatasourceParameter]] = None
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
output_schema: Optional[dict] = None
|
||||
|
||||
|
||||
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
|
||||
|
||||
|
||||
class DatasourceProviderApiEntity(BaseModel):
|
||||
id: str
|
||||
author: str
|
||||
name: str # identifier
|
||||
description: I18nObject
|
||||
icon: str | dict
|
||||
label: I18nObject # label
|
||||
type: str
|
||||
masked_credentials: Optional[dict] = None
|
||||
original_credentials: Optional[dict] = None
|
||||
is_team_authorization: bool = False
|
||||
allow_delete: bool = True
|
||||
plugin_id: Optional[str] = Field(default="", description="The plugin id of the datasource")
|
||||
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the datasource")
|
||||
datasources: list[DatasourceApiEntity] = Field(default_factory=list)
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
|
||||
@field_validator("datasources", mode="before")
|
||||
@classmethod
|
||||
def convert_none_to_empty_list(cls, v):
|
||||
return v if v is not None else []
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
# -------------
|
||||
# overwrite datasource parameter types for temp fix
|
||||
datasources = jsonable_encoder(self.datasources)
|
||||
for datasource in datasources:
|
||||
if datasource.get("parameters"):
|
||||
for parameter in datasource.get("parameters"):
|
||||
if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES.value:
|
||||
parameter["type"] = "files"
|
||||
# -------------
|
||||
|
||||
return {
|
||||
"id": self.id,
|
||||
"author": self.author,
|
||||
"name": self.name,
|
||||
"plugin_id": self.plugin_id,
|
||||
"plugin_unique_identifier": self.plugin_unique_identifier,
|
||||
"description": self.description.to_dict(),
|
||||
"icon": self.icon,
|
||||
"label": self.label.to_dict(),
|
||||
"type": self.type.value,
|
||||
"team_credentials": self.masked_credentials,
|
||||
"is_team_authorization": self.is_team_authorization,
|
||||
"allow_delete": self.allow_delete,
|
||||
"datasources": datasources,
|
||||
"labels": self.labels,
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class I18nObject(BaseModel):
|
||||
"""
|
||||
Model class for i18n object.
|
||||
"""
|
||||
|
||||
en_US: str
|
||||
zh_Hans: Optional[str] = Field(default=None)
|
||||
pt_BR: Optional[str] = Field(default=None)
|
||||
ja_JP: Optional[str] = Field(default=None)
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self.zh_Hans = self.zh_Hans or self.en_US
|
||||
self.pt_BR = self.pt_BR or self.en_US
|
||||
self.ja_JP = self.ja_JP or self.en_US
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}
|
||||
@@ -1,361 +0,0 @@
|
||||
import enum
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.plugin.entities.oauth import OAuthSchema
|
||||
from core.plugin.entities.parameters import (
|
||||
PluginParameter,
|
||||
PluginParameterOption,
|
||||
PluginParameterType,
|
||||
as_normal_type,
|
||||
cast_parameter_value,
|
||||
init_frontend_parameter,
|
||||
)
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolLabelEnum
|
||||
|
||||
|
||||
class DatasourceProviderType(enum.StrEnum):
|
||||
"""
|
||||
Enum class for datasource provider
|
||||
"""
|
||||
|
||||
ONLINE_DOCUMENT = "online_document"
|
||||
LOCAL_FILE = "local_file"
|
||||
WEBSITE_CRAWL = "website_crawl"
|
||||
ONLINE_DRIVE = "online_drive"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "DatasourceProviderType":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f"invalid mode value {value}")
|
||||
|
||||
|
||||
class DatasourceParameter(PluginParameter):
|
||||
"""
|
||||
Overrides type
|
||||
"""
|
||||
|
||||
class DatasourceParameterType(enum.StrEnum):
|
||||
"""
|
||||
removes TOOLS_SELECTOR from PluginParameterType
|
||||
"""
|
||||
|
||||
STRING = PluginParameterType.STRING.value
|
||||
NUMBER = PluginParameterType.NUMBER.value
|
||||
BOOLEAN = PluginParameterType.BOOLEAN.value
|
||||
SELECT = PluginParameterType.SELECT.value
|
||||
SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
|
||||
FILE = PluginParameterType.FILE.value
|
||||
FILES = PluginParameterType.FILES.value
|
||||
|
||||
# deprecated, should not use.
|
||||
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
|
||||
|
||||
def as_normal_type(self):
|
||||
return as_normal_type(self)
|
||||
|
||||
def cast_value(self, value: Any):
|
||||
return cast_parameter_value(self, value)
|
||||
|
||||
type: DatasourceParameterType = Field(..., description="The type of the parameter")
|
||||
description: I18nObject = Field(..., description="The description of the parameter")
|
||||
|
||||
@classmethod
|
||||
def get_simple_instance(
|
||||
cls,
|
||||
name: str,
|
||||
typ: DatasourceParameterType,
|
||||
required: bool,
|
||||
options: Optional[list[str]] = None,
|
||||
) -> "DatasourceParameter":
|
||||
"""
|
||||
get a simple datasource parameter
|
||||
|
||||
:param name: the name of the parameter
|
||||
:param llm_description: the description presented to the LLM
|
||||
:param typ: the type of the parameter
|
||||
:param required: if the parameter is required
|
||||
:param options: the options of the parameter
|
||||
"""
|
||||
# convert options to ToolParameterOption
|
||||
# FIXME fix the type error
|
||||
if options:
|
||||
option_objs = [
|
||||
PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
|
||||
for option in options
|
||||
]
|
||||
else:
|
||||
option_objs = []
|
||||
|
||||
return cls(
|
||||
name=name,
|
||||
label=I18nObject(en_US="", zh_Hans=""),
|
||||
placeholder=None,
|
||||
type=typ,
|
||||
required=required,
|
||||
options=option_objs,
|
||||
description=I18nObject(en_US="", zh_Hans=""),
|
||||
)
|
||||
|
||||
def init_frontend_parameter(self, value: Any):
|
||||
return init_frontend_parameter(self, self.type, value)
|
||||
|
||||
|
||||
class DatasourceIdentity(BaseModel):
|
||||
author: str = Field(..., description="The author of the datasource")
|
||||
name: str = Field(..., description="The name of the datasource")
|
||||
label: I18nObject = Field(..., description="The label of the datasource")
|
||||
provider: str = Field(..., description="The provider of the datasource")
|
||||
icon: Optional[str] = None
|
||||
|
||||
|
||||
class DatasourceEntity(BaseModel):
|
||||
identity: DatasourceIdentity
|
||||
parameters: list[DatasourceParameter] = Field(default_factory=list)
|
||||
description: I18nObject = Field(..., description="The label of the datasource")
|
||||
|
||||
@field_validator("parameters", mode="before")
|
||||
@classmethod
|
||||
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]:
|
||||
return v or []
|
||||
|
||||
|
||||
class DatasourceProviderIdentity(BaseModel):
|
||||
author: str = Field(..., description="The author of the tool")
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
description: I18nObject = Field(..., description="The description of the tool")
|
||||
icon: str = Field(..., description="The icon of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
tags: Optional[list[ToolLabelEnum]] = Field(
|
||||
default=[],
|
||||
description="The tags of the tool",
|
||||
)
|
||||
|
||||
|
||||
class DatasourceProviderEntity(BaseModel):
|
||||
"""
|
||||
Datasource provider entity
|
||||
"""
|
||||
|
||||
identity: DatasourceProviderIdentity
|
||||
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
|
||||
oauth_schema: Optional[OAuthSchema] = None
|
||||
provider_type: DatasourceProviderType
|
||||
|
||||
|
||||
class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity):
|
||||
datasources: list[DatasourceEntity] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DatasourceInvokeMeta(BaseModel):
|
||||
"""
|
||||
Datasource invoke meta
|
||||
"""
|
||||
|
||||
time_cost: float = Field(..., description="The time cost of the tool invoke")
|
||||
error: Optional[str] = None
|
||||
tool_config: Optional[dict] = None
|
||||
|
||||
@classmethod
|
||||
def empty(cls) -> "DatasourceInvokeMeta":
|
||||
"""
|
||||
Get an empty instance of DatasourceInvokeMeta
|
||||
"""
|
||||
return cls(time_cost=0.0, error=None, tool_config={})
|
||||
|
||||
@classmethod
|
||||
def error_instance(cls, error: str) -> "DatasourceInvokeMeta":
|
||||
"""
|
||||
Get an instance of DatasourceInvokeMeta with error
|
||||
"""
|
||||
return cls(time_cost=0.0, error=error, tool_config={})
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"time_cost": self.time_cost,
|
||||
"error": self.error,
|
||||
"tool_config": self.tool_config,
|
||||
}
|
||||
|
||||
|
||||
class DatasourceLabel(BaseModel):
|
||||
"""
|
||||
Datasource label
|
||||
"""
|
||||
|
||||
name: str = Field(..., description="The name of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
icon: str = Field(..., description="The icon of the tool")
|
||||
|
||||
|
||||
class DatasourceInvokeFrom(Enum):
|
||||
"""
|
||||
Enum class for datasource invoke
|
||||
"""
|
||||
|
||||
RAG_PIPELINE = "rag_pipeline"
|
||||
|
||||
|
||||
class OnlineDocumentPage(BaseModel):
|
||||
"""
|
||||
Online document page
|
||||
"""
|
||||
|
||||
page_id: str = Field(..., description="The page id")
|
||||
page_name: str = Field(..., description="The page title")
|
||||
page_icon: Optional[dict] = Field(None, description="The page icon")
|
||||
type: str = Field(..., description="The type of the page")
|
||||
last_edited_time: str = Field(..., description="The last edited time")
|
||||
parent_id: Optional[str] = Field(None, description="The parent page id")
|
||||
|
||||
|
||||
class OnlineDocumentInfo(BaseModel):
|
||||
"""
|
||||
Online document info
|
||||
"""
|
||||
|
||||
workspace_id: str = Field(..., description="The workspace id")
|
||||
workspace_name: str = Field(..., description="The workspace name")
|
||||
workspace_icon: str = Field(..., description="The workspace icon")
|
||||
total: int = Field(..., description="The total number of documents")
|
||||
pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document")
|
||||
|
||||
|
||||
class OnlineDocumentPagesMessage(BaseModel):
|
||||
"""
|
||||
Get online document pages response
|
||||
"""
|
||||
|
||||
result: list[OnlineDocumentInfo]
|
||||
|
||||
|
||||
class GetOnlineDocumentPageContentRequest(BaseModel):
|
||||
"""
|
||||
Get online document page content request
|
||||
"""
|
||||
|
||||
workspace_id: str = Field(..., description="The workspace id")
|
||||
page_id: str = Field(..., description="The page id")
|
||||
type: str = Field(..., description="The type of the page")
|
||||
|
||||
|
||||
class OnlineDocumentPageContent(BaseModel):
|
||||
"""
|
||||
Online document page content
|
||||
"""
|
||||
|
||||
workspace_id: str = Field(..., description="The workspace id")
|
||||
page_id: str = Field(..., description="The page id")
|
||||
content: str = Field(..., description="The content of the page")
|
||||
|
||||
|
||||
class GetOnlineDocumentPageContentResponse(BaseModel):
|
||||
"""
|
||||
Get online document page content response
|
||||
"""
|
||||
|
||||
result: OnlineDocumentPageContent
|
||||
|
||||
|
||||
class GetWebsiteCrawlRequest(BaseModel):
|
||||
"""
|
||||
Get website crawl request
|
||||
"""
|
||||
|
||||
crawl_parameters: dict = Field(..., description="The crawl parameters")
|
||||
|
||||
|
||||
class WebSiteInfoDetail(BaseModel):
|
||||
source_url: str = Field(..., description="The url of the website")
|
||||
content: str = Field(..., description="The content of the website")
|
||||
title: str = Field(..., description="The title of the website")
|
||||
description: str = Field(..., description="The description of the website")
|
||||
|
||||
|
||||
class WebSiteInfo(BaseModel):
|
||||
"""
|
||||
Website info
|
||||
"""
|
||||
|
||||
status: Optional[str] = Field(..., description="crawl job status")
|
||||
web_info_list: Optional[list[WebSiteInfoDetail]] = []
|
||||
total: Optional[int] = Field(default=0, description="The total number of websites")
|
||||
completed: Optional[int] = Field(default=0, description="The number of completed websites")
|
||||
|
||||
|
||||
class WebsiteCrawlMessage(BaseModel):
|
||||
"""
|
||||
Get website crawl response
|
||||
"""
|
||||
|
||||
result: WebSiteInfo = WebSiteInfo(status="", web_info_list=[], total=0, completed=0)
|
||||
|
||||
|
||||
class DatasourceMessage(ToolInvokeMessage):
|
||||
pass
|
||||
|
||||
|
||||
#########################
|
||||
# Online driver file
|
||||
#########################
|
||||
|
||||
|
||||
class OnlineDriveFile(BaseModel):
|
||||
"""
|
||||
Online driver file
|
||||
"""
|
||||
|
||||
key: str = Field(..., description="The key of the file")
|
||||
size: int = Field(..., description="The size of the file")
|
||||
|
||||
|
||||
class OnlineDriveFileBucket(BaseModel):
|
||||
"""
|
||||
Online driver file bucket
|
||||
"""
|
||||
|
||||
bucket: Optional[str] = Field(None, description="The bucket of the file")
|
||||
files: list[OnlineDriveFile] = Field(..., description="The files of the bucket")
|
||||
is_truncated: bool = Field(False, description="Whether the bucket has more files")
|
||||
|
||||
|
||||
class OnlineDriveBrowseFilesRequest(BaseModel):
|
||||
"""
|
||||
Get online driver file list request
|
||||
"""
|
||||
|
||||
prefix: Optional[str] = Field(None, description="File path prefix for filtering eg: 'docs/dify/'")
|
||||
bucket: Optional[str] = Field(None, description="Storage bucket name")
|
||||
max_keys: int = Field(20, description="Maximum number of files to return")
|
||||
start_after: Optional[str] = Field(
|
||||
None, description="Pagination token for continuing from a specific file eg: 'docs/dify/1.txt'"
|
||||
)
|
||||
|
||||
|
||||
class OnlineDriveBrowseFilesResponse(BaseModel):
|
||||
"""
|
||||
Get online driver file list response
|
||||
"""
|
||||
|
||||
result: list[OnlineDriveFileBucket] = Field(..., description="The bucket of the files")
|
||||
|
||||
|
||||
class OnlineDriveDownloadFileRequest(BaseModel):
|
||||
"""
|
||||
Get online driver file
|
||||
"""
|
||||
|
||||
key: str = Field(..., description="The name of the file")
|
||||
bucket: Optional[str] = Field(None, description="The name of the bucket")
|
||||
@@ -1,37 +0,0 @@
|
||||
from core.datasource.entities.datasource_entities import DatasourceInvokeMeta
|
||||
|
||||
|
||||
class DatasourceProviderNotFoundError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class DatasourceNotFoundError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class DatasourceParameterValidationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class DatasourceProviderCredentialValidationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class DatasourceNotSupportedError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class DatasourceInvokeError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class DatasourceApiSchemaError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class DatasourceEngineInvokeError(Exception):
|
||||
meta: DatasourceInvokeMeta
|
||||
|
||||
def __init__(self, meta, **kwargs):
|
||||
self.meta = meta
|
||||
super().__init__(**kwargs)
|
||||
@@ -1,28 +0,0 @@
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceProviderType,
|
||||
)
|
||||
|
||||
|
||||
class LocalFileDatasourcePlugin(DatasourcePlugin):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entity: DatasourceEntity,
|
||||
runtime: DatasourceRuntime,
|
||||
tenant_id: str,
|
||||
icon: str,
|
||||
plugin_unique_identifier: str,
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
def datasource_provider_type(self) -> str:
|
||||
return DatasourceProviderType.LOCAL_FILE
|
||||
@@ -1,56 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin
|
||||
|
||||
|
||||
class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||
entity: DatasourceProviderEntityWithPlugin
|
||||
plugin_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||
) -> None:
|
||||
super().__init__(entity, tenant_id)
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
"""
|
||||
return DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_datasource(self, datasource_name: str) -> LocalFileDatasourcePlugin: # type: ignore
|
||||
"""
|
||||
return datasource with given name
|
||||
"""
|
||||
datasource_entity = next(
|
||||
(
|
||||
datasource_entity
|
||||
for datasource_entity in self.entity.datasources
|
||||
if datasource_entity.identity.name == datasource_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not datasource_entity:
|
||||
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||
|
||||
return LocalFileDatasourcePlugin(
|
||||
entity=datasource_entity,
|
||||
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
@@ -1,73 +0,0 @@
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceMessage,
|
||||
DatasourceProviderType,
|
||||
GetOnlineDocumentPageContentRequest,
|
||||
OnlineDocumentPagesMessage,
|
||||
)
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
|
||||
|
||||
class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
entity: DatasourceEntity
|
||||
runtime: DatasourceRuntime
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entity: DatasourceEntity,
|
||||
runtime: DatasourceRuntime,
|
||||
tenant_id: str,
|
||||
icon: str,
|
||||
plugin_unique_identifier: str,
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
def get_online_document_pages(
|
||||
self,
|
||||
user_id: str,
|
||||
datasource_parameters: Mapping[str, Any],
|
||||
provider_type: str,
|
||||
) -> Generator[OnlineDocumentPagesMessage, None, None]:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
return manager.get_online_document_pages(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider=self.entity.identity.provider,
|
||||
datasource_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
datasource_parameters=datasource_parameters,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def get_online_document_page_content(
|
||||
self,
|
||||
user_id: str,
|
||||
datasource_parameters: GetOnlineDocumentPageContentRequest,
|
||||
provider_type: str,
|
||||
) -> Generator[DatasourceMessage, None, None]:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
return manager.get_online_document_page_content(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider=self.entity.identity.provider,
|
||||
datasource_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
datasource_parameters=datasource_parameters,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def datasource_provider_type(self) -> str:
|
||||
return DatasourceProviderType.ONLINE_DOCUMENT
|
||||
@@ -1,48 +0,0 @@
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
|
||||
|
||||
class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||
entity: DatasourceProviderEntityWithPlugin
|
||||
plugin_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||
) -> None:
|
||||
super().__init__(entity, tenant_id)
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
"""
|
||||
return DatasourceProviderType.ONLINE_DOCUMENT
|
||||
|
||||
def get_datasource(self, datasource_name: str) -> OnlineDocumentDatasourcePlugin: # type: ignore
|
||||
"""
|
||||
return datasource with given name
|
||||
"""
|
||||
datasource_entity = next(
|
||||
(
|
||||
datasource_entity
|
||||
for datasource_entity in self.entity.datasources
|
||||
if datasource_entity.identity.name == datasource_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not datasource_entity:
|
||||
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||
|
||||
return OnlineDocumentDatasourcePlugin(
|
||||
entity=datasource_entity,
|
||||
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
@@ -1,73 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceMessage,
|
||||
DatasourceProviderType,
|
||||
OnlineDriveBrowseFilesRequest,
|
||||
OnlineDriveBrowseFilesResponse,
|
||||
OnlineDriveDownloadFileRequest,
|
||||
)
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
|
||||
|
||||
class OnlineDriveDatasourcePlugin(DatasourcePlugin):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
entity: DatasourceEntity
|
||||
runtime: DatasourceRuntime
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entity: DatasourceEntity,
|
||||
runtime: DatasourceRuntime,
|
||||
tenant_id: str,
|
||||
icon: str,
|
||||
plugin_unique_identifier: str,
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
def online_drive_browse_files(
|
||||
self,
|
||||
user_id: str,
|
||||
request: OnlineDriveBrowseFilesRequest,
|
||||
provider_type: str,
|
||||
) -> Generator[OnlineDriveBrowseFilesResponse, None, None]:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
return manager.online_drive_browse_files(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider=self.entity.identity.provider,
|
||||
datasource_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
request=request,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def online_drive_download_file(
|
||||
self,
|
||||
user_id: str,
|
||||
request: OnlineDriveDownloadFileRequest,
|
||||
provider_type: str,
|
||||
) -> Generator[DatasourceMessage, None, None]:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
return manager.online_drive_download_file(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider=self.entity.identity.provider,
|
||||
datasource_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
request=request,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def datasource_provider_type(self) -> str:
|
||||
return DatasourceProviderType.ONLINE_DRIVE
|
||||
@@ -1,48 +0,0 @@
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
|
||||
|
||||
|
||||
class OnlineDriveDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||
entity: DatasourceProviderEntityWithPlugin
|
||||
plugin_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||
) -> None:
|
||||
super().__init__(entity, tenant_id)
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
"""
|
||||
return DatasourceProviderType.ONLINE_DRIVE
|
||||
|
||||
def get_datasource(self, datasource_name: str) -> OnlineDriveDatasourcePlugin: # type: ignore
|
||||
"""
|
||||
return datasource with given name
|
||||
"""
|
||||
datasource_entity = next(
|
||||
(
|
||||
datasource_entity
|
||||
for datasource_entity in self.entity.datasources
|
||||
if datasource_entity.identity.name == datasource_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not datasource_entity:
|
||||
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||
|
||||
return OnlineDriveDatasourcePlugin(
|
||||
entity=datasource_entity,
|
||||
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
@@ -1,265 +0,0 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper import encrypter
|
||||
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
|
||||
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
|
||||
|
||||
class ProviderConfigEncrypter(BaseModel):
|
||||
tenant_id: str
|
||||
config: list[BasicProviderConfig]
|
||||
provider_type: str
|
||||
provider_identity: str
|
||||
|
||||
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
deep copy data
|
||||
"""
|
||||
return deepcopy(data)
|
||||
|
||||
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
encrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with encrypted values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
|
||||
data[field_name] = encrypted
|
||||
|
||||
return data
|
||||
|
||||
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
mask tool credentials
|
||||
|
||||
return a deep copy of credentials with masked values
|
||||
"""
|
||||
data = self._deep_copy(data)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
if len(data[field_name]) > 6:
|
||||
data[field_name] = (
|
||||
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
|
||||
)
|
||||
else:
|
||||
data[field_name] = "*" * len(data[field_name])
|
||||
|
||||
return data
|
||||
|
||||
def decrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
decrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with decrypted values
|
||||
"""
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
cached_credentials = cache.get()
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
data = self._deep_copy(data)
|
||||
# get fields need to be decrypted
|
||||
fields = dict[str, BasicProviderConfig]()
|
||||
for credential in self.config:
|
||||
fields[credential.name] = credential
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||
if field_name in data:
|
||||
try:
|
||||
# if the value is None or empty string, skip decrypt
|
||||
if not data[field_name]:
|
||||
continue
|
||||
|
||||
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
cache.set(data)
|
||||
return data
|
||||
|
||||
def delete_tool_credentials_cache(self):
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
cache.delete()
|
||||
|
||||
|
||||
class ToolParameterConfigurationManager:
|
||||
"""
|
||||
Tool parameter configuration manager
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
tool_runtime: Tool
|
||||
provider_name: str
|
||||
provider_type: ToolProviderType
|
||||
identity_id: str
|
||||
|
||||
def __init__(
|
||||
self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str
|
||||
) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.tool_runtime = tool_runtime
|
||||
self.provider_name = provider_name
|
||||
self.provider_type = provider_type
|
||||
self.identity_id = identity_id
|
||||
|
||||
def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
deep copy parameters
|
||||
"""
|
||||
return deepcopy(parameters)
|
||||
|
||||
def _merge_parameters(self) -> list[ToolParameter]:
|
||||
"""
|
||||
merge parameters
|
||||
"""
|
||||
# get tool parameters
|
||||
tool_parameters = self.tool_runtime.entity.parameters or []
|
||||
# get tool runtime parameters
|
||||
runtime_parameters = self.tool_runtime.get_runtime_parameters()
|
||||
# override parameters
|
||||
current_parameters = tool_parameters.copy()
|
||||
for runtime_parameter in runtime_parameters:
|
||||
found = False
|
||||
for index, parameter in enumerate(current_parameters):
|
||||
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
|
||||
current_parameters[index] = runtime_parameter
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
current_parameters.append(runtime_parameter)
|
||||
|
||||
return current_parameters
|
||||
|
||||
def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
mask tool parameters
|
||||
|
||||
return a deep copy of parameters with masked values
|
||||
"""
|
||||
parameters = self._deep_copy(parameters)
|
||||
|
||||
# override parameters
|
||||
current_parameters = self._merge_parameters()
|
||||
|
||||
for parameter in current_parameters:
|
||||
if (
|
||||
parameter.form == ToolParameter.ToolParameterForm.FORM
|
||||
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
|
||||
):
|
||||
if parameter.name in parameters:
|
||||
if len(parameters[parameter.name]) > 6:
|
||||
parameters[parameter.name] = (
|
||||
parameters[parameter.name][:2]
|
||||
+ "*" * (len(parameters[parameter.name]) - 4)
|
||||
+ parameters[parameter.name][-2:]
|
||||
)
|
||||
else:
|
||||
parameters[parameter.name] = "*" * len(parameters[parameter.name])
|
||||
|
||||
return parameters
|
||||
|
||||
def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
encrypt tool parameters with tenant id
|
||||
|
||||
return a deep copy of parameters with encrypted values
|
||||
"""
|
||||
# override parameters
|
||||
current_parameters = self._merge_parameters()
|
||||
|
||||
parameters = self._deep_copy(parameters)
|
||||
|
||||
for parameter in current_parameters:
|
||||
if (
|
||||
parameter.form == ToolParameter.ToolParameterForm.FORM
|
||||
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
|
||||
):
|
||||
if parameter.name in parameters:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
|
||||
parameters[parameter.name] = encrypted
|
||||
|
||||
return parameters
|
||||
|
||||
def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
decrypt tool parameters with tenant id
|
||||
|
||||
return a deep copy of parameters with decrypted values
|
||||
"""
|
||||
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f"{self.provider_type.value}.{self.provider_name}",
|
||||
tool_name=self.tool_runtime.entity.identity.name,
|
||||
cache_type=ToolParameterCacheType.PARAMETER,
|
||||
identity_id=self.identity_id,
|
||||
)
|
||||
cached_parameters = cache.get()
|
||||
if cached_parameters:
|
||||
return cached_parameters
|
||||
|
||||
# override parameters
|
||||
current_parameters = self._merge_parameters()
|
||||
has_secret_input = False
|
||||
|
||||
for parameter in current_parameters:
|
||||
if (
|
||||
parameter.form == ToolParameter.ToolParameterForm.FORM
|
||||
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
|
||||
):
|
||||
if parameter.name in parameters:
|
||||
try:
|
||||
has_secret_input = True
|
||||
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if has_secret_input:
|
||||
cache.set(parameters)
|
||||
|
||||
return parameters
|
||||
|
||||
def delete_tool_parameters_cache(self):
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f"{self.provider_type.value}.{self.provider_name}",
|
||||
tool_name=self.tool_runtime.entity.identity.name,
|
||||
cache_type=ToolParameterCacheType.PARAMETER,
|
||||
identity_id=self.identity_id,
|
||||
)
|
||||
cache.delete()
|
||||
@@ -1,121 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from mimetypes import guess_extension
|
||||
from typing import Optional
|
||||
|
||||
from core.datasource.datasource_file_manager import DatasourceFileManager
|
||||
from core.datasource.entities.datasource_entities import DatasourceMessage
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatasourceFileMessageTransformer:
|
||||
@classmethod
|
||||
def transform_datasource_invoke_messages(
|
||||
cls,
|
||||
messages: Generator[DatasourceMessage, None, None],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
) -> Generator[DatasourceMessage, None, None]:
|
||||
"""
|
||||
Transform datasource message and handle file download
|
||||
"""
|
||||
for message in messages:
|
||||
if message.type in {DatasourceMessage.MessageType.TEXT, DatasourceMessage.MessageType.LINK}:
|
||||
yield message
|
||||
elif message.type == DatasourceMessage.MessageType.IMAGE and isinstance(
|
||||
message.message, DatasourceMessage.TextMessage
|
||||
):
|
||||
# try to download image
|
||||
try:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
|
||||
file = DatasourceFileManager.create_file_by_url(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
file_url=message.message.text,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
url = f"/files/datasources/{file.id}{guess_extension(file.mime_type) or '.png'}"
|
||||
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
message=DatasourceMessage.TextMessage(text=url),
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
except Exception as e:
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.TEXT,
|
||||
message=DatasourceMessage.TextMessage(
|
||||
text=f"Failed to download image: {message.message.text}: {e}"
|
||||
),
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.BLOB:
|
||||
# get mime type and save blob to storage
|
||||
meta = message.meta or {}
|
||||
|
||||
mimetype = meta.get("mime_type", "application/octet-stream")
|
||||
# get filename from meta
|
||||
filename = meta.get("file_name", None)
|
||||
# if message is str, encode it to bytes
|
||||
|
||||
if not isinstance(message.message, DatasourceMessage.BlobMessage):
|
||||
raise ValueError("unexpected message type")
|
||||
|
||||
# FIXME: should do a type check here.
|
||||
assert isinstance(message.message.blob, bytes)
|
||||
file = DatasourceFileManager.create_file_by_raw(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_binary=message.message.blob,
|
||||
mimetype=mimetype,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mime_type))
|
||||
|
||||
# check if file is image
|
||||
if "image" in mimetype:
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
message=DatasourceMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
else:
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.BINARY_LINK,
|
||||
message=DatasourceMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.FILE:
|
||||
meta = message.meta or {}
|
||||
file = meta.get("file", None)
|
||||
if isinstance(file, File):
|
||||
if file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
assert file.related_id is not None
|
||||
url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension)
|
||||
if file.type == FileType.IMAGE:
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
message=DatasourceMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
else:
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.LINK,
|
||||
message=DatasourceMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
else:
|
||||
yield message
|
||||
else:
|
||||
yield message
|
||||
|
||||
@classmethod
|
||||
def get_datasource_file_url(cls, datasource_file_id: str, extension: Optional[str]) -> str:
|
||||
return f"/files/datasources/{datasource_file_id}{extension or '.bin'}"
|
||||
@@ -1,389 +0,0 @@
|
||||
import re
|
||||
import uuid
|
||||
from json import dumps as json_dumps
|
||||
from json import loads as json_loads
|
||||
from json.decoder import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from flask import request
|
||||
from requests import get
|
||||
from yaml import YAMLError, safe_load # type: ignore
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
|
||||
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
|
||||
|
||||
|
||||
class ApiBasedToolSchemaParser:
|
||||
@staticmethod
|
||||
def parse_openapi_to_tool_bundle(
|
||||
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
# set description to extra_info
|
||||
extra_info["description"] = openapi["info"].get("description", "")
|
||||
|
||||
if len(openapi["servers"]) == 0:
|
||||
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
|
||||
|
||||
server_url = openapi["servers"][0]["url"]
|
||||
request_env = request.headers.get("X-Request-Env")
|
||||
if request_env:
|
||||
matched_servers = [server["url"] for server in openapi["servers"] if server["env"] == request_env]
|
||||
server_url = matched_servers[0] if matched_servers else server_url
|
||||
|
||||
# list all interfaces
|
||||
interfaces = []
|
||||
for path, path_item in openapi["paths"].items():
|
||||
methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"]
|
||||
for method in methods:
|
||||
if method in path_item:
|
||||
interfaces.append(
|
||||
{
|
||||
"path": path,
|
||||
"method": method,
|
||||
"operation": path_item[method],
|
||||
}
|
||||
)
|
||||
|
||||
# get all parameters
|
||||
bundles = []
|
||||
for interface in interfaces:
|
||||
# convert parameters
|
||||
parameters = []
|
||||
if "parameters" in interface["operation"]:
|
||||
for parameter in interface["operation"]["parameters"]:
|
||||
tool_parameter = ToolParameter(
|
||||
name=parameter["name"],
|
||||
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
|
||||
human_description=I18nObject(
|
||||
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=parameter.get("required", False),
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=parameter.get("description"),
|
||||
default=parameter["schema"]["default"]
|
||||
if "schema" in parameter and "default" in parameter["schema"]
|
||||
else None,
|
||||
placeholder=I18nObject(
|
||||
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||
),
|
||||
)
|
||||
|
||||
# check if there is a type
|
||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter)
|
||||
if typ:
|
||||
tool_parameter.type = typ
|
||||
|
||||
parameters.append(tool_parameter)
|
||||
# create tool bundle
|
||||
# check if there is a request body
|
||||
if "requestBody" in interface["operation"]:
|
||||
request_body = interface["operation"]["requestBody"]
|
||||
if "content" in request_body:
|
||||
for content_type, content in request_body["content"].items():
|
||||
# if there is a reference, get the reference and overwrite the content
|
||||
if "schema" not in content:
|
||||
continue
|
||||
|
||||
if "$ref" in content["schema"]:
|
||||
# get the reference
|
||||
root = openapi
|
||||
reference = content["schema"]["$ref"].split("/")[1:]
|
||||
for ref in reference:
|
||||
root = root[ref]
|
||||
# overwrite the content
|
||||
interface["operation"]["requestBody"]["content"][content_type]["schema"] = root
|
||||
|
||||
# parse body parameters
|
||||
if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
|
||||
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
|
||||
required = body_schema.get("required", [])
|
||||
properties = body_schema.get("properties", {})
|
||||
for name, property in properties.items():
|
||||
tool = ToolParameter(
|
||||
name=name,
|
||||
label=I18nObject(en_US=name, zh_Hans=name),
|
||||
human_description=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=name in required,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=property.get("description", ""),
|
||||
default=property.get("default", None),
|
||||
placeholder=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
)
|
||||
|
||||
# check if there is a type
|
||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
|
||||
if typ:
|
||||
tool.type = typ
|
||||
|
||||
parameters.append(tool)
|
||||
|
||||
# check if parameters is duplicated
|
||||
parameters_count = {}
|
||||
for parameter in parameters:
|
||||
if parameter.name not in parameters_count:
|
||||
parameters_count[parameter.name] = 0
|
||||
parameters_count[parameter.name] += 1
|
||||
for name, count in parameters_count.items():
|
||||
if count > 1:
|
||||
warning["duplicated_parameter"] = f"Parameter {name} is duplicated."
|
||||
|
||||
# check if there is a operation id, use $path_$method as operation id if not
|
||||
if "operationId" not in interface["operation"]:
|
||||
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
|
||||
path = interface["path"]
|
||||
if interface["path"].startswith("/"):
|
||||
path = interface["path"][1:]
|
||||
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
|
||||
path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
|
||||
if not path:
|
||||
path = str(uuid.uuid4())
|
||||
|
||||
interface["operation"]["operationId"] = f"{path}_{interface['method']}"
|
||||
|
||||
bundles.append(
|
||||
ApiToolBundle(
|
||||
server_url=server_url + interface["path"],
|
||||
method=interface["method"],
|
||||
summary=interface["operation"]["description"]
|
||||
if "description" in interface["operation"]
|
||||
else interface["operation"].get("summary", None),
|
||||
operation_id=interface["operation"]["operationId"],
|
||||
parameters=parameters,
|
||||
author="",
|
||||
icon=None,
|
||||
openapi=interface["operation"],
|
||||
)
|
||||
)
|
||||
|
||||
return bundles
|
||||
|
||||
@staticmethod
|
||||
def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]:
|
||||
parameter = parameter or {}
|
||||
typ: Optional[str] = None
|
||||
if parameter.get("format") == "binary":
|
||||
return ToolParameter.ToolParameterType.FILE
|
||||
|
||||
if "type" in parameter:
|
||||
typ = parameter["type"]
|
||||
elif "schema" in parameter and "type" in parameter["schema"]:
|
||||
typ = parameter["schema"]["type"]
|
||||
|
||||
if typ in {"integer", "number"}:
|
||||
return ToolParameter.ToolParameterType.NUMBER
|
||||
elif typ == "boolean":
|
||||
return ToolParameter.ToolParameterType.BOOLEAN
|
||||
elif typ == "string":
|
||||
return ToolParameter.ToolParameterType.STRING
|
||||
elif typ == "array":
|
||||
items = parameter.get("items") or parameter.get("schema", {}).get("items")
|
||||
return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_openapi_yaml_to_tool_bundle(
|
||||
yaml: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi yaml to tool bundle
|
||||
|
||||
:param yaml: the yaml string
|
||||
:param extra_info: the extra info
|
||||
:param warning: the warning message
|
||||
:return: the tool bundle
|
||||
"""
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
openapi: dict = safe_load(yaml)
|
||||
if openapi is None:
|
||||
raise ToolApiSchemaError("Invalid openapi yaml.")
|
||||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
|
||||
|
||||
@staticmethod
|
||||
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict:
|
||||
warning = warning or {}
|
||||
"""
|
||||
parse swagger to openapi
|
||||
|
||||
:param swagger: the swagger dict
|
||||
:return: the openapi dict
|
||||
"""
|
||||
# convert swagger to openapi
|
||||
info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"})
|
||||
|
||||
servers = swagger.get("servers", [])
|
||||
|
||||
if len(servers) == 0:
|
||||
raise ToolApiSchemaError("No server found in the swagger yaml.")
|
||||
|
||||
openapi = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
"title": info.get("title", "Swagger"),
|
||||
"description": info.get("description", "Swagger"),
|
||||
"version": info.get("version", "1.0.0"),
|
||||
},
|
||||
"servers": swagger["servers"],
|
||||
"paths": {},
|
||||
"components": {"schemas": {}},
|
||||
}
|
||||
|
||||
# check paths
|
||||
if "paths" not in swagger or len(swagger["paths"]) == 0:
|
||||
raise ToolApiSchemaError("No paths found in the swagger yaml.")
|
||||
|
||||
# convert paths
|
||||
for path, path_item in swagger["paths"].items():
|
||||
openapi["paths"][path] = {}
|
||||
for method, operation in path_item.items():
|
||||
if "operationId" not in operation:
|
||||
raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")
|
||||
|
||||
if ("summary" not in operation or len(operation["summary"]) == 0) and (
|
||||
"description" not in operation or len(operation["description"]) == 0
|
||||
):
|
||||
if warning is not None:
|
||||
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
|
||||
|
||||
openapi["paths"][path][method] = {
|
||||
"operationId": operation["operationId"],
|
||||
"summary": operation.get("summary", ""),
|
||||
"description": operation.get("description", ""),
|
||||
"parameters": operation.get("parameters", []),
|
||||
"responses": operation.get("responses", {}),
|
||||
}
|
||||
|
||||
if "requestBody" in operation:
|
||||
openapi["paths"][path][method]["requestBody"] = operation["requestBody"]
|
||||
|
||||
# convert definitions
|
||||
for name, definition in swagger["definitions"].items():
|
||||
openapi["components"]["schemas"][name] = definition
|
||||
|
||||
return openapi
|
||||
|
||||
@staticmethod
|
||||
def parse_openai_plugin_json_to_tool_bundle(
|
||||
json: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi plugin yaml to tool bundle
|
||||
|
||||
:param json: the json string
|
||||
:param extra_info: the extra info
|
||||
:param warning: the warning message
|
||||
:return: the tool bundle
|
||||
"""
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
try:
|
||||
openai_plugin = json_loads(json)
|
||||
api = openai_plugin["api"]
|
||||
api_url = api["url"]
|
||||
api_type = api["type"]
|
||||
except JSONDecodeError:
|
||||
raise ToolProviderNotFoundError("Invalid openai plugin json.")
|
||||
|
||||
if api_type != "openapi":
|
||||
raise ToolNotSupportedError("Only openapi is supported now.")
|
||||
|
||||
# get openapi yaml
|
||||
response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise ToolProviderNotFoundError("cannot get openapi yaml from url.")
|
||||
|
||||
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(
|
||||
response.text, extra_info=extra_info, warning=warning
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def auto_parse_to_tool_bundle(
|
||||
content: str, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> tuple[list[ApiToolBundle], str]:
|
||||
"""
|
||||
auto parse to tool bundle
|
||||
|
||||
:param content: the content
|
||||
:param extra_info: the extra info
|
||||
:param warning: the warning message
|
||||
:return: tools bundle, schema_type
|
||||
"""
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
content = content.strip()
|
||||
loaded_content = None
|
||||
json_error = None
|
||||
yaml_error = None
|
||||
|
||||
try:
|
||||
loaded_content = json_loads(content)
|
||||
except JSONDecodeError as e:
|
||||
json_error = e
|
||||
|
||||
if loaded_content is None:
|
||||
try:
|
||||
loaded_content = safe_load(content)
|
||||
except YAMLError as e:
|
||||
yaml_error = e
|
||||
if loaded_content is None:
|
||||
raise ToolApiSchemaError(
|
||||
f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)},"
|
||||
f" yaml error: {str(yaml_error)}"
|
||||
)
|
||||
|
||||
swagger_error = None
|
||||
openapi_error = None
|
||||
openapi_plugin_error = None
|
||||
schema_type = None
|
||||
|
||||
try:
|
||||
openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
|
||||
loaded_content, extra_info=extra_info, warning=warning
|
||||
)
|
||||
schema_type = ApiProviderSchemaType.OPENAPI.value
|
||||
return openapi, schema_type
|
||||
except ToolApiSchemaError as e:
|
||||
openapi_error = e
|
||||
|
||||
# openai parse error, fallback to swagger
|
||||
try:
|
||||
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
|
||||
loaded_content, extra_info=extra_info, warning=warning
|
||||
)
|
||||
schema_type = ApiProviderSchemaType.SWAGGER.value
|
||||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
|
||||
converted_swagger, extra_info=extra_info, warning=warning
|
||||
), schema_type
|
||||
except ToolApiSchemaError as e:
|
||||
swagger_error = e
|
||||
|
||||
# swagger parse error, fallback to openai plugin
|
||||
try:
|
||||
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
|
||||
json_dumps(loaded_content), extra_info=extra_info, warning=warning
|
||||
)
|
||||
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value
|
||||
except ToolNotSupportedError as e:
|
||||
# maybe it's not plugin at all
|
||||
openapi_plugin_error = e
|
||||
|
||||
raise ToolApiSchemaError(
|
||||
f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)},"
|
||||
f" openapi plugin error: {str(openapi_plugin_error)}"
|
||||
)
|
||||
@@ -1,17 +0,0 @@
|
||||
import re
|
||||
|
||||
|
||||
def remove_leading_symbols(text: str) -> str:
|
||||
"""
|
||||
Remove leading punctuation or symbols from the given text.
|
||||
|
||||
Args:
|
||||
text (str): The input text to process.
|
||||
|
||||
Returns:
|
||||
str: The text with leading punctuation or symbols removed.
|
||||
"""
|
||||
# Match Unicode ranges for punctuation and symbols
|
||||
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
|
||||
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"
|
||||
return re.sub(pattern, "", text)
|
||||
@@ -1,9 +0,0 @@
|
||||
import uuid
|
||||
|
||||
|
||||
def is_valid_uuid(uuid_str: str) -> bool:
|
||||
try:
|
||||
uuid.UUID(uuid_str)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
@@ -1,43 +0,0 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
|
||||
|
||||
class WorkflowToolConfigurationUtils:
|
||||
@classmethod
|
||||
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
|
||||
for configuration in configurations:
|
||||
WorkflowToolParameterConfiguration.model_validate(configuration)
|
||||
|
||||
@classmethod
|
||||
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
|
||||
"""
|
||||
get workflow graph variables
|
||||
"""
|
||||
nodes = graph.get("nodes", [])
|
||||
start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)
|
||||
|
||||
if not start_node:
|
||||
return []
|
||||
|
||||
return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]
|
||||
|
||||
@classmethod
|
||||
def check_is_synced(
|
||||
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
|
||||
):
|
||||
"""
|
||||
check is synced
|
||||
|
||||
raise ValueError if not synced
|
||||
"""
|
||||
variable_names = [variable.variable for variable in variables]
|
||||
|
||||
if len(tool_configurations) != len(variables):
|
||||
raise ValueError("parameter configuration mismatch, please republish the tool to update")
|
||||
|
||||
for parameter in tool_configurations:
|
||||
if parameter.name not in variable_names:
|
||||
raise ValueError("parameter configuration mismatch, please republish the tool to update")
|
||||
@@ -1,35 +0,0 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml # type: ignore
|
||||
from yaml import YAMLError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any:
|
||||
"""
|
||||
Safe loading a YAML file
|
||||
:param file_path: the path of the YAML file
|
||||
:param ignore_error:
|
||||
if True, return default_value if error occurs and the error will be logged in debug level
|
||||
if False, raise error if error occurs
|
||||
:param default_value: the value returned when errors ignored
|
||||
:return: an object of the YAML content
|
||||
"""
|
||||
if not file_path or not Path(file_path).exists():
|
||||
if ignore_error:
|
||||
return default_value
|
||||
else:
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
with open(file_path, encoding="utf-8") as yaml_file:
|
||||
try:
|
||||
yaml_content = yaml.safe_load(yaml_file)
|
||||
return yaml_content or default_value
|
||||
except Exception as e:
|
||||
if ignore_error:
|
||||
return default_value
|
||||
else:
|
||||
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e
|
||||
@@ -1,53 +0,0 @@
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceProviderType,
|
||||
WebsiteCrawlMessage,
|
||||
)
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
|
||||
|
||||
class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
|
||||
tenant_id: str
|
||||
icon: str
|
||||
plugin_unique_identifier: str
|
||||
entity: DatasourceEntity
|
||||
runtime: DatasourceRuntime
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entity: DatasourceEntity,
|
||||
runtime: DatasourceRuntime,
|
||||
tenant_id: str,
|
||||
icon: str,
|
||||
plugin_unique_identifier: str,
|
||||
) -> None:
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
def get_website_crawl(
|
||||
self,
|
||||
user_id: str,
|
||||
datasource_parameters: Mapping[str, Any],
|
||||
provider_type: str,
|
||||
) -> Generator[WebsiteCrawlMessage, None, None]:
|
||||
manager = PluginDatasourceManager()
|
||||
|
||||
return manager.get_website_crawl(
|
||||
tenant_id=self.tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider=self.entity.identity.provider,
|
||||
datasource_name=self.entity.identity.name,
|
||||
credentials=self.runtime.credentials,
|
||||
datasource_parameters=datasource_parameters,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def datasource_provider_type(self) -> str:
|
||||
return DatasourceProviderType.WEBSITE_CRAWL
|
||||
@@ -1,52 +0,0 @@
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
|
||||
|
||||
|
||||
class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||
entity: DatasourceProviderEntityWithPlugin
|
||||
plugin_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entity: DatasourceProviderEntityWithPlugin,
|
||||
plugin_id: str,
|
||||
plugin_unique_identifier: str,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
super().__init__(entity, tenant_id)
|
||||
self.plugin_id = plugin_id
|
||||
self.plugin_unique_identifier = plugin_unique_identifier
|
||||
|
||||
@property
|
||||
def provider_type(self) -> DatasourceProviderType:
|
||||
"""
|
||||
returns the type of the provider
|
||||
"""
|
||||
return DatasourceProviderType.WEBSITE_CRAWL
|
||||
|
||||
def get_datasource(self, datasource_name: str) -> WebsiteCrawlDatasourcePlugin: # type: ignore
|
||||
"""
|
||||
return datasource with given name
|
||||
"""
|
||||
datasource_entity = next(
|
||||
(
|
||||
datasource_entity
|
||||
for datasource_entity in self.entity.datasources
|
||||
if datasource_entity.identity.name == datasource_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not datasource_entity:
|
||||
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||
|
||||
return WebsiteCrawlDatasourcePlugin(
|
||||
entity=datasource_entity,
|
||||
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||
tenant_id=self.tenant_id,
|
||||
icon=self.entity.identity.icon,
|
||||
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||
)
|
||||
@@ -17,27 +17,3 @@ class IndexingEstimate(BaseModel):
|
||||
total_segments: int
|
||||
preview: list[PreviewDetail]
|
||||
qa_preview: Optional[list[QAPreviewDetail]] = None
|
||||
|
||||
|
||||
class PipelineDataset(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
chunk_structure: str
|
||||
|
||||
|
||||
class PipelineDocument(BaseModel):
|
||||
id: str
|
||||
position: int
|
||||
data_source_type: str
|
||||
data_source_info: Optional[dict] = None
|
||||
name: str
|
||||
indexing_status: str
|
||||
error: Optional[str] = None
|
||||
enabled: bool
|
||||
|
||||
|
||||
class PipelineGenerateResponse(BaseModel):
|
||||
batch: str
|
||||
dataset: PipelineDataset
|
||||
documents: list[PipelineDocument]
|
||||
|
||||
@@ -21,6 +21,9 @@ class CommonParameterType(StrEnum):
|
||||
DYNAMIC_SELECT = "dynamic-select"
|
||||
|
||||
# TOOL_SELECTOR = "tool-selector"
|
||||
# MCP object and array type parameters
|
||||
ARRAY = "array"
|
||||
OBJECT = "object"
|
||||
|
||||
|
||||
class AppSelectorScope(StrEnum):
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from core.datasource import datasource_file_manager
|
||||
from core.datasource.datasource_file_manager import DatasourceFileManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.datasource.datasource_file_manager import DatasourceFileManager
|
||||
|
||||
tool_file_manager: dict[str, Any] = {"manager": None}
|
||||
|
||||
|
||||
class DatasourceFileParser:
|
||||
@staticmethod
|
||||
def get_datasource_file_manager() -> "DatasourceFileManager":
|
||||
return cast("DatasourceFileManager", datasource_file_manager["manager"])
|
||||
@@ -20,7 +20,6 @@ class FileTransferMethod(StrEnum):
|
||||
REMOTE_URL = "remote_url"
|
||||
LOCAL_FILE = "local_file"
|
||||
TOOL_FILE = "tool_file"
|
||||
DATASOURCE_FILE = "datasource_file"
|
||||
|
||||
@staticmethod
|
||||
def value_of(value):
|
||||
|
||||
@@ -51,7 +51,7 @@ class File(BaseModel):
|
||||
# It should be set to `ToolFile.id` when `transfer_method` is `tool_file`.
|
||||
related_id: Optional[str] = None
|
||||
filename: Optional[str] = None
|
||||
extension: Optional[str] = Field(default=None, description="File extension, should contains dot")
|
||||
extension: Optional[str] = Field(default=None, description="File extension, should contain dot")
|
||||
mime_type: Optional[str] = None
|
||||
size: int = -1
|
||||
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from configs import dify_config
|
||||
from constants import IMAGE_EXTENSIONS
|
||||
from core.helper.url_signer import UrlSigner
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
|
||||
class UploadFileParser:
|
||||
@classmethod
|
||||
def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
|
||||
if not upload_file:
|
||||
return None
|
||||
|
||||
if upload_file.extension not in IMAGE_EXTENSIONS:
|
||||
return None
|
||||
|
||||
if dify_config.MULTIMODAL_SEND_FORMAT == "url" or force_url:
|
||||
return cls.get_signed_temp_image_url(upload_file.id)
|
||||
else:
|
||||
# get image file base64
|
||||
try:
|
||||
data = storage.load(upload_file.key)
|
||||
except FileNotFoundError:
|
||||
logging.exception(f"File not found: {upload_file.key}")
|
||||
return None
|
||||
|
||||
encoded_string = base64.b64encode(data).decode("utf-8")
|
||||
return f"data:{upload_file.mime_type};base64,{encoded_string}"
|
||||
|
||||
@classmethod
|
||||
def get_signed_temp_image_url(cls, upload_file_id) -> str:
|
||||
"""
|
||||
get signed url from upload file
|
||||
|
||||
:param upload_file_id: the id of UploadFile object
|
||||
:return:
|
||||
"""
|
||||
base_url = dify_config.FILES_URL
|
||||
image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
|
||||
|
||||
return UrlSigner.get_signed_url(url=image_preview_url, sign_key=upload_file_id, prefix="image-preview")
|
||||
|
||||
@classmethod
|
||||
def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
"""
|
||||
verify signature
|
||||
|
||||
:param upload_file_id: file id
|
||||
:param timestamp: timestamp
|
||||
:param nonce: nonce
|
||||
:param sign: signature
|
||||
:return:
|
||||
"""
|
||||
result = UrlSigner.verify(
|
||||
sign_key=upload_file_id, timestamp=timestamp, nonce=nonce, sign=sign, prefix="image-preview"
|
||||
)
|
||||
|
||||
# verify signature
|
||||
if not result:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
@@ -28,7 +28,7 @@ class TemplateTransformer(ABC):
|
||||
def extract_result_str_from_response(cls, response: str):
|
||||
result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL)
|
||||
if not result:
|
||||
raise ValueError("Failed to parse result")
|
||||
raise ValueError(f"Failed to parse result: no result tag found in response. Response: {response[:200]}...")
|
||||
return result.group(1)
|
||||
|
||||
@classmethod
|
||||
@@ -38,16 +38,53 @@ class TemplateTransformer(ABC):
|
||||
:param response: response
|
||||
:return:
|
||||
"""
|
||||
|
||||
try:
|
||||
result = json.loads(cls.extract_result_str_from_response(response))
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("failed to parse response")
|
||||
result_str = cls.extract_result_str_from_response(response)
|
||||
result = json.loads(result_str)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Failed to parse JSON response: {str(e)}. Response content: {result_str[:200]}...")
|
||||
except ValueError as e:
|
||||
# Re-raise ValueError from extract_result_str_from_response
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise ValueError(f"Unexpected error during response transformation: {str(e)}")
|
||||
|
||||
# Check if the result contains an error
|
||||
if isinstance(result, dict) and "error" in result:
|
||||
raise ValueError(f"JavaScript execution error: {result['error']}")
|
||||
|
||||
if not isinstance(result, dict):
|
||||
raise ValueError("result must be a dict")
|
||||
raise ValueError(f"Result must be a dict, got {type(result).__name__}")
|
||||
if not all(isinstance(k, str) for k in result):
|
||||
raise ValueError("result keys must be strings")
|
||||
raise ValueError("Result keys must be strings")
|
||||
|
||||
# Post-process the result to convert scientific notation strings back to numbers
|
||||
result = cls._post_process_result(result)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _post_process_result(cls, result: dict[Any, Any]) -> dict[Any, Any]:
|
||||
"""
|
||||
Post-process the result to convert scientific notation strings back to numbers
|
||||
"""
|
||||
|
||||
def convert_scientific_notation(value):
|
||||
if isinstance(value, str):
|
||||
# Check if the string looks like scientific notation
|
||||
if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE):
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
pass
|
||||
elif isinstance(value, dict):
|
||||
return {k: convert_scientific_notation(v) for k, v in value.items()}
|
||||
elif isinstance(value, list):
|
||||
return [convert_scientific_notation(v) for v in value]
|
||||
return value
|
||||
|
||||
return convert_scientific_notation(result) # type: ignore[no-any-return]
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_runner_script(cls) -> str:
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Any
|
||||
|
||||
|
||||
class LRUCache:
|
||||
def __init__(self, capacity: int):
|
||||
self.cache: OrderedDict[Any, Any] = OrderedDict()
|
||||
self.capacity = capacity
|
||||
|
||||
def get(self, key: Any) -> Any:
|
||||
if key not in self.cache:
|
||||
return None
|
||||
else:
|
||||
self.cache.move_to_end(key) # move the key to the end of the OrderedDict
|
||||
return self.cache[key]
|
||||
|
||||
def put(self, key: Any, value: Any) -> None:
|
||||
if key in self.cache:
|
||||
self.cache.move_to_end(key)
|
||||
self.cache[key] = value
|
||||
if len(self.cache) > self.capacity:
|
||||
self.cache.popitem(last=False) # pop the first item
|
||||
@@ -317,9 +317,10 @@ class IndexingRunner:
|
||||
image_upload_file_ids = get_image_upload_file_ids(document.page_content)
|
||||
for upload_file_id in image_upload_file_ids:
|
||||
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
|
||||
if image_file is None:
|
||||
continue
|
||||
try:
|
||||
if image_file:
|
||||
storage.delete(image_file.key)
|
||||
storage.delete(image_file.key)
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"Delete image_files failed while indexing_estimate, \
|
||||
|
||||
@@ -23,6 +23,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule
|
||||
|
||||
@@ -170,10 +171,15 @@ def invoke_llm_with_structured_output(
|
||||
system_fingerprint: Optional[str] = None
|
||||
for event in llm_result:
|
||||
if isinstance(event, LLMResultChunk):
|
||||
prompt_messages = event.prompt_messages
|
||||
system_fingerprint = event.system_fingerprint
|
||||
|
||||
if isinstance(event.delta.message.content, str):
|
||||
result_text += event.delta.message.content
|
||||
prompt_messages = event.prompt_messages
|
||||
system_fingerprint = event.system_fingerprint
|
||||
elif isinstance(event.delta.message.content, list):
|
||||
for item in event.delta.message.content:
|
||||
if isinstance(item, TextPromptMessageContent):
|
||||
result_text += item.data
|
||||
|
||||
yield LLMResultChunkWithStructuredOutput(
|
||||
model=model_schema.model,
|
||||
|
||||
342
api/core/mcp/auth/auth_flow.py
Normal file
342
api/core/mcp/auth/auth_flow.py
Normal file
@@ -0,0 +1,342 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
import urllib.parse
|
||||
from typing import Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
||||
from core.mcp.types import (
|
||||
OAuthClientInformation,
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthMetadata,
|
||||
OAuthTokens,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
LATEST_PROTOCOL_VERSION = "1.0"
|
||||
OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
|
||||
OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
|
||||
|
||||
|
||||
class OAuthCallbackState(BaseModel):
|
||||
provider_id: str
|
||||
tenant_id: str
|
||||
server_url: str
|
||||
metadata: OAuthMetadata | None = None
|
||||
client_information: OAuthClientInformation
|
||||
code_verifier: str
|
||||
redirect_uri: str
|
||||
|
||||
|
||||
def generate_pkce_challenge() -> tuple[str, str]:
|
||||
"""Generate PKCE challenge and verifier."""
|
||||
code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
|
||||
code_verifier = code_verifier.replace("=", "").replace("+", "-").replace("/", "_")
|
||||
|
||||
code_challenge_hash = hashlib.sha256(code_verifier.encode("utf-8")).digest()
|
||||
code_challenge = base64.urlsafe_b64encode(code_challenge_hash).decode("utf-8")
|
||||
code_challenge = code_challenge.replace("=", "").replace("+", "-").replace("/", "_")
|
||||
|
||||
return code_verifier, code_challenge
|
||||
|
||||
|
||||
def _create_secure_redis_state(state_data: OAuthCallbackState) -> str:
|
||||
"""Create a secure state parameter by storing state data in Redis and returning a random state key."""
|
||||
# Generate a secure random state key
|
||||
state_key = secrets.token_urlsafe(32)
|
||||
|
||||
# Store the state data in Redis with expiration
|
||||
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
|
||||
redis_client.setex(redis_key, OAUTH_STATE_EXPIRY_SECONDS, state_data.model_dump_json())
|
||||
|
||||
return state_key
|
||||
|
||||
|
||||
def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
|
||||
"""Retrieve and decode OAuth state data from Redis using the state key, then delete it."""
|
||||
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
|
||||
|
||||
# Get state data from Redis
|
||||
state_data = redis_client.get(redis_key)
|
||||
|
||||
if not state_data:
|
||||
raise ValueError("State parameter has expired or does not exist")
|
||||
|
||||
# Delete the state data from Redis immediately after retrieval to prevent reuse
|
||||
redis_client.delete(redis_key)
|
||||
|
||||
try:
|
||||
# Parse and validate the state data
|
||||
oauth_state = OAuthCallbackState.model_validate_json(state_data)
|
||||
|
||||
return oauth_state
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"Invalid state parameter: {str(e)}")
|
||||
|
||||
|
||||
def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState:
|
||||
"""Handle the callback from the OAuth provider."""
|
||||
# Retrieve state data from Redis (state is automatically deleted after retrieval)
|
||||
full_state_data = _retrieve_redis_state(state_key)
|
||||
|
||||
tokens = exchange_authorization(
|
||||
full_state_data.server_url,
|
||||
full_state_data.metadata,
|
||||
full_state_data.client_information,
|
||||
authorization_code,
|
||||
full_state_data.code_verifier,
|
||||
full_state_data.redirect_uri,
|
||||
)
|
||||
provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True)
|
||||
provider.save_tokens(tokens)
|
||||
return full_state_data
|
||||
|
||||
|
||||
def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]:
|
||||
"""Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata."""
|
||||
url = urljoin(server_url, "/.well-known/oauth-authorization-server")
|
||||
|
||||
try:
|
||||
headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
|
||||
response = requests.get(url, headers=headers)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
if not response.ok:
|
||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
except requests.RequestException as e:
|
||||
if isinstance(e, requests.ConnectionError):
|
||||
response = requests.get(url)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
if not response.ok:
|
||||
raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
|
||||
return OAuthMetadata.model_validate(response.json())
|
||||
raise
|
||||
|
||||
|
||||
def start_authorization(
|
||||
server_url: str,
|
||||
metadata: Optional[OAuthMetadata],
|
||||
client_information: OAuthClientInformation,
|
||||
redirect_url: str,
|
||||
provider_id: str,
|
||||
tenant_id: str,
|
||||
) -> tuple[str, str]:
|
||||
"""Begins the authorization flow with secure Redis state storage."""
|
||||
response_type = "code"
|
||||
code_challenge_method = "S256"
|
||||
|
||||
if metadata:
|
||||
authorization_url = metadata.authorization_endpoint
|
||||
if response_type not in metadata.response_types_supported:
|
||||
raise ValueError(f"Incompatible auth server: does not support response type {response_type}")
|
||||
if (
|
||||
not metadata.code_challenge_methods_supported
|
||||
or code_challenge_method not in metadata.code_challenge_methods_supported
|
||||
):
|
||||
raise ValueError(
|
||||
f"Incompatible auth server: does not support code challenge method {code_challenge_method}"
|
||||
)
|
||||
else:
|
||||
authorization_url = urljoin(server_url, "/authorize")
|
||||
|
||||
code_verifier, code_challenge = generate_pkce_challenge()
|
||||
|
||||
# Prepare state data with all necessary information
|
||||
state_data = OAuthCallbackState(
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
server_url=server_url,
|
||||
metadata=metadata,
|
||||
client_information=client_information,
|
||||
code_verifier=code_verifier,
|
||||
redirect_uri=redirect_url,
|
||||
)
|
||||
|
||||
# Store state data in Redis and generate secure state key
|
||||
state_key = _create_secure_redis_state(state_data)
|
||||
|
||||
params = {
|
||||
"response_type": response_type,
|
||||
"client_id": client_information.client_id,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": code_challenge_method,
|
||||
"redirect_uri": redirect_url,
|
||||
"state": state_key,
|
||||
}
|
||||
|
||||
authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
|
||||
return authorization_url, code_verifier
|
||||
|
||||
|
||||
def exchange_authorization(
|
||||
server_url: str,
|
||||
metadata: Optional[OAuthMetadata],
|
||||
client_information: OAuthClientInformation,
|
||||
authorization_code: str,
|
||||
code_verifier: str,
|
||||
redirect_uri: str,
|
||||
) -> OAuthTokens:
|
||||
"""Exchanges an authorization code for an access token."""
|
||||
grant_type = "authorization_code"
|
||||
|
||||
if metadata:
|
||||
token_url = metadata.token_endpoint
|
||||
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
|
||||
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
|
||||
else:
|
||||
token_url = urljoin(server_url, "/token")
|
||||
|
||||
params = {
|
||||
"grant_type": grant_type,
|
||||
"client_id": client_information.client_id,
|
||||
"code": authorization_code,
|
||||
"code_verifier": code_verifier,
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
|
||||
if client_information.client_secret:
|
||||
params["client_secret"] = client_information.client_secret
|
||||
|
||||
response = requests.post(token_url, data=params)
|
||||
if not response.ok:
|
||||
raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
|
||||
return OAuthTokens.model_validate(response.json())
|
||||
|
||||
|
||||
def refresh_authorization(
|
||||
server_url: str,
|
||||
metadata: Optional[OAuthMetadata],
|
||||
client_information: OAuthClientInformation,
|
||||
refresh_token: str,
|
||||
) -> OAuthTokens:
|
||||
"""Exchange a refresh token for an updated access token."""
|
||||
grant_type = "refresh_token"
|
||||
|
||||
if metadata:
|
||||
token_url = metadata.token_endpoint
|
||||
if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
|
||||
raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
|
||||
else:
|
||||
token_url = urljoin(server_url, "/token")
|
||||
|
||||
params = {
|
||||
"grant_type": grant_type,
|
||||
"client_id": client_information.client_id,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
|
||||
if client_information.client_secret:
|
||||
params["client_secret"] = client_information.client_secret
|
||||
|
||||
response = requests.post(token_url, data=params)
|
||||
if not response.ok:
|
||||
raise ValueError(f"Token refresh failed: HTTP {response.status_code}")
|
||||
return OAuthTokens.parse_obj(response.json())
|
||||
|
||||
|
||||
def register_client(
|
||||
server_url: str,
|
||||
metadata: Optional[OAuthMetadata],
|
||||
client_metadata: OAuthClientMetadata,
|
||||
) -> OAuthClientInformationFull:
|
||||
"""Performs OAuth 2.0 Dynamic Client Registration."""
|
||||
if metadata:
|
||||
if not metadata.registration_endpoint:
|
||||
raise ValueError("Incompatible auth server: does not support dynamic client registration")
|
||||
registration_url = metadata.registration_endpoint
|
||||
else:
|
||||
registration_url = urljoin(server_url, "/register")
|
||||
|
||||
response = requests.post(
|
||||
registration_url,
|
||||
json=client_metadata.model_dump(),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
if not response.ok:
|
||||
response.raise_for_status()
|
||||
return OAuthClientInformationFull.model_validate(response.json())
|
||||
|
||||
|
||||
def auth(
|
||||
provider: OAuthClientProvider,
|
||||
server_url: str,
|
||||
authorization_code: Optional[str] = None,
|
||||
state_param: Optional[str] = None,
|
||||
for_list: bool = False,
|
||||
) -> dict[str, str]:
|
||||
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
|
||||
metadata = discover_oauth_metadata(server_url)
|
||||
|
||||
# Handle client registration if needed
|
||||
client_information = provider.client_information()
|
||||
if not client_information:
|
||||
if authorization_code is not None:
|
||||
raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
|
||||
try:
|
||||
full_information = register_client(server_url, metadata, provider.client_metadata)
|
||||
except requests.RequestException as e:
|
||||
raise ValueError(f"Could not register OAuth client: {e}")
|
||||
provider.save_client_information(full_information)
|
||||
client_information = full_information
|
||||
|
||||
# Exchange authorization code for tokens
|
||||
if authorization_code is not None:
|
||||
if not state_param:
|
||||
raise ValueError("State parameter is required when exchanging authorization code")
|
||||
|
||||
try:
|
||||
# Retrieve state data from Redis using state key
|
||||
full_state_data = _retrieve_redis_state(state_param)
|
||||
|
||||
code_verifier = full_state_data.code_verifier
|
||||
redirect_uri = full_state_data.redirect_uri
|
||||
|
||||
if not code_verifier or not redirect_uri:
|
||||
raise ValueError("Missing code_verifier or redirect_uri in state data")
|
||||
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise ValueError(f"Invalid state parameter: {e}")
|
||||
|
||||
tokens = exchange_authorization(
|
||||
server_url,
|
||||
metadata,
|
||||
client_information,
|
||||
authorization_code,
|
||||
code_verifier,
|
||||
redirect_uri,
|
||||
)
|
||||
provider.save_tokens(tokens)
|
||||
return {"result": "success"}
|
||||
|
||||
provider_tokens = provider.tokens()
|
||||
|
||||
# Handle token refresh or new authorization
|
||||
if provider_tokens and provider_tokens.refresh_token:
|
||||
try:
|
||||
new_tokens = refresh_authorization(server_url, metadata, client_information, provider_tokens.refresh_token)
|
||||
provider.save_tokens(new_tokens)
|
||||
return {"result": "success"}
|
||||
except Exception as e:
|
||||
raise ValueError(f"Could not refresh OAuth tokens: {e}")
|
||||
|
||||
# Start new authorization flow
|
||||
authorization_url, code_verifier = start_authorization(
|
||||
server_url,
|
||||
metadata,
|
||||
client_information,
|
||||
provider.redirect_url,
|
||||
provider.mcp_provider.id,
|
||||
provider.mcp_provider.tenant_id,
|
||||
)
|
||||
|
||||
provider.save_code_verifier(code_verifier)
|
||||
return {"authorization_url": authorization_url}
|
||||
81
api/core/mcp/auth/auth_provider.py
Normal file
81
api/core/mcp/auth/auth_provider.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from typing import Optional
|
||||
|
||||
from configs import dify_config
|
||||
from core.mcp.types import (
|
||||
OAuthClientInformation,
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthTokens,
|
||||
)
|
||||
from models.tools import MCPToolProvider
|
||||
from services.tools.mcp_tools_mange_service import MCPToolManageService
|
||||
|
||||
LATEST_PROTOCOL_VERSION = "1.0"
|
||||
|
||||
|
||||
class OAuthClientProvider:
|
||||
mcp_provider: MCPToolProvider
|
||||
|
||||
def __init__(self, provider_id: str, tenant_id: str, for_list: bool = False):
|
||||
if for_list:
|
||||
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
else:
|
||||
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_server_identifier(provider_id, tenant_id)
|
||||
|
||||
@property
|
||||
def redirect_url(self) -> str:
|
||||
"""The URL to redirect the user agent to after authorization."""
|
||||
return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
|
||||
|
||||
@property
|
||||
def client_metadata(self) -> OAuthClientMetadata:
|
||||
"""Metadata about this OAuth client."""
|
||||
return OAuthClientMetadata(
|
||||
redirect_uris=[self.redirect_url],
|
||||
token_endpoint_auth_method="none",
|
||||
grant_types=["authorization_code", "refresh_token"],
|
||||
response_types=["code"],
|
||||
client_name="Dify",
|
||||
client_uri="https://github.com/langgenius/dify",
|
||||
)
|
||||
|
||||
def client_information(self) -> Optional[OAuthClientInformation]:
|
||||
"""Loads information about this OAuth client."""
|
||||
client_information = self.mcp_provider.decrypted_credentials.get("client_information", {})
|
||||
if not client_information:
|
||||
return None
|
||||
return OAuthClientInformation.model_validate(client_information)
|
||||
|
||||
def save_client_information(self, client_information: OAuthClientInformationFull) -> None:
|
||||
"""Saves client information after dynamic registration."""
|
||||
MCPToolManageService.update_mcp_provider_credentials(
|
||||
self.mcp_provider,
|
||||
{"client_information": client_information.model_dump()},
|
||||
)
|
||||
|
||||
def tokens(self) -> Optional[OAuthTokens]:
|
||||
"""Loads any existing OAuth tokens for the current session."""
|
||||
credentials = self.mcp_provider.decrypted_credentials
|
||||
if not credentials:
|
||||
return None
|
||||
return OAuthTokens(
|
||||
access_token=credentials.get("access_token", ""),
|
||||
token_type=credentials.get("token_type", "Bearer"),
|
||||
expires_in=int(credentials.get("expires_in", "3600") or 3600),
|
||||
refresh_token=credentials.get("refresh_token", ""),
|
||||
)
|
||||
|
||||
def save_tokens(self, tokens: OAuthTokens) -> None:
|
||||
"""Stores new OAuth tokens for the current session."""
|
||||
# update mcp provider credentials
|
||||
token_dict = tokens.model_dump()
|
||||
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True)
|
||||
|
||||
def save_code_verifier(self, code_verifier: str) -> None:
|
||||
"""Saves a PKCE code verifier for the current session."""
|
||||
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier})
|
||||
|
||||
def code_verifier(self) -> str:
|
||||
"""Loads the PKCE code verifier for the current session."""
|
||||
# get code verifier from mcp provider credentials
|
||||
return str(self.mcp_provider.decrypted_credentials.get("code_verifier", ""))
|
||||
361
api/core/mcp/client/sse_client.py
Normal file
361
api/core/mcp/client/sse_client.py
Normal file
@@ -0,0 +1,361 @@
|
||||
import logging
|
||||
import queue
|
||||
from collections.abc import Generator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, TypeAlias, final
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import httpx
|
||||
from sseclient import SSEClient
|
||||
|
||||
from core.mcp import types
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.types import SessionMessage
|
||||
from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_QUEUE_READ_TIMEOUT = 3
|
||||
|
||||
|
||||
@final
|
||||
class _StatusReady:
|
||||
def __init__(self, endpoint_url: str):
|
||||
self._endpoint_url = endpoint_url
|
||||
|
||||
|
||||
@final
|
||||
class _StatusError:
|
||||
def __init__(self, exc: Exception):
|
||||
self._exc = exc
|
||||
|
||||
|
||||
# Type aliases for better readability
|
||||
ReadQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
|
||||
WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
|
||||
StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError]
|
||||
|
||||
|
||||
def remove_request_params(url: str) -> str:
|
||||
"""Remove request parameters from URL, keeping only the path."""
|
||||
return urljoin(url, urlparse(url).path)
|
||||
|
||||
|
||||
class SSETransport:
|
||||
"""SSE client transport implementation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: float = 5.0,
|
||||
sse_read_timeout: float = 5 * 60,
|
||||
) -> None:
|
||||
"""Initialize the SSE transport.
|
||||
|
||||
Args:
|
||||
url: The SSE endpoint URL.
|
||||
headers: Optional headers to include in requests.
|
||||
timeout: HTTP timeout for regular operations.
|
||||
sse_read_timeout: Timeout for SSE read operations.
|
||||
"""
|
||||
self.url = url
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
self.endpoint_url: str | None = None
|
||||
|
||||
def _validate_endpoint_url(self, endpoint_url: str) -> bool:
|
||||
"""Validate that the endpoint URL matches the connection origin.
|
||||
|
||||
Args:
|
||||
endpoint_url: The endpoint URL to validate.
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise.
|
||||
"""
|
||||
url_parsed = urlparse(self.url)
|
||||
endpoint_parsed = urlparse(endpoint_url)
|
||||
|
||||
return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme
|
||||
|
||||
def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue) -> None:
|
||||
"""Handle an 'endpoint' SSE event.
|
||||
|
||||
Args:
|
||||
sse_data: The SSE event data.
|
||||
status_queue: Queue to put status updates.
|
||||
"""
|
||||
endpoint_url = urljoin(self.url, sse_data)
|
||||
logger.info(f"Received endpoint URL: {endpoint_url}")
|
||||
|
||||
if not self._validate_endpoint_url(endpoint_url):
|
||||
error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}"
|
||||
logger.error(error_msg)
|
||||
status_queue.put(_StatusError(ValueError(error_msg)))
|
||||
return
|
||||
|
||||
status_queue.put(_StatusReady(endpoint_url))
|
||||
|
||||
def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None:
|
||||
"""Handle a 'message' SSE event.
|
||||
|
||||
Args:
|
||||
sse_data: The SSE event data.
|
||||
read_queue: Queue to put parsed messages.
|
||||
"""
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json(sse_data)
|
||||
logger.debug(f"Received server message: {message}")
|
||||
session_message = SessionMessage(message)
|
||||
read_queue.put(session_message)
|
||||
except Exception as exc:
|
||||
logger.exception("Error parsing server message")
|
||||
read_queue.put(exc)
|
||||
|
||||
def _handle_sse_event(self, sse, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
|
||||
"""Handle a single SSE event.
|
||||
|
||||
Args:
|
||||
sse: The SSE event object.
|
||||
read_queue: Queue for message events.
|
||||
status_queue: Queue for status events.
|
||||
"""
|
||||
match sse.event:
|
||||
case "endpoint":
|
||||
self._handle_endpoint_event(sse.data, status_queue)
|
||||
case "message":
|
||||
self._handle_message_event(sse.data, read_queue)
|
||||
case _:
|
||||
logger.warning(f"Unknown SSE event: {sse.event}")
|
||||
|
||||
def sse_reader(self, event_source, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
|
||||
"""Read and process SSE events.
|
||||
|
||||
Args:
|
||||
event_source: The SSE event source.
|
||||
read_queue: Queue to put received messages.
|
||||
status_queue: Queue to put status updates.
|
||||
"""
|
||||
try:
|
||||
for sse in event_source.iter_sse():
|
||||
self._handle_sse_event(sse, read_queue, status_queue)
|
||||
except httpx.ReadError as exc:
|
||||
logger.debug(f"SSE reader shutting down normally: {exc}")
|
||||
except Exception as exc:
|
||||
read_queue.put(exc)
|
||||
finally:
|
||||
read_queue.put(None)
|
||||
|
||||
def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage) -> None:
|
||||
"""Send a single message to the server.
|
||||
|
||||
Args:
|
||||
client: HTTP client to use.
|
||||
endpoint_url: The endpoint URL to send to.
|
||||
message: The message to send.
|
||||
"""
|
||||
response = client.post(
|
||||
endpoint_url,
|
||||
json=message.message.model_dump(
|
||||
by_alias=True,
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.debug(f"Client message sent successfully: {response.status_code}")
|
||||
|
||||
def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue) -> None:
|
||||
"""Handle writing messages to the server.
|
||||
|
||||
Args:
|
||||
client: HTTP client to use.
|
||||
endpoint_url: The endpoint URL to send messages to.
|
||||
write_queue: Queue to read messages from.
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
|
||||
if message is None:
|
||||
break
|
||||
if isinstance(message, Exception):
|
||||
write_queue.put(message)
|
||||
continue
|
||||
|
||||
self._send_message(client, endpoint_url, message)
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
except httpx.ReadError as exc:
|
||||
logger.debug(f"Post writer shutting down normally: {exc}")
|
||||
except Exception as exc:
|
||||
logger.exception("Error writing messages")
|
||||
write_queue.put(exc)
|
||||
finally:
|
||||
write_queue.put(None)
|
||||
|
||||
def _wait_for_endpoint(self, status_queue: StatusQueue) -> str:
|
||||
"""Wait for the endpoint URL from the status queue.
|
||||
|
||||
Args:
|
||||
status_queue: Queue to read status from.
|
||||
|
||||
Returns:
|
||||
The endpoint URL.
|
||||
|
||||
Raises:
|
||||
ValueError: If endpoint URL is not received or there's an error.
|
||||
"""
|
||||
try:
|
||||
status = status_queue.get(timeout=1)
|
||||
except queue.Empty:
|
||||
raise ValueError("failed to get endpoint URL")
|
||||
|
||||
if isinstance(status, _StatusReady):
|
||||
return status._endpoint_url
|
||||
elif isinstance(status, _StatusError):
|
||||
raise status._exc
|
||||
else:
|
||||
raise ValueError("failed to get endpoint URL")
|
||||
|
||||
def connect(
|
||||
self,
|
||||
executor: ThreadPoolExecutor,
|
||||
client: httpx.Client,
|
||||
event_source,
|
||||
) -> tuple[ReadQueue, WriteQueue]:
|
||||
"""Establish connection and start worker threads.
|
||||
|
||||
Args:
|
||||
executor: Thread pool executor.
|
||||
client: HTTP client.
|
||||
event_source: SSE event source.
|
||||
|
||||
Returns:
|
||||
Tuple of (read_queue, write_queue).
|
||||
"""
|
||||
read_queue: ReadQueue = queue.Queue()
|
||||
write_queue: WriteQueue = queue.Queue()
|
||||
status_queue: StatusQueue = queue.Queue()
|
||||
|
||||
# Start SSE reader thread
|
||||
executor.submit(self.sse_reader, event_source, read_queue, status_queue)
|
||||
|
||||
# Wait for endpoint URL
|
||||
endpoint_url = self._wait_for_endpoint(status_queue)
|
||||
self.endpoint_url = endpoint_url
|
||||
|
||||
# Start post writer thread
|
||||
executor.submit(self.post_writer, client, endpoint_url, write_queue)
|
||||
|
||||
return read_queue, write_queue
|
||||
|
||||
|
||||
@contextmanager
|
||||
def sse_client(
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: float = 5.0,
|
||||
sse_read_timeout: float = 5 * 60,
|
||||
) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
|
||||
"""
|
||||
Client transport for SSE.
|
||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
||||
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||
|
||||
Args:
|
||||
url: The SSE endpoint URL.
|
||||
headers: Optional headers to include in requests.
|
||||
timeout: HTTP timeout for regular operations.
|
||||
sse_read_timeout: Timeout for SSE read operations.
|
||||
|
||||
Yields:
|
||||
Tuple of (read_queue, write_queue) for message communication.
|
||||
"""
|
||||
transport = SSETransport(url, headers, timeout, sse_read_timeout)
|
||||
|
||||
read_queue: ReadQueue | None = None
|
||||
write_queue: WriteQueue | None = None
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
try:
|
||||
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
|
||||
with ssrf_proxy_sse_connect(
|
||||
url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
|
||||
read_queue, write_queue = transport.connect(executor, client, event_source)
|
||||
|
||||
yield read_queue, write_queue
|
||||
|
||||
except httpx.HTTPStatusError as exc:
|
||||
if exc.response.status_code == 401:
|
||||
raise MCPAuthError()
|
||||
raise MCPConnectionError()
|
||||
except Exception:
|
||||
logger.exception("Error connecting to SSE endpoint")
|
||||
raise
|
||||
finally:
|
||||
# Clean up queues
|
||||
if read_queue:
|
||||
read_queue.put(None)
|
||||
if write_queue:
|
||||
write_queue.put(None)
|
||||
|
||||
|
||||
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage) -> None:
|
||||
"""
|
||||
Send a message to the server using the provided HTTP client.
|
||||
|
||||
Args:
|
||||
http_client: The HTTP client to use for sending
|
||||
endpoint_url: The endpoint URL to send the message to
|
||||
session_message: The message to send
|
||||
"""
|
||||
try:
|
||||
response = http_client.post(
|
||||
endpoint_url,
|
||||
json=session_message.message.model_dump(
|
||||
by_alias=True,
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.debug(f"Client message sent successfully: {response.status_code}")
|
||||
except Exception as exc:
|
||||
logger.exception("Error sending message")
|
||||
raise
|
||||
|
||||
|
||||
def read_messages(
|
||||
sse_client: SSEClient,
|
||||
) -> Generator[SessionMessage | Exception, None, None]:
|
||||
"""
|
||||
Read messages from the SSE client.
|
||||
|
||||
Args:
|
||||
sse_client: The SSE client to read from
|
||||
|
||||
Yields:
|
||||
SessionMessage or Exception for each event received
|
||||
"""
|
||||
try:
|
||||
for sse in sse_client.events():
|
||||
if sse.event == "message":
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json(sse.data)
|
||||
logger.debug(f"Received server message: {message}")
|
||||
yield SessionMessage(message)
|
||||
except Exception as exc:
|
||||
logger.exception("Error parsing server message")
|
||||
yield exc
|
||||
else:
|
||||
logger.warning(f"Unknown SSE event: {sse.event}")
|
||||
except Exception as exc:
|
||||
logger.exception("Error reading SSE messages")
|
||||
yield exc
|
||||
476
api/core/mcp/client/streamable_client.py
Normal file
476
api/core/mcp/client/streamable_client.py
Normal file
@@ -0,0 +1,476 @@
|
||||
"""
|
||||
StreamableHTTP Client Transport Module
|
||||
|
||||
This module implements the StreamableHTTP transport for MCP clients,
|
||||
providing support for HTTP POST requests with optional SSE streaming responses
|
||||
and session management.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import queue
|
||||
from collections.abc import Callable, Generator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
from httpx_sse import EventSource, ServerSentEvent
|
||||
|
||||
from core.mcp.types import (
|
||||
ClientMessageMetadata,
|
||||
ErrorData,
|
||||
JSONRPCError,
|
||||
JSONRPCMessage,
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
RequestId,
|
||||
SessionMessage,
|
||||
)
|
||||
from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SessionMessageOrError = SessionMessage | Exception | None
|
||||
# Queue types with clearer names for their roles
|
||||
ServerToClientQueue = queue.Queue[SessionMessageOrError] # Server to client messages
|
||||
ClientToServerQueue = queue.Queue[SessionMessage | None] # Client to server messages
|
||||
GetSessionIdCallback = Callable[[], str | None]
|
||||
|
||||
MCP_SESSION_ID = "mcp-session-id"
|
||||
LAST_EVENT_ID = "last-event-id"
|
||||
CONTENT_TYPE = "content-type"
|
||||
ACCEPT = "Accept"
|
||||
|
||||
|
||||
JSON = "application/json"
|
||||
SSE = "text/event-stream"
|
||||
|
||||
DEFAULT_QUEUE_READ_TIMEOUT = 3
|
||||
|
||||
|
||||
class StreamableHTTPError(Exception):
|
||||
"""Base exception for StreamableHTTP transport errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ResumptionError(StreamableHTTPError):
|
||||
"""Raised when resumption request is invalid."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext:
|
||||
"""Context for a request operation."""
|
||||
|
||||
client: httpx.Client
|
||||
headers: dict[str, str]
|
||||
session_id: str | None
|
||||
session_message: SessionMessage
|
||||
metadata: ClientMessageMetadata | None
|
||||
server_to_client_queue: ServerToClientQueue # Renamed for clarity
|
||||
sse_read_timeout: timedelta
|
||||
|
||||
|
||||
class StreamableHTTPTransport:
|
||||
"""StreamableHTTP client transport implementation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: timedelta = timedelta(seconds=30),
|
||||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
|
||||
) -> None:
|
||||
"""Initialize the StreamableHTTP transport.
|
||||
|
||||
Args:
|
||||
url: The endpoint URL.
|
||||
headers: Optional headers to include in requests.
|
||||
timeout: HTTP timeout for regular operations.
|
||||
sse_read_timeout: Timeout for SSE read operations.
|
||||
"""
|
||||
self.url = url
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
self.session_id: str | None = None
|
||||
self.request_headers = {
|
||||
ACCEPT: f"{JSON}, {SSE}",
|
||||
CONTENT_TYPE: JSON,
|
||||
**self.headers,
|
||||
}
|
||||
|
||||
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
|
||||
"""Update headers with session ID if available."""
|
||||
headers = base_headers.copy()
|
||||
if self.session_id:
|
||||
headers[MCP_SESSION_ID] = self.session_id
|
||||
return headers
|
||||
|
||||
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
|
||||
"""Check if the message is an initialization request."""
|
||||
return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
|
||||
|
||||
def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
|
||||
"""Check if the message is an initialized notification."""
|
||||
return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized"
|
||||
|
||||
def _maybe_extract_session_id_from_response(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
) -> None:
|
||||
"""Extract and store session ID from response headers."""
|
||||
new_session_id = response.headers.get(MCP_SESSION_ID)
|
||||
if new_session_id:
|
||||
self.session_id = new_session_id
|
||||
logger.info(f"Received session ID: {self.session_id}")
|
||||
|
||||
def _handle_sse_event(
|
||||
self,
|
||||
sse: ServerSentEvent,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
original_request_id: RequestId | None = None,
|
||||
resumption_callback: Callable[[str], None] | None = None,
|
||||
) -> bool:
|
||||
"""Handle an SSE event, returning True if the response is complete."""
|
||||
if sse.event == "message":
|
||||
try:
|
||||
message = JSONRPCMessage.model_validate_json(sse.data)
|
||||
logger.debug(f"SSE message: {message}")
|
||||
|
||||
# If this is a response and we have original_request_id, replace it
|
||||
if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
|
||||
message.root.id = original_request_id
|
||||
|
||||
session_message = SessionMessage(message)
|
||||
# Put message in queue that goes to client
|
||||
server_to_client_queue.put(session_message)
|
||||
|
||||
# Call resumption token callback if we have an ID
|
||||
if sse.id and resumption_callback:
|
||||
resumption_callback(sse.id)
|
||||
|
||||
# If this is a response or error return True indicating completion
|
||||
# Otherwise, return False to continue listening
|
||||
return isinstance(message.root, JSONRPCResponse | JSONRPCError)
|
||||
|
||||
except Exception as exc:
|
||||
# Put exception in queue that goes to client
|
||||
server_to_client_queue.put(exc)
|
||||
return False
|
||||
elif sse.event == "ping":
|
||||
logger.debug("Received ping event")
|
||||
return False
|
||||
else:
|
||||
logger.warning(f"Unknown SSE event: {sse.event}")
|
||||
return False
|
||||
|
||||
def handle_get_stream(
|
||||
self,
|
||||
client: httpx.Client,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
) -> None:
|
||||
"""Handle GET stream for server-initiated messages."""
|
||||
try:
|
||||
if not self.session_id:
|
||||
return
|
||||
|
||||
headers = self._update_headers_with_session(self.request_headers)
|
||||
|
||||
with ssrf_proxy_sse_connect(
|
||||
self.url,
|
||||
headers=headers,
|
||||
timeout=httpx.Timeout(self.timeout.seconds, read=self.sse_read_timeout.seconds),
|
||||
client=client,
|
||||
method="GET",
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.debug("GET SSE connection established")
|
||||
|
||||
for sse in event_source.iter_sse():
|
||||
self._handle_sse_event(sse, server_to_client_queue)
|
||||
|
||||
except Exception as exc:
|
||||
logger.debug(f"GET stream error (non-fatal): {exc}")
|
||||
|
||||
def _handle_resumption_request(self, ctx: RequestContext) -> None:
|
||||
"""Handle a resumption request using GET with SSE."""
|
||||
headers = self._update_headers_with_session(ctx.headers)
|
||||
if ctx.metadata and ctx.metadata.resumption_token:
|
||||
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
|
||||
else:
|
||||
raise ResumptionError("Resumption request requires a resumption token")
|
||||
|
||||
# Extract original request ID to map responses
|
||||
original_request_id = None
|
||||
if isinstance(ctx.session_message.message.root, JSONRPCRequest):
|
||||
original_request_id = ctx.session_message.message.root.id
|
||||
|
||||
with ssrf_proxy_sse_connect(
|
||||
self.url,
|
||||
headers=headers,
|
||||
timeout=httpx.Timeout(self.timeout.seconds, read=ctx.sse_read_timeout.seconds),
|
||||
client=ctx.client,
|
||||
method="GET",
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.debug("Resumption GET SSE connection established")
|
||||
|
||||
for sse in event_source.iter_sse():
|
||||
is_complete = self._handle_sse_event(
|
||||
sse,
|
||||
ctx.server_to_client_queue,
|
||||
original_request_id,
|
||||
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
|
||||
)
|
||||
if is_complete:
|
||||
break
|
||||
|
||||
def _handle_post_request(self, ctx: RequestContext) -> None:
|
||||
"""Handle a POST request with response processing."""
|
||||
headers = self._update_headers_with_session(ctx.headers)
|
||||
message = ctx.session_message.message
|
||||
is_initialization = self._is_initialization_request(message)
|
||||
|
||||
with ctx.client.stream(
|
||||
"POST",
|
||||
self.url,
|
||||
json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
headers=headers,
|
||||
) as response:
|
||||
if response.status_code == 202:
|
||||
logger.debug("Received 202 Accepted")
|
||||
return
|
||||
|
||||
if response.status_code == 404:
|
||||
if isinstance(message.root, JSONRPCRequest):
|
||||
self._send_session_terminated_error(
|
||||
ctx.server_to_client_queue,
|
||||
message.root.id,
|
||||
)
|
||||
return
|
||||
|
||||
response.raise_for_status()
|
||||
if is_initialization:
|
||||
self._maybe_extract_session_id_from_response(response)
|
||||
|
||||
content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
|
||||
|
||||
if content_type.startswith(JSON):
|
||||
self._handle_json_response(response, ctx.server_to_client_queue)
|
||||
elif content_type.startswith(SSE):
|
||||
self._handle_sse_response(response, ctx)
|
||||
else:
|
||||
self._handle_unexpected_content_type(
|
||||
content_type,
|
||||
ctx.server_to_client_queue,
|
||||
)
|
||||
|
||||
def _handle_json_response(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
) -> None:
|
||||
"""Handle JSON response from the server."""
|
||||
try:
|
||||
content = response.read()
|
||||
message = JSONRPCMessage.model_validate_json(content)
|
||||
session_message = SessionMessage(message)
|
||||
server_to_client_queue.put(session_message)
|
||||
except Exception as exc:
|
||||
server_to_client_queue.put(exc)
|
||||
|
||||
def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None:
|
||||
"""Handle SSE response from the server."""
|
||||
try:
|
||||
event_source = EventSource(response)
|
||||
for sse in event_source.iter_sse():
|
||||
is_complete = self._handle_sse_event(
|
||||
sse,
|
||||
ctx.server_to_client_queue,
|
||||
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
|
||||
)
|
||||
if is_complete:
|
||||
break
|
||||
except Exception as e:
|
||||
ctx.server_to_client_queue.put(e)
|
||||
|
||||
def _handle_unexpected_content_type(
|
||||
self,
|
||||
content_type: str,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
) -> None:
|
||||
"""Handle unexpected content type in response."""
|
||||
error_msg = f"Unexpected content type: {content_type}"
|
||||
logger.error(error_msg)
|
||||
server_to_client_queue.put(ValueError(error_msg))
|
||||
|
||||
def _send_session_terminated_error(
|
||||
self,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
request_id: RequestId,
|
||||
) -> None:
|
||||
"""Send a session terminated error response."""
|
||||
jsonrpc_error = JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
error=ErrorData(code=32600, message="Session terminated by server"),
|
||||
)
|
||||
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
|
||||
server_to_client_queue.put(session_message)
|
||||
|
||||
def post_writer(
|
||||
self,
|
||||
client: httpx.Client,
|
||||
client_to_server_queue: ClientToServerQueue,
|
||||
server_to_client_queue: ServerToClientQueue,
|
||||
start_get_stream: Callable[[], None],
|
||||
) -> None:
|
||||
"""Handle writing requests to the server.
|
||||
|
||||
This method processes messages from the client_to_server_queue and sends them to the server.
|
||||
Responses are written to the server_to_client_queue.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
# Read message from client queue with timeout to check stop_event periodically
|
||||
session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
|
||||
if session_message is None:
|
||||
break
|
||||
|
||||
message = session_message.message
|
||||
metadata = (
|
||||
session_message.metadata if isinstance(session_message.metadata, ClientMessageMetadata) else None
|
||||
)
|
||||
|
||||
# Check if this is a resumption request
|
||||
is_resumption = bool(metadata and metadata.resumption_token)
|
||||
|
||||
logger.debug(f"Sending client message: {message}")
|
||||
|
||||
# Handle initialized notification
|
||||
if self._is_initialized_notification(message):
|
||||
start_get_stream()
|
||||
|
||||
ctx = RequestContext(
|
||||
client=client,
|
||||
headers=self.request_headers,
|
||||
session_id=self.session_id,
|
||||
session_message=session_message,
|
||||
metadata=metadata,
|
||||
server_to_client_queue=server_to_client_queue, # Queue to write responses to client
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
)
|
||||
|
||||
if is_resumption:
|
||||
self._handle_resumption_request(ctx)
|
||||
else:
|
||||
self._handle_post_request(ctx)
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception as exc:
|
||||
server_to_client_queue.put(exc)
|
||||
|
||||
def terminate_session(self, client: httpx.Client) -> None:
|
||||
"""Terminate the session by sending a DELETE request."""
|
||||
if not self.session_id:
|
||||
return
|
||||
|
||||
try:
|
||||
headers = self._update_headers_with_session(self.request_headers)
|
||||
response = client.delete(self.url, headers=headers)
|
||||
|
||||
if response.status_code == 405:
|
||||
logger.debug("Server does not allow session termination")
|
||||
elif response.status_code != 200:
|
||||
logger.warning(f"Session termination failed: {response.status_code}")
|
||||
except Exception as exc:
|
||||
logger.warning(f"Session termination failed: {exc}")
|
||||
|
||||
def get_session_id(self) -> str | None:
|
||||
"""Get the current session ID."""
|
||||
return self.session_id
|
||||
|
||||
|
||||
@contextmanager
|
||||
def streamablehttp_client(
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: timedelta = timedelta(seconds=30),
|
||||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
|
||||
terminate_on_close: bool = True,
|
||||
) -> Generator[
|
||||
tuple[
|
||||
ServerToClientQueue, # Queue for receiving messages FROM server
|
||||
ClientToServerQueue, # Queue for sending messages TO server
|
||||
GetSessionIdCallback,
|
||||
],
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
"""
|
||||
Client transport for StreamableHTTP.
|
||||
|
||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
||||
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||
|
||||
Yields:
|
||||
Tuple containing:
|
||||
- server_to_client_queue: Queue for reading messages FROM the server
|
||||
- client_to_server_queue: Queue for sending messages TO the server
|
||||
- get_session_id_callback: Function to retrieve the current session ID
|
||||
"""
|
||||
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout)
|
||||
|
||||
# Create queues with clear directional meaning
|
||||
server_to_client_queue: ServerToClientQueue = queue.Queue() # For messages FROM server TO client
|
||||
client_to_server_queue: ClientToServerQueue = queue.Queue() # For messages FROM client TO server
|
||||
|
||||
with ThreadPoolExecutor(max_workers=2) as executor:
|
||||
try:
|
||||
with create_ssrf_proxy_mcp_http_client(
|
||||
headers=transport.request_headers,
|
||||
timeout=httpx.Timeout(transport.timeout.seconds, read=transport.sse_read_timeout.seconds),
|
||||
) as client:
|
||||
# Define callbacks that need access to thread pool
|
||||
def start_get_stream() -> None:
|
||||
"""Start a worker thread to handle server-initiated messages."""
|
||||
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
|
||||
|
||||
# Start the post_writer worker thread
|
||||
executor.submit(
|
||||
transport.post_writer,
|
||||
client,
|
||||
client_to_server_queue, # Queue for messages FROM client TO server
|
||||
server_to_client_queue, # Queue for messages FROM server TO client
|
||||
start_get_stream,
|
||||
)
|
||||
|
||||
try:
|
||||
yield (
|
||||
server_to_client_queue, # Queue for receiving messages FROM server
|
||||
client_to_server_queue, # Queue for sending messages TO server
|
||||
transport.get_session_id,
|
||||
)
|
||||
finally:
|
||||
if transport.session_id and terminate_on_close:
|
||||
transport.terminate_session(client)
|
||||
|
||||
# Signal threads to stop
|
||||
client_to_server_queue.put(None)
|
||||
finally:
|
||||
# Clear any remaining items and add None sentinel to unblock any waiting threads
|
||||
try:
|
||||
while not client_to_server_queue.empty():
|
||||
client_to_server_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
client_to_server_queue.put(None)
|
||||
server_to_client_queue.put(None)
|
||||
19
api/core/mcp/entities.py
Normal file
19
api/core/mcp/entities.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from core.mcp.session.base_session import BaseSession
|
||||
from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestId, RequestParams
|
||||
|
||||
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", LATEST_PROTOCOL_VERSION]
|
||||
|
||||
|
||||
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
|
||||
LifespanContextT = TypeVar("LifespanContextT")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext(Generic[SessionT, LifespanContextT]):
|
||||
request_id: RequestId
|
||||
meta: RequestParams.Meta | None
|
||||
session: SessionT
|
||||
lifespan_context: LifespanContextT
|
||||
10
api/core/mcp/error.py
Normal file
10
api/core/mcp/error.py
Normal file
@@ -0,0 +1,10 @@
|
||||
class MCPError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class MCPConnectionError(MCPError):
|
||||
pass
|
||||
|
||||
|
||||
class MCPAuthError(MCPConnectionError):
|
||||
pass
|
||||
150
api/core/mcp/mcp_client.py
Normal file
150
api/core/mcp/mcp_client.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from contextlib import AbstractContextManager, ExitStack
|
||||
from types import TracebackType
|
||||
from typing import Any, Optional, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from core.mcp.client.sse_client import sse_client
|
||||
from core.mcp.client.streamable_client import streamablehttp_client
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.session.client_session import ClientSession
|
||||
from core.mcp.types import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPClient:
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
provider_id: str,
|
||||
tenant_id: str,
|
||||
authed: bool = True,
|
||||
authorization_code: Optional[str] = None,
|
||||
for_list: bool = False,
|
||||
):
|
||||
# Initialize info
|
||||
self.provider_id = provider_id
|
||||
self.tenant_id = tenant_id
|
||||
self.client_type = "streamable"
|
||||
self.server_url = server_url
|
||||
|
||||
# Authentication info
|
||||
self.authed = authed
|
||||
self.authorization_code = authorization_code
|
||||
if authed:
|
||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
||||
|
||||
self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list)
|
||||
self.token = self.provider.tokens()
|
||||
|
||||
# Initialize session and client objects
|
||||
self._session: Optional[ClientSession] = None
|
||||
self._streams_context: Optional[AbstractContextManager[Any]] = None
|
||||
self._session_context: Optional[ClientSession] = None
|
||||
self.exit_stack = ExitStack()
|
||||
|
||||
# Whether the client has been initialized
|
||||
self._initialized = False
|
||||
|
||||
def __enter__(self):
|
||||
self._initialize()
|
||||
self._initialized = True
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[TracebackType]
|
||||
):
|
||||
self.cleanup()
|
||||
|
||||
def _initialize(
|
||||
self,
|
||||
):
|
||||
"""Initialize the client with fallback to SSE if streamable connection fails"""
|
||||
connection_methods: dict[str, Callable[..., AbstractContextManager[Any]]] = {
|
||||
"mcp": streamablehttp_client,
|
||||
"sse": sse_client,
|
||||
}
|
||||
|
||||
parsed_url = urlparse(self.server_url)
|
||||
path = parsed_url.path
|
||||
method_name = path.rstrip("/").split("/")[-1] if path else ""
|
||||
try:
|
||||
client_factory = connection_methods[method_name]
|
||||
self.connect_server(client_factory, method_name)
|
||||
except KeyError:
|
||||
try:
|
||||
self.connect_server(sse_client, "sse")
|
||||
except MCPConnectionError:
|
||||
self.connect_server(streamablehttp_client, "mcp")
|
||||
|
||||
def connect_server(
|
||||
self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str, first_try: bool = True
|
||||
):
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
|
||||
try:
|
||||
headers = (
|
||||
{"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
|
||||
if self.authed and self.token
|
||||
else {}
|
||||
)
|
||||
self._streams_context = client_factory(url=self.server_url, headers=headers)
|
||||
if self._streams_context is None:
|
||||
raise MCPConnectionError("Failed to create connection context")
|
||||
|
||||
# Use exit_stack to manage context managers properly
|
||||
if method_name == "mcp":
|
||||
read_stream, write_stream, _ = self.exit_stack.enter_context(self._streams_context)
|
||||
streams = (read_stream, write_stream)
|
||||
else: # sse_client
|
||||
streams = self.exit_stack.enter_context(self._streams_context)
|
||||
|
||||
self._session_context = ClientSession(*streams)
|
||||
self._session = self.exit_stack.enter_context(self._session_context)
|
||||
session = cast(ClientSession, self._session)
|
||||
session.initialize()
|
||||
return
|
||||
|
||||
except MCPAuthError:
|
||||
if not self.authed:
|
||||
raise
|
||||
try:
|
||||
auth(self.provider, self.server_url, self.authorization_code)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to authenticate: {e}")
|
||||
self.token = self.provider.tokens()
|
||||
if first_try:
|
||||
return self.connect_server(client_factory, method_name, first_try=False)
|
||||
|
||||
except MCPConnectionError:
|
||||
raise
|
||||
|
||||
def list_tools(self) -> list[Tool]:
|
||||
"""Connect to an MCP server running with SSE transport"""
|
||||
# List available tools to verify connection
|
||||
if not self._initialized or not self._session:
|
||||
raise ValueError("Session not initialized.")
|
||||
response = self._session.list_tools()
|
||||
tools = response.tools
|
||||
return tools
|
||||
|
||||
def invoke_tool(self, tool_name: str, tool_args: dict):
|
||||
"""Call a tool"""
|
||||
if not self._initialized or not self._session:
|
||||
raise ValueError("Session not initialized.")
|
||||
return self._session.call_tool(tool_name, tool_args)
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up resources"""
|
||||
try:
|
||||
# ExitStack will handle proper cleanup of all managed context managers
|
||||
self.exit_stack.close()
|
||||
self._session = None
|
||||
self._session_context = None
|
||||
self._streams_context = None
|
||||
self._initialized = False
|
||||
except Exception as e:
|
||||
logging.exception("Error during cleanup")
|
||||
raise ValueError(f"Error during cleanup: {e}")
|
||||
224
api/core/mcp/server/streamable_http.py
Normal file
224
api/core/mcp/server/streamable_http.py
Normal file
@@ -0,0 +1,224 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.web.passport import generate_session_id
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
|
||||
from core.mcp import types
|
||||
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
|
||||
from core.mcp.utils import create_mcp_error_response
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
"""
|
||||
Apply to MCP HTTP streamable server with stateless http
|
||||
"""
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPServerStreamableHTTPRequestHandler:
|
||||
def __init__(
|
||||
self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
|
||||
):
|
||||
self.app = app
|
||||
self.request = request
|
||||
mcp_server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == self.app.id).first()
|
||||
if not mcp_server:
|
||||
raise ValueError("MCP server not found")
|
||||
self.mcp_server: AppMCPServer = mcp_server
|
||||
self.end_user = self.retrieve_end_user()
|
||||
self.user_input_form = user_input_form
|
||||
|
||||
@property
|
||||
def request_type(self):
|
||||
return type(self.request.root)
|
||||
|
||||
@property
|
||||
def parameter_schema(self):
|
||||
parameters, required = self._convert_input_form_to_parameters(self.user_input_form)
|
||||
if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": parameters,
|
||||
"required": required,
|
||||
}
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "User Input/Question content"},
|
||||
**parameters,
|
||||
},
|
||||
"required": ["query", *required],
|
||||
}
|
||||
|
||||
@property
|
||||
def capabilities(self):
|
||||
return types.ServerCapabilities(
|
||||
tools=types.ToolsCapability(listChanged=False),
|
||||
)
|
||||
|
||||
def response(self, response: types.Result | str):
|
||||
if isinstance(response, str):
|
||||
sse_content = f"event: ping\ndata: {response}\n\n".encode()
|
||||
yield sse_content
|
||||
return
|
||||
json_response = types.JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=(self.request.root.model_extra or {}).get("id", 1),
|
||||
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
json_data = json.dumps(jsonable_encoder(json_response))
|
||||
|
||||
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
|
||||
|
||||
yield sse_content
|
||||
|
||||
def error_response(self, code: int, message: str, data=None):
|
||||
request_id = (self.request.root.model_extra or {}).get("id", 1) or 1
|
||||
return create_mcp_error_response(request_id, code, message, data)
|
||||
|
||||
def handle(self):
|
||||
handle_map = {
|
||||
types.InitializeRequest: self.initialize,
|
||||
types.ListToolsRequest: self.list_tools,
|
||||
types.CallToolRequest: self.invoke_tool,
|
||||
types.InitializedNotification: self.handle_notification,
|
||||
}
|
||||
try:
|
||||
if self.request_type in handle_map:
|
||||
return self.response(handle_map[self.request_type]())
|
||||
else:
|
||||
return self.error_response(METHOD_NOT_FOUND, f"Method not found: {self.request_type}")
|
||||
except ValueError as e:
|
||||
logger.exception("Invalid params")
|
||||
return self.error_response(INVALID_PARAMS, str(e))
|
||||
except Exception as e:
|
||||
logger.exception("Internal server error")
|
||||
return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}")
|
||||
|
||||
def handle_notification(self):
|
||||
return "ping"
|
||||
|
||||
def initialize(self):
|
||||
request = cast(types.InitializeRequest, self.request.root)
|
||||
client_info = request.params.clientInfo
|
||||
clinet_name = f"{client_info.name}@{client_info.version}"
|
||||
if not self.end_user:
|
||||
end_user = EndUser(
|
||||
tenant_id=self.app.tenant_id,
|
||||
app_id=self.app.id,
|
||||
type="mcp",
|
||||
name=clinet_name,
|
||||
session_id=generate_session_id(),
|
||||
external_user_id=self.mcp_server.id,
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
return types.InitializeResult(
|
||||
protocolVersion=types.SERVER_LATEST_PROTOCOL_VERSION,
|
||||
capabilities=self.capabilities,
|
||||
serverInfo=types.Implementation(name="Dify", version=dify_config.project.version),
|
||||
instructions=self.mcp_server.description,
|
||||
)
|
||||
|
||||
def list_tools(self):
|
||||
if not self.end_user:
|
||||
raise ValueError("User not found")
|
||||
return types.ListToolsResult(
|
||||
tools=[
|
||||
types.Tool(
|
||||
name=self.app.name,
|
||||
description=self.mcp_server.description,
|
||||
inputSchema=self.parameter_schema,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def invoke_tool(self):
|
||||
if not self.end_user:
|
||||
raise ValueError("User not found")
|
||||
request = cast(types.CallToolRequest, self.request.root)
|
||||
args = request.params.arguments
|
||||
if not args:
|
||||
raise ValueError("No arguments provided")
|
||||
if self.app.mode in {AppMode.WORKFLOW.value}:
|
||||
args = {"inputs": args}
|
||||
elif self.app.mode in {AppMode.COMPLETION.value}:
|
||||
args = {"query": "", "inputs": args}
|
||||
else:
|
||||
args = {"query": args["query"], "inputs": {k: v for k, v in args.items() if k != "query"}}
|
||||
response = AppGenerateService.generate(
|
||||
self.app,
|
||||
self.end_user,
|
||||
args,
|
||||
InvokeFrom.SERVICE_API,
|
||||
streaming=self.app.mode == AppMode.AGENT_CHAT.value,
|
||||
)
|
||||
answer = ""
|
||||
if isinstance(response, RateLimitGenerator):
|
||||
for item in response.generator:
|
||||
data = item
|
||||
if isinstance(data, str) and data.startswith("data: "):
|
||||
try:
|
||||
json_str = data[6:].strip()
|
||||
parsed_data = json.loads(json_str)
|
||||
if parsed_data.get("event") == "agent_thought":
|
||||
answer += parsed_data.get("thought", "")
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if isinstance(response, Mapping):
|
||||
if self.app.mode in {
|
||||
AppMode.ADVANCED_CHAT.value,
|
||||
AppMode.COMPLETION.value,
|
||||
AppMode.CHAT.value,
|
||||
AppMode.AGENT_CHAT.value,
|
||||
}:
|
||||
answer = response["answer"]
|
||||
elif self.app.mode in {AppMode.WORKFLOW.value}:
|
||||
answer = json.dumps(response["data"]["outputs"], ensure_ascii=False)
|
||||
else:
|
||||
raise ValueError("Invalid app mode")
|
||||
# Not support image yet
|
||||
return types.CallToolResult(content=[types.TextContent(text=answer, type="text")])
|
||||
|
||||
def retrieve_end_user(self):
|
||||
return (
|
||||
db.session.query(EndUser)
|
||||
.filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
|
||||
.first()
|
||||
)
|
||||
|
||||
def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]):
|
||||
parameters: dict[str, dict[str, Any]] = {}
|
||||
required = []
|
||||
for item in user_input_form:
|
||||
parameters[item.variable] = {}
|
||||
if item.type in (
|
||||
VariableEntityType.FILE,
|
||||
VariableEntityType.FILE_LIST,
|
||||
VariableEntityType.EXTERNAL_DATA_TOOL,
|
||||
):
|
||||
continue
|
||||
if item.required:
|
||||
required.append(item.variable)
|
||||
# if the workflow republished, the parameters not changed
|
||||
# we should not raise error here
|
||||
try:
|
||||
description = self.mcp_server.parameters_dict[item.variable]
|
||||
except KeyError:
|
||||
description = ""
|
||||
parameters[item.variable]["description"] = description
|
||||
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
|
||||
parameters[item.variable]["type"] = "string"
|
||||
elif item.type == VariableEntityType.SELECT:
|
||||
parameters[item.variable]["type"] = "string"
|
||||
parameters[item.variable]["enum"] = item.options
|
||||
elif item.type == VariableEntityType.NUMBER:
|
||||
parameters[item.variable]["type"] = "float"
|
||||
return parameters, required
|
||||
397
api/core/mcp/session/base_session.py
Normal file
397
api/core/mcp/session/base_session.py
Normal file
@@ -0,0 +1,397 @@
|
||||
import logging
|
||||
import queue
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import ExitStack
|
||||
from datetime import timedelta
|
||||
from types import TracebackType
|
||||
from typing import Any, Generic, Self, TypeVar
|
||||
|
||||
from httpx import HTTPStatusError
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.types import (
|
||||
CancelledNotification,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
ClientResult,
|
||||
ErrorData,
|
||||
JSONRPCError,
|
||||
JSONRPCMessage,
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
MessageMetadata,
|
||||
RequestId,
|
||||
RequestParams,
|
||||
ServerMessageMetadata,
|
||||
ServerNotification,
|
||||
ServerRequest,
|
||||
ServerResult,
|
||||
SessionMessage,
|
||||
)
|
||||
|
||||
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
|
||||
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
|
||||
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
|
||||
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
|
||||
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
|
||||
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
|
||||
DEFAULT_RESPONSE_READ_TIMEOUT = 1.0
|
||||
|
||||
|
||||
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||
"""Handles responding to MCP requests and manages request lifecycle.
|
||||
|
||||
This class MUST be used as a context manager to ensure proper cleanup and
|
||||
cancellation handling:
|
||||
|
||||
Example:
|
||||
with request_responder as resp:
|
||||
resp.respond(result)
|
||||
|
||||
The context manager ensures:
|
||||
1. Proper cancellation scope setup and cleanup
|
||||
2. Request completion tracking
|
||||
3. Cleanup of in-flight requests
|
||||
"""
|
||||
|
||||
request: ReceiveRequestT
|
||||
_session: Any
|
||||
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
request_id: RequestId,
|
||||
request_meta: RequestParams.Meta | None,
|
||||
request: ReceiveRequestT,
|
||||
session: """BaseSession[
|
||||
SendRequestT,
|
||||
SendNotificationT,
|
||||
SendResultT,
|
||||
ReceiveRequestT,
|
||||
ReceiveNotificationT
|
||||
]""",
|
||||
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.request_meta = request_meta
|
||||
self.request = request
|
||||
self._session = session
|
||||
self._completed = False
|
||||
self._on_complete = on_complete
|
||||
self._entered = False # Track if we're in a context manager
|
||||
|
||||
def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
|
||||
"""Enter the context manager, enabling request cancellation tracking."""
|
||||
self._entered = True
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
"""Exit the context manager, performing cleanup and notifying completion."""
|
||||
try:
|
||||
if self._completed:
|
||||
self._on_complete(self)
|
||||
finally:
|
||||
self._entered = False
|
||||
|
||||
def respond(self, response: SendResultT | ErrorData) -> None:
|
||||
"""Send a response for this request.
|
||||
|
||||
Must be called within a context manager block.
|
||||
Raises:
|
||||
RuntimeError: If not used within a context manager
|
||||
AssertionError: If request was already responded to
|
||||
"""
|
||||
if not self._entered:
|
||||
raise RuntimeError("RequestResponder must be used as a context manager")
|
||||
assert not self._completed, "Request already responded to"
|
||||
|
||||
self._completed = True
|
||||
|
||||
self._session._send_response(request_id=self.request_id, response=response)
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Cancel this request and mark it as completed."""
|
||||
if not self._entered:
|
||||
raise RuntimeError("RequestResponder must be used as a context manager")
|
||||
|
||||
self._completed = True # Mark as completed so it's removed from in_flight
|
||||
# Send an error response to indicate cancellation
|
||||
self._session._send_response(
|
||||
request_id=self.request_id,
|
||||
response=ErrorData(code=0, message="Request cancelled", data=None),
|
||||
)
|
||||
|
||||
|
||||
class BaseSession(
|
||||
Generic[
|
||||
SendRequestT,
|
||||
SendNotificationT,
|
||||
SendResultT,
|
||||
ReceiveRequestT,
|
||||
ReceiveNotificationT,
|
||||
],
|
||||
):
|
||||
"""
|
||||
Implements an MCP "session" on top of read/write streams, including features
|
||||
like request/response linking, notifications, and progress.
|
||||
|
||||
This class is a context manager that automatically starts processing
|
||||
messages when entered.
|
||||
"""
|
||||
|
||||
_response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError]]
|
||||
_request_id: int
|
||||
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
|
||||
_receive_request_type: type[ReceiveRequestT]
|
||||
_receive_notification_type: type[ReceiveNotificationT]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
read_stream: queue.Queue,
|
||||
write_stream: queue.Queue,
|
||||
receive_request_type: type[ReceiveRequestT],
|
||||
receive_notification_type: type[ReceiveNotificationT],
|
||||
# If none, reading will never time out
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
) -> None:
|
||||
self._read_stream = read_stream
|
||||
self._write_stream = write_stream
|
||||
self._response_streams = {}
|
||||
self._request_id = 0
|
||||
self._receive_request_type = receive_request_type
|
||||
self._receive_notification_type = receive_notification_type
|
||||
self._session_read_timeout_seconds = read_timeout_seconds
|
||||
self._in_flight = {}
|
||||
self._exit_stack = ExitStack()
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
self._executor = ThreadPoolExecutor()
|
||||
self._receiver_future = self._executor.submit(self._receive_loop)
|
||||
return self
|
||||
|
||||
def check_receiver_status(self) -> None:
|
||||
if self._receiver_future.done():
|
||||
self._receiver_future.result()
|
||||
|
||||
def __exit__(
|
||||
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
|
||||
) -> None:
|
||||
self._exit_stack.close()
|
||||
self._read_stream.put(None)
|
||||
self._write_stream.put(None)
|
||||
|
||||
def send_request(
|
||||
self,
|
||||
request: SendRequestT,
|
||||
result_type: type[ReceiveResultT],
|
||||
request_read_timeout_seconds: timedelta | None = None,
|
||||
metadata: MessageMetadata = None,
|
||||
) -> ReceiveResultT:
|
||||
"""
|
||||
Sends a request and wait for a response. Raises an McpError if the
|
||||
response contains an error. If a request read timeout is provided, it
|
||||
will take precedence over the session read timeout.
|
||||
|
||||
Do not use this method to emit notifications! Use send_notification()
|
||||
instead.
|
||||
"""
|
||||
self.check_receiver_status()
|
||||
|
||||
request_id = self._request_id
|
||||
self._request_id = request_id + 1
|
||||
|
||||
response_queue: queue.Queue[JSONRPCResponse | JSONRPCError] = queue.Queue()
|
||||
self._response_streams[request_id] = response_queue
|
||||
|
||||
try:
|
||||
jsonrpc_request = JSONRPCRequest(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
|
||||
self._write_stream.put(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))
|
||||
timeout = DEFAULT_RESPONSE_READ_TIMEOUT
|
||||
if request_read_timeout_seconds is not None:
|
||||
timeout = float(request_read_timeout_seconds.total_seconds())
|
||||
elif self._session_read_timeout_seconds is not None:
|
||||
timeout = float(self._session_read_timeout_seconds.total_seconds())
|
||||
while True:
|
||||
try:
|
||||
response_or_error = response_queue.get(timeout=timeout)
|
||||
break
|
||||
except queue.Empty:
|
||||
self.check_receiver_status()
|
||||
continue
|
||||
|
||||
if response_or_error is None:
|
||||
raise MCPConnectionError(
|
||||
ErrorData(
|
||||
code=500,
|
||||
message="No response received",
|
||||
)
|
||||
)
|
||||
elif isinstance(response_or_error, JSONRPCError):
|
||||
if response_or_error.error.code == 401:
|
||||
raise MCPAuthError(
|
||||
ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
|
||||
)
|
||||
else:
|
||||
raise MCPConnectionError(
|
||||
ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
|
||||
)
|
||||
else:
|
||||
return result_type.model_validate(response_or_error.result)
|
||||
|
||||
finally:
|
||||
self._response_streams.pop(request_id, None)
|
||||
|
||||
def send_notification(
|
||||
self,
|
||||
notification: SendNotificationT,
|
||||
related_request_id: RequestId | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Emits a notification, which is a one-way message that does not expect
|
||||
a response.
|
||||
"""
|
||||
self.check_receiver_status()
|
||||
|
||||
# Some transport implementations may need to set the related_request_id
|
||||
# to attribute to the notifications to the request that triggered them.
|
||||
jsonrpc_notification = JSONRPCNotification(
|
||||
jsonrpc="2.0",
|
||||
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
session_message = SessionMessage(
|
||||
message=JSONRPCMessage(jsonrpc_notification),
|
||||
metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None,
|
||||
)
|
||||
self._write_stream.put(session_message)
|
||||
|
||||
def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None:
|
||||
if isinstance(response, ErrorData):
|
||||
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
|
||||
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
|
||||
self._write_stream.put(session_message)
|
||||
else:
|
||||
jsonrpc_response = JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
|
||||
self._write_stream.put(session_message)
|
||||
|
||||
def _receive_loop(self) -> None:
|
||||
"""
|
||||
Main message processing loop.
|
||||
In a real synchronous implementation, this would likely run in a separate thread.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
# Attempt to receive a message (this would be blocking in a synchronous context)
|
||||
message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT)
|
||||
if message is None:
|
||||
break
|
||||
if isinstance(message, HTTPStatusError):
|
||||
response_queue = self._response_streams.get(self._request_id - 1)
|
||||
if response_queue is not None:
|
||||
response_queue.put(
|
||||
JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=self._request_id - 1,
|
||||
error=ErrorData(code=message.response.status_code, message=message.args[0]),
|
||||
)
|
||||
)
|
||||
else:
|
||||
self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
|
||||
elif isinstance(message, Exception):
|
||||
self._handle_incoming(message)
|
||||
elif isinstance(message.message.root, JSONRPCRequest):
|
||||
validated_request = self._receive_request_type.model_validate(
|
||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
|
||||
responder = RequestResponder(
|
||||
request_id=message.message.root.id,
|
||||
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
|
||||
request=validated_request,
|
||||
session=self,
|
||||
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
|
||||
)
|
||||
|
||||
self._in_flight[responder.request_id] = responder
|
||||
self._received_request(responder)
|
||||
|
||||
if not responder._completed:
|
||||
self._handle_incoming(responder)
|
||||
|
||||
elif isinstance(message.message.root, JSONRPCNotification):
|
||||
try:
|
||||
notification = self._receive_notification_type.model_validate(
|
||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
# Handle cancellation notifications
|
||||
if isinstance(notification.root, CancelledNotification):
|
||||
cancelled_id = notification.root.params.requestId
|
||||
if cancelled_id in self._in_flight:
|
||||
self._in_flight[cancelled_id].cancel()
|
||||
else:
|
||||
self._received_notification(notification)
|
||||
self._handle_incoming(notification)
|
||||
except Exception as e:
|
||||
# For other validation errors, log and continue
|
||||
logging.warning(f"Failed to validate notification: {e}. Message was: {message.message.root}")
|
||||
else: # Response or error
|
||||
response_queue = self._response_streams.get(message.message.root.id)
|
||||
if response_queue is not None:
|
||||
response_queue.put(message.message.root)
|
||||
else:
|
||||
self._handle_incoming(RuntimeError(f"Server Error: {message}"))
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logging.exception("Error in message processing loop")
|
||||
raise
|
||||
|
||||
def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
|
||||
"""
|
||||
Can be overridden by subclasses to handle a request without needing to
|
||||
listen on the message stream.
|
||||
|
||||
If the request is responded to within this method, it will not be
|
||||
forwarded on to the message stream.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _received_notification(self, notification: ReceiveNotificationT) -> None:
|
||||
"""
|
||||
Can be overridden by subclasses to handle a notification without needing
|
||||
to listen on the message stream.
|
||||
"""
|
||||
pass
|
||||
|
||||
def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Sends a progress notification for a request that is currently being
|
||||
processed.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _handle_incoming(
|
||||
self,
|
||||
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
|
||||
) -> None:
|
||||
"""A generic handler for incoming messages. Overwritten by subclasses."""
|
||||
pass
|
||||
365
api/core/mcp/session/client_session.py
Normal file
365
api/core/mcp/session/client_session.py
Normal file
@@ -0,0 +1,365 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any, Protocol
|
||||
|
||||
from pydantic import AnyUrl, TypeAdapter
|
||||
|
||||
from configs import dify_config
|
||||
from core.mcp import types
|
||||
from core.mcp.entities import SUPPORTED_PROTOCOL_VERSIONS, RequestContext
|
||||
from core.mcp.session.base_session import BaseSession, RequestResponder
|
||||
|
||||
DEFAULT_CLIENT_INFO = types.Implementation(name="Dify", version=dify_config.project.version)
|
||||
|
||||
|
||||
class SamplingFnT(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
) -> types.CreateMessageResult | types.ErrorData: ...
|
||||
|
||||
|
||||
class ListRootsFnT(Protocol):
|
||||
def __call__(self, context: RequestContext["ClientSession", Any]) -> types.ListRootsResult | types.ErrorData: ...
|
||||
|
||||
|
||||
class LoggingFnT(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
params: types.LoggingMessageNotificationParams,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class MessageHandlerFnT(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
def _default_message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
) -> None:
|
||||
if isinstance(message, Exception):
|
||||
raise ValueError(str(message))
|
||||
elif isinstance(message, (types.ServerNotification | RequestResponder)):
|
||||
pass
|
||||
|
||||
|
||||
def _default_sampling_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
) -> types.CreateMessageResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.INVALID_REQUEST,
|
||||
message="Sampling not supported",
|
||||
)
|
||||
|
||||
|
||||
def _default_list_roots_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
) -> types.ListRootsResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.INVALID_REQUEST,
|
||||
message="List roots not supported",
|
||||
)
|
||||
|
||||
|
||||
def _default_logging_callback(
|
||||
params: types.LoggingMessageNotificationParams,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
|
||||
|
||||
|
||||
class ClientSession(
|
||||
BaseSession[
|
||||
types.ClientRequest,
|
||||
types.ClientNotification,
|
||||
types.ClientResult,
|
||||
types.ServerRequest,
|
||||
types.ServerNotification,
|
||||
]
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
read_stream,
|
||||
write_stream,
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
sampling_callback: SamplingFnT | None = None,
|
||||
list_roots_callback: ListRootsFnT | None = None,
|
||||
logging_callback: LoggingFnT | None = None,
|
||||
message_handler: MessageHandlerFnT | None = None,
|
||||
client_info: types.Implementation | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
read_stream,
|
||||
write_stream,
|
||||
types.ServerRequest,
|
||||
types.ServerNotification,
|
||||
read_timeout_seconds=read_timeout_seconds,
|
||||
)
|
||||
self._client_info = client_info or DEFAULT_CLIENT_INFO
|
||||
self._sampling_callback = sampling_callback or _default_sampling_callback
|
||||
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
|
||||
self._logging_callback = logging_callback or _default_logging_callback
|
||||
self._message_handler = message_handler or _default_message_handler
|
||||
|
||||
def initialize(self) -> types.InitializeResult:
|
||||
sampling = types.SamplingCapability()
|
||||
roots = types.RootsCapability(
|
||||
# TODO: Should this be based on whether we
|
||||
# _will_ send notifications, or only whether
|
||||
# they're supported?
|
||||
listChanged=True,
|
||||
)
|
||||
|
||||
result = self.send_request(
|
||||
types.ClientRequest(
|
||||
types.InitializeRequest(
|
||||
method="initialize",
|
||||
params=types.InitializeRequestParams(
|
||||
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
||||
capabilities=types.ClientCapabilities(
|
||||
sampling=sampling,
|
||||
experimental=None,
|
||||
roots=roots,
|
||||
),
|
||||
clientInfo=self._client_info,
|
||||
),
|
||||
)
|
||||
),
|
||||
types.InitializeResult,
|
||||
)
|
||||
|
||||
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
|
||||
raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}")
|
||||
|
||||
self.send_notification(
|
||||
types.ClientNotification(types.InitializedNotification(method="notifications/initialized"))
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def send_ping(self) -> types.EmptyResult:
|
||||
"""Send a ping request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.PingRequest(
|
||||
method="ping",
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
def send_progress_notification(
|
||||
self, progress_token: str | int, progress: float, total: float | None = None
|
||||
) -> None:
|
||||
"""Send a progress notification."""
|
||||
self.send_notification(
|
||||
types.ClientNotification(
|
||||
types.ProgressNotification(
|
||||
method="notifications/progress",
|
||||
params=types.ProgressNotificationParams(
|
||||
progressToken=progress_token,
|
||||
progress=progress,
|
||||
total=total,
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
|
||||
"""Send a logging/setLevel request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.SetLevelRequest(
|
||||
method="logging/setLevel",
|
||||
params=types.SetLevelRequestParams(level=level),
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
def list_resources(self) -> types.ListResourcesResult:
|
||||
"""Send a resources/list request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListResourcesRequest(
|
||||
method="resources/list",
|
||||
)
|
||||
),
|
||||
types.ListResourcesResult,
|
||||
)
|
||||
|
||||
def list_resource_templates(self) -> types.ListResourceTemplatesResult:
|
||||
"""Send a resources/templates/list request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListResourceTemplatesRequest(
|
||||
method="resources/templates/list",
|
||||
)
|
||||
),
|
||||
types.ListResourceTemplatesResult,
|
||||
)
|
||||
|
||||
def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
|
||||
"""Send a resources/read request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ReadResourceRequest(
|
||||
method="resources/read",
|
||||
params=types.ReadResourceRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
types.ReadResourceResult,
|
||||
)
|
||||
|
||||
def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
|
||||
"""Send a resources/subscribe request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.SubscribeRequest(
|
||||
method="resources/subscribe",
|
||||
params=types.SubscribeRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
|
||||
"""Send a resources/unsubscribe request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.UnsubscribeRequest(
|
||||
method="resources/unsubscribe",
|
||||
params=types.UnsubscribeRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
def call_tool(
|
||||
self,
|
||||
name: str,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
) -> types.CallToolResult:
|
||||
"""Send a tools/call request."""
|
||||
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.CallToolRequest(
|
||||
method="tools/call",
|
||||
params=types.CallToolRequestParams(name=name, arguments=arguments),
|
||||
)
|
||||
),
|
||||
types.CallToolResult,
|
||||
request_read_timeout_seconds=read_timeout_seconds,
|
||||
)
|
||||
|
||||
def list_prompts(self) -> types.ListPromptsResult:
|
||||
"""Send a prompts/list request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListPromptsRequest(
|
||||
method="prompts/list",
|
||||
)
|
||||
),
|
||||
types.ListPromptsResult,
|
||||
)
|
||||
|
||||
def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
|
||||
"""Send a prompts/get request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.GetPromptRequest(
|
||||
method="prompts/get",
|
||||
params=types.GetPromptRequestParams(name=name, arguments=arguments),
|
||||
)
|
||||
),
|
||||
types.GetPromptResult,
|
||||
)
|
||||
|
||||
def complete(
|
||||
self,
|
||||
ref: types.ResourceReference | types.PromptReference,
|
||||
argument: dict[str, str],
|
||||
) -> types.CompleteResult:
|
||||
"""Send a completion/complete request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.CompleteRequest(
|
||||
method="completion/complete",
|
||||
params=types.CompleteRequestParams(
|
||||
ref=ref,
|
||||
argument=types.CompletionArgument(**argument),
|
||||
),
|
||||
)
|
||||
),
|
||||
types.CompleteResult,
|
||||
)
|
||||
|
||||
def list_tools(self) -> types.ListToolsResult:
|
||||
"""Send a tools/list request."""
|
||||
return self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListToolsRequest(
|
||||
method="tools/list",
|
||||
)
|
||||
),
|
||||
types.ListToolsResult,
|
||||
)
|
||||
|
||||
def send_roots_list_changed(self) -> None:
|
||||
"""Send a roots/list_changed notification."""
|
||||
self.send_notification(
|
||||
types.ClientNotification(
|
||||
types.RootsListChangedNotification(
|
||||
method="notifications/roots/list_changed",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
|
||||
ctx = RequestContext[ClientSession, Any](
|
||||
request_id=responder.request_id,
|
||||
meta=responder.request_meta,
|
||||
session=self,
|
||||
lifespan_context=None,
|
||||
)
|
||||
|
||||
match responder.request.root:
|
||||
case types.CreateMessageRequest(params=params):
|
||||
with responder:
|
||||
response = self._sampling_callback(ctx, params)
|
||||
client_response = ClientResponse.validate_python(response)
|
||||
responder.respond(client_response)
|
||||
|
||||
case types.ListRootsRequest():
|
||||
with responder:
|
||||
list_roots_response = self._list_roots_callback(ctx)
|
||||
client_response = ClientResponse.validate_python(list_roots_response)
|
||||
responder.respond(client_response)
|
||||
|
||||
case types.PingRequest():
|
||||
with responder:
|
||||
return responder.respond(types.ClientResult(root=types.EmptyResult()))
|
||||
|
||||
def _handle_incoming(
|
||||
self,
|
||||
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
) -> None:
|
||||
"""Handle incoming messages by forwarding to the message handler."""
|
||||
self._message_handler(req)
|
||||
|
||||
def _received_notification(self, notification: types.ServerNotification) -> None:
|
||||
"""Handle notifications from the server."""
|
||||
# Process specific notification types
|
||||
match notification.root:
|
||||
case types.LoggingMessageNotification(params=params):
|
||||
self._logging_callback(params)
|
||||
case _:
|
||||
pass
|
||||
1217
api/core/mcp/types.py
Normal file
1217
api/core/mcp/types.py
Normal file
File diff suppressed because it is too large
Load Diff
114
api/core/mcp/utils.py
Normal file
114
api/core/mcp/utils.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import json
|
||||
|
||||
import httpx
|
||||
|
||||
from configs import dify_config
|
||||
from core.mcp.types import ErrorData, JSONRPCError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
|
||||
STATUS_FORCELIST = [429, 500, 502, 503, 504]
|
||||
|
||||
|
||||
def create_ssrf_proxy_mcp_http_client(
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: httpx.Timeout | None = None,
|
||||
) -> httpx.Client:
|
||||
"""Create an HTTPX client with SSRF proxy configuration for MCP connections.
|
||||
|
||||
Args:
|
||||
headers: Optional headers to include in the client
|
||||
timeout: Optional timeout configuration
|
||||
|
||||
Returns:
|
||||
Configured httpx.Client with proxy settings
|
||||
"""
|
||||
if dify_config.SSRF_PROXY_ALL_URL:
|
||||
return httpx.Client(
|
||||
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
headers=headers or {},
|
||||
timeout=timeout,
|
||||
follow_redirects=True,
|
||||
proxy=dify_config.SSRF_PROXY_ALL_URL,
|
||||
)
|
||||
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
|
||||
proxy_mounts = {
|
||||
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY),
|
||||
"https://": httpx.HTTPTransport(
|
||||
proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
),
|
||||
}
|
||||
return httpx.Client(
|
||||
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
headers=headers or {},
|
||||
timeout=timeout,
|
||||
follow_redirects=True,
|
||||
mounts=proxy_mounts,
|
||||
)
|
||||
else:
|
||||
return httpx.Client(
|
||||
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
headers=headers or {},
|
||||
timeout=timeout,
|
||||
follow_redirects=True,
|
||||
)
|
||||
|
||||
|
||||
def ssrf_proxy_sse_connect(url, **kwargs):
|
||||
"""Connect to SSE endpoint with SSRF proxy protection.
|
||||
|
||||
This function creates an SSE connection using the configured proxy settings
|
||||
to prevent SSRF attacks when connecting to external endpoints.
|
||||
|
||||
Args:
|
||||
url: The SSE endpoint URL
|
||||
**kwargs: Additional arguments passed to the SSE connection
|
||||
|
||||
Returns:
|
||||
EventSource object for SSE streaming
|
||||
"""
|
||||
from httpx_sse import connect_sse
|
||||
|
||||
# Extract client if provided, otherwise create one
|
||||
client = kwargs.pop("client", None)
|
||||
if client is None:
|
||||
# Create client with SSRF proxy configuration
|
||||
timeout = kwargs.pop(
|
||||
"timeout",
|
||||
httpx.Timeout(
|
||||
timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
|
||||
connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
|
||||
read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
|
||||
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
|
||||
),
|
||||
)
|
||||
headers = kwargs.pop("headers", {})
|
||||
client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout)
|
||||
client_provided = False
|
||||
else:
|
||||
client_provided = True
|
||||
|
||||
# Extract method if provided, default to GET
|
||||
method = kwargs.pop("method", "GET")
|
||||
|
||||
try:
|
||||
return connect_sse(client, method, url, **kwargs)
|
||||
except Exception:
|
||||
# If we created the client, we need to clean it up on error
|
||||
if not client_provided:
|
||||
client.close()
|
||||
raise
|
||||
|
||||
|
||||
def create_mcp_error_response(request_id: int | str | None, code: int, message: str, data=None):
|
||||
"""Create MCP error response"""
|
||||
error_data = ErrorData(code=code, message=message, data=data)
|
||||
json_response = JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=request_id or 1,
|
||||
error=error_data,
|
||||
)
|
||||
json_data = json.dumps(jsonable_encoder(json_response))
|
||||
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
|
||||
yield sse_content
|
||||
@@ -53,6 +53,37 @@ class LLMUsage(ModelUsage):
|
||||
latency=0.0,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_metadata(cls, metadata: dict) -> "LLMUsage":
|
||||
"""
|
||||
Create LLMUsage instance from metadata dictionary with default values.
|
||||
|
||||
Args:
|
||||
metadata: Dictionary containing usage metadata
|
||||
|
||||
Returns:
|
||||
LLMUsage instance with values from metadata or defaults
|
||||
"""
|
||||
total_tokens = metadata.get("total_tokens", 0)
|
||||
completion_tokens = metadata.get("completion_tokens", 0)
|
||||
if total_tokens > 0 and completion_tokens == 0:
|
||||
completion_tokens = total_tokens
|
||||
|
||||
return cls(
|
||||
prompt_tokens=metadata.get("prompt_tokens", 0),
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
prompt_unit_price=Decimal(str(metadata.get("prompt_unit_price", 0))),
|
||||
completion_unit_price=Decimal(str(metadata.get("completion_unit_price", 0))),
|
||||
total_price=Decimal(str(metadata.get("total_price", 0))),
|
||||
currency=metadata.get("currency", "USD"),
|
||||
prompt_price_unit=Decimal(str(metadata.get("prompt_price_unit", 0))),
|
||||
completion_price_unit=Decimal(str(metadata.get("completion_price_unit", 0))),
|
||||
prompt_price=Decimal(str(metadata.get("prompt_price", 0))),
|
||||
completion_price=Decimal(str(metadata.get("completion_price", 0))),
|
||||
latency=metadata.get("latency", 0.0),
|
||||
)
|
||||
|
||||
def plus(self, other: "LLMUsage") -> "LLMUsage":
|
||||
"""
|
||||
Add two LLMUsage instances together.
|
||||
|
||||
487
api/core/ops/aliyun_trace/aliyun_trace.py
Normal file
487
api/core/ops/aliyun_trace/aliyun_trace.py
Normal file
@@ -0,0 +1,487 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from opentelemetry.trace import Status, StatusCode
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.ops.aliyun_trace.data_exporter.traceclient import (
|
||||
TraceClient,
|
||||
convert_datetime_to_nanoseconds,
|
||||
convert_to_span_id,
|
||||
convert_to_trace_id,
|
||||
generate_span_id,
|
||||
)
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
|
||||
from core.ops.aliyun_trace.entities.semconv import (
|
||||
GEN_AI_COMPLETION,
|
||||
GEN_AI_FRAMEWORK,
|
||||
GEN_AI_MODEL_NAME,
|
||||
GEN_AI_PROMPT,
|
||||
GEN_AI_PROMPT_TEMPLATE_TEMPLATE,
|
||||
GEN_AI_PROMPT_TEMPLATE_VARIABLE,
|
||||
GEN_AI_RESPONSE_FINISH_REASON,
|
||||
GEN_AI_SESSION_ID,
|
||||
GEN_AI_SPAN_KIND,
|
||||
GEN_AI_SYSTEM,
|
||||
GEN_AI_USAGE_INPUT_TOKENS,
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS,
|
||||
GEN_AI_USAGE_TOTAL_TOKENS,
|
||||
GEN_AI_USER_ID,
|
||||
INPUT_VALUE,
|
||||
OUTPUT_VALUE,
|
||||
RETRIEVAL_DOCUMENT,
|
||||
RETRIEVAL_QUERY,
|
||||
TOOL_DESCRIPTION,
|
||||
TOOL_NAME,
|
||||
TOOL_PARAMETERS,
|
||||
GenAISpanKind,
|
||||
)
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import AliyunConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.rag.models.document import Document
|
||||
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.nodes import NodeType
|
||||
from models import Account, App, EndUser, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom, db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AliyunDataTrace(BaseTraceInstance):
|
||||
def __init__(
|
||||
self,
|
||||
aliyun_config: AliyunConfig,
|
||||
):
|
||||
super().__init__(aliyun_config)
|
||||
base_url = aliyun_config.endpoint.rstrip("/")
|
||||
endpoint = urljoin(base_url, f"adapt_{aliyun_config.license_key}/api/otlp/traces")
|
||||
self.trace_client = TraceClient(service_name=aliyun_config.app_name, endpoint=endpoint)
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
if isinstance(trace_info, WorkflowTraceInfo):
|
||||
self.workflow_trace(trace_info)
|
||||
if isinstance(trace_info, MessageTraceInfo):
|
||||
self.message_trace(trace_info)
|
||||
if isinstance(trace_info, ModerationTraceInfo):
|
||||
pass
|
||||
if isinstance(trace_info, SuggestedQuestionTraceInfo):
|
||||
self.suggested_question_trace(trace_info)
|
||||
if isinstance(trace_info, DatasetRetrievalTraceInfo):
|
||||
self.dataset_retrieval_trace(trace_info)
|
||||
if isinstance(trace_info, ToolTraceInfo):
|
||||
self.tool_trace(trace_info)
|
||||
if isinstance(trace_info, GenerateNameTraceInfo):
|
||||
pass
|
||||
|
||||
def api_check(self):
|
||||
return self.trace_client.api_check()
|
||||
|
||||
def get_project_url(self):
|
||||
try:
|
||||
return self.trace_client.get_project_url()
|
||||
except Exception as e:
|
||||
logger.info(f"Aliyun get run url failed: {str(e)}", exc_info=True)
|
||||
raise ValueError(f"Aliyun get run url failed: {str(e)}")
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
trace_id = convert_to_trace_id(trace_info.workflow_run_id)
|
||||
workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow")
|
||||
self.add_workflow_span(trace_id, workflow_span_id, trace_info)
|
||||
|
||||
workflow_node_executions = self.get_workflow_node_executions(trace_info)
|
||||
for node_execution in workflow_node_executions:
|
||||
node_span = self.build_workflow_node_span(node_execution, trace_id, trace_info, workflow_span_id)
|
||||
self.trace_client.add_span(node_span)
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo):
|
||||
message_data = trace_info.message_data
|
||||
if message_data is None:
|
||||
return
|
||||
message_id = trace_info.message_id
|
||||
|
||||
user_id = message_data.from_account_id
|
||||
if message_data.from_end_user_id:
|
||||
end_user_data: Optional[EndUser] = (
|
||||
db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
|
||||
)
|
||||
if end_user_data is not None:
|
||||
user_id = end_user_data.session_id
|
||||
|
||||
status: Status = Status(StatusCode.OK)
|
||||
if trace_info.error:
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
|
||||
trace_id = convert_to_trace_id(message_id)
|
||||
message_span_id = convert_to_span_id(message_id, "message")
|
||||
message_span = SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=None,
|
||||
span_id=message_span_id,
|
||||
name="message",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_USER_ID: str(user_id),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: str(trace_info.outputs),
|
||||
},
|
||||
status=status,
|
||||
)
|
||||
self.trace_client.add_span(message_span)
|
||||
|
||||
app_model_config = getattr(trace_info.message_data, "app_model_config", {})
|
||||
pre_prompt = getattr(app_model_config, "pre_prompt", "")
|
||||
inputs_data = getattr(trace_info.message_data, "inputs", {})
|
||||
llm_span = SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=message_span_id,
|
||||
span_id=convert_to_span_id(message_id, "llm"),
|
||||
name="llm",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_USER_ID: str(user_id),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name", ""),
|
||||
GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider", ""),
|
||||
GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens),
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens),
|
||||
GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens),
|
||||
GEN_AI_PROMPT_TEMPLATE_VARIABLE: json.dumps(inputs_data, ensure_ascii=False),
|
||||
GEN_AI_PROMPT_TEMPLATE_TEMPLATE: pre_prompt,
|
||||
GEN_AI_PROMPT: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
GEN_AI_COMPLETION: str(trace_info.outputs),
|
||||
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: str(trace_info.outputs),
|
||||
},
|
||||
status=status,
|
||||
)
|
||||
self.trace_client.add_span(llm_span)
|
||||
|
||||
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
message_id = trace_info.message_id
|
||||
|
||||
documents_data = extract_retrieval_documents(trace_info.documents)
|
||||
dataset_retrieval_span = SpanData(
|
||||
trace_id=convert_to_trace_id(message_id),
|
||||
parent_span_id=convert_to_span_id(message_id, "message"),
|
||||
span_id=generate_span_id(),
|
||||
name="dataset_retrieval",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
RETRIEVAL_QUERY: str(trace_info.inputs),
|
||||
RETRIEVAL_DOCUMENT: json.dumps(documents_data, ensure_ascii=False),
|
||||
INPUT_VALUE: str(trace_info.inputs),
|
||||
OUTPUT_VALUE: json.dumps(documents_data, ensure_ascii=False),
|
||||
},
|
||||
)
|
||||
self.trace_client.add_span(dataset_retrieval_span)
|
||||
|
||||
def tool_trace(self, trace_info: ToolTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
message_id = trace_info.message_id
|
||||
|
||||
status: Status = Status(StatusCode.OK)
|
||||
if trace_info.error:
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
|
||||
tool_span = SpanData(
|
||||
trace_id=convert_to_trace_id(message_id),
|
||||
parent_span_id=convert_to_span_id(message_id, "message"),
|
||||
span_id=generate_span_id(),
|
||||
name=trace_info.tool_name,
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
TOOL_NAME: trace_info.tool_name,
|
||||
TOOL_DESCRIPTION: json.dumps(trace_info.tool_config, ensure_ascii=False),
|
||||
TOOL_PARAMETERS: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
|
||||
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: str(trace_info.tool_outputs),
|
||||
},
|
||||
status=status,
|
||||
)
|
||||
self.trace_client.add_span(tool_span)
|
||||
|
||||
def get_workflow_node_executions(self, trace_info: WorkflowTraceInfo) -> Sequence[WorkflowNodeExecution]:
|
||||
# through workflow_run_id get all_nodes_execution using repository
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
# Find the app's creator account
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get the app to find its creator
|
||||
app_id = trace_info.metadata.get("app_id")
|
||||
if not app_id:
|
||||
raise ValueError("No app_id found in trace_info metadata")
|
||||
|
||||
app = session.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise ValueError(f"App with id {app_id} not found")
|
||||
|
||||
if not app.created_by:
|
||||
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
|
||||
|
||||
service_account = session.query(Account).filter(Account.id == app.created_by).first()
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||
current_tenant = (
|
||||
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
|
||||
)
|
||||
if not current_tenant:
|
||||
raise ValueError(f"Current tenant not found for account {service_account.id}")
|
||||
service_account.set_tenant_id(current_tenant.tenant_id)
|
||||
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory,
|
||||
user=service_account,
|
||||
app_id=trace_info.metadata.get("app_id"),
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
# Get all executions for this workflow run
|
||||
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
|
||||
workflow_run_id=trace_info.workflow_run_id
|
||||
)
|
||||
return workflow_node_executions
|
||||
|
||||
def build_workflow_node_span(
|
||||
self, node_execution: WorkflowNodeExecution, trace_id: int, trace_info: WorkflowTraceInfo, workflow_span_id: int
|
||||
):
|
||||
try:
|
||||
if node_execution.node_type == NodeType.LLM:
|
||||
node_span = self.build_workflow_llm_span(trace_id, workflow_span_id, trace_info, node_execution)
|
||||
elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
node_span = self.build_workflow_retrieval_span(trace_id, workflow_span_id, trace_info, node_execution)
|
||||
elif node_execution.node_type == NodeType.TOOL:
|
||||
node_span = self.build_workflow_tool_span(trace_id, workflow_span_id, trace_info, node_execution)
|
||||
else:
|
||||
node_span = self.build_workflow_task_span(trace_id, workflow_span_id, trace_info, node_execution)
|
||||
return node_span
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_workflow_node_status(self, node_execution: WorkflowNodeExecution) -> Status:
|
||||
span_status: Status = Status(StatusCode.UNSET)
|
||||
if node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
span_status = Status(StatusCode.OK)
|
||||
elif node_execution.status in [WorkflowNodeExecutionStatus.FAILED, WorkflowNodeExecutionStatus.EXCEPTION]:
|
||||
span_status = Status(StatusCode.ERROR, str(node_execution.error))
|
||||
return span_status
|
||||
|
||||
def build_workflow_task_span(
|
||||
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||
) -> SpanData:
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=workflow_span_id,
|
||||
span_id=convert_to_span_id(node_execution.id, "node"),
|
||||
name=node_execution.title,
|
||||
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
|
||||
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.TASK.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
INPUT_VALUE: json.dumps(node_execution.inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
|
||||
},
|
||||
status=self.get_workflow_node_status(node_execution),
|
||||
)
|
||||
|
||||
def build_workflow_tool_span(
|
||||
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||
) -> SpanData:
|
||||
tool_des = {}
|
||||
if node_execution.metadata:
|
||||
tool_des = node_execution.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {})
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=workflow_span_id,
|
||||
span_id=convert_to_span_id(node_execution.id, "node"),
|
||||
name=node_execution.title,
|
||||
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
|
||||
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
TOOL_NAME: node_execution.title,
|
||||
TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False),
|
||||
TOOL_PARAMETERS: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False),
|
||||
INPUT_VALUE: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False),
|
||||
OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
|
||||
},
|
||||
status=self.get_workflow_node_status(node_execution),
|
||||
)
|
||||
|
||||
def build_workflow_retrieval_span(
|
||||
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||
) -> SpanData:
|
||||
input_value = ""
|
||||
if node_execution.inputs:
|
||||
input_value = str(node_execution.inputs.get("query", ""))
|
||||
output_value = ""
|
||||
if node_execution.outputs:
|
||||
output_value = json.dumps(node_execution.outputs.get("result", []), ensure_ascii=False)
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=workflow_span_id,
|
||||
span_id=convert_to_span_id(node_execution.id, "node"),
|
||||
name=node_execution.title,
|
||||
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
|
||||
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
RETRIEVAL_QUERY: input_value,
|
||||
RETRIEVAL_DOCUMENT: output_value,
|
||||
INPUT_VALUE: input_value,
|
||||
OUTPUT_VALUE: output_value,
|
||||
},
|
||||
status=self.get_workflow_node_status(node_execution),
|
||||
)
|
||||
|
||||
def build_workflow_llm_span(
|
||||
self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
|
||||
) -> SpanData:
|
||||
process_data = node_execution.process_data or {}
|
||||
outputs = node_execution.outputs or {}
|
||||
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
return SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=workflow_span_id,
|
||||
span_id=convert_to_span_id(node_execution.id, "node"),
|
||||
name=node_execution.title,
|
||||
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
|
||||
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
GEN_AI_MODEL_NAME: process_data.get("model_name", ""),
|
||||
GEN_AI_SYSTEM: process_data.get("model_provider", ""),
|
||||
GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
|
||||
GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
|
||||
GEN_AI_PROMPT: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
|
||||
GEN_AI_COMPLETION: str(outputs.get("text", "")),
|
||||
GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason", ""),
|
||||
INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
|
||||
OUTPUT_VALUE: str(outputs.get("text", "")),
|
||||
},
|
||||
status=self.get_workflow_node_status(node_execution),
|
||||
)
|
||||
|
||||
def add_workflow_span(self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo):
|
||||
message_span_id = None
|
||||
if trace_info.message_id:
|
||||
message_span_id = convert_to_span_id(trace_info.message_id, "message")
|
||||
user_id = trace_info.metadata.get("user_id")
|
||||
status: Status = Status(StatusCode.OK)
|
||||
if trace_info.error:
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
if message_span_id: # chatflow
|
||||
message_span = SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=None,
|
||||
span_id=message_span_id,
|
||||
name="message",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
|
||||
GEN_AI_USER_ID: str(user_id),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
INPUT_VALUE: trace_info.workflow_run_inputs.get("sys.query", ""),
|
||||
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
|
||||
},
|
||||
status=status,
|
||||
)
|
||||
self.trace_client.add_span(message_span)
|
||||
|
||||
workflow_span = SpanData(
|
||||
trace_id=trace_id,
|
||||
parent_span_id=message_span_id,
|
||||
span_id=workflow_span_id,
|
||||
name="workflow",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_USER_ID: str(user_id),
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
|
||||
},
|
||||
status=status,
|
||||
)
|
||||
self.trace_client.add_span(workflow_span)
|
||||
|
||||
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
|
||||
message_id = trace_info.message_id
|
||||
status: Status = Status(StatusCode.OK)
|
||||
if trace_info.error:
|
||||
status = Status(StatusCode.ERROR, trace_info.error)
|
||||
suggested_question_span = SpanData(
|
||||
trace_id=convert_to_trace_id(message_id),
|
||||
parent_span_id=convert_to_span_id(message_id, "message"),
|
||||
span_id=convert_to_span_id(message_id, "suggested_question"),
|
||||
name="suggested_question",
|
||||
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
|
||||
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
|
||||
attributes={
|
||||
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
|
||||
GEN_AI_FRAMEWORK: "dify",
|
||||
GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name", ""),
|
||||
GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider", ""),
|
||||
GEN_AI_PROMPT: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
GEN_AI_COMPLETION: json.dumps(trace_info.suggested_question, ensure_ascii=False),
|
||||
INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
|
||||
},
|
||||
status=status,
|
||||
)
|
||||
self.trace_client.add_span(suggested_question_span)
|
||||
|
||||
|
||||
def extract_retrieval_documents(documents: list[Document]):
|
||||
documents_data = []
|
||||
for document in documents:
|
||||
document_data = {
|
||||
"content": document.page_content,
|
||||
"metadata": {
|
||||
"dataset_id": document.metadata.get("dataset_id"),
|
||||
"doc_id": document.metadata.get("doc_id"),
|
||||
"document_id": document.metadata.get("document_id"),
|
||||
},
|
||||
"score": document.metadata.get("score"),
|
||||
}
|
||||
documents_data.append(document_data)
|
||||
return documents_data
|
||||
200
api/core/ops/aliyun_trace/data_exporter/traceclient.py
Normal file
200
api/core/ops/aliyun_trace/data_exporter/traceclient.py
Normal file
@@ -0,0 +1,200 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import random
|
||||
import socket
|
||||
import threading
|
||||
import uuid
|
||||
from collections import deque
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from opentelemetry import trace as trace_api
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import ReadableSpan
|
||||
from opentelemetry.sdk.util.instrumentation import InstrumentationScope
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
|
||||
from configs import dify_config
|
||||
from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
|
||||
|
||||
INVALID_SPAN_ID = 0x0000000000000000
|
||||
INVALID_TRACE_ID = 0x00000000000000000000000000000000
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TraceClient:
|
||||
def __init__(
|
||||
self,
|
||||
service_name: str,
|
||||
endpoint: str,
|
||||
max_queue_size: int = 1000,
|
||||
schedule_delay_sec: int = 5,
|
||||
max_export_batch_size: int = 50,
|
||||
):
|
||||
self.endpoint = endpoint
|
||||
self.resource = Resource(
|
||||
attributes={
|
||||
ResourceAttributes.SERVICE_NAME: service_name,
|
||||
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
|
||||
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
|
||||
ResourceAttributes.HOST_NAME: socket.gethostname(),
|
||||
}
|
||||
)
|
||||
self.span_builder = SpanBuilder(self.resource)
|
||||
self.exporter = OTLPSpanExporter(endpoint=endpoint)
|
||||
|
||||
self.max_queue_size = max_queue_size
|
||||
self.schedule_delay_sec = schedule_delay_sec
|
||||
self.max_export_batch_size = max_export_batch_size
|
||||
|
||||
self.queue: deque = deque(maxlen=max_queue_size)
|
||||
self.condition = threading.Condition(threading.Lock())
|
||||
self.done = False
|
||||
|
||||
self.worker_thread = threading.Thread(target=self._worker, daemon=True)
|
||||
self.worker_thread.start()
|
||||
|
||||
self._spans_dropped = False
|
||||
|
||||
def export(self, spans: Sequence[ReadableSpan]):
|
||||
self.exporter.export(spans)
|
||||
|
||||
def api_check(self):
|
||||
try:
|
||||
response = requests.head(self.endpoint, timeout=5)
|
||||
if response.status_code == 405:
|
||||
return True
|
||||
else:
|
||||
logger.debug(f"AliyunTrace API check failed: Unexpected status code: {response.status_code}")
|
||||
return False
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.debug(f"AliyunTrace API check failed: {str(e)}")
|
||||
raise ValueError(f"AliyunTrace API check failed: {str(e)}")
|
||||
|
||||
def get_project_url(self):
|
||||
return "https://arms.console.aliyun.com/#/llm"
|
||||
|
||||
def add_span(self, span_data: SpanData):
|
||||
if span_data is None:
|
||||
return
|
||||
span: ReadableSpan = self.span_builder.build_span(span_data)
|
||||
with self.condition:
|
||||
if len(self.queue) == self.max_queue_size:
|
||||
if not self._spans_dropped:
|
||||
logger.warning("Queue is full, likely spans will be dropped.")
|
||||
self._spans_dropped = True
|
||||
|
||||
self.queue.appendleft(span)
|
||||
if len(self.queue) >= self.max_export_batch_size:
|
||||
self.condition.notify()
|
||||
|
||||
def _worker(self):
|
||||
while not self.done:
|
||||
with self.condition:
|
||||
if len(self.queue) < self.max_export_batch_size and not self.done:
|
||||
self.condition.wait(timeout=self.schedule_delay_sec)
|
||||
self._export_batch()
|
||||
|
||||
def _export_batch(self):
|
||||
spans_to_export: list[ReadableSpan] = []
|
||||
with self.condition:
|
||||
while len(spans_to_export) < self.max_export_batch_size and self.queue:
|
||||
spans_to_export.append(self.queue.pop())
|
||||
|
||||
if spans_to_export:
|
||||
try:
|
||||
self.exporter.export(spans_to_export)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error exporting spans: {e}")
|
||||
|
||||
def shutdown(self):
|
||||
with self.condition:
|
||||
self.done = True
|
||||
self.condition.notify_all()
|
||||
self.worker_thread.join()
|
||||
self._export_batch()
|
||||
self.exporter.shutdown()
|
||||
|
||||
|
||||
class SpanBuilder:
|
||||
def __init__(self, resource):
|
||||
self.resource = resource
|
||||
self.instrumentation_scope = InstrumentationScope(
|
||||
__name__,
|
||||
"",
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
def build_span(self, span_data: SpanData) -> ReadableSpan:
|
||||
span_context = trace_api.SpanContext(
|
||||
trace_id=span_data.trace_id,
|
||||
span_id=span_data.span_id,
|
||||
is_remote=False,
|
||||
trace_flags=trace_api.TraceFlags(trace_api.TraceFlags.SAMPLED),
|
||||
trace_state=None,
|
||||
)
|
||||
|
||||
parent_span_context = None
|
||||
if span_data.parent_span_id is not None:
|
||||
parent_span_context = trace_api.SpanContext(
|
||||
trace_id=span_data.trace_id,
|
||||
span_id=span_data.parent_span_id,
|
||||
is_remote=False,
|
||||
trace_flags=trace_api.TraceFlags(trace_api.TraceFlags.SAMPLED),
|
||||
trace_state=None,
|
||||
)
|
||||
|
||||
span = ReadableSpan(
|
||||
name=span_data.name,
|
||||
context=span_context,
|
||||
parent=parent_span_context,
|
||||
resource=self.resource,
|
||||
attributes=span_data.attributes,
|
||||
events=span_data.events,
|
||||
links=span_data.links,
|
||||
kind=trace_api.SpanKind.INTERNAL,
|
||||
status=span_data.status,
|
||||
start_time=span_data.start_time,
|
||||
end_time=span_data.end_time,
|
||||
instrumentation_scope=self.instrumentation_scope,
|
||||
)
|
||||
return span
|
||||
|
||||
|
||||
def generate_span_id() -> int:
|
||||
span_id = random.getrandbits(64)
|
||||
while span_id == INVALID_SPAN_ID:
|
||||
span_id = random.getrandbits(64)
|
||||
return span_id
|
||||
|
||||
|
||||
def convert_to_trace_id(uuid_v4: Optional[str]) -> int:
|
||||
try:
|
||||
uuid_obj = uuid.UUID(uuid_v4)
|
||||
return uuid_obj.int
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid UUID input: {e}")
|
||||
|
||||
|
||||
def convert_to_span_id(uuid_v4: Optional[str], span_type: str) -> int:
|
||||
try:
|
||||
uuid_obj = uuid.UUID(uuid_v4)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid UUID input: {e}")
|
||||
combined_key = f"{uuid_obj.hex}-{span_type}"
|
||||
hash_bytes = hashlib.sha256(combined_key.encode("utf-8")).digest()
|
||||
span_id = int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
|
||||
return span_id
|
||||
|
||||
|
||||
def convert_datetime_to_nanoseconds(start_time_a: Optional[datetime]) -> Optional[int]:
|
||||
if start_time_a is None:
|
||||
return None
|
||||
timestamp_in_seconds = start_time_a.timestamp()
|
||||
timestamp_in_nanoseconds = int(timestamp_in_seconds * 1e9)
|
||||
return timestamp_in_nanoseconds
|
||||
21
api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py
Normal file
21
api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
from opentelemetry import trace as trace_api
|
||||
from opentelemetry.sdk.trace import Event, Status, StatusCode
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SpanData(BaseModel):
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
trace_id: int = Field(..., description="The unique identifier for the trace.")
|
||||
parent_span_id: Optional[int] = Field(None, description="The ID of the parent span, if any.")
|
||||
span_id: int = Field(..., description="The unique identifier for this span.")
|
||||
name: str = Field(..., description="The name of the span.")
|
||||
attributes: dict[str, str] = Field(default_factory=dict, description="Attributes associated with the span.")
|
||||
events: Sequence[Event] = Field(default_factory=list, description="Events recorded in the span.")
|
||||
links: Sequence[trace_api.Link] = Field(default_factory=list, description="Links to other spans.")
|
||||
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
|
||||
start_time: Optional[int] = Field(..., description="The start time of the span in nanoseconds.")
|
||||
end_time: Optional[int] = Field(..., description="The end time of the span in nanoseconds.")
|
||||
64
api/core/ops/aliyun_trace/entities/semconv.py
Normal file
64
api/core/ops/aliyun_trace/entities/semconv.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from enum import Enum
|
||||
|
||||
# public
|
||||
GEN_AI_SESSION_ID = "gen_ai.session.id"
|
||||
|
||||
GEN_AI_USER_ID = "gen_ai.user.id"
|
||||
|
||||
GEN_AI_USER_NAME = "gen_ai.user.name"
|
||||
|
||||
GEN_AI_SPAN_KIND = "gen_ai.span.kind"
|
||||
|
||||
GEN_AI_FRAMEWORK = "gen_ai.framework"
|
||||
|
||||
|
||||
# Chain
|
||||
INPUT_VALUE = "input.value"
|
||||
|
||||
OUTPUT_VALUE = "output.value"
|
||||
|
||||
|
||||
# Retriever
|
||||
RETRIEVAL_QUERY = "retrieval.query"
|
||||
|
||||
RETRIEVAL_DOCUMENT = "retrieval.document"
|
||||
|
||||
|
||||
# LLM
|
||||
GEN_AI_MODEL_NAME = "gen_ai.model_name"
|
||||
|
||||
GEN_AI_SYSTEM = "gen_ai.system"
|
||||
|
||||
GEN_AI_USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
|
||||
|
||||
GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
|
||||
|
||||
GEN_AI_USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
|
||||
|
||||
GEN_AI_PROMPT_TEMPLATE_TEMPLATE = "gen_ai.prompt_template.template"
|
||||
|
||||
GEN_AI_PROMPT_TEMPLATE_VARIABLE = "gen_ai.prompt_template.variable"
|
||||
|
||||
GEN_AI_PROMPT = "gen_ai.prompt"
|
||||
|
||||
GEN_AI_COMPLETION = "gen_ai.completion"
|
||||
|
||||
GEN_AI_RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason"
|
||||
|
||||
# Tool
|
||||
TOOL_NAME = "tool.name"
|
||||
|
||||
TOOL_DESCRIPTION = "tool.description"
|
||||
|
||||
TOOL_PARAMETERS = "tool.parameters"
|
||||
|
||||
|
||||
class GenAISpanKind(Enum):
|
||||
CHAIN = "CHAIN"
|
||||
RETRIEVER = "RETRIEVER"
|
||||
RERANKER = "RERANKER"
|
||||
LLM = "LLM"
|
||||
EMBEDDING = "EMBEDDING"
|
||||
TOOL = "TOOL"
|
||||
AGENT = "AGENT"
|
||||
TASK = "TASK"
|
||||
726
api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
Normal file
726
api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
Normal file
@@ -0,0 +1,726 @@
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
|
||||
from opentelemetry import trace
|
||||
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcOTLPSpanExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter
|
||||
from opentelemetry.sdk import trace as trace_sdk
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
||||
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
|
||||
from opentelemetry.trace import SpanContext, TraceFlags, TraceState
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
|
||||
from core.ops.entities.trace_entity import (
|
||||
BaseTraceInfo,
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
ToolTraceInfo,
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser, MessageFile
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[trace_sdk.Tracer, SimpleSpanProcessor]:
|
||||
"""Configure OpenTelemetry tracer with OTLP exporter for Arize/Phoenix."""
|
||||
try:
|
||||
# Choose the appropriate exporter based on config type
|
||||
exporter: Union[GrpcOTLPSpanExporter, HttpOTLPSpanExporter]
|
||||
if isinstance(arize_phoenix_config, ArizeConfig):
|
||||
arize_endpoint = f"{arize_phoenix_config.endpoint}/v1"
|
||||
arize_headers = {
|
||||
"api_key": arize_phoenix_config.api_key or "",
|
||||
"space_id": arize_phoenix_config.space_id or "",
|
||||
"authorization": f"Bearer {arize_phoenix_config.api_key or ''}",
|
||||
}
|
||||
exporter = GrpcOTLPSpanExporter(
|
||||
endpoint=arize_endpoint,
|
||||
headers=arize_headers,
|
||||
timeout=30,
|
||||
)
|
||||
else:
|
||||
phoenix_endpoint = f"{arize_phoenix_config.endpoint}/v1/traces"
|
||||
phoenix_headers = {
|
||||
"api_key": arize_phoenix_config.api_key or "",
|
||||
"authorization": f"Bearer {arize_phoenix_config.api_key or ''}",
|
||||
}
|
||||
exporter = HttpOTLPSpanExporter(
|
||||
endpoint=phoenix_endpoint,
|
||||
headers=phoenix_headers,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
attributes = {
|
||||
"openinference.project.name": arize_phoenix_config.project or "",
|
||||
"model_id": arize_phoenix_config.project or "",
|
||||
}
|
||||
resource = Resource(attributes=attributes)
|
||||
provider = trace_sdk.TracerProvider(resource=resource)
|
||||
processor = SimpleSpanProcessor(
|
||||
exporter,
|
||||
)
|
||||
provider.add_span_processor(processor)
|
||||
|
||||
# Create a named tracer instead of setting the global provider
|
||||
tracer_name = f"arize_phoenix_tracer_{arize_phoenix_config.project}"
|
||||
logger.info(f"[Arize/Phoenix] Created tracer with name: {tracer_name}")
|
||||
return cast(trace_sdk.Tracer, provider.get_tracer(tracer_name)), processor
|
||||
except Exception as e:
|
||||
logger.error(f"[Arize/Phoenix] Failed to setup the tracer: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def datetime_to_nanos(dt: Optional[datetime]) -> int:
|
||||
"""Convert datetime to nanoseconds since epoch. If None, use current time."""
|
||||
if dt is None:
|
||||
dt = datetime.now()
|
||||
return int(dt.timestamp() * 1_000_000_000)
|
||||
|
||||
|
||||
def uuid_to_trace_id(string: Optional[str]) -> int:
|
||||
"""Convert UUID string to a valid trace ID (16-byte integer)."""
|
||||
if string is None:
|
||||
string = ""
|
||||
hash_object = hashlib.sha256(string.encode())
|
||||
|
||||
# Take the first 16 bytes (128 bits) of the hash
|
||||
digest = hash_object.digest()[:16]
|
||||
|
||||
# Convert to integer (128 bits)
|
||||
return int.from_bytes(digest, byteorder="big")
|
||||
|
||||
|
||||
class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
def __init__(
|
||||
self,
|
||||
arize_phoenix_config: ArizeConfig | PhoenixConfig,
|
||||
):
|
||||
super().__init__(arize_phoenix_config)
|
||||
import logging
|
||||
|
||||
logging.basicConfig()
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
self.arize_phoenix_config = arize_phoenix_config
|
||||
self.tracer, self.processor = setup_tracer(arize_phoenix_config)
|
||||
self.project = arize_phoenix_config.project
|
||||
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
|
||||
|
||||
def trace(self, trace_info: BaseTraceInfo):
|
||||
logger.info(f"[Arize/Phoenix] Trace: {trace_info}")
|
||||
try:
|
||||
if isinstance(trace_info, WorkflowTraceInfo):
|
||||
self.workflow_trace(trace_info)
|
||||
if isinstance(trace_info, MessageTraceInfo):
|
||||
self.message_trace(trace_info)
|
||||
if isinstance(trace_info, ModerationTraceInfo):
|
||||
self.moderation_trace(trace_info)
|
||||
if isinstance(trace_info, SuggestedQuestionTraceInfo):
|
||||
self.suggested_question_trace(trace_info)
|
||||
if isinstance(trace_info, DatasetRetrievalTraceInfo):
|
||||
self.dataset_retrieval_trace(trace_info)
|
||||
if isinstance(trace_info, ToolTraceInfo):
|
||||
self.tool_trace(trace_info)
|
||||
if isinstance(trace_info, GenerateNameTraceInfo):
|
||||
self.generate_name_trace(trace_info)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Arize/Phoenix] Error in the trace: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
workflow_metadata = {
|
||||
"workflow_id": trace_info.workflow_run_id or "",
|
||||
"message_id": trace_info.message_id or "",
|
||||
"workflow_app_log_id": trace_info.workflow_app_log_id or "",
|
||||
"status": trace_info.workflow_run_status or "",
|
||||
"status_message": trace_info.error or "",
|
||||
"level": "ERROR" if trace_info.error else "DEFAULT",
|
||||
"total_tokens": trace_info.total_tokens or 0,
|
||||
}
|
||||
workflow_metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
workflow_span = self.tracer.start_span(
|
||||
name=TraceTaskName.WORKFLOW_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.METADATA: json.dumps(workflow_metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
# Process workflow nodes
|
||||
for node_execution in self._get_workflow_nodes(trace_info.workflow_run_id):
|
||||
created_at = node_execution.created_at or datetime.now()
|
||||
elapsed_time = node_execution.elapsed_time
|
||||
finished_at = created_at + timedelta(seconds=elapsed_time)
|
||||
|
||||
process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
|
||||
|
||||
node_metadata = {
|
||||
"node_id": node_execution.id,
|
||||
"node_type": node_execution.node_type,
|
||||
"node_status": node_execution.status,
|
||||
"tenant_id": node_execution.tenant_id,
|
||||
"app_id": node_execution.app_id,
|
||||
"app_name": node_execution.title,
|
||||
"status": node_execution.status,
|
||||
"level": "ERROR" if node_execution.status != "succeeded" else "DEFAULT",
|
||||
}
|
||||
|
||||
if node_execution.execution_metadata:
|
||||
node_metadata.update(json.loads(node_execution.execution_metadata))
|
||||
|
||||
# Determine the correct span kind based on node type
|
||||
span_kind = OpenInferenceSpanKindValues.CHAIN.value
|
||||
if node_execution.node_type == "llm":
|
||||
span_kind = OpenInferenceSpanKindValues.LLM.value
|
||||
provider = process_data.get("model_provider")
|
||||
model = process_data.get("model_name")
|
||||
if provider:
|
||||
node_metadata["ls_provider"] = provider
|
||||
if model:
|
||||
node_metadata["ls_model_name"] = model
|
||||
|
||||
outputs = json.loads(node_execution.outputs).get("usage", {})
|
||||
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
if usage_data:
|
||||
node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
|
||||
node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
|
||||
node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0)
|
||||
elif node_execution.node_type == "dataset_retrieval":
|
||||
span_kind = OpenInferenceSpanKindValues.RETRIEVER.value
|
||||
elif node_execution.node_type == "tool":
|
||||
span_kind = OpenInferenceSpanKindValues.TOOL.value
|
||||
else:
|
||||
span_kind = OpenInferenceSpanKindValues.CHAIN.value
|
||||
|
||||
node_span = self.tracer.start_span(
|
||||
name=node_execution.node_type,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: node_execution.inputs or "{}",
|
||||
SpanAttributes.OUTPUT_VALUE: node_execution.outputs or "{}",
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind,
|
||||
SpanAttributes.METADATA: json.dumps(node_metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
|
||||
},
|
||||
start_time=datetime_to_nanos(created_at),
|
||||
)
|
||||
|
||||
try:
|
||||
if node_execution.node_type == "llm":
|
||||
provider = process_data.get("model_provider")
|
||||
model = process_data.get("model_name")
|
||||
if provider:
|
||||
node_span.set_attribute(SpanAttributes.LLM_PROVIDER, provider)
|
||||
if model:
|
||||
node_span.set_attribute(SpanAttributes.LLM_MODEL_NAME, model)
|
||||
|
||||
outputs = json.loads(node_execution.outputs).get("usage", {})
|
||||
usage_data = (
|
||||
process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
)
|
||||
if usage_data:
|
||||
node_span.set_attribute(
|
||||
SpanAttributes.LLM_TOKEN_COUNT_TOTAL, usage_data.get("total_tokens", 0)
|
||||
)
|
||||
node_span.set_attribute(
|
||||
SpanAttributes.LLM_TOKEN_COUNT_PROMPT, usage_data.get("prompt_tokens", 0)
|
||||
)
|
||||
node_span.set_attribute(
|
||||
SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, usage_data.get("completion_tokens", 0)
|
||||
)
|
||||
finally:
|
||||
node_span.end(end_time=datetime_to_nanos(finished_at))
|
||||
finally:
|
||||
workflow_span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def message_trace(self, trace_info: MessageTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
file_list = cast(list[str], trace_info.file_list) or []
|
||||
message_file_data: Optional[MessageFile] = trace_info.message_file_data
|
||||
|
||||
if message_file_data is not None:
|
||||
file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
|
||||
file_list.append(file_url)
|
||||
|
||||
message_metadata = {
|
||||
"message_id": trace_info.message_id or "",
|
||||
"conversation_mode": str(trace_info.conversation_mode or ""),
|
||||
"user_id": trace_info.message_data.from_account_id or "",
|
||||
"file_list": json.dumps(file_list),
|
||||
"status": trace_info.message_data.status or "",
|
||||
"status_message": trace_info.error or "",
|
||||
"level": "ERROR" if trace_info.error else "DEFAULT",
|
||||
"total_tokens": trace_info.total_tokens or 0,
|
||||
"prompt_tokens": trace_info.message_tokens or 0,
|
||||
"completion_tokens": trace_info.answer_tokens or 0,
|
||||
"ls_provider": trace_info.message_data.model_provider or "",
|
||||
"ls_model_name": trace_info.message_data.model_id or "",
|
||||
}
|
||||
message_metadata.update(trace_info.metadata)
|
||||
|
||||
# Add end user data if available
|
||||
if trace_info.message_data.from_end_user_id:
|
||||
end_user_data: Optional[EndUser] = (
|
||||
db.session.query(EndUser).filter(EndUser.id == trace_info.message_data.from_end_user_id).first()
|
||||
)
|
||||
if end_user_data is not None:
|
||||
message_metadata["end_user_id"] = end_user_data.session_id
|
||||
|
||||
attributes = {
|
||||
SpanAttributes.INPUT_VALUE: trace_info.message_data.query,
|
||||
SpanAttributes.OUTPUT_VALUE: trace_info.message_data.answer,
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
|
||||
}
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
message_span_id = RandomIdGenerator().generate_span_id()
|
||||
span_context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=message_span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
message_span = self.tracer.start_span(
|
||||
name=TraceTaskName.MESSAGE_TRACE.value,
|
||||
attributes=attributes,
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.error:
|
||||
message_span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
|
||||
# Convert outputs to string based on type
|
||||
if isinstance(trace_info.outputs, dict | list):
|
||||
outputs_str = json.dumps(trace_info.outputs, ensure_ascii=False)
|
||||
elif isinstance(trace_info.outputs, str):
|
||||
outputs_str = trace_info.outputs
|
||||
else:
|
||||
outputs_str = str(trace_info.outputs)
|
||||
|
||||
llm_attributes = {
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.LLM.value,
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: outputs_str,
|
||||
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
|
||||
}
|
||||
|
||||
if isinstance(trace_info.inputs, list):
|
||||
for i, msg in enumerate(trace_info.inputs):
|
||||
if isinstance(msg, dict):
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get(
|
||||
"role", "user"
|
||||
)
|
||||
# todo: handle assistant and tool role messages, as they don't always
|
||||
# have a text field, but may have a tool_calls field instead
|
||||
# e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
|
||||
# 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
|
||||
elif isinstance(trace_info.inputs, dict):
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(trace_info.inputs)
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
|
||||
elif isinstance(trace_info.inputs, str):
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = trace_info.inputs
|
||||
llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
|
||||
|
||||
if trace_info.total_tokens is not None and trace_info.total_tokens > 0:
|
||||
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = trace_info.total_tokens
|
||||
if trace_info.message_tokens is not None and trace_info.message_tokens > 0:
|
||||
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_PROMPT] = trace_info.message_tokens
|
||||
if trace_info.answer_tokens is not None and trace_info.answer_tokens > 0:
|
||||
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION] = trace_info.answer_tokens
|
||||
|
||||
if trace_info.message_data.model_id is not None:
|
||||
llm_attributes[SpanAttributes.LLM_MODEL_NAME] = trace_info.message_data.model_id
|
||||
if trace_info.message_data.model_provider is not None:
|
||||
llm_attributes[SpanAttributes.LLM_PROVIDER] = trace_info.message_data.model_provider
|
||||
|
||||
if trace_info.message_data and trace_info.message_data.message_metadata:
|
||||
metadata_dict = json.loads(trace_info.message_data.message_metadata)
|
||||
if model_params := metadata_dict.get("model_parameters"):
|
||||
llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps(model_params)
|
||||
|
||||
llm_span = self.tracer.start_span(
|
||||
name="llm",
|
||||
attributes=llm_attributes,
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.error:
|
||||
llm_span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
llm_span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
finally:
|
||||
message_span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def moderation_trace(self, trace_info: ModerationTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
metadata = {
|
||||
"message_id": trace_info.message_id,
|
||||
"tool_name": "moderation",
|
||||
"status": trace_info.message_data.status,
|
||||
"status_message": trace_info.message_data.error or "",
|
||||
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.MODERATION_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: json.dumps(
|
||||
{
|
||||
"action": trace_info.action,
|
||||
"flagged": trace_info.flagged,
|
||||
"preset_response": trace_info.preset_response,
|
||||
"inputs": trace_info.inputs,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.message_data.error:
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.message_data.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.message_data.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
start_time = trace_info.start_time or trace_info.message_data.created_at
|
||||
end_time = trace_info.end_time or trace_info.message_data.updated_at
|
||||
|
||||
metadata = {
|
||||
"message_id": trace_info.message_id,
|
||||
"tool_name": "suggested_question",
|
||||
"status": trace_info.status,
|
||||
"status_message": trace_info.error or "",
|
||||
"level": "ERROR" if trace_info.error else "DEFAULT",
|
||||
"total_tokens": trace_info.total_tokens,
|
||||
"ls_provider": trace_info.model_provider or "",
|
||||
"ls_model_name": trace_info.model_id or "",
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
},
|
||||
start_time=datetime_to_nanos(start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.error:
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(end_time))
|
||||
|
||||
def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
start_time = trace_info.start_time or trace_info.message_data.created_at
|
||||
end_time = trace_info.end_time or trace_info.message_data.updated_at
|
||||
|
||||
metadata = {
|
||||
"message_id": trace_info.message_id,
|
||||
"tool_name": "dataset_retrieval",
|
||||
"status": trace_info.message_data.status,
|
||||
"status_message": trace_info.message_data.error or "",
|
||||
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
|
||||
"ls_provider": trace_info.message_data.model_provider or "",
|
||||
"ls_model_name": trace_info.message_data.model_id or "",
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: json.dumps({"documents": trace_info.documents}, ensure_ascii=False),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.RETRIEVER.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
"start_time": start_time.isoformat() if start_time else "",
|
||||
"end_time": end_time.isoformat() if end_time else "",
|
||||
},
|
||||
start_time=datetime_to_nanos(start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.message_data.error:
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.message_data.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.message_data.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(end_time))
|
||||
|
||||
def tool_trace(self, trace_info: ToolTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
logger.warning("[Arize/Phoenix] Message data is None, skipping tool trace.")
|
||||
return
|
||||
|
||||
metadata = {
|
||||
"message_id": trace_info.message_id,
|
||||
"tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False),
|
||||
}
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
tool_span_id = RandomIdGenerator().generate_span_id()
|
||||
logger.info(f"[Arize/Phoenix] Creating tool trace with trace_id: {trace_id}, span_id: {tool_span_id}")
|
||||
|
||||
# Create span context with the same trace_id as the parent
|
||||
# todo: Create with the appropriate parent span context, so that the tool span is
|
||||
# a child of the appropriate span (e.g. message span)
|
||||
span_context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=tool_span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
tool_params_str = (
|
||||
json.dumps(trace_info.tool_parameters, ensure_ascii=False)
|
||||
if isinstance(trace_info.tool_parameters, dict)
|
||||
else str(trace_info.tool_parameters)
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=trace_info.tool_name,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: trace_info.tool_outputs,
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
SpanAttributes.TOOL_NAME: trace_info.tool_name,
|
||||
SpanAttributes.TOOL_PARAMETERS: tool_params_str,
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.error:
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
|
||||
if trace_info.message_data is None:
|
||||
return
|
||||
|
||||
metadata = {
|
||||
"project_name": self.project,
|
||||
"message_id": trace_info.message_id,
|
||||
"status": trace_info.message_data.status,
|
||||
"status_message": trace_info.message_data.error or "",
|
||||
"level": "ERROR" if trace_info.message_data.error else "DEFAULT",
|
||||
}
|
||||
metadata.update(trace_info.metadata)
|
||||
|
||||
trace_id = uuid_to_trace_id(trace_info.message_id)
|
||||
span_id = RandomIdGenerator().generate_span_id()
|
||||
context = SpanContext(
|
||||
trace_id=trace_id,
|
||||
span_id=span_id,
|
||||
is_remote=False,
|
||||
trace_flags=TraceFlags(TraceFlags.SAMPLED),
|
||||
trace_state=TraceState(),
|
||||
)
|
||||
|
||||
span = self.tracer.start_span(
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE.value,
|
||||
attributes={
|
||||
SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
|
||||
SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.outputs, ensure_ascii=False),
|
||||
SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
|
||||
SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
|
||||
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
|
||||
"start_time": trace_info.start_time.isoformat() if trace_info.start_time else "",
|
||||
"end_time": trace_info.end_time.isoformat() if trace_info.end_time else "",
|
||||
},
|
||||
start_time=datetime_to_nanos(trace_info.start_time),
|
||||
context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
|
||||
)
|
||||
|
||||
try:
|
||||
if trace_info.message_data.error:
|
||||
span.add_event(
|
||||
"exception",
|
||||
attributes={
|
||||
"exception.message": trace_info.message_data.error,
|
||||
"exception.type": "Error",
|
||||
"exception.stacktrace": trace_info.message_data.error,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
span.end(end_time=datetime_to_nanos(trace_info.end_time))
|
||||
|
||||
def api_check(self):
|
||||
try:
|
||||
with self.tracer.start_span("api_check") as span:
|
||||
span.set_attribute("test", "true")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.info(f"[Arize/Phoenix] API check failed: {str(e)}", exc_info=True)
|
||||
raise ValueError(f"[Arize/Phoenix] API check failed: {str(e)}")
|
||||
|
||||
def get_project_url(self):
|
||||
try:
|
||||
if self.arize_phoenix_config.endpoint == "https://otlp.arize.com":
|
||||
return "https://app.arize.com/"
|
||||
else:
|
||||
return f"{self.arize_phoenix_config.endpoint}/projects/"
|
||||
except Exception as e:
|
||||
logger.info(f"[Arize/Phoenix] Get run url failed: {str(e)}", exc_info=True)
|
||||
raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}")
|
||||
|
||||
def _get_workflow_nodes(self, workflow_run_id: str):
|
||||
"""Helper method to get workflow nodes"""
|
||||
workflow_nodes = (
|
||||
db.session.query(
|
||||
WorkflowNodeExecutionModel.id,
|
||||
WorkflowNodeExecutionModel.tenant_id,
|
||||
WorkflowNodeExecutionModel.app_id,
|
||||
WorkflowNodeExecutionModel.title,
|
||||
WorkflowNodeExecutionModel.node_type,
|
||||
WorkflowNodeExecutionModel.status,
|
||||
WorkflowNodeExecutionModel.inputs,
|
||||
WorkflowNodeExecutionModel.outputs,
|
||||
WorkflowNodeExecutionModel.created_at,
|
||||
WorkflowNodeExecutionModel.elapsed_time,
|
||||
WorkflowNodeExecutionModel.process_data,
|
||||
WorkflowNodeExecutionModel.execution_metadata,
|
||||
)
|
||||
.filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
|
||||
.all()
|
||||
)
|
||||
return workflow_nodes
|
||||
@@ -2,20 +2,92 @@ from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, ValidationInfo, field_validator
|
||||
|
||||
from core.ops.utils import validate_project_name, validate_url, validate_url_with_path
|
||||
|
||||
|
||||
class TracingProviderEnum(StrEnum):
|
||||
ARIZE = "arize"
|
||||
PHOENIX = "phoenix"
|
||||
LANGFUSE = "langfuse"
|
||||
LANGSMITH = "langsmith"
|
||||
OPIK = "opik"
|
||||
WEAVE = "weave"
|
||||
ALIYUN = "aliyun"
|
||||
|
||||
|
||||
class BaseTracingConfig(BaseModel):
|
||||
"""
|
||||
Base model class for tracing
|
||||
Base model class for tracing configurations
|
||||
"""
|
||||
|
||||
...
|
||||
@classmethod
|
||||
def validate_endpoint_url(cls, v: str, default_url: str) -> str:
|
||||
"""
|
||||
Common endpoint URL validation logic
|
||||
|
||||
Args:
|
||||
v: URL value to validate
|
||||
default_url: Default URL to use if input is None or empty
|
||||
|
||||
Returns:
|
||||
Validated and normalized URL
|
||||
"""
|
||||
return validate_url(v, default_url)
|
||||
|
||||
@classmethod
|
||||
def validate_project_field(cls, v: str, default_name: str) -> str:
|
||||
"""
|
||||
Common project name validation logic
|
||||
|
||||
Args:
|
||||
v: Project name to validate
|
||||
default_name: Default name to use if input is None or empty
|
||||
|
||||
Returns:
|
||||
Validated project name
|
||||
"""
|
||||
return validate_project_name(v, default_name)
|
||||
|
||||
|
||||
class ArizeConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Arize tracing config.
|
||||
"""
|
||||
|
||||
api_key: str | None = None
|
||||
space_id: str | None = None
|
||||
project: str | None = None
|
||||
endpoint: str = "https://otlp.arize.com"
|
||||
|
||||
@field_validator("project")
|
||||
@classmethod
|
||||
def project_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_project_field(v, "default")
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_endpoint_url(v, "https://otlp.arize.com")
|
||||
|
||||
|
||||
class PhoenixConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Phoenix tracing config.
|
||||
"""
|
||||
|
||||
api_key: str | None = None
|
||||
project: str | None = None
|
||||
endpoint: str = "https://app.phoenix.arize.com"
|
||||
|
||||
@field_validator("project")
|
||||
@classmethod
|
||||
def project_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_project_field(v, "default")
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_endpoint_url(v, "https://app.phoenix.arize.com")
|
||||
|
||||
|
||||
class LangfuseConfig(BaseTracingConfig):
|
||||
@@ -29,13 +101,8 @@ class LangfuseConfig(BaseTracingConfig):
|
||||
|
||||
@field_validator("host")
|
||||
@classmethod
|
||||
def set_value(cls, v, info: ValidationInfo):
|
||||
if v is None or v == "":
|
||||
v = "https://api.langfuse.com"
|
||||
if not v.startswith("https://") and not v.startswith("http://"):
|
||||
raise ValueError("host must start with https:// or http://")
|
||||
|
||||
return v
|
||||
def host_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_endpoint_url(v, "https://api.langfuse.com")
|
||||
|
||||
|
||||
class LangSmithConfig(BaseTracingConfig):
|
||||
@@ -49,13 +116,9 @@ class LangSmithConfig(BaseTracingConfig):
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def set_value(cls, v, info: ValidationInfo):
|
||||
if v is None or v == "":
|
||||
v = "https://api.smith.langchain.com"
|
||||
if not v.startswith("https://"):
|
||||
raise ValueError("endpoint must start with https://")
|
||||
|
||||
return v
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
# LangSmith only allows HTTPS
|
||||
return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",))
|
||||
|
||||
|
||||
class OpikConfig(BaseTracingConfig):
|
||||
@@ -71,22 +134,12 @@ class OpikConfig(BaseTracingConfig):
|
||||
@field_validator("project")
|
||||
@classmethod
|
||||
def project_validator(cls, v, info: ValidationInfo):
|
||||
if v is None or v == "":
|
||||
v = "Default Project"
|
||||
|
||||
return v
|
||||
return cls.validate_project_field(v, "Default Project")
|
||||
|
||||
@field_validator("url")
|
||||
@classmethod
|
||||
def url_validator(cls, v, info: ValidationInfo):
|
||||
if v is None or v == "":
|
||||
v = "https://www.comet.com/opik/api/"
|
||||
if not v.startswith(("https://", "http://")):
|
||||
raise ValueError("url must start with https:// or http://")
|
||||
if not v.endswith("/api/"):
|
||||
raise ValueError("url should ends with /api/")
|
||||
|
||||
return v
|
||||
return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/")
|
||||
|
||||
|
||||
class WeaveConfig(BaseTracingConfig):
|
||||
@@ -102,22 +155,44 @@ class WeaveConfig(BaseTracingConfig):
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def set_value(cls, v, info: ValidationInfo):
|
||||
if v is None or v == "":
|
||||
v = "https://trace.wandb.ai"
|
||||
if not v.startswith("https://"):
|
||||
raise ValueError("endpoint must start with https://")
|
||||
|
||||
return v
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
# Weave only allows HTTPS for endpoint
|
||||
return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",))
|
||||
|
||||
@field_validator("host")
|
||||
@classmethod
|
||||
def validate_host(cls, v, info: ValidationInfo):
|
||||
if v is not None and v != "":
|
||||
if not v.startswith(("https://", "http://")):
|
||||
raise ValueError("host must start with https:// or http://")
|
||||
def host_validator(cls, v, info: ValidationInfo):
|
||||
if v is not None and v.strip() != "":
|
||||
return validate_url(v, v, allowed_schemes=("https", "http"))
|
||||
return v
|
||||
|
||||
|
||||
class AliyunConfig(BaseTracingConfig):
|
||||
"""
|
||||
Model class for Aliyun tracing config.
|
||||
"""
|
||||
|
||||
app_name: str = "dify_app"
|
||||
license_key: str
|
||||
endpoint: str
|
||||
|
||||
@field_validator("app_name")
|
||||
@classmethod
|
||||
def app_name_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_project_field(v, "dify_app")
|
||||
|
||||
@field_validator("license_key")
|
||||
@classmethod
|
||||
def license_key_validator(cls, v, info: ValidationInfo):
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError("License key cannot be empty")
|
||||
return v
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def endpoint_validator(cls, v, info: ValidationInfo):
|
||||
return cls.validate_endpoint_url(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
|
||||
|
||||
|
||||
OPS_FILE_PATH = "ops_trace/"
|
||||
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
|
||||
|
||||
@@ -135,4 +135,3 @@ class TraceTaskName(StrEnum):
|
||||
DATASET_RETRIEVAL_TRACE = "dataset_retrieval"
|
||||
TOOL_TRACE = "tool"
|
||||
GENERATE_NAME_TRACE = "generate_conversation_name"
|
||||
DATASOURCE_TRACE = "datasource"
|
||||
|
||||
@@ -32,6 +32,7 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.enums import MessageStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -180,12 +181,9 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
try:
|
||||
if outputs.get("usage"):
|
||||
prompt_tokens = outputs.get("usage", {}).get("prompt_tokens", 0)
|
||||
completion_tokens = outputs.get("usage", {}).get("completion_tokens", 0)
|
||||
else:
|
||||
prompt_tokens = process_data.get("usage", {}).get("prompt_tokens", 0)
|
||||
completion_tokens = process_data.get("usage", {}).get("completion_tokens", 0)
|
||||
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
|
||||
prompt_tokens = usage_data.get("prompt_tokens", 0)
|
||||
completion_tokens = usage_data.get("completion_tokens", 0)
|
||||
except Exception:
|
||||
logger.error("Failed to extract usage", exc_info=True)
|
||||
|
||||
@@ -293,7 +291,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
input=trace_info.inputs,
|
||||
output=message_data.answer,
|
||||
metadata=metadata,
|
||||
level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
|
||||
level=(LevelEnum.DEFAULT if message_data.status != MessageStatus.ERROR else LevelEnum.ERROR),
|
||||
status_message=message_data.error or "",
|
||||
usage=generation_usage,
|
||||
)
|
||||
@@ -339,7 +337,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
start_time=trace_info.start_time,
|
||||
end_time=trace_info.end_time,
|
||||
metadata=trace_info.metadata,
|
||||
level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
|
||||
level=(LevelEnum.DEFAULT if message_data.status != MessageStatus.ERROR else LevelEnum.ERROR),
|
||||
status_message=message_data.error or "",
|
||||
usage=generation_usage,
|
||||
)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user