Compare commits

..

2 Commits

Author SHA1 Message Date
copilot-swe-agent[bot]
62fed068a9 test: enable async leak detection in vitest config
Co-authored-by: hyoban <38493346+hyoban@users.noreply.github.com>
2026-03-13 14:55:07 +00:00
copilot-swe-agent[bot]
a10aeeed59 Initial plan 2026-03-13 14:45:57 +00:00
346 changed files with 3280 additions and 6090 deletions

View File

@@ -27,7 +27,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}

View File

@@ -39,7 +39,7 @@ jobs:
with:
python-version: "3.11"
- uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
- uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
- name: Generate Docker Compose
if: steps.docker-compose-changes.outputs.any_changed == 'true'

View File

@@ -113,7 +113,7 @@ jobs:
context: "web"
steps:
- name: Download digests
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
with:
path: /tmp/digests
pattern: digests-${{ matrix.context }}-*

View File

@@ -19,7 +19,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
with:
enable-cache: true
python-version: "3.12"
@@ -69,7 +69,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
with:
enable-cache: true
python-version: "3.12"

View File

@@ -28,7 +28,7 @@ jobs:
migration-changed: ${{ steps.changes.outputs.migration }}
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1
- uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
id: changes
with:
filters: |

View File

@@ -22,7 +22,7 @@ jobs:
fetch-depth: 0
- name: Setup Python & UV
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
with:
enable-cache: true

View File

@@ -33,7 +33,7 @@ jobs:
- name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true'
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
with:
enable-cache: false
python-version: "3.12"

View File

@@ -120,7 +120,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.detect_changes.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@cd77b50d2b0808657f8e6774085c8bf54484351c # v1.0.72
uses: anthropics/claude-code-action@26ec041249acb0a944c0a47b6c0c13f05dbc5b44 # v1.0.70
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -31,7 +31,7 @@ jobs:
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}

View File

@@ -77,7 +77,7 @@ jobs:
uses: ./.github/actions/setup-web
- name: Download blob reports
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
with:
path: web/.vitest-reports
pattern: blob-report-*

3
.gitignore vendored
View File

@@ -237,6 +237,3 @@ scripts/stress-test/reports/
# settings
*.local.json
*.local.md
# Code Agent Folder
.qoder/*

View File

@@ -22,10 +22,10 @@ APP_WEB_URL=http://localhost:3000
# Files URL
FILES_URL=http://localhost:5001
# INTERNAL_FILES_URL is used by services running in Docker to reach the API file endpoints.
# For Docker Desktop (Mac/Windows), use http://host.docker.internal:5001 when the API runs on the host.
# For Docker Compose on Linux, use http://api:5001 when the API runs inside the Docker network.
INTERNAL_FILES_URL=http://host.docker.internal:5001
# INTERNAL_FILES_URL is used for plugin daemon communication within Docker network.
# Set this to the internal Docker service URL for proper plugin file access.
# Example: INTERNAL_FILES_URL=http://api:5001
INTERNAL_FILES_URL=http://127.0.0.1:5001
# TRIGGER URL
TRIGGER_URL=http://localhost:5001
@@ -180,7 +180,7 @@ CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
COOKIE_DOMAIN=
# Vector database configuration
# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `hologres`.
# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.
VECTOR_STORE=weaviate
# Prefix used to create collection name in vector database
VECTOR_INDEX_NAME_PREFIX=Vector_index
@@ -217,20 +217,6 @@ COUCHBASE_PASSWORD=password
COUCHBASE_BUCKET_NAME=Embeddings
COUCHBASE_SCOPE_NAME=_default
# Hologres configuration
# access_key_id is used as the PG username, access_key_secret is used as the PG password
HOLOGRES_HOST=
HOLOGRES_PORT=80
HOLOGRES_DATABASE=
HOLOGRES_ACCESS_KEY_ID=
HOLOGRES_ACCESS_KEY_SECRET=
HOLOGRES_SCHEMA=public
HOLOGRES_TOKENIZER=jieba
HOLOGRES_DISTANCE_METHOD=Cosine
HOLOGRES_BASE_QUANTIZATION_TYPE=rabitq
HOLOGRES_MAX_DEGREE=64
HOLOGRES_EF_CONSTRUCTION=400
# Milvus configuration
MILVUS_URI=http://127.0.0.1:19530
MILVUS_TOKEN=

View File

@@ -43,6 +43,7 @@ forbidden_modules =
extensions.ext_redis
allow_indirect_imports = True
ignore_imports =
dify_graph.nodes.agent.agent_node -> extensions.ext_database
dify_graph.nodes.llm.node -> extensions.ext_database
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
@@ -89,6 +90,9 @@ forbidden_modules =
core.trigger
core.variables
ignore_imports =
dify_graph.nodes.agent.agent_node -> core.model_manager
dify_graph.nodes.agent.agent_node -> core.provider_manager
dify_graph.nodes.agent.agent_node -> core.tools.tool_manager
dify_graph.nodes.llm.llm_utils -> core.model_manager
dify_graph.nodes.llm.protocols -> core.model_manager
dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model
@@ -96,6 +100,9 @@ ignore_imports =
dify_graph.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
dify_graph.nodes.tool.tool_node -> core.tools.tool_engine
dify_graph.nodes.tool.tool_node -> core.tools.tool_manager
dify_graph.nodes.agent.agent_node -> core.agent.entities
dify_graph.nodes.agent.agent_node -> core.agent.plugin_entities
dify_graph.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> dify_graph.model_runtime.model_providers.__base.large_language_model
@@ -103,9 +110,12 @@ ignore_imports =
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
dify_graph.nodes.agent.agent_node -> models.model
dify_graph.nodes.llm.node -> core.helper.code_executor
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
dify_graph.nodes.llm.node -> core.model_manager
dify_graph.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.llm.node -> core.prompt.utils.prompt_message_util
@@ -114,12 +124,17 @@ ignore_imports =
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util
dify_graph.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities
dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util
dify_graph.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
dify_graph.nodes.llm.node -> models.dataset
dify_graph.nodes.agent.agent_node -> core.tools.utils.message_transformer
dify_graph.nodes.llm.file_saver -> core.tools.signature
dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager
dify_graph.nodes.tool.tool_node -> core.tools.errors
dify_graph.nodes.agent.agent_node -> extensions.ext_database
dify_graph.nodes.llm.node -> extensions.ext_database
dify_graph.nodes.agent.agent_node -> models
dify_graph.nodes.llm.node -> models.model
dify_graph.nodes.agent.agent_node -> services
dify_graph.nodes.tool.tool_node -> services
dify_graph.model_runtime.model_providers.__base.ai_model -> configs
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis

View File

@@ -97,7 +97,7 @@ ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Download nltk data
RUN mkdir -p /usr/local/share/nltk_data \
&& NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords')" \
&& NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; from unstructured.nlp.tokenize import download_nltk_packages; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords'); download_nltk_packages()" \
&& chmod -R 755 /usr/local/share/nltk_data
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache

View File

@@ -1,45 +1,16 @@
import logging
import time
from flask import request
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from controllers.console.error import UnauthorizedAndForceLogout
from core.logging.context import init_request_context
from dify_app import DifyApp
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import LicenseStatus
logger = logging.getLogger(__name__)
# Console bootstrap APIs exempt from license check.
# Defined at module level to avoid per-request tuple construction.
# - system-features: license status for expiry UI (GlobalPublicStoreProvider)
# - setup: install/setup status check (AppInitializer)
# - init: init password validation for fresh install (InitPasswordPopup)
# - login: auto-login after setup completion (InstallForm)
# - features: billing/plan features (ProviderContextProvider)
# - account/profile: login check + user profile (AppContextProvider, useIsLogin)
# - workspaces/current: workspace + model providers (AppContextProvider)
# - version: version check (AppContextProvider)
# - activate/check: invitation link validation (signin page)
# Without these exemptions, the signin page triggers location.reload()
# on unauthorized_and_force_logout, causing an infinite loop.
_CONSOLE_EXEMPT_PREFIXES = (
"/console/api/system-features",
"/console/api/setup",
"/console/api/init",
"/console/api/login",
"/console/api/features",
"/console/api/account/profile",
"/console/api/workspaces/current",
"/console/api/version",
"/console/api/activate/check",
)
# ----------------------------
# Application Factory Function
@@ -60,39 +31,6 @@ def create_flask_app_with_configs() -> DifyApp:
init_request_context()
RecyclableContextVar.increment_thread_recycles()
# Enterprise license validation for API endpoints (both console and webapp)
# When license expires, block all API access except bootstrap endpoints needed
# for the frontend to load the license expiration page without infinite reloads.
if dify_config.ENTERPRISE_ENABLED:
is_console_api = request.path.startswith("/console/api/")
is_webapp_api = request.path.startswith("/api/")
if is_console_api or is_webapp_api:
if is_console_api:
is_exempt = any(request.path.startswith(p) for p in _CONSOLE_EXEMPT_PREFIXES)
else: # webapp API
is_exempt = request.path.startswith("/api/system-features")
if not is_exempt:
try:
# Check license status (cached — see EnterpriseService for TTL details)
license_status = EnterpriseService.get_cached_license_status()
if license_status in (LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST):
raise UnauthorizedAndForceLogout(
f"Enterprise license is {license_status}. Please contact your administrator."
)
if license_status is None:
raise UnauthorizedAndForceLogout(
"Unable to verify enterprise license. Please contact your administrator."
)
except UnauthorizedAndForceLogout:
raise
except Exception:
logger.exception("Failed to check enterprise license status")
raise UnauthorizedAndForceLogout(
"Unable to verify enterprise license. Please contact your administrator."
)
# add after request hook for injecting trace headers from OpenTelemetry span context
# Only adds headers when OTEL is enabled and has valid context
@dify_app.after_request

View File

@@ -160,7 +160,6 @@ def migrate_knowledge_vector_database():
}
lower_collection_vector_types = {
VectorType.ANALYTICDB,
VectorType.HOLOGRES,
VectorType.CHROMA,
VectorType.MYSCALE,
VectorType.PGVECTO_RS,

View File

@@ -26,7 +26,6 @@ from .vdb.chroma_config import ChromaConfig
from .vdb.clickzetta_config import ClickzettaConfig
from .vdb.couchbase_config import CouchbaseConfig
from .vdb.elasticsearch_config import ElasticsearchConfig
from .vdb.hologres_config import HologresConfig
from .vdb.huawei_cloud_config import HuaweiCloudConfig
from .vdb.iris_config import IrisVectorConfig
from .vdb.lindorm_config import LindormConfig
@@ -348,7 +347,6 @@ class MiddlewareConfig(
AnalyticdbConfig,
ChromaConfig,
ClickzettaConfig,
HologresConfig,
HuaweiCloudConfig,
IrisVectorConfig,
MilvusConfig,

View File

@@ -1,68 +0,0 @@
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
from pydantic import Field
from pydantic_settings import BaseSettings
class HologresConfig(BaseSettings):
"""
Configuration settings for Hologres vector database.
Hologres is compatible with PostgreSQL protocol.
access_key_id is used as the PostgreSQL username,
and access_key_secret is used as the PostgreSQL password.
"""
HOLOGRES_HOST: str | None = Field(
description="Hostname or IP address of the Hologres instance.",
default=None,
)
HOLOGRES_PORT: int = Field(
description="Port number for connecting to the Hologres instance.",
default=80,
)
HOLOGRES_DATABASE: str | None = Field(
description="Name of the Hologres database to connect to.",
default=None,
)
HOLOGRES_ACCESS_KEY_ID: str | None = Field(
description="Alibaba Cloud AccessKey ID, also used as the PostgreSQL username.",
default=None,
)
HOLOGRES_ACCESS_KEY_SECRET: str | None = Field(
description="Alibaba Cloud AccessKey Secret, also used as the PostgreSQL password.",
default=None,
)
HOLOGRES_SCHEMA: str = Field(
description="Schema name in the Hologres database.",
default="public",
)
HOLOGRES_TOKENIZER: TokenizerType = Field(
description="Tokenizer for full-text search index (e.g., 'jieba', 'ik', 'standard', 'simple').",
default="jieba",
)
HOLOGRES_DISTANCE_METHOD: DistanceType = Field(
description="Distance method for vector index (e.g., 'Cosine', 'Euclidean', 'InnerProduct').",
default="Cosine",
)
HOLOGRES_BASE_QUANTIZATION_TYPE: BaseQuantizationType = Field(
description="Base quantization type for vector index (e.g., 'rabitq', 'sq8', 'fp16', 'fp32').",
default="rabitq",
)
HOLOGRES_MAX_DEGREE: int = Field(
description="Max degree (M) parameter for HNSW vector index.",
default=64,
)
HOLOGRES_EF_CONSTRUCTION: int = Field(
description="ef_construction parameter for HNSW vector index.",
default=400,
)

View File

@@ -25,8 +25,7 @@ from controllers.console.wraps import (
)
from core.ops.ops_trace_manager import OpsTraceManager
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.trigger.constants import TRIGGER_NODE_TYPES
from dify_graph.enums import WorkflowExecutionStatus
from dify_graph.enums import NodeType, WorkflowExecutionStatus
from dify_graph.file import helpers as file_helpers
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
@@ -509,7 +508,11 @@ class AppListApi(Resource):
.scalars()
.all()
)
trigger_node_types = TRIGGER_NODE_TYPES
trigger_node_types = {
NodeType.TRIGGER_WEBHOOK,
NodeType.TRIGGER_SCHEDULE,
NodeType.TRIGGER_PLUGIN,
}
for workflow in draft_workflows:
node_id = None
try:

View File

@@ -1,4 +1,5 @@
import json
from enum import StrEnum
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
@@ -10,7 +11,6 @@ from controllers.console.wraps import account_initialization_required, edit_perm
from extensions.ext_database import db
from fields.app_fields import app_server_fields
from libs.login import current_account_with_tenant, login_required
from models.enums import AppMCPServerStatus
from models.model import AppMCPServer
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -19,6 +19,11 @@ DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
app_server_model = console_ns.model("AppServer", app_server_fields)
class AppMCPServerStatus(StrEnum):
ACTIVE = "active"
INACTIVE = "inactive"
class MCPServerCreatePayload(BaseModel):
description: str | None = Field(default=None, description="Server description")
parameters: dict = Field(..., description="Server parameters configuration")
@@ -112,10 +117,9 @@ class AppMCPServerController(Resource):
server.parameters = json.dumps(payload.parameters, ensure_ascii=False)
if payload.status:
try:
server.status = AppMCPServerStatus(payload.status)
except ValueError:
if payload.status not in [status.value for status in AppMCPServerStatus]:
raise ValueError("Invalid status")
server.status = payload.status
db.session.commit()
return server

View File

@@ -22,7 +22,6 @@ from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY
from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.trace_id_helper import get_external_trace_id
from core.plugin.impl.exc import PluginInvokeError
from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE
from core.trigger.debug.event_selectors import (
TriggerDebugEvent,
TriggerDebugEventPoller,
@@ -1210,7 +1209,7 @@ class DraftWorkflowTriggerNodeApi(Resource):
node_type: NodeType = draft_workflow.get_node_type_from_node_config(node_config)
event: TriggerDebugEvent | None = None
# for schedule trigger, when run single node, just execute directly
if node_type == TRIGGER_SCHEDULE_NODE_TYPE:
if node_type == NodeType.TRIGGER_SCHEDULE:
event = TriggerDebugEvent(
workflow_args={},
node_id=node_id,

View File

@@ -263,7 +263,6 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool
VectorType.BAIDU,
VectorType.ALIBABACLOUD_MYSQL,
VectorType.IRIS,
VectorType.HOLOGRES,
}
semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}

View File

@@ -43,7 +43,6 @@ from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
from models import AccountIntegrate, InvitationCode
from models.account import AccountStatus, InvitationCodeStatus
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@@ -216,7 +215,7 @@ class AccountInitApi(Resource):
db.session.query(InvitationCode)
.where(
InvitationCode.code == args.invitation_code,
InvitationCode.status == InvitationCodeStatus.UNUSED,
InvitationCode.status == "unused",
)
.first()
)
@@ -224,7 +223,7 @@ class AccountInitApi(Resource):
if not invitation_code:
raise InvalidInvitationCodeError()
invitation_code.status = InvitationCodeStatus.USED
invitation_code.status = "used"
invitation_code.used_at = naive_utc_now()
invitation_code.used_by_tenant_id = account.current_tenant_id
invitation_code.used_by_account_id = account.id
@@ -232,7 +231,7 @@ class AccountInitApi(Resource):
account.interface_language = args.interface_language
account.timezone = args.timezone
account.interface_theme = "light"
account.status = AccountStatus.ACTIVE
account.status = "active"
account.initialized_at = naive_utc_now()
db.session.commit()

View File

@@ -5,7 +5,6 @@ from typing import Any, Literal
from flask import request, send_file
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import Forbidden
from configs import dify_config
@@ -170,20 +169,6 @@ register_enum_models(
)
def _read_upload_content(file: FileStorage, max_size: int) -> bytes:
"""
Read the uploaded file and validate its actual size before delegating to the plugin service.
FileStorage.content_length is not reliable for multipart test uploads and may be zero even when
content exists, so the controllers validate against the loaded bytes instead.
"""
content = file.read()
if len(content) > max_size:
raise ValueError("File size exceeds the maximum allowed size")
return content
@console_ns.route("/workspaces/current/plugin/debugging-key")
class PluginDebuggingKeyApi(Resource):
@setup_required
@@ -299,7 +284,12 @@ class PluginUploadFromPkgApi(Resource):
_, tenant_id = current_account_with_tenant()
file = request.files["pkg"]
content = _read_upload_content(file, dify_config.PLUGIN_MAX_PACKAGE_SIZE)
# check file size
if file.content_length > dify_config.PLUGIN_MAX_PACKAGE_SIZE:
raise ValueError("File size exceeds the maximum allowed size")
content = file.read()
try:
response = PluginService.upload_pkg(tenant_id, content)
except PluginDaemonClientSideError as e:
@@ -338,7 +328,12 @@ class PluginUploadFromBundleApi(Resource):
_, tenant_id = current_account_with_tenant()
file = request.files["bundle"]
content = _read_upload_content(file, dify_config.PLUGIN_MAX_BUNDLE_SIZE)
# check file size
if file.content_length > dify_config.PLUGIN_MAX_BUNDLE_SIZE:
raise ValueError("File size exceeds the maximum allowed size")
content = file.read()
try:
response = PluginService.upload_bundle(tenant_id, content)
except PluginDaemonClientSideError as e:

View File

@@ -6,13 +6,13 @@ from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_model
from controllers.console.app.mcp_server import AppMCPServerStatus
from controllers.mcp import mcp_ns
from core.mcp import types as mcp_types
from core.mcp.server.streamable_http import handle_mcp_request
from dify_graph.variables.input_entities import VariableEntity
from extensions.ext_database import db
from libs import helper
from models.enums import AppMCPServerStatus
from models.model import App, AppMCPServer, AppMode, EndUser

View File

@@ -3,7 +3,7 @@ import time
from collections.abc import Callable
from enum import StrEnum, auto
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar, cast, overload
from typing import Concatenate, ParamSpec, TypeVar, cast
from flask import current_app, request
from flask_login import user_logged_in
@@ -44,22 +44,10 @@ class FetchUserArg(BaseModel):
required: bool = False
@overload
def validate_app_token(view: Callable[P, R]) -> Callable[P, R]: ...
@overload
def validate_app_token(
view: None = None, *, fetch_user_arg: FetchUserArg | None = None
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
def validate_app_token(
view: Callable[P, R] | None = None, *, fetch_user_arg: FetchUserArg | None = None
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: FetchUserArg | None = None):
def decorator(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
def decorated_view(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token("app")
app_model = db.session.query(App).where(App.id == api_token.app_id).first()
@@ -225,20 +213,10 @@ def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
return interceptor
@overload
def validate_dataset_token(view: Callable[Concatenate[T, P], R]) -> Callable[P, R]: ...
@overload
def validate_dataset_token(view: None = None) -> Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]: ...
def validate_dataset_token(
view: Callable[Concatenate[T, P], R] | None = None,
) -> Callable[P, R] | Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]:
def decorator(view_func: Callable[Concatenate[T, P], R]) -> Callable[P, R]:
@wraps(view_func)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
def decorator(view: Callable[Concatenate[T, P], R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
api_token = validate_and_get_api_token("dataset")
# get url path dataset_id from positional args or kwargs
@@ -309,7 +287,7 @@ def validate_dataset_token(
raise Unauthorized("Tenant owner account does not exist.")
else:
raise Unauthorized("Tenant does not exist.")
return view_func(api_token.tenant_id, *args, **kwargs) # type: ignore[arg-type]
return view(api_token.tenant_id, *args, **kwargs)
return decorated

View File

@@ -6,7 +6,6 @@ from typing import Any
from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentScratchpadUnit
from core.agent.errors import AgentMaxIterationError
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
@@ -23,6 +22,7 @@ from dify_graph.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
from dify_graph.nodes.agent.exc import AgentMaxIterationError
from models.model import Message
logger = logging.getLogger(__name__)

View File

@@ -1,9 +0,0 @@
class AgentMaxIterationError(Exception):
"""Raised when an agent runner exceeds the configured max iteration count."""
def __init__(self, max_iteration: int):
self.max_iteration = max_iteration
super().__init__(
f"Agent exceeded the maximum iteration limit of {max_iteration}. "
f"The agent was unable to complete the task within the allowed number of iterations."
)

View File

@@ -5,7 +5,6 @@ from copy import deepcopy
from typing import Any, Union
from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.errors import AgentMaxIterationError
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
@@ -26,6 +25,7 @@ from dify_graph.model_runtime.entities import (
UserPromptMessage,
)
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from dify_graph.nodes.agent.exc import AgentMaxIterationError
from models.model import Message
logger = logging.getLogger(__name__)

View File

@@ -69,7 +69,7 @@ from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.enums import WorkflowExecutionStatus
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.nodes import BuiltinNodeTypes
from dify_graph.nodes import NodeType
from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory
from dify_graph.runtime import GraphRuntimeState
from dify_graph.system_variable import SystemVariable
@@ -357,7 +357,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
) -> Generator[StreamResponse, None, None]:
"""Handle node succeeded events."""
# Record files if it's an answer node or end node
if event.node_type in [BuiltinNodeTypes.ANSWER, BuiltinNodeTypes.END, BuiltinNodeTypes.LLM]:
if event.node_type in [NodeType.ANSWER, NodeType.END, NodeType.LLM]:
self._recorded_files.extend(
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
)

View File

@@ -48,13 +48,12 @@ from core.app.entities.task_entities import (
from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
from core.trigger.trigger_manager import TriggerManager
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
from dify_graph.enums import (
BuiltinNodeTypes,
NodeType,
SystemVariableKey,
WorkflowExecutionStatus,
WorkflowNodeExecutionMetadataKey,
@@ -443,7 +442,7 @@ class WorkflowResponseConverter:
event: QueueNodeStartedEvent,
task_id: str,
) -> NodeStartStreamResponse | None:
if event.node_type in {BuiltinNodeTypes.ITERATION, BuiltinNodeTypes.LOOP}:
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
run_id = self._ensure_workflow_run_id()
snapshot = self._store_snapshot(event)
@@ -465,13 +464,13 @@ class WorkflowResponseConverter:
)
try:
if event.node_type == BuiltinNodeTypes.TOOL:
if event.node_type == NodeType.TOOL:
response.data.extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self._application_generate_entity.app_config.tenant_id,
provider_type=ToolProviderType(event.provider_type),
provider_id=event.provider_id,
)
elif event.node_type == BuiltinNodeTypes.DATASOURCE:
elif event.node_type == NodeType.DATASOURCE:
manager = PluginDatasourceManager()
provider_entity = manager.fetch_datasource_provider(
self._application_generate_entity.app_config.tenant_id,
@@ -480,7 +479,7 @@ class WorkflowResponseConverter:
response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url(
self._application_generate_entity.app_config.tenant_id
)
elif event.node_type == TRIGGER_PLUGIN_NODE_TYPE:
elif event.node_type == NodeType.TRIGGER_PLUGIN:
response.data.extras["icon"] = TriggerManager.get_trigger_plugin_icon(
self._application_generate_entity.app_config.tenant_id,
event.provider_id,
@@ -497,7 +496,7 @@ class WorkflowResponseConverter:
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent,
task_id: str,
) -> NodeFinishStreamResponse | None:
if event.node_type in {BuiltinNodeTypes.ITERATION, BuiltinNodeTypes.LOOP}:
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
run_id = self._ensure_workflow_run_id()
snapshot = self._pop_snapshot(event.node_execution_id)
@@ -555,7 +554,7 @@ class WorkflowResponseConverter:
event: QueueNodeRetryEvent,
task_id: str,
) -> NodeRetryStreamResponse | None:
if event.node_type in {BuiltinNodeTypes.ITERATION, BuiltinNodeTypes.LOOP}:
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
run_id = self._ensure_workflow_run_id()
@@ -613,7 +612,7 @@ class WorkflowResponseConverter:
data=IterationNodeStartStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type,
node_type=event.node_type.value,
title=event.node_title,
created_at=int(time.time()),
extras={},
@@ -636,7 +635,7 @@ class WorkflowResponseConverter:
data=IterationNodeNextStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type,
node_type=event.node_type.value,
title=event.node_title,
index=event.index,
created_at=int(time.time()),
@@ -663,7 +662,7 @@ class WorkflowResponseConverter:
data=IterationNodeCompletedStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type,
node_type=event.node_type.value,
title=event.node_title,
outputs=new_outputs,
outputs_truncated=outputs_truncated,
@@ -693,7 +692,7 @@ class WorkflowResponseConverter:
data=LoopNodeStartStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type,
node_type=event.node_type.value,
title=event.node_title,
created_at=int(time.time()),
extras={},
@@ -716,7 +715,7 @@ class WorkflowResponseConverter:
data=LoopNodeNextStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type,
node_type=event.node_type.value,
title=event.node_title,
index=event.index,
# The `pre_loop_output` field is not utilized by the frontend.
@@ -745,7 +744,7 @@ class WorkflowResponseConverter:
data=LoopNodeCompletedStreamResponse.Data(
id=event.node_id,
node_id=event.node_id,
node_type=event.node_type,
node_type=event.node_type.value,
title=event.node_title,
outputs=new_outputs,
outputs_truncated=outputs_truncated,

View File

@@ -12,7 +12,7 @@ from core.app.entities.app_invoke_entities import (
build_dify_run_context,
)
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id
from core.workflow.node_factory import DifyNodeFactory
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.entities.graph_init_params import GraphInitParams
from dify_graph.enums import WorkflowType
@@ -274,8 +274,6 @@ class PipelineRunner(WorkflowBasedAppRunner):
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
if start_node_id is None:
start_node_id = get_default_root_node_id(graph_config)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=start_node_id)
if not graph:

View File

@@ -3,10 +3,7 @@ import time
from collections.abc import Mapping, Sequence
from typing import Any, cast
from pydantic import ValidationError
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
from core.app.entities.queue_entities import (
AppQueueEvent,
@@ -32,8 +29,7 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
from core.workflow.node_factory import DifyNodeFactory
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDictAdapter
@@ -67,6 +63,7 @@ from dify_graph.graph_events import (
NodeRunSucceededEvent,
)
from dify_graph.graph_events.graph import GraphRunAbortedEvent
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
@@ -140,9 +137,6 @@ class WorkflowBasedAppRunner:
graph_runtime_state=graph_runtime_state,
)
if root_node_id is None:
root_node_id = get_default_root_node_id(graph_config)
# init graph
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id)
@@ -314,7 +308,7 @@ class WorkflowBasedAppRunner:
# Get node class
node_type = target_node_config["data"].type
node_version = str(target_node_config["data"].version)
node_cls = resolve_workflow_node_class(node_type=node_type, node_version=node_version)
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# Use the variable pool from graph_runtime_state instead of creating a new one
variable_pool = graph_runtime_state.variable_pool
@@ -342,18 +336,6 @@ class WorkflowBasedAppRunner:
return graph, variable_pool
@staticmethod
def _build_agent_strategy_info(event: NodeRunStartedEvent) -> AgentStrategyInfo | None:
raw_agent_strategy = event.extras.get("agent_strategy")
if raw_agent_strategy is None:
return None
try:
return AgentStrategyInfo.model_validate(raw_agent_strategy)
except ValidationError:
logger.warning("Invalid agent strategy payload for node %s", event.node_id, exc_info=True)
return None
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent):
"""
Handle event
@@ -439,7 +421,7 @@ class WorkflowBasedAppRunner:
start_at=event.start_at,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
agent_strategy=self._build_agent_strategy_info(event),
agent_strategy=event.agent_strategy,
provider_type=event.provider_type,
provider_id=event.provider_id,
)
@@ -508,9 +490,7 @@ class WorkflowBasedAppRunner:
elif isinstance(event, NodeRunRetrieverResourceEvent):
self._publish_event(
QueueRetrieverResourcesEvent(
retriever_resources=[
RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources
],
retriever_resources=event.retriever_resources,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
)

View File

@@ -1,3 +0,0 @@
from .agent_strategy import AgentStrategyInfo
__all__ = ["AgentStrategyInfo"]

View File

@@ -1,8 +0,0 @@
from pydantic import BaseModel, ConfigDict
class AgentStrategyInfo(BaseModel):
name: str
icon: str | None = None
model_config = ConfigDict(extra="forbid")

View File

@@ -5,12 +5,13 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from dify_graph.entities import AgentNodeStrategyInit
from dify_graph.entities.pause_reason import PauseReason
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
from dify_graph.enums import WorkflowNodeExecutionMetadataKey
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from dify_graph.nodes import NodeType
class QueueEvent(StrEnum):
@@ -313,7 +314,7 @@ class QueueNodeStartedEvent(AppQueueEvent):
in_iteration_id: str | None = None
in_loop_id: str | None = None
start_at: datetime
agent_strategy: AgentStrategyInfo | None = None
agent_strategy: AgentNodeStrategyInit | None = None
# FIXME(-LAN-): only for ToolNode, need to refactor
provider_type: str # should be a core.tools.entities.tool_entities.ToolProviderType

View File

@@ -4,8 +4,8 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from dify_graph.entities import AgentNodeStrategyInit
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage
@@ -349,7 +349,7 @@ class NodeStartStreamResponse(StreamResponse):
extras: dict[str, object] = Field(default_factory=dict)
iteration_id: str | None = None
loop_id: str | None = None
agent_strategy: AgentStrategyInfo | None = None
agent_strategy: AgentNodeStrategyInit | None = None
event: StreamEvent = StreamEvent.NODE_STARTED
workflow_run_id: str

View File

@@ -2,7 +2,7 @@ import logging
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
from dify_graph.conversation_variable_updater import ConversationVariableUpdater
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.enums import NodeType
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.graph_events import GraphEngineEvent, NodeRunSucceededEvent
from dify_graph.nodes.variable_assigner.common import helpers as common_helpers
@@ -22,7 +22,7 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
def on_event(self, event: GraphEngineEvent) -> None:
if not isinstance(event, NodeRunSucceededEvent):
return
if event.node_type != BuiltinNodeTypes.VARIABLE_ASSIGNER:
if event.node_type != NodeType.VARIABLE_ASSIGNER:
return
if self.graph_runtime_state is None:
return

View File

@@ -12,7 +12,7 @@ from typing_extensions import override
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.enums import NodeType
from dify_graph.graph_engine.entities.commands import AbortCommand, CommandType
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase
@@ -113,11 +113,11 @@ class LLMQuotaLayer(GraphEngineLayer):
def _extract_model_instance(node: Node) -> ModelInstance | None:
try:
match node.node_type:
case BuiltinNodeTypes.LLM:
case NodeType.LLM:
return cast("LLMNode", node).model_instance
case BuiltinNodeTypes.PARAMETER_EXTRACTOR:
case NodeType.PARAMETER_EXTRACTOR:
return cast("ParameterExtractorNode", node).model_instance
case BuiltinNodeTypes.QUESTION_CLASSIFIER:
case NodeType.QUESTION_CLASSIFIER:
return cast("QuestionClassifierNode", node).model_instance
case _:
return None

View File

@@ -16,7 +16,7 @@ from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_
from typing_extensions import override
from configs import dify_config
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.enums import NodeType
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.graph_events import GraphNodeEventBase
from dify_graph.nodes.base.node import Node
@@ -74,13 +74,16 @@ class ObservabilityLayer(GraphEngineLayer):
def _build_parser_registry(self) -> None:
"""Initialize parser registry for node types."""
self._parsers = {
BuiltinNodeTypes.TOOL: ToolNodeOTelParser(),
BuiltinNodeTypes.LLM: LLMNodeOTelParser(),
BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: RetrievalNodeOTelParser(),
NodeType.TOOL: ToolNodeOTelParser(),
NodeType.LLM: LLMNodeOTelParser(),
NodeType.KNOWLEDGE_RETRIEVAL: RetrievalNodeOTelParser(),
}
def _get_parser(self, node: Node) -> NodeOTelParser:
return self._parsers.get(node.node_type, self._default_parser)
node_type = getattr(node, "node_type", None)
if isinstance(node_type, NodeType):
return self._parsers.get(node_type, self._default_parser)
return self._default_parser
@override
def on_graph_start(self) -> None:

View File

@@ -12,7 +12,6 @@ from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.enums import CreatorUserRole
_logger = logging.getLogger(__name__)
@@ -39,9 +38,7 @@ class DatasetIndexToolCallbackHandler:
source="app",
source_app_id=self._app_id,
created_by_role=(
CreatorUserRole.ACCOUNT
if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else CreatorUserRole.END_USER
"account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"
),
created_by=self._user_id,
)

View File

@@ -24,12 +24,12 @@ from core.datasource.utils.message_transformer import DatasourceFileMessageTrans
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
from core.db.session_factory import session_factory
from core.plugin.impl.datasource import PluginDatasourceManager
from core.workflow.nodes.datasource.entities import DatasourceParameter, OnlineDriveDownloadFileParam
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from dify_graph.enums import WorkflowNodeExecutionMetadataKey
from dify_graph.file import File
from dify_graph.file.enums import FileTransferMethod, FileType
from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from dify_graph.repositories.datasource_manager_protocol import DatasourceParameter, OnlineDriveDownloadFileParam
from factories import file_factory
from models.model import UploadFile
from models.tools import ToolFile

View File

@@ -58,7 +58,7 @@ from core.ops.entities.trace_entity import (
)
from core.repositories import DifyCoreRepositoryFactory
from dify_graph.entities import WorkflowNodeExecution
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from models import WorkflowNodeExecutionTriggeredFrom
@@ -302,11 +302,11 @@ class AliyunDataTrace(BaseTraceInstance):
self, node_execution: WorkflowNodeExecution, trace_info: WorkflowTraceInfo, trace_metadata: TraceMetadata
):
try:
if node_execution.node_type == BuiltinNodeTypes.LLM:
if node_execution.node_type == NodeType.LLM:
node_span = self.build_workflow_llm_span(trace_info, node_execution, trace_metadata)
elif node_execution.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
node_span = self.build_workflow_retrieval_span(trace_info, node_execution, trace_metadata)
elif node_execution.node_type == BuiltinNodeTypes.TOOL:
elif node_execution.node_type == NodeType.TOOL:
node_span = self.build_workflow_tool_span(trace_info, node_execution, trace_metadata)
else:
node_span = self.build_workflow_task_span(trace_info, node_execution, trace_metadata)

View File

@@ -155,8 +155,8 @@ def wrap_span_metadata(metadata, **kwargs):
return metadata
# Mapping from built-in node type strings to OpenInference span kinds.
# Node types not listed here default to CHAIN.
# Mapping from NodeType string values to OpenInference span kinds.
# NodeType values not listed here default to CHAIN.
_NODE_TYPE_TO_SPAN_KIND: dict[str, OpenInferenceSpanKindValues] = {
"llm": OpenInferenceSpanKindValues.LLM,
"knowledge-retrieval": OpenInferenceSpanKindValues.RETRIEVER,
@@ -168,7 +168,7 @@ _NODE_TYPE_TO_SPAN_KIND: dict[str, OpenInferenceSpanKindValues] = {
def _get_node_span_kind(node_type: str) -> OpenInferenceSpanKindValues:
"""Return the OpenInference span kind for a given workflow node type.
Covers every built-in node type string. Nodes that do not have a
Covers every ``NodeType`` enum value. Nodes that do not have a
specialised span kind (e.g. ``start``, ``end``, ``if-else``,
``code``, ``loop``, ``iteration``, etc.) are mapped to ``CHAIN``.
"""

View File

@@ -28,7 +28,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
)
from core.ops.utils import filter_none_values
from core.repositories import DifyCoreRepositoryFactory
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.enums import NodeType
from extensions.ext_database import db
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
from models.enums import MessageStatus
@@ -141,7 +141,7 @@ class LangFuseDataTrace(BaseTraceInstance):
node_name = node_execution.title
node_type = node_execution.node_type
status = node_execution.status
if node_type == BuiltinNodeTypes.LLM:
if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = node_execution.inputs or {}

View File

@@ -28,7 +28,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
)
from core.ops.utils import filter_none_values, generate_dotted_order
from core.repositories import DifyCoreRepositoryFactory
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
@@ -163,7 +163,7 @@ class LangSmithDataTrace(BaseTraceInstance):
node_name = node_execution.title
node_type = node_execution.node_type
status = node_execution.status
if node_type == BuiltinNodeTypes.LLM:
if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = node_execution.inputs or {}
@@ -197,7 +197,7 @@ class LangSmithDataTrace(BaseTraceInstance):
"ls_model_name": process_data.get("model_name", ""),
}
)
elif node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
elif node_type == NodeType.KNOWLEDGE_RETRIEVAL:
run_type = LangSmithRunType.retriever
else:
run_type = LangSmithRunType.tool

View File

@@ -23,7 +23,7 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.enums import NodeType
from extensions.ext_database import db
from models import EndUser
from models.workflow import WorkflowNodeExecutionModel
@@ -145,10 +145,10 @@ class MLflowDataTrace(BaseTraceInstance):
"app_name": node.title,
}
if node.node_type in (BuiltinNodeTypes.LLM, BuiltinNodeTypes.QUESTION_CLASSIFIER):
if node.node_type in (NodeType.LLM, NodeType.QUESTION_CLASSIFIER):
inputs, llm_attributes = self._parse_llm_inputs_and_attributes(node)
attributes.update(llm_attributes)
elif node.node_type == BuiltinNodeTypes.HTTP_REQUEST:
elif node.node_type == NodeType.HTTP_REQUEST:
inputs = node.process_data # contains request URL
if not inputs:
@@ -180,9 +180,9 @@ class MLflowDataTrace(BaseTraceInstance):
# End node span
finished_at = node.created_at + timedelta(seconds=node.elapsed_time)
outputs = json.loads(node.outputs) if node.outputs else {}
if node.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
if node.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
outputs = self._parse_knowledge_retrieval_outputs(outputs)
elif node.node_type == BuiltinNodeTypes.LLM:
elif node.node_type == NodeType.LLM:
outputs = outputs.get("text", outputs)
node_span.end(
outputs=outputs,
@@ -471,13 +471,13 @@ class MLflowDataTrace(BaseTraceInstance):
def _get_node_span_type(self, node_type: str) -> str:
"""Map Dify node types to MLflow span types"""
node_type_mapping = {
BuiltinNodeTypes.LLM: SpanType.LLM,
BuiltinNodeTypes.QUESTION_CLASSIFIER: SpanType.LLM,
BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: SpanType.RETRIEVER,
BuiltinNodeTypes.TOOL: SpanType.TOOL,
BuiltinNodeTypes.CODE: SpanType.TOOL,
BuiltinNodeTypes.HTTP_REQUEST: SpanType.TOOL,
BuiltinNodeTypes.AGENT: SpanType.AGENT,
NodeType.LLM: SpanType.LLM,
NodeType.QUESTION_CLASSIFIER: SpanType.LLM,
NodeType.KNOWLEDGE_RETRIEVAL: SpanType.RETRIEVER,
NodeType.TOOL: SpanType.TOOL,
NodeType.CODE: SpanType.TOOL,
NodeType.HTTP_REQUEST: SpanType.TOOL,
NodeType.AGENT: SpanType.AGENT,
}
return node_type_mapping.get(node_type, "CHAIN") # type: ignore[arg-type,call-overload]

View File

@@ -23,7 +23,7 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.repositories import DifyCoreRepositoryFactory
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
@@ -187,7 +187,7 @@ class OpikDataTrace(BaseTraceInstance):
node_name = node_execution.title
node_type = node_execution.node_type
status = node_execution.status
if node_type == BuiltinNodeTypes.LLM:
if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = node_execution.inputs or {}

View File

@@ -628,10 +628,10 @@ class TraceTask:
if not message_data:
return {}
conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id)
conversation_modes = db.session.scalars(conversation_mode_stmt).all()
if not conversation_modes or len(conversation_modes) == 0:
conversation_mode = db.session.scalars(conversation_mode_stmt).all()
if not conversation_mode or len(conversation_mode) == 0:
return {}
conversation_mode = conversation_modes[0]
conversation_mode = conversation_mode[0]
created_at = message_data.created_at
inputs = message_data.message

View File

@@ -27,7 +27,7 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from dify_graph.entities.workflow_node_execution import (
WorkflowNodeExecution,
)
from dify_graph.nodes import BuiltinNodeTypes
from dify_graph.nodes import NodeType
from extensions.ext_database import db
from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom
@@ -179,7 +179,7 @@ class TencentDataTrace(BaseTraceInstance):
if node_span:
self.trace_client.add_span(node_span)
if node_execution.node_type == BuiltinNodeTypes.LLM:
if node_execution.node_type == NodeType.LLM:
self._record_llm_metrics(node_execution)
except Exception:
logger.exception("[Tencent APM] Failed to process node execution: %s", node_execution.id)
@@ -192,15 +192,15 @@ class TencentDataTrace(BaseTraceInstance):
) -> SpanData | None:
"""Build span for different node types"""
try:
if node_execution.node_type == BuiltinNodeTypes.LLM:
if node_execution.node_type == NodeType.LLM:
return TencentSpanBuilder.build_workflow_llm_span(
trace_id, workflow_span_id, trace_info, node_execution
)
elif node_execution.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
return TencentSpanBuilder.build_workflow_retrieval_span(
trace_id, workflow_span_id, trace_info, node_execution
)
elif node_execution.node_type == BuiltinNodeTypes.TOOL:
elif node_execution.node_type == NodeType.TOOL:
return TencentSpanBuilder.build_workflow_tool_span(
trace_id, workflow_span_id, trace_info, node_execution
)

View File

@@ -31,7 +31,7 @@ from core.ops.entities.trace_entity import (
)
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
from core.repositories import DifyCoreRepositoryFactory
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
@@ -175,7 +175,7 @@ class WeaveDataTrace(BaseTraceInstance):
node_name = node_execution.title
node_type = node_execution.node_type
status = node_execution.status
if node_type == BuiltinNodeTypes.LLM:
if node_type == NodeType.LLM:
inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {}
else:
inputs = node_execution.inputs or {}

View File

@@ -1,5 +1,5 @@
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.enums import NodeType
from dify_graph.nodes.parameter_extractor.entities import (
ModelConfig as ParameterExtractorModelConfig,
)
@@ -52,7 +52,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
instruction=instruction, # instruct with variables are not supported
)
node_data_dict = node_data.model_dump()
node_data_dict["type"] = BuiltinNodeTypes.PARAMETER_EXTRACTOR
node_data_dict["type"] = NodeType.PARAMETER_EXTRACTOR
execution = workflow_service.run_free_workflow_node(
node_data_dict,
tenant_id=tenant_id,

View File

@@ -305,7 +305,9 @@ class ProviderManager:
available_models = provider_configurations.get_models(model_type=model_type, only_active=True)
if available_models:
available_model = available_models[0]
available_model = next(
(model for model in available_models if model.model == "gpt-4"), available_models[0]
)
default_model = TenantDefaultModel(
tenant_id=tenant_id,
@@ -625,7 +627,7 @@ class ProviderManager:
tenant_id=tenant_id,
# TODO: Use provider name with prefix after the data migration.
provider_name=ModelProviderID(provider_name).provider_name,
provider_type=ProviderType.SYSTEM,
provider_type=ProviderType.SYSTEM.value,
quota_type=quota.quota_type,
quota_limit=0, # type: ignore
quota_used=0,

View File

@@ -1,361 +0,0 @@
import json
import logging
import time
from typing import Any
import holo_search_sdk as holo # type: ignore
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
from psycopg import sql as psql
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class HologresVectorConfig(BaseModel):
"""
Configuration for Hologres vector database connection.
In Hologres, access_key_id is used as the PostgreSQL username,
and access_key_secret is used as the PostgreSQL password.
"""
host: str
port: int = 80
database: str
access_key_id: str
access_key_secret: str
schema_name: str = "public"
tokenizer: TokenizerType = "jieba"
distance_method: DistanceType = "Cosine"
base_quantization_type: BaseQuantizationType = "rabitq"
max_degree: int = 64
ef_construction: int = 400
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict):
if not values.get("host"):
raise ValueError("config HOLOGRES_HOST is required")
if not values.get("database"):
raise ValueError("config HOLOGRES_DATABASE is required")
if not values.get("access_key_id"):
raise ValueError("config HOLOGRES_ACCESS_KEY_ID is required")
if not values.get("access_key_secret"):
raise ValueError("config HOLOGRES_ACCESS_KEY_SECRET is required")
return values
class HologresVector(BaseVector):
"""
Hologres vector storage implementation using holo-search-sdk.
Supports semantic search (vector), full-text search, and hybrid search.
"""
def __init__(self, collection_name: str, config: HologresVectorConfig):
super().__init__(collection_name)
self._config = config
self._client = self._init_client(config)
self.table_name = f"embedding_{collection_name}".lower()
def _init_client(self, config: HologresVectorConfig):
"""Initialize and return a holo-search-sdk client."""
client = holo.connect(
host=config.host,
port=config.port,
database=config.database,
access_key_id=config.access_key_id,
access_key_secret=config.access_key_secret,
schema=config.schema_name,
)
client.connect()
return client
def get_type(self) -> str:
return VectorType.HOLOGRES
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
"""Create collection table with vector and full-text indexes, then add texts."""
dimension = len(embeddings[0])
self._create_collection(dimension)
self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
"""Add texts with embeddings to the collection using batch upsert."""
if not documents:
return []
pks: list[str] = []
batch_size = 100
for i in range(0, len(documents), batch_size):
batch_docs = documents[i : i + batch_size]
batch_embeddings = embeddings[i : i + batch_size]
values = []
column_names = ["id", "text", "meta", "embedding"]
for j, doc in enumerate(batch_docs):
doc_id = doc.metadata.get("doc_id", "") if doc.metadata else ""
pks.append(doc_id)
values.append(
[
doc_id,
doc.page_content,
json.dumps(doc.metadata or {}),
batch_embeddings[j],
]
)
table = self._client.open_table(self.table_name)
table.upsert_multi(
index_column="id",
values=values,
column_names=column_names,
update=True,
update_columns=["text", "meta", "embedding"],
)
return pks
def text_exists(self, id: str) -> bool:
"""Check if a text with the given doc_id exists in the collection."""
if not self._client.check_table_exist(self.table_name):
return False
result = self._client.execute(
psql.SQL("SELECT 1 FROM {} WHERE id = {} LIMIT 1").format(
psql.Identifier(self.table_name), psql.Literal(id)
),
fetch_result=True,
)
return bool(result)
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str] | None:
"""Get document IDs by metadata field key and value."""
result = self._client.execute(
psql.SQL("SELECT id FROM {} WHERE meta->>{} = {}").format(
psql.Identifier(self.table_name), psql.Literal(key), psql.Literal(value)
),
fetch_result=True,
)
if result:
return [row[0] for row in result]
return None
def delete_by_ids(self, ids: list[str]):
"""Delete documents by their doc_id list."""
if not ids:
return
if not self._client.check_table_exist(self.table_name):
return
self._client.execute(
psql.SQL("DELETE FROM {} WHERE id IN ({})").format(
psql.Identifier(self.table_name),
psql.SQL(", ").join(psql.Literal(id) for id in ids),
)
)
def delete_by_metadata_field(self, key: str, value: str):
"""Delete documents by metadata field key and value."""
if not self._client.check_table_exist(self.table_name):
return
self._client.execute(
psql.SQL("DELETE FROM {} WHERE meta->>{} = {}").format(
psql.Identifier(self.table_name), psql.Literal(key), psql.Literal(value)
)
)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""Search for documents by vector similarity."""
if not self._client.check_table_exist(self.table_name):
return []
top_k = kwargs.get("top_k", 4)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
table = self._client.open_table(self.table_name)
query = (
table.search_vector(
vector=query_vector,
column="embedding",
distance_method=self._config.distance_method,
output_name="distance",
)
.select(["id", "text", "meta"])
.limit(top_k)
)
# Apply document_ids_filter if provided
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
filter_sql = psql.SQL("meta->>'document_id' IN ({})").format(
psql.SQL(", ").join(psql.Literal(id) for id in document_ids_filter)
)
query = query.where(filter_sql)
results = query.fetchall()
return self._process_vector_results(results, score_threshold)
def _process_vector_results(self, results: list, score_threshold: float) -> list[Document]:
"""Process vector search results into Document objects."""
docs = []
for row in results:
# row format: (distance, id, text, meta)
# distance is first because search_vector() adds the computed column before selected columns
distance = row[0]
text = row[2]
meta = row[3]
if isinstance(meta, str):
meta = json.loads(meta)
# Convert distance to similarity score (consistent with pgvector)
score = 1 - distance
meta["score"] = score
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=meta))
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Search for documents by full-text search."""
if not self._client.check_table_exist(self.table_name):
return []
top_k = kwargs.get("top_k", 4)
table = self._client.open_table(self.table_name)
search_query = table.search_text(
column="text",
expression=query,
return_score=True,
return_score_name="score",
return_all_columns=True,
).limit(top_k)
# Apply document_ids_filter if provided
document_ids_filter = kwargs.get("document_ids_filter")
if document_ids_filter:
filter_sql = psql.SQL("meta->>'document_id' IN ({})").format(
psql.SQL(", ").join(psql.Literal(id) for id in document_ids_filter)
)
search_query = search_query.where(filter_sql)
results = search_query.fetchall()
return self._process_full_text_results(results)
def _process_full_text_results(self, results: list) -> list[Document]:
"""Process full-text search results into Document objects."""
docs = []
for row in results:
# row format: (id, text, meta, embedding, score)
text = row[1]
meta = row[2]
score = row[-1] # score is the last column from return_score
if isinstance(meta, str):
meta = json.loads(meta)
meta["score"] = score
docs.append(Document(page_content=text, metadata=meta))
return docs
def delete(self):
"""Delete the entire collection table."""
if self._client.check_table_exist(self.table_name):
self._client.drop_table(self.table_name)
def _create_collection(self, dimension: int):
"""Create the collection table with vector and full-text indexes."""
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return
if not self._client.check_table_exist(self.table_name):
# Create table via SQL with CHECK constraint for vector dimension
create_table_sql = psql.SQL("""
CREATE TABLE IF NOT EXISTS {} (
id TEXT PRIMARY KEY,
text TEXT NOT NULL,
meta JSONB NOT NULL,
embedding float4[] NOT NULL
CHECK (array_ndims(embedding) = 1
AND array_length(embedding, 1) = {})
);
""").format(psql.Identifier(self.table_name), psql.Literal(dimension))
self._client.execute(create_table_sql)
# Wait for table to be fully ready before creating indexes
max_wait_seconds = 30
poll_interval = 2
for _ in range(max_wait_seconds // poll_interval):
if self._client.check_table_exist(self.table_name):
break
time.sleep(poll_interval)
else:
raise RuntimeError(f"Table {self.table_name} was not ready after {max_wait_seconds}s")
# Open table and set vector index
table = self._client.open_table(self.table_name)
table.set_vector_index(
column="embedding",
distance_method=self._config.distance_method,
base_quantization_type=self._config.base_quantization_type,
max_degree=self._config.max_degree,
ef_construction=self._config.ef_construction,
use_reorder=self._config.base_quantization_type == "rabitq",
)
# Create full-text search index
table.create_text_index(
index_name=f"ft_idx_{self._collection_name}",
column="text",
tokenizer=self._config.tokenizer,
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
class HologresVectorFactory(AbstractVectorFactory):
"""Factory class for creating HologresVector instances."""
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> HologresVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.HOLOGRES, collection_name))
return HologresVector(
collection_name=collection_name,
config=HologresVectorConfig(
host=dify_config.HOLOGRES_HOST or "",
port=dify_config.HOLOGRES_PORT,
database=dify_config.HOLOGRES_DATABASE or "",
access_key_id=dify_config.HOLOGRES_ACCESS_KEY_ID or "",
access_key_secret=dify_config.HOLOGRES_ACCESS_KEY_SECRET or "",
schema_name=dify_config.HOLOGRES_SCHEMA,
tokenizer=dify_config.HOLOGRES_TOKENIZER,
distance_method=dify_config.HOLOGRES_DISTANCE_METHOD,
base_quantization_type=dify_config.HOLOGRES_BASE_QUANTIZATION_TYPE,
max_degree=dify_config.HOLOGRES_MAX_DEGREE,
ef_construction=dify_config.HOLOGRES_EF_CONSTRUCTION,
),
)

View File

@@ -135,8 +135,8 @@ class PGVectoRS(BaseVector):
def get_ids_by_metadata_field(self, key: str, value: str):
result = None
with Session(self._client) as session:
select_statement = sql_text(f"SELECT id FROM {self._collection_name} WHERE meta->>:key = :value")
result = session.execute(select_statement, {"key": key, "value": value}).fetchall()
select_statement = sql_text(f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; ")
result = session.execute(select_statement).fetchall()
if result:
return [item[0] for item in result]
else:
@@ -172,9 +172,9 @@ class PGVectoRS(BaseVector):
def text_exists(self, id: str) -> bool:
with Session(self._client) as session:
select_statement = sql_text(
f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = :doc_id limit 1"
f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = '{id}' limit 1; "
)
result = session.execute(select_statement, {"doc_id": id}).fetchall()
result = session.execute(select_statement).fetchall()
return len(result) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:

View File

@@ -154,8 +154,10 @@ class RelytVector(BaseVector):
def get_ids_by_metadata_field(self, key: str, value: str):
result = None
with Session(self.client) as session:
select_statement = sql_text(f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>:key = :value""")
result = session.execute(select_statement, {"key": key, "value": value}).fetchall()
select_statement = sql_text(
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'{key}' = '{value}'; """
)
result = session.execute(select_statement).fetchall()
if result:
return [item[0] for item in result]
else:
@@ -199,10 +201,11 @@ class RelytVector(BaseVector):
def delete_by_ids(self, ids: list[str]):
with Session(self.client) as session:
ids_str = ",".join(f"'{doc_id}'" for doc_id in ids)
select_statement = sql_text(
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' = ANY(:doc_ids)"""
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """
)
result = session.execute(select_statement, {"doc_ids": ids}).fetchall()
result = session.execute(select_statement).fetchall()
if result:
ids = [item[0] for item in result]
self.delete_by_uuids(ids)
@@ -215,9 +218,9 @@ class RelytVector(BaseVector):
def text_exists(self, id: str) -> bool:
with Session(self.client) as session:
select_statement = sql_text(
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' = :doc_id limit 1"""
f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' = '{id}' limit 1; """
)
result = session.execute(select_statement, {"doc_id": id}).fetchall()
result = session.execute(select_statement).fetchall()
return len(result) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:

View File

@@ -38,7 +38,7 @@ class AbstractVectorFactory(ABC):
class Vector:
def __init__(self, dataset: Dataset, attributes: list | None = None):
if attributes is None:
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"]
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
self._dataset = dataset
self._embeddings = self._get_embeddings()
self._attributes = attributes
@@ -191,10 +191,6 @@ class Vector:
from core.rag.datasource.vdb.iris.iris_vector import IrisVectorFactory
return IrisVectorFactory
case VectorType.HOLOGRES:
from core.rag.datasource.vdb.hologres.hologres_vector import HologresVectorFactory
return HologresVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")

View File

@@ -34,4 +34,3 @@ class VectorType(StrEnum):
MATRIXONE = "matrixone"
CLICKZETTA = "clickzetta"
IRIS = "iris"
HOLOGRES = "hologres"

View File

@@ -196,7 +196,6 @@ class WeaviateVector(BaseVector):
),
wc.Property(name="document_id", data_type=wc.DataType.TEXT),
wc.Property(name="doc_id", data_type=wc.DataType.TEXT),
wc.Property(name="doc_type", data_type=wc.DataType.TEXT),
wc.Property(name="chunk_index", data_type=wc.DataType.INT),
],
vector_config=wc.Configure.Vectors.self_provided(),
@@ -226,8 +225,6 @@ class WeaviateVector(BaseVector):
to_add.append(wc.Property(name="document_id", data_type=wc.DataType.TEXT))
if "doc_id" not in existing:
to_add.append(wc.Property(name="doc_id", data_type=wc.DataType.TEXT))
if "doc_type" not in existing:
to_add.append(wc.Property(name="doc_type", data_type=wc.DataType.TEXT))
if "chunk_index" not in existing:
to_add.append(wc.Property(name="chunk_index", data_type=wc.DataType.INT))

View File

@@ -9,8 +9,8 @@ from flask import current_app
from sqlalchemy import delete, func, select
from core.db.session_factory import session_factory
from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError
from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview
from dify_graph.nodes.knowledge_index.exc import KnowledgeIndexNodeError
from dify_graph.repositories.index_processor_protocol import Preview, PreviewItem, QaPreview
from models.dataset import Dataset, Document, DocumentSegment
from .index_processor_factory import IndexProcessorFactory

View File

@@ -294,7 +294,7 @@ class BaseIndexProcessor(ABC):
logging.warning("Error downloading image from %s: %s", image_url, str(e))
return None
except Exception:
logging.warning("Unexpected error downloading image from %s", image_url, exc_info=True)
logging.exception("Unexpected error downloading image from %s", image_url)
return None
def _download_tool_file(self, tool_file_id: str, current_user: Account) -> str | None:

View File

@@ -56,18 +56,18 @@ from core.rag.retrieval.template_prompts import (
)
from core.tools.signature import sign_upload_file
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from core.workflow.nodes.knowledge_retrieval import exc
from core.workflow.nodes.knowledge_retrieval.retrieval import (
KnowledgeRetrievalRequest,
Source,
SourceChildChunk,
SourceMetadata,
)
from dify_graph.file import File, FileTransferMethod, FileType
from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from dify_graph.nodes.knowledge_retrieval import exc
from dify_graph.repositories.rag_retrieval_protocol import (
KnowledgeRetrievalRequest,
Source,
SourceChildChunk,
SourceMetadata,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.json_in_md_parser import parse_and_check_json_markdown
@@ -83,7 +83,6 @@ from models.dataset import (
)
from models.dataset import Document as DatasetDocument
from models.dataset import Document as DocumentModel
from models.enums import CreatorUserRole
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureService
@@ -1010,7 +1009,7 @@ class DatasetRetrieval:
content=json.dumps(contents),
source="app",
source_app_id=app_id,
created_by_role=CreatorUserRole(user_from),
created_by_role=user_from,
created_by=user_id,
)
dataset_queries.append(dataset_query)

View File

@@ -146,9 +146,7 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
# No sequence number generation needed anymore
from models.workflow import WorkflowType as ModelWorkflowType
db_model.type = ModelWorkflowType(domain_model.workflow_type.value)
db_model.type = domain_model.workflow_type
db_model.version = domain_model.workflow_version
db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None
db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None

View File

@@ -18,7 +18,7 @@ from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_att
from configs import dify_config
from dify_graph.entities import WorkflowNodeExecution
from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter
@@ -146,7 +146,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
index=db_model.index,
predecessor_node_id=db_model.predecessor_node_id,
node_id=db_model.node_id,
node_type=db_model.node_type,
node_type=NodeType(db_model.node_type),
title=db_model.title,
inputs=inputs,
process_data=process_data,

View File

@@ -116,7 +116,6 @@ class ToolParameterConfigurationManager:
return a deep copy of parameters with decrypted values
"""
parameters = self._deep_copy(parameters)
cache = ToolParameterCache(
tenant_id=self.tenant_id,

View File

@@ -3,7 +3,7 @@ from typing import Any
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.enums import NodeType
from dify_graph.nodes.base.entities import OutputVariableEntity
from dify_graph.variables.input_entities import VariableEntity
@@ -51,7 +51,7 @@ class WorkflowToolConfigurationUtils:
def ensure_no_human_input_nodes(cls, graph: Mapping[str, Any]) -> None:
nodes = graph.get("nodes", [])
for node in nodes:
if node.get("data", {}).get("type") == BuiltinNodeTypes.HUMAN_INPUT:
if node.get("data", {}).get("type") == NodeType.HUMAN_INPUT:
raise WorkflowToolHumanInputNotSupportedError()
@classmethod

View File

@@ -1,18 +0,0 @@
from typing import Final
TRIGGER_WEBHOOK_NODE_TYPE: Final[str] = "trigger-webhook"
TRIGGER_SCHEDULE_NODE_TYPE: Final[str] = "trigger-schedule"
TRIGGER_PLUGIN_NODE_TYPE: Final[str] = "trigger-plugin"
TRIGGER_INFO_METADATA_KEY: Final[str] = "trigger_info"
TRIGGER_NODE_TYPES: Final[frozenset[str]] = frozenset(
{
TRIGGER_WEBHOOK_NODE_TYPE,
TRIGGER_SCHEDULE_NODE_TYPE,
TRIGGER_PLUGIN_NODE_TYPE,
}
)
def is_trigger_node_type(node_type: str) -> bool:
return node_type in TRIGGER_NODE_TYPES

View File

@@ -11,11 +11,6 @@ from typing import Any
from pydantic import BaseModel
from core.plugin.entities.request import TriggerInvokeEventResponse
from core.trigger.constants import (
TRIGGER_PLUGIN_NODE_TYPE,
TRIGGER_SCHEDULE_NODE_TYPE,
TRIGGER_WEBHOOK_NODE_TYPE,
)
from core.trigger.debug.event_bus import TriggerDebugEventBus
from core.trigger.debug.events import (
PluginTriggerDebugEvent,
@@ -24,9 +19,10 @@ from core.trigger.debug.events import (
build_plugin_pool_key,
build_webhook_pool_key,
)
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import NodeType
from dify_graph.nodes.trigger_plugin.entities import TriggerEventNodeData
from dify_graph.nodes.trigger_schedule.entities import ScheduleConfig
from extensions.ext_redis import redis_client
from libs.datetime_utils import ensure_naive_utc, naive_utc_now
from libs.schedule_utils import calculate_next_run_at
@@ -210,19 +206,21 @@ def create_event_poller(
if not node_config:
raise ValueError("Node data not found for node %s", node_id)
node_type = draft_workflow.get_node_type_from_node_config(node_config)
if node_type == TRIGGER_PLUGIN_NODE_TYPE:
return PluginTriggerDebugEventPoller(
tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id
)
if node_type == TRIGGER_WEBHOOK_NODE_TYPE:
return WebhookTriggerDebugEventPoller(
tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id
)
if node_type == TRIGGER_SCHEDULE_NODE_TYPE:
return ScheduleTriggerDebugEventPoller(
tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id
)
raise ValueError("unable to create event poller for node type %s", node_type)
match node_type:
case NodeType.TRIGGER_PLUGIN:
return PluginTriggerDebugEventPoller(
tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id
)
case NodeType.TRIGGER_WEBHOOK:
return WebhookTriggerDebugEventPoller(
tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id
)
case NodeType.TRIGGER_SCHEDULE:
return ScheduleTriggerDebugEventPoller(
tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id
)
case _:
raise ValueError("unable to create event poller for node type %s", node_type)
def select_trigger_debug_events(

View File

@@ -1 +1,4 @@
"""Core workflow package."""
from .node_factory import DifyNodeFactory
from .workflow_entry import WorkflowEntry
__all__ = ["DifyNodeFactory", "WorkflowEntry"]

View File

@@ -1,7 +1,4 @@
import importlib
import pkgutil
from collections.abc import Callable, Iterator, Mapping, MutableMapping
from functools import lru_cache
from collections.abc import Callable, Mapping
from typing import TYPE_CHECKING, Any, TypeAlias, cast, final
from sqlalchemy import select
@@ -11,6 +8,7 @@ from typing_extensions import override
from configs import dify_config
from core.app.entities.app_invoke_entities import DifyRunContext
from core.app.llm.model_access import build_dify_model_access
from core.datasource.datasource_manager import DatasourceManager
from core.helper.code_executor.code_executor import (
CodeExecutionError,
CodeExecutor,
@@ -19,19 +17,15 @@ from core.helper.ssrf_proxy import ssrf_proxy
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.rag.index_processor.index_processor import IndexProcessor
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.summary_index.summary_index import SummaryIndex
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.tools.tool_file_manager import ToolFileManager
from core.trigger.constants import TRIGGER_NODE_TYPES
from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer
from core.workflow.nodes.agent.plugin_strategy_adapter import (
PluginAgentStrategyPresentationProvider,
PluginAgentStrategyResolver,
)
from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
from dify_graph.enums import BuiltinNodeTypes, NodeType, SystemVariableKey
from dify_graph.enums import NodeType, SystemVariableKey
from dify_graph.file.file_manager import file_manager
from dify_graph.graph.graph import NodeFactory
from dify_graph.model_runtime.entities.model_entities import ModelType
@@ -45,7 +39,7 @@ from dify_graph.nodes.document_extractor import UnstructuredApiConfig
from dify_graph.nodes.http_request import build_http_request_config
from dify_graph.nodes.llm.entities import LLMNodeData
from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from dify_graph.nodes.llm.protocols import TemplateRenderer
from dify_graph.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData
from dify_graph.nodes.template_transform.template_renderer import (
@@ -59,135 +53,6 @@ if TYPE_CHECKING:
from dify_graph.entities import GraphInitParams
from dify_graph.runtime import GraphRuntimeState
LATEST_VERSION = "latest"
_START_NODE_TYPES: frozenset[NodeType] = frozenset(
(BuiltinNodeTypes.START, BuiltinNodeTypes.DATASOURCE, *TRIGGER_NODE_TYPES)
)
def _import_node_package(package_name: str, *, excluded_modules: frozenset[str] = frozenset()) -> None:
package = importlib.import_module(package_name)
for _, module_name, _ in pkgutil.walk_packages(package.__path__, package.__name__ + "."):
if module_name in excluded_modules:
continue
importlib.import_module(module_name)
@lru_cache(maxsize=1)
def register_nodes() -> None:
"""Import production node modules so they self-register with ``Node``."""
_import_node_package("dify_graph.nodes")
_import_node_package("core.workflow.nodes")
def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]]]:
"""Return a read-only snapshot of the current production node registry.
The workflow layer owns node bootstrap because it must compose built-in
`dify_graph.nodes.*` implementations with workflow-local nodes under
`core.workflow.nodes.*`. Keeping this import side effect here avoids
reintroducing registry bootstrapping into lower-level graph primitives.
"""
register_nodes()
return Node.get_node_type_classes_mapping()
def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
node_mapping = get_node_type_classes_mapping().get(node_type)
if not node_mapping:
raise ValueError(f"No class mapping found for node type: {node_type}")
latest_node_class = node_mapping.get(LATEST_VERSION)
matched_node_class = node_mapping.get(node_version)
node_class = matched_node_class or latest_node_class
if not node_class:
raise ValueError(f"No latest version class found for node type: {node_type}")
return node_class
def is_start_node_type(node_type: NodeType) -> bool:
"""Return True when the node type can serve as a workflow entry point."""
return node_type in _START_NODE_TYPES
def get_default_root_node_id(graph_config: Mapping[str, Any]) -> str:
"""Resolve the default entry node for a persisted top-level workflow graph.
This workflow-layer helper depends on start-node semantics defined by
`is_start_node_type`, so it intentionally lives next to the node registry
instead of in the raw `dify_graph.entities.graph_config` schema module.
"""
nodes = graph_config.get("nodes")
if not isinstance(nodes, list):
raise ValueError("nodes in workflow graph must be a list")
for node in nodes:
if not isinstance(node, Mapping):
continue
if node.get("type") == "custom-note":
continue
node_id = node.get("id")
data = node.get("data")
if not isinstance(node_id, str) or not isinstance(data, Mapping):
continue
node_type = data.get("type")
if isinstance(node_type, str) and is_start_node_type(node_type):
return node_id
raise ValueError("Unable to determine default root node ID from workflow graph")
class _LazyNodeTypeClassesMapping(MutableMapping[NodeType, Mapping[str, type[Node]]]):
"""Mutable dict-like view over the current node registry."""
def __init__(self) -> None:
self._cached_snapshot: dict[NodeType, Mapping[str, type[Node]]] = {}
self._cached_version = -1
self._deleted: set[NodeType] = set()
self._overrides: dict[NodeType, Mapping[str, type[Node]]] = {}
def _snapshot(self) -> dict[NodeType, Mapping[str, type[Node]]]:
current_version = Node.get_registry_version()
if self._cached_version != current_version:
self._cached_snapshot = dict(get_node_type_classes_mapping())
self._cached_version = current_version
if not self._deleted and not self._overrides:
return self._cached_snapshot
snapshot = {key: value for key, value in self._cached_snapshot.items() if key not in self._deleted}
snapshot.update(self._overrides)
return snapshot
def __getitem__(self, key: NodeType) -> Mapping[str, type[Node]]:
return self._snapshot()[key]
def __setitem__(self, key: NodeType, value: Mapping[str, type[Node]]) -> None:
self._deleted.discard(key)
self._overrides[key] = value
def __delitem__(self, key: NodeType) -> None:
if key in self._overrides:
del self._overrides[key]
return
if key in self._cached_snapshot:
self._deleted.add(key)
return
raise KeyError(key)
def __iter__(self) -> Iterator[NodeType]:
return iter(self._snapshot())
def __len__(self) -> int:
return len(self._snapshot())
# Keep the canonical node-class mapping in the workflow layer that also bootstraps
# legacy `core.workflow.nodes.*` registrations.
NODE_TYPE_CLASSES_MAPPING: MutableMapping[NodeType, Mapping[str, type[Node]]] = _LazyNodeTypeClassesMapping()
LLMCompatibleNodeData: TypeAlias = LLMNodeData | QuestionClassifierNodeData | ParameterExtractorNodeData
@@ -229,20 +94,13 @@ class DefaultWorkflowCodeExecutor:
return isinstance(error, CodeExecutionError)
class DefaultLLMTemplateRenderer(TemplateRenderer):
def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str:
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2,
code=template,
inputs=inputs,
)
return str(result.get("result", ""))
@final
class DifyNodeFactory(NodeFactory):
"""
Default implementation of NodeFactory that resolves node classes from the live registry.
Default implementation of NodeFactory that uses the traditional node mapping.
This factory creates nodes by looking up their types in NODE_TYPE_CLASSES_MAPPING
and instantiating the appropriate node class.
"""
def __init__(
@@ -265,11 +123,11 @@ class DifyNodeFactory(NodeFactory):
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = CodeExecutorJinja2TemplateRenderer(code_executor=self._code_executor)
self._llm_template_renderer: TemplateRenderer = DefaultLLMTemplateRenderer()
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
self._http_request_http_client = ssrf_proxy
self._http_request_tool_file_manager_factory = ToolFileManager
self._http_request_file_manager = file_manager
self._rag_retrieval = DatasetRetrieval()
self._document_extractor_unstructured_api_config = UnstructuredApiConfig(
api_url=dify_config.UNSTRUCTURED_API_URL,
api_key=dify_config.UNSTRUCTURED_API_KEY or "",
@@ -285,10 +143,6 @@ class DifyNodeFactory(NodeFactory):
)
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id)
self._agent_strategy_resolver = PluginAgentStrategyResolver()
self._agent_strategy_presentation_provider = PluginAgentStrategyPresentationProvider()
self._agent_runtime_support = AgentRuntimeSupport()
self._agent_message_transformer = AgentMessageTransformer()
@staticmethod
def _resolve_dify_context(run_context: Mapping[str, Any]) -> DifyRunContext:
@@ -316,51 +170,55 @@ class DifyNodeFactory(NodeFactory):
node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
node_type = node_data.type
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
BuiltinNodeTypes.CODE: lambda: {
NodeType.CODE: lambda: {
"code_executor": self._code_executor,
"code_limits": self._code_limits,
},
BuiltinNodeTypes.TEMPLATE_TRANSFORM: lambda: {
NodeType.TEMPLATE_TRANSFORM: lambda: {
"template_renderer": self._template_renderer,
"max_output_length": self._template_transform_max_output_length,
},
BuiltinNodeTypes.HTTP_REQUEST: lambda: {
NodeType.HTTP_REQUEST: lambda: {
"http_request_config": self._http_request_config,
"http_client": self._http_request_http_client,
"tool_file_manager_factory": self._http_request_tool_file_manager_factory,
"file_manager": self._http_request_file_manager,
},
BuiltinNodeTypes.HUMAN_INPUT: lambda: {
NodeType.HUMAN_INPUT: lambda: {
"form_repository": HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id),
},
BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
NodeType.KNOWLEDGE_INDEX: lambda: {
"index_processor": IndexProcessor(),
"summary_index_service": SummaryIndex(),
},
NodeType.LLM: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
include_http_client=True,
),
BuiltinNodeTypes.DOCUMENT_EXTRACTOR: lambda: {
NodeType.DATASOURCE: lambda: {
"datasource_manager": DatasourceManager,
},
NodeType.KNOWLEDGE_RETRIEVAL: lambda: {
"rag_retrieval": self._rag_retrieval,
},
NodeType.DOCUMENT_EXTRACTOR: lambda: {
"unstructured_api_config": self._document_extractor_unstructured_api_config,
"http_client": self._http_request_http_client,
},
BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
NodeType.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
include_http_client=True,
),
BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
NodeType.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs(
node_class=node_class,
node_data=node_data,
include_http_client=False,
),
BuiltinNodeTypes.TOOL: lambda: {
NodeType.TOOL: lambda: {
"tool_file_manager_factory": self._http_request_tool_file_manager_factory(),
},
BuiltinNodeTypes.AGENT: lambda: {
"strategy_resolver": self._agent_strategy_resolver,
"presentation_provider": self._agent_strategy_presentation_provider,
"runtime_support": self._agent_runtime_support,
"message_transformer": self._agent_message_transformer,
},
}
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
return node_class(
@@ -380,7 +238,16 @@ class DifyNodeFactory(NodeFactory):
@staticmethod
def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]:
return resolve_workflow_node_class(node_type=node_type, node_version=node_version)
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
if not node_mapping:
raise ValueError(f"No class mapping found for node type: {node_type}")
latest_node_class = node_mapping.get(LATEST_VERSION)
matched_node_class = node_mapping.get(node_version)
node_class = matched_node_class or latest_node_class
if not node_class:
raise ValueError(f"No latest version class found for node type: {node_type}")
return node_class
def _build_llm_compatible_node_init_kwargs(
self,
@@ -403,8 +270,6 @@ class DifyNodeFactory(NodeFactory):
model_instance=model_instance,
),
}
if validated_node_data.type in {BuiltinNodeTypes.LLM, BuiltinNodeTypes.QUESTION_CLASSIFIER}:
node_init_kwargs["template_renderer"] = self._llm_template_renderer
if include_http_client:
node_init_kwargs["http_client"] = self._http_request_http_client
return node_init_kwargs

View File

@@ -1 +0,0 @@
"""Workflow node implementations that remain under the legacy core.workflow namespace."""

View File

@@ -1,4 +0,0 @@
from .agent_node import AgentNode
from .entities import AgentNodeData
__all__ = ["AgentNode", "AgentNodeData"]

View File

@@ -1,188 +0,0 @@
from __future__ import annotations
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.enums import BuiltinNodeTypes, SystemVariableKey, WorkflowNodeExecutionStatus
from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from .entities import AgentNodeData
from .exceptions import (
AgentInvocationError,
AgentMessageTransformError,
)
from .message_transformer import AgentMessageTransformer
from .runtime_support import AgentRuntimeSupport
from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver
if TYPE_CHECKING:
from dify_graph.entities import GraphInitParams
from dify_graph.runtime import GraphRuntimeState
class AgentNode(Node[AgentNodeData]):
node_type = BuiltinNodeTypes.AGENT
_strategy_resolver: AgentStrategyResolver
_presentation_provider: AgentStrategyPresentationProvider
_runtime_support: AgentRuntimeSupport
_message_transformer: AgentMessageTransformer
def __init__(
self,
id: str,
config: NodeConfigDict,
graph_init_params: GraphInitParams,
graph_runtime_state: GraphRuntimeState,
*,
strategy_resolver: AgentStrategyResolver,
presentation_provider: AgentStrategyPresentationProvider,
runtime_support: AgentRuntimeSupport,
message_transformer: AgentMessageTransformer,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._strategy_resolver = strategy_resolver
self._presentation_provider = presentation_provider
self._runtime_support = runtime_support
self._message_transformer = message_transformer
@classmethod
def version(cls) -> str:
return "1"
def populate_start_event(self, event) -> None:
dify_ctx = self.require_dify_context()
event.extras["agent_strategy"] = {
"name": self.node_data.agent_strategy_name,
"icon": self._presentation_provider.get_icon(
tenant_id=dify_ctx.tenant_id,
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
),
}
def _run(self) -> Generator[NodeEventBase, None, None]:
from core.plugin.impl.exc import PluginDaemonClientSideError
dify_ctx = self.require_dify_context()
try:
strategy = self._strategy_resolver.resolve(
tenant_id=dify_ctx.tenant_id,
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
agent_strategy_name=self.node_data.agent_strategy_name,
)
except Exception as e:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
error=f"Failed to get agent strategy: {str(e)}",
),
)
return
agent_parameters = strategy.get_parameters()
parameters = self._runtime_support.build_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
strategy=strategy,
tenant_id=dify_ctx.tenant_id,
app_id=dify_ctx.app_id,
invoke_from=dify_ctx.invoke_from,
)
parameters_for_log = self._runtime_support.build_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
strategy=strategy,
tenant_id=dify_ctx.tenant_id,
app_id=dify_ctx.app_id,
invoke_from=dify_ctx.invoke_from,
for_log=True,
)
credentials = self._runtime_support.build_credentials(parameters=parameters)
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
try:
message_stream = strategy.invoke(
params=parameters,
user_id=dify_ctx.user_id,
app_id=dify_ctx.app_id,
conversation_id=conversation_id.text if conversation_id else None,
credentials=credentials,
)
except Exception as e:
error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=str(error),
)
)
return
try:
yield from self._message_transformer.transform(
messages=message_stream,
tool_info={
"icon": self._presentation_provider.get_icon(
tenant_id=dify_ctx.tenant_id,
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
),
"agent_strategy": self.node_data.agent_strategy_name,
},
parameters_for_log=parameters_for_log,
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
node_type=self.node_type,
node_id=self._node_id,
node_execution_id=self.id,
)
except PluginDaemonClientSideError as e:
transform_error = AgentMessageTransformError(
f"Failed to transform agent message: {str(e)}", original_error=e
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=str(transform_error),
)
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: AgentNodeData,
) -> Mapping[str, Sequence[str]]:
_ = graph_config # Explicitly mark as unused
result: dict[str, Any] = {}
typed_node_data = node_data
for parameter_name in typed_node_data.agent_parameters:
input = typed_node_data.agent_parameters[parameter_name]
match input.type:
case "mixed" | "constant":
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
result[parameter_name] = input.value
result = {node_id + "." + key: value for key, value in result.items()}
return result

View File

@@ -1,292 +0,0 @@
from __future__ import annotations
from collections.abc import Generator, Mapping
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from dify_graph.file import File, FileTransferMethod
from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import (
AgentLogEvent,
NodeEventBase,
NodeRunResult,
StreamChunkEvent,
StreamCompletedEvent,
)
from dify_graph.variables.segments import ArrayFileSegment
from extensions.ext_database import db
from factories import file_factory
from models import ToolFile
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .exceptions import AgentNodeError, AgentVariableTypeError, ToolFileNotFoundError
class AgentMessageTransformer:
def transform(
self,
*,
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
user_id: str,
tenant_id: str,
node_type: NodeType,
node_id: str,
node_execution_id: str,
) -> Generator[NodeEventBase, None, None]:
from core.plugin.impl.plugin import PluginInstaller
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
)
text = ""
files: list[File] = []
json_list: list[dict | list] = []
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
llm_usage = LLMUsage.empty_usage()
variables: dict[str, Any] = {}
for message in message_stream:
if message.type in {
ToolInvokeMessage.MessageType.IMAGE_LINK,
ToolInvokeMessage.MessageType.BINARY_LINK,
ToolInvokeMessage.MessageType.IMAGE,
}:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
url = message.message.text
if message.meta:
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
else:
transfer_method = FileTransferMethod.TOOL_FILE
tool_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileNotFoundError(tool_file_id)
mapping = {
"tool_file_id": tool_file_id,
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
files.append(file)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
assert message.meta
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileNotFoundError(tool_file_id)
mapping = {
"tool_file_id": tool_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
)
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=message.message.text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if node_type == BuiltinNodeTypes.AGENT:
if isinstance(message.message.json_object, dict):
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
agent_execution_metadata = {
WorkflowNodeExecutionMetadataKey(key): value
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
else:
llm_usage = LLMUsage.empty_usage()
agent_execution_metadata = {}
if message.message.json_object:
json_list.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=stream_text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise AgentVariableTypeError(
"When 'stream' is True, 'variable_value' must be a string.",
variable_name=variable_name,
expected_type="str",
actual_type=type(variable_value).__name__,
)
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield StreamChunkEvent(
selector=[node_id, variable_name],
chunk=variable_value,
is_final=False,
)
else:
variables[variable_name] = variable_value
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
assert isinstance(message.meta, dict)
if "file" not in message.meta:
raise AgentNodeError("File message is missing 'file' key in meta")
if not isinstance(message.meta["file"], File):
raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
files.append(message.meta["file"])
elif message.type == ToolInvokeMessage.MessageType.LOG:
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
if message.message.metadata:
icon = tool_info.get("icon", "")
dict_metadata = dict(message.message.metadata)
if dict_metadata.get("provider"):
manager = PluginInstaller()
plugins = manager.list_plugins(tenant_id)
try:
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
)
icon = current_plugin.declaration.icon
except StopIteration:
pass
icon_dark = None
try:
builtin_tool = next(
provider
for provider in BuiltinToolManageService.list_builtin_tools(
user_id,
tenant_id,
)
if provider.name == dict_metadata["provider"]
)
icon = builtin_tool.icon
icon_dark = builtin_tool.icon_dark
except StopIteration:
pass
dict_metadata["icon"] = icon
dict_metadata["icon_dark"] = icon_dark
message.message.metadata = dict_metadata
agent_log = AgentLogEvent(
message_id=message.message.id,
node_execution_id=node_execution_id,
parent_id=message.message.parent_id,
error=message.message.error,
status=message.message.status.value,
data=message.message.data,
label=message.message.label,
metadata=message.message.metadata,
node_id=node_id,
)
for log in agent_logs:
if log.message_id == agent_log.message_id:
log.data = agent_log.data
log.status = agent_log.status
log.error = agent_log.error
log.label = agent_log.label
log.metadata = agent_log.metadata
break
else:
agent_logs.append(agent_log)
yield agent_log
json_output: list[dict[str, Any] | list[Any]] = []
if agent_logs:
for log in agent_logs:
json_output.append(
{
"id": log.message_id,
"parent_id": log.parent_id,
"error": log.error,
"status": log.status,
"data": log.data,
"label": log.label,
"metadata": log.metadata,
"node_id": log.node_id,
}
)
if json_list:
json_output.extend(json_list)
else:
json_output.append({"data": []})
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk="",
is_final=True,
)
for var_name in variables:
yield StreamChunkEvent(
selector=[node_id, var_name],
chunk="",
is_final=True,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
"text": text,
"usage": jsonable_encoder(llm_usage),
"files": ArrayFileSegment(value=files),
"json": json_output,
**variables,
},
metadata={
**agent_execution_metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
},
inputs=parameters_for_log,
llm_usage=llm_usage,
)
)

View File

@@ -1,40 +0,0 @@
from __future__ import annotations
from factories.agent_factory import get_plugin_agent_strategy
from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver, ResolvedAgentStrategy
class PluginAgentStrategyResolver(AgentStrategyResolver):
def resolve(
self,
*,
tenant_id: str,
agent_strategy_provider_name: str,
agent_strategy_name: str,
) -> ResolvedAgentStrategy:
return get_plugin_agent_strategy(
tenant_id=tenant_id,
agent_strategy_provider_name=agent_strategy_provider_name,
agent_strategy_name=agent_strategy_name,
)
class PluginAgentStrategyPresentationProvider(AgentStrategyPresentationProvider):
def get_icon(self, *, tenant_id: str, agent_strategy_provider_name: str) -> str | None:
from core.plugin.impl.plugin import PluginInstaller
manager = PluginInstaller()
try:
plugins = manager.list_plugins(tenant_id)
except Exception:
return None
try:
current_plugin = next(
plugin for plugin in plugins if f"{plugin.plugin_id}/{plugin.name}" == agent_strategy_provider_name
)
except StopIteration:
return None
return current_plugin.declaration.icon

View File

@@ -1,276 +0,0 @@
from __future__ import annotations
import json
from collections.abc import Sequence
from typing import Any, cast
from packaging.version import Version
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.plugin.entities.request import InvokeCredentials
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderType
from core.tools.tool_manager import ToolManager
from dify_graph.enums import SystemVariableKey
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
from dify_graph.runtime import VariablePool
from dify_graph.variables.segments import StringSegment
from extensions.ext_database import db
from models.model import Conversation
from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from .exceptions import AgentInputTypeError, AgentVariableNotFoundError
from .strategy_protocols import ResolvedAgentStrategy
class AgentRuntimeSupport:
def build_parameters(
self,
*,
agent_parameters: Sequence[AgentStrategyParameter],
variable_pool: VariablePool,
node_data: AgentNodeData,
strategy: ResolvedAgentStrategy,
tenant_id: str,
app_id: str,
invoke_from: Any,
for_log: bool = False,
) -> dict[str, Any]:
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
result: dict[str, Any] = {}
for parameter_name in node_data.agent_parameters:
parameter = agent_parameters_dictionary.get(parameter_name)
if not parameter:
result[parameter_name] = None
continue
agent_input = node_data.agent_parameters[parameter_name]
match agent_input.type:
case "variable":
variable = variable_pool.get(agent_input.value) # type: ignore[arg-type]
if variable is None:
raise AgentVariableNotFoundError(str(agent_input.value))
parameter_value = variable.value
case "mixed" | "constant":
try:
if not isinstance(agent_input.value, str):
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
else:
parameter_value = str(agent_input.value)
except TypeError:
parameter_value = str(agent_input.value)
segment_group = variable_pool.convert_template(parameter_value)
parameter_value = segment_group.log if for_log else segment_group.text
try:
if not isinstance(agent_input.value, str):
parameter_value = json.loads(parameter_value)
except json.JSONDecodeError:
parameter_value = parameter_value
case _:
raise AgentInputTypeError(agent_input.type)
value = parameter_value
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
value = [tool for tool in value if tool.get("enabled", False)]
value = self._filter_mcp_type_tool(strategy, value)
for tool in value:
if "schemas" in tool:
tool.pop("schemas")
parameters = tool.get("parameters", {})
if all(isinstance(v, dict) for _, v in parameters.items()):
params = {}
for key, param in parameters.items():
if param.get("auto", ParamsAutoGenerated.OPEN) in (
ParamsAutoGenerated.CLOSE,
0,
):
value_param = param.get("value", {})
if value_param and value_param.get("type", "") == "variable":
variable_selector = value_param.get("value")
if not variable_selector:
raise ValueError("Variable selector is missing for a variable-type parameter.")
variable = variable_pool.get(variable_selector)
if variable is None:
raise AgentVariableNotFoundError(str(variable_selector))
params[key] = variable.value
else:
params[key] = value_param.get("value", "") if value_param is not None else None
else:
params[key] = None
parameters = params
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
tool["parameters"] = parameters
if not for_log:
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
tool_value = []
for tool in value:
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
setting_params = tool.get("settings", {})
parameters = tool.get("parameters", {})
manual_input_params = [key for key, value in parameters.items() if value is not None]
parameters = {**parameters, **setting_params}
entity = AgentToolEntity(
provider_id=tool.get("provider_name", ""),
provider_type=provider_type,
tool_name=tool.get("tool_name", ""),
tool_parameters=parameters,
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
credential_id=tool.get("credential_id", None),
)
extra = tool.get("extra", {})
runtime_variable_pool: VariablePool | None = None
if node_data.version != "1" or node_data.tool_node_version is not None:
runtime_variable_pool = variable_pool
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id,
app_id,
entity,
invoke_from,
runtime_variable_pool,
)
if tool_runtime.entity.description:
tool_runtime.entity.description.llm = (
extra.get("description", "") or tool_runtime.entity.description.llm
)
for tool_runtime_params in tool_runtime.entity.parameters:
tool_runtime_params.form = (
ToolParameter.ToolParameterForm.FORM
if tool_runtime_params.name in manual_input_params
else tool_runtime_params.form
)
manual_input_value = {}
if tool_runtime.entity.parameters:
manual_input_value = {
key: value for key, value in parameters.items() if key in manual_input_params
}
runtime_parameters = {
**tool_runtime.runtime.runtime_parameters,
**manual_input_value,
}
tool_value.append(
{
**tool_runtime.entity.model_dump(mode="json"),
"runtime_parameters": runtime_parameters,
"credential_id": tool.get("credential_id", None),
"provider_type": provider_type.value,
}
)
value = tool_value
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
value = cast(dict[str, Any], value)
model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value)
history_prompt_messages = []
if node_data.memory:
memory = self.fetch_memory(
variable_pool=variable_pool,
app_id=app_id,
model_instance=model_instance,
)
if memory:
prompt_messages = memory.get_history_prompt_messages(
message_limit=node_data.memory.window.size or None
)
history_prompt_messages = [
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
]
value["history_prompt_messages"] = history_prompt_messages
if model_schema:
model_schema = self._remove_unsupported_model_features_for_old_version(model_schema)
value["entity"] = model_schema.model_dump(mode="json")
else:
value["entity"] = None
result[parameter_name] = value
return result
def build_credentials(self, *, parameters: dict[str, Any]) -> InvokeCredentials:
credentials = InvokeCredentials()
credentials.tool_credentials = {}
for tool in parameters.get("tools", []):
if not tool.get("credential_id"):
continue
try:
identity = ToolIdentity.model_validate(tool.get("identity", {}))
except ValidationError:
continue
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
return credentials
def fetch_memory(
self,
*,
variable_pool: VariablePool,
app_id: str,
model_instance: ModelInstance,
) -> TokenBufferMemory | None:
conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
if not isinstance(conversation_id_variable, StringSegment):
return None
conversation_id = conversation_id_variable.value
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=tenant_id,
provider=value.get("provider", ""),
model_type=ModelType.LLM,
)
model_name = value.get("model", "")
model_credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=ModelType.LLM,
model=model_name,
)
provider_name = provider_model_bundle.configuration.provider.provider
model_type_instance = provider_model_bundle.model_type_instance
model_instance = ModelManager().get_model_instance(
tenant_id=tenant_id,
provider=provider_name,
model_type=ModelType(value.get("model_type", "")),
model=model_name,
)
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
return model_instance, model_schema
@staticmethod
def _remove_unsupported_model_features_for_old_version(model_schema: AIModelEntity) -> AIModelEntity:
if model_schema.features:
for feature in model_schema.features[:]:
try:
AgentOldVersionModelFeatures(feature.value)
except ValueError:
model_schema.features.remove(feature)
return model_schema
@staticmethod
def _filter_mcp_type_tool(
strategy: ResolvedAgentStrategy,
tools: list[dict[str, Any]],
) -> list[dict[str, Any]]:
meta_version = strategy.meta_version
if meta_version and Version(meta_version) > Version("0.0.1"):
return tools
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]

View File

@@ -1,39 +0,0 @@
from __future__ import annotations
from collections.abc import Generator, Sequence
from typing import Any, Protocol
from core.agent.plugin_entities import AgentStrategyParameter
from core.plugin.entities.request import InvokeCredentials
from core.tools.entities.tool_entities import ToolInvokeMessage
class ResolvedAgentStrategy(Protocol):
meta_version: str | None
def get_parameters(self) -> Sequence[AgentStrategyParameter]: ...
def invoke(
self,
*,
params: dict[str, Any],
user_id: str,
conversation_id: str | None = None,
app_id: str | None = None,
message_id: str | None = None,
credentials: InvokeCredentials | None = None,
) -> Generator[ToolInvokeMessage, None, None]: ...
class AgentStrategyResolver(Protocol):
def resolve(
self,
*,
tenant_id: str,
agent_strategy_provider_name: str,
agent_strategy_name: str,
) -> ResolvedAgentStrategy: ...
class AgentStrategyPresentationProvider(Protocol):
def get_icon(self, *, tenant_id: str, agent_strategy_provider_name: str) -> str | None: ...

View File

@@ -1 +0,0 @@
"""Datasource workflow node package."""

View File

@@ -1,5 +0,0 @@
"""Knowledge index workflow node package."""
KNOWLEDGE_INDEX_NODE_TYPE = "knowledge-index"
__all__ = ["KNOWLEDGE_INDEX_NODE_TYPE"]

View File

@@ -1 +0,0 @@
"""Knowledge retrieval workflow node package."""

View File

@@ -1,3 +0,0 @@
from .trigger_schedule_node import TriggerScheduleNode
__all__ = ["TriggerScheduleNode"]

View File

@@ -8,7 +8,7 @@ from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.layers.observability import ObservabilityLayer
from core.workflow.node_factory import DifyNodeFactory, resolve_workflow_node_class
from core.workflow.node_factory import DifyNodeFactory
from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDictAdapter
@@ -21,8 +21,9 @@ from dify_graph.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLay
from dify_graph.graph_engine.layers.base import GraphEngineLayer
from dify_graph.graph_engine.protocols.command_channel import CommandChannel
from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
from dify_graph.nodes import BuiltinNodeTypes
from dify_graph.nodes import NodeType
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
@@ -252,7 +253,7 @@ class WorkflowEntry:
variable_mapping=variable_mapping,
user_inputs=user_inputs,
)
if node_type != BuiltinNodeTypes.DATASOURCE:
if node_type != NodeType.DATASOURCE:
cls.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
@@ -302,7 +303,7 @@ class WorkflowEntry:
"height": node_height,
"type": "custom",
"data": {
"type": BuiltinNodeTypes.START,
"type": NodeType.START,
"title": "Start",
"desc": "Start",
},
@@ -338,11 +339,11 @@ class WorkflowEntry:
# Create a minimal graph for single node execution
graph_dict = cls._create_single_node_graph(node_id, node_data)
node_type = node_data.get("type", "")
if node_type not in {BuiltinNodeTypes.PARAMETER_EXTRACTOR, BuiltinNodeTypes.QUESTION_CLASSIFIER}:
node_type = NodeType(node_data.get("type", ""))
if node_type not in {NodeType.PARAMETER_EXTRACTOR, NodeType.QUESTION_CLASSIFIER}:
raise ValueError(f"Node type {node_type} not supported")
node_cls = resolve_workflow_node_class(node_type=node_type, node_version="1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type]["1"]
if not node_cls:
raise ValueError(f"Node class not found for node type {node_type}")

View File

@@ -113,7 +113,7 @@ The codebase enforces strict layering via import-linter:
1. Create node class in `nodes/<node_type>/`
1. Inherit from `BaseNode` or appropriate base class
1. Implement `_run()` method
1. Ensure the node module is importable under `nodes/<node_type>/`
1. Register in `nodes/node_mapping.py`
1. Add tests in `tests/unit_tests/dify_graph/nodes/`
### Implementing a Custom Layer

View File

@@ -1,9 +1,11 @@
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
from .workflow_start_reason import WorkflowStartReason
__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"WorkflowExecution",
"WorkflowNodeExecution",

View File

@@ -0,0 +1,8 @@
from pydantic import BaseModel
class AgentNodeStrategyInit(BaseModel):
"""Agent node strategy initialization data."""
name: str
icon: str | None = None

View File

@@ -121,8 +121,6 @@ class DefaultValue(BaseModel):
class BaseNodeData(ABC, BaseModel):
# Raw graph payloads are first validated through `NodeConfigDictAdapter`, where
# `node["data"]` is typed as `BaseNodeData` before the concrete node class is known.
# `type` therefore accepts downstream string node kinds; unknown node implementations
# are rejected later when the node factory resolves the node registry.
# At that boundary, node-specific fields are still "extra" relative to this shared DTO,
# and persisted templates/workflows also carry undeclared compatibility keys such as
# `selected`, `params`, `paramSchemas`, and `datasource_label`. Keep extras permissive

View File

@@ -48,7 +48,7 @@ class WorkflowNodeExecution(BaseModel):
index: int # Sequence number for ordering in trace visualization
predecessor_node_id: str | None = None # ID of the node that executed before this one
node_id: str # ID of the node being executed
node_type: NodeType # Type of node (e.g., start, llm, downstream response node)
node_type: NodeType # Type of node (e.g., start, llm, knowledge)
title: str # Display title of the node
# Execution data

View File

@@ -1,5 +1,4 @@
from enum import StrEnum
from typing import ClassVar, TypeAlias
class NodeState(StrEnum):
@@ -34,71 +33,56 @@ class SystemVariableKey(StrEnum):
INVOKE_FROM = "invoke_from"
NodeType: TypeAlias = str
class NodeType(StrEnum):
START = "start"
END = "end"
ANSWER = "answer"
LLM = "llm"
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
KNOWLEDGE_INDEX = "knowledge-index"
IF_ELSE = "if-else"
CODE = "code"
TEMPLATE_TRANSFORM = "template-transform"
QUESTION_CLASSIFIER = "question-classifier"
HTTP_REQUEST = "http-request"
TOOL = "tool"
DATASOURCE = "datasource"
VARIABLE_AGGREGATOR = "variable-aggregator"
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
LOOP = "loop"
LOOP_START = "loop-start"
LOOP_END = "loop-end"
ITERATION = "iteration"
ITERATION_START = "iteration-start" # Fake start node for iteration.
PARAMETER_EXTRACTOR = "parameter-extractor"
VARIABLE_ASSIGNER = "assigner"
DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator"
AGENT = "agent"
TRIGGER_WEBHOOK = "trigger-webhook"
TRIGGER_SCHEDULE = "trigger-schedule"
TRIGGER_PLUGIN = "trigger-plugin"
HUMAN_INPUT = "human-input"
@property
def is_trigger_node(self) -> bool:
"""Check if this node type is a trigger node."""
return self in [
NodeType.TRIGGER_WEBHOOK,
NodeType.TRIGGER_SCHEDULE,
NodeType.TRIGGER_PLUGIN,
]
class BuiltinNodeTypes:
"""Built-in node type string constants.
`node_type` values are plain strings throughout the graph runtime. This namespace
only exposes the built-in values shipped by `dify_graph`; downstream packages can
use additional strings without extending this class.
"""
START: ClassVar[NodeType] = "start"
END: ClassVar[NodeType] = "end"
ANSWER: ClassVar[NodeType] = "answer"
LLM: ClassVar[NodeType] = "llm"
KNOWLEDGE_RETRIEVAL: ClassVar[NodeType] = "knowledge-retrieval"
IF_ELSE: ClassVar[NodeType] = "if-else"
CODE: ClassVar[NodeType] = "code"
TEMPLATE_TRANSFORM: ClassVar[NodeType] = "template-transform"
QUESTION_CLASSIFIER: ClassVar[NodeType] = "question-classifier"
HTTP_REQUEST: ClassVar[NodeType] = "http-request"
TOOL: ClassVar[NodeType] = "tool"
DATASOURCE: ClassVar[NodeType] = "datasource"
VARIABLE_AGGREGATOR: ClassVar[NodeType] = "variable-aggregator"
LEGACY_VARIABLE_AGGREGATOR: ClassVar[NodeType] = "variable-assigner"
LOOP: ClassVar[NodeType] = "loop"
LOOP_START: ClassVar[NodeType] = "loop-start"
LOOP_END: ClassVar[NodeType] = "loop-end"
ITERATION: ClassVar[NodeType] = "iteration"
ITERATION_START: ClassVar[NodeType] = "iteration-start"
PARAMETER_EXTRACTOR: ClassVar[NodeType] = "parameter-extractor"
VARIABLE_ASSIGNER: ClassVar[NodeType] = "assigner"
DOCUMENT_EXTRACTOR: ClassVar[NodeType] = "document-extractor"
LIST_OPERATOR: ClassVar[NodeType] = "list-operator"
AGENT: ClassVar[NodeType] = "agent"
HUMAN_INPUT: ClassVar[NodeType] = "human-input"
BUILT_IN_NODE_TYPES: tuple[NodeType, ...] = (
BuiltinNodeTypes.START,
BuiltinNodeTypes.END,
BuiltinNodeTypes.ANSWER,
BuiltinNodeTypes.LLM,
BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL,
BuiltinNodeTypes.IF_ELSE,
BuiltinNodeTypes.CODE,
BuiltinNodeTypes.TEMPLATE_TRANSFORM,
BuiltinNodeTypes.QUESTION_CLASSIFIER,
BuiltinNodeTypes.HTTP_REQUEST,
BuiltinNodeTypes.TOOL,
BuiltinNodeTypes.DATASOURCE,
BuiltinNodeTypes.VARIABLE_AGGREGATOR,
BuiltinNodeTypes.LEGACY_VARIABLE_AGGREGATOR,
BuiltinNodeTypes.LOOP,
BuiltinNodeTypes.LOOP_START,
BuiltinNodeTypes.LOOP_END,
BuiltinNodeTypes.ITERATION,
BuiltinNodeTypes.ITERATION_START,
BuiltinNodeTypes.PARAMETER_EXTRACTOR,
BuiltinNodeTypes.VARIABLE_ASSIGNER,
BuiltinNodeTypes.DOCUMENT_EXTRACTOR,
BuiltinNodeTypes.LIST_OPERATOR,
BuiltinNodeTypes.AGENT,
BuiltinNodeTypes.HUMAN_INPUT,
)
@property
def is_start_node(self) -> bool:
"""Check if this node type can serve as a workflow entry point."""
return self in [
NodeType.START,
NodeType.DATASOURCE,
NodeType.TRIGGER_WEBHOOK,
NodeType.TRIGGER_SCHEDULE,
NodeType.TRIGGER_PLUGIN,
]
class NodeExecutionType(StrEnum):
@@ -252,6 +236,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
CURRENCY = "currency"
TOOL_INFO = "tool_info"
AGENT_LOG = "agent_log"
TRIGGER_INFO = "trigger_info"
ITERATION_ID = "iteration_id"
ITERATION_INDEX = "iteration_index"
LOOP_ID = "loop_id"

View File

@@ -83,6 +83,50 @@ class Graph:
return node_configs_map
@classmethod
def _find_root_node_id(
cls,
node_configs_map: Mapping[str, NodeConfigDict],
edge_configs: Sequence[Mapping[str, object]],
root_node_id: str | None = None,
) -> str:
"""
Find the root node ID if not specified.
:param node_configs_map: mapping of node ID to node config
:param edge_configs: list of edge configurations
:param root_node_id: explicitly specified root node ID
:return: determined root node ID
"""
if root_node_id:
if root_node_id not in node_configs_map:
raise ValueError(f"Root node id {root_node_id} not found in the graph")
return root_node_id
# Find nodes with no incoming edges
nodes_with_incoming: set[str] = set()
for edge_config in edge_configs:
target = edge_config.get("target")
if isinstance(target, str):
nodes_with_incoming.add(target)
root_candidates = [nid for nid in node_configs_map if nid not in nodes_with_incoming]
# Prefer START node if available
start_node_id = None
for nid in root_candidates:
node_data = node_configs_map[nid]["data"]
if node_data.type.is_start_node:
start_node_id = nid
break
root_node_id = start_node_id or (root_candidates[0] if root_candidates else None)
if not root_node_id:
raise ValueError("Unable to determine root node ID")
return root_node_id
@classmethod
def _build_edges(
cls, edge_configs: list[dict[str, object]]
@@ -257,15 +301,15 @@ class Graph:
*,
graph_config: Mapping[str, object],
node_factory: NodeFactory,
root_node_id: str,
root_node_id: str | None = None,
skip_validation: bool = False,
) -> Graph:
"""
Initialize a graph with an explicit execution entry point.
Initialize graph
:param graph_config: graph config containing nodes and edges
:param node_factory: factory for creating node instances from config data
:param root_node_id: active root node id
:param root_node_id: root node id
:return: graph instance
"""
# Parse configs
@@ -283,8 +327,8 @@ class Graph:
# Parse node configurations
node_configs_map = cls._parse_node_configs(node_configs)
if root_node_id not in node_configs_map:
raise ValueError(f"Root node id {root_node_id} not found in the graph")
# Find root node
root_node_id = cls._find_root_node_id(node_configs_map, edge_configs, root_node_id)
# Build edges
edges, in_edges, out_edges = cls._build_edges(edge_configs)

View File

@@ -4,7 +4,7 @@ from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Protocol
from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeType
from dify_graph.enums import NodeExecutionType, NodeType
if TYPE_CHECKING:
from .graph import Graph
@@ -71,7 +71,7 @@ class _RootNodeValidator:
"""Validates root node invariants."""
invalid_root_code: str = "INVALID_ROOT"
container_entry_types: tuple[NodeType, ...] = (BuiltinNodeTypes.ITERATION_START, BuiltinNodeTypes.LOOP_START)
container_entry_types: tuple[NodeType, ...] = (NodeType.ITERATION_START, NodeType.LOOP_START)
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
root_node = graph.root_node
@@ -86,7 +86,7 @@ class _RootNodeValidator:
)
return issues
node_type = root_node.node_type
node_type = getattr(root_node, "node_type", None)
if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types:
issues.append(
GraphValidationIssue(
@@ -114,9 +114,45 @@ class GraphValidator:
raise GraphValidationError(issues)
@dataclass(frozen=True, slots=True)
class _TriggerStartExclusivityValidator:
"""Ensures trigger nodes do not coexist with UserInput (start) nodes."""
conflict_code: str = "TRIGGER_START_NODE_CONFLICT"
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
start_node_id: str | None = None
trigger_node_ids: list[str] = []
for node in graph.nodes.values():
node_type = getattr(node, "node_type", None)
if not isinstance(node_type, NodeType):
continue
if node_type == NodeType.START:
start_node_id = node.id
elif node_type.is_trigger_node:
trigger_node_ids.append(node.id)
if start_node_id and trigger_node_ids:
trigger_list = ", ".join(trigger_node_ids)
return [
GraphValidationIssue(
code=self.conflict_code,
message=(
f"UserInput (start) node '{start_node_id}' cannot coexist with trigger nodes: {trigger_list}."
),
node_id=start_node_id,
)
]
return []
_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (
_EdgeEndpointValidator(),
_RootNodeValidator(),
_TriggerStartExclusivityValidator(),
)

View File

@@ -6,6 +6,5 @@ of responses based on upstream node outputs and constants.
"""
from .coordinator import ResponseStreamCoordinator
from .session import RESPONSE_SESSION_NODE_TYPES
__all__ = ["RESPONSE_SESSION_NODE_TYPES", "ResponseStreamCoordinator"]
__all__ = ["ResponseStreamCoordinator"]

View File

@@ -3,34 +3,19 @@ Internal response session management for response coordinator.
This module contains the private ResponseSession class used internally
by ResponseStreamCoordinator to manage streaming sessions.
`RESPONSE_SESSION_NODE_TYPES` is intentionally mutable so downstream applications
can opt additional response-capable node types into session creation without
patching the coordinator.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol, cast
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.nodes.answer.answer_node import AnswerNode
from dify_graph.nodes.base.template import Template
from dify_graph.nodes.end.end_node import EndNode
from dify_graph.nodes.knowledge_index import KnowledgeIndexNode
from dify_graph.runtime.graph_runtime_state import NodeProtocol
class _ResponseSessionNodeProtocol(NodeProtocol, Protocol):
"""Structural contract required from nodes that can open a response session."""
def get_streaming_template(self) -> Template: ...
RESPONSE_SESSION_NODE_TYPES: list[NodeType] = [
BuiltinNodeTypes.ANSWER,
BuiltinNodeTypes.END,
]
@dataclass
class ResponseSession:
"""
@@ -48,9 +33,10 @@ class ResponseSession:
"""
Create a ResponseSession from a response-capable node.
The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer.
At runtime this must be a node whose `node_type` is listed in `RESPONSE_SESSION_NODE_TYPES`
and which implements `get_streaming_template()`.
The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer,
but at runtime this must be an `AnswerNode`, `EndNode`, or `KnowledgeIndexNode` that provides:
- `id: str`
- `get_streaming_template() -> Template`
Args:
node: Node from the materialized workflow graph.
@@ -61,22 +47,11 @@ class ResponseSession:
Raises:
TypeError: If node is not a supported response node type.
"""
if node.node_type not in RESPONSE_SESSION_NODE_TYPES:
supported_node_types = ", ".join(RESPONSE_SESSION_NODE_TYPES)
raise TypeError(
"ResponseSession.from_node only supports node types in "
f"RESPONSE_SESSION_NODE_TYPES: {supported_node_types}"
)
response_node = cast(_ResponseSessionNodeProtocol, node)
try:
template = response_node.get_streaming_template()
except AttributeError as exc:
raise TypeError("ResponseSession.from_node requires get_streaming_template() on response nodes") from exc
if not isinstance(node, AnswerNode | EndNode | KnowledgeIndexNode):
raise TypeError("ResponseSession.from_node only supports AnswerNode, EndNode, or KnowledgeIndexNode")
return cls(
node_id=node.id,
template=template,
template=node.get_streaming_template(),
)
def is_complete(self) -> bool:

View File

@@ -4,6 +4,7 @@ from datetime import datetime
from pydantic import Field
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from dify_graph.entities import AgentNodeStrategyInit
from dify_graph.entities.pause_reason import PauseReason
from .base import GraphNodeEventBase
@@ -12,8 +13,8 @@ from .base import GraphNodeEventBase
class NodeRunStartedEvent(GraphNodeEventBase):
node_title: str
predecessor_node_id: str | None = None
agent_strategy: AgentNodeStrategyInit | None = None
start_at: datetime = Field(..., description="node start time")
extras: dict[str, object] = Field(default_factory=dict)
# FIXME(-LAN-): only for ToolNode
provider_type: str = ""

View File

@@ -1,9 +1,9 @@
from collections.abc import Mapping, Sequence
from collections.abc import Sequence
from datetime import datetime
from typing import Any
from pydantic import Field
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from dify_graph.entities.pause_reason import PauseReason
from dify_graph.file import File
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
@@ -13,7 +13,7 @@ from .base import NodeEventBase
class RunRetrieverResourceEvent(NodeEventBase):
retriever_resources: Sequence[Mapping[str, Any]] = Field(..., description="retriever resources")
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
context_files: list[File] | None = Field(default=None, description="context files")

View File

@@ -1,3 +1,3 @@
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.enums import NodeType
__all__ = ["BuiltinNodeTypes"]
__all__ = ["NodeType"]

View File

@@ -0,0 +1,3 @@
from .agent_node import AgentNode
__all__ = ["AgentNode"]

View File

@@ -0,0 +1,761 @@
from __future__ import annotations
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from packaging.version import Version
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import (
ToolIdentity,
ToolInvokeMessage,
ToolParameter,
ToolProviderType,
)
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from dify_graph.enums import (
NodeType,
SystemVariableKey,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from dify_graph.file import File, FileTransferMethod
from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import (
AgentLogEvent,
NodeEventBase,
NodeRunResult,
StreamChunkEvent,
StreamCompletedEvent,
)
from dify_graph.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.runtime import VariablePool
from dify_graph.variables.segments import ArrayFileSegment, StringSegment
from extensions.ext_database import db
from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy
from models import ToolFile
from models.model import Conversation
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .exc import (
AgentInputTypeError,
AgentInvocationError,
AgentMessageTransformError,
AgentNodeError,
AgentVariableNotFoundError,
AgentVariableTypeError,
ToolFileNotFoundError,
)
if TYPE_CHECKING:
from core.agent.strategy.plugin import PluginAgentStrategy
from core.plugin.entities.request import InvokeCredentials
class AgentNode(Node[AgentNodeData]):
"""
Agent Node
"""
node_type = NodeType.AGENT
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEventBase, None, None]:
from core.plugin.impl.exc import PluginDaemonClientSideError
dify_ctx = self.require_dify_context()
try:
strategy = get_plugin_agent_strategy(
tenant_id=dify_ctx.tenant_id,
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
agent_strategy_name=self.node_data.agent_strategy_name,
)
except Exception as e:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
error=f"Failed to get agent strategy: {str(e)}",
),
)
return
agent_parameters = strategy.get_parameters()
# get parameters
parameters = self._generate_agent_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
strategy=strategy,
)
parameters_for_log = self._generate_agent_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self.node_data,
for_log=True,
strategy=strategy,
)
credentials = self._generate_credentials(parameters=parameters)
# get conversation id
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
try:
message_stream = strategy.invoke(
params=parameters,
user_id=dify_ctx.user_id,
app_id=dify_ctx.app_id,
conversation_id=conversation_id.text if conversation_id else None,
credentials=credentials,
)
except Exception as e:
error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=str(error),
)
)
return
try:
yield from self._transform_message(
messages=message_stream,
tool_info={
"icon": self.agent_strategy_icon,
"agent_strategy": self.node_data.agent_strategy_name,
},
parameters_for_log=parameters_for_log,
user_id=dify_ctx.user_id,
tenant_id=dify_ctx.tenant_id,
node_type=self.node_type,
node_id=self._node_id,
node_execution_id=self.id,
)
except PluginDaemonClientSideError as e:
transform_error = AgentMessageTransformError(
f"Failed to transform agent message: {str(e)}", original_error=e
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
error=str(transform_error),
)
)
def _generate_agent_parameters(
self,
*,
agent_parameters: Sequence[AgentStrategyParameter],
variable_pool: VariablePool,
node_data: AgentNodeData,
for_log: bool = False,
strategy: PluginAgentStrategy,
) -> dict[str, Any]:
"""
Generate parameters based on the given tool parameters, variable pool, and node data.
Args:
agent_parameters (Sequence[AgentParameter]): The list of agent parameters.
variable_pool (VariablePool): The variable pool containing the variables.
node_data (AgentNodeData): The data associated with the agent node.
Returns:
Mapping[str, Any]: A dictionary containing the generated parameters.
"""
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
result: dict[str, Any] = {}
for parameter_name in node_data.agent_parameters:
parameter = agent_parameters_dictionary.get(parameter_name)
if not parameter:
result[parameter_name] = None
continue
agent_input = node_data.agent_parameters[parameter_name]
match agent_input.type:
case "variable":
variable = variable_pool.get(agent_input.value) # type: ignore
if variable is None:
raise AgentVariableNotFoundError(str(agent_input.value))
parameter_value = variable.value
case "mixed" | "constant":
# variable_pool.convert_template expects a string template,
# but if passing a dict, convert to JSON string first before rendering
try:
if not isinstance(agent_input.value, str):
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
else:
parameter_value = str(agent_input.value)
except TypeError:
parameter_value = str(agent_input.value)
segment_group = variable_pool.convert_template(parameter_value)
parameter_value = segment_group.log if for_log else segment_group.text
# variable_pool.convert_template returns a string,
# so we need to convert it back to a dictionary
try:
if not isinstance(agent_input.value, str):
parameter_value = json.loads(parameter_value)
except json.JSONDecodeError:
parameter_value = parameter_value
case _:
raise AgentInputTypeError(agent_input.type)
value = parameter_value
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
value = [tool for tool in value if tool.get("enabled", False)]
value = self._filter_mcp_type_tool(strategy, value)
for tool in value:
if "schemas" in tool:
tool.pop("schemas")
parameters = tool.get("parameters", {})
if all(isinstance(v, dict) for _, v in parameters.items()):
params = {}
for key, param in parameters.items():
if param.get("auto", ParamsAutoGenerated.OPEN) in (
ParamsAutoGenerated.CLOSE,
0,
):
value_param = param.get("value", {})
if value_param and value_param.get("type", "") == "variable":
variable_selector = value_param.get("value")
if not variable_selector:
raise ValueError("Variable selector is missing for a variable-type parameter.")
variable = variable_pool.get(variable_selector)
if variable is None:
raise AgentVariableNotFoundError(str(variable_selector))
params[key] = variable.value
else:
params[key] = value_param.get("value", "") if value_param is not None else None
else:
params[key] = None
parameters = params
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
tool["parameters"] = parameters
if not for_log:
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
tool_value = []
for tool in value:
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
setting_params = tool.get("settings", {})
parameters = tool.get("parameters", {})
manual_input_params = [key for key, value in parameters.items() if value is not None]
parameters = {**parameters, **setting_params}
entity = AgentToolEntity(
provider_id=tool.get("provider_name", ""),
provider_type=provider_type,
tool_name=tool.get("tool_name", ""),
tool_parameters=parameters,
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
credential_id=tool.get("credential_id", None),
)
extra = tool.get("extra", {})
# This is an issue that caused problems before.
# Logically, we shouldn't use the node_data.version field for judgment
# But for backward compatibility with historical data
# this version field judgment is still preserved here.
runtime_variable_pool: VariablePool | None = None
if node_data.version != "1" or node_data.tool_node_version is not None:
runtime_variable_pool = variable_pool
dify_ctx = self.require_dify_context()
tool_runtime = ToolManager.get_agent_tool_runtime(
dify_ctx.tenant_id,
dify_ctx.app_id,
entity,
dify_ctx.invoke_from,
runtime_variable_pool,
)
if tool_runtime.entity.description:
tool_runtime.entity.description.llm = (
extra.get("description", "") or tool_runtime.entity.description.llm
)
for tool_runtime_params in tool_runtime.entity.parameters:
tool_runtime_params.form = (
ToolParameter.ToolParameterForm.FORM
if tool_runtime_params.name in manual_input_params
else tool_runtime_params.form
)
manual_input_value = {}
if tool_runtime.entity.parameters:
manual_input_value = {
key: value for key, value in parameters.items() if key in manual_input_params
}
runtime_parameters = {
**tool_runtime.runtime.runtime_parameters,
**manual_input_value,
}
tool_value.append(
{
**tool_runtime.entity.model_dump(mode="json"),
"runtime_parameters": runtime_parameters,
"credential_id": tool.get("credential_id", None),
"provider_type": provider_type.value,
}
)
value = tool_value
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
value = cast(dict[str, Any], value)
model_instance, model_schema = self._fetch_model(value)
# memory config
history_prompt_messages = []
if node_data.memory:
memory = self._fetch_memory(model_instance)
if memory:
prompt_messages = memory.get_history_prompt_messages(
message_limit=node_data.memory.window.size or None
)
history_prompt_messages = [
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
]
value["history_prompt_messages"] = history_prompt_messages
if model_schema:
# remove structured output feature to support old version agent plugin
model_schema = self._remove_unsupported_model_features_for_old_version(model_schema)
value["entity"] = model_schema.model_dump(mode="json")
else:
value["entity"] = None
result[parameter_name] = value
return result
def _generate_credentials(
self,
parameters: dict[str, Any],
) -> InvokeCredentials:
"""
Generate credentials based on the given agent parameters.
"""
from core.plugin.entities.request import InvokeCredentials
credentials = InvokeCredentials()
# generate credentials for tools selector
credentials.tool_credentials = {}
for tool in parameters.get("tools", []):
if tool.get("credential_id"):
try:
identity = ToolIdentity.model_validate(tool.get("identity", {}))
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
except ValidationError:
continue
return credentials
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: AgentNodeData,
) -> Mapping[str, Sequence[str]]:
_ = graph_config # Explicitly mark as unused
result: dict[str, Any] = {}
typed_node_data = node_data
for parameter_name in typed_node_data.agent_parameters:
input = typed_node_data.agent_parameters[parameter_name]
match input.type:
case "mixed" | "constant":
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
result[parameter_name] = input.value
result = {node_id + "." + key: value for key, value in result.items()}
return result
@property
def agent_strategy_icon(self) -> str | None:
"""
Get agent strategy icon
:return:
"""
from core.plugin.impl.plugin import PluginInstaller
manager = PluginInstaller()
dify_ctx = self.require_dify_context()
plugins = manager.list_plugins(dify_ctx.tenant_id)
try:
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}" == self.node_data.agent_strategy_provider_name
)
icon = current_plugin.declaration.icon
except StopIteration:
icon = None
return icon
def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
# get conversation id
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID]
)
if not isinstance(conversation_id_variable, StringSegment):
return None
conversation_id = conversation_id_variable.value
dify_ctx = self.require_dify_context()
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(
Conversation.app_id == dify_ctx.app_id, Conversation.id == conversation_id
)
conversation = session.scalar(stmt)
if not conversation:
return None
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
return memory
def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
dify_ctx = self.require_dify_context()
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=dify_ctx.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
)
model_name = value.get("model", "")
model_credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=ModelType.LLM, model=model_name
)
provider_name = provider_model_bundle.configuration.provider.provider
model_type_instance = provider_model_bundle.model_type_instance
model_instance = ModelManager().get_model_instance(
tenant_id=dify_ctx.tenant_id,
provider=provider_name,
model_type=ModelType(value.get("model_type", "")),
model=model_name,
)
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
return model_instance, model_schema
def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity:
if model_schema.features:
for feature in model_schema.features[:]: # Create a copy to safely modify during iteration
try:
AgentOldVersionModelFeatures(feature.value) # Try to create enum member from value
except ValueError:
model_schema.features.remove(feature)
return model_schema
def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Filter MCP type tool
:param strategy: plugin agent strategy
:param tool: tool
:return: filtered tool dict
"""
meta_version = strategy.meta_version
if meta_version and Version(meta_version) > Version("0.0.1"):
return tools
else:
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
def _transform_message(
self,
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
user_id: str,
tenant_id: str,
node_type: NodeType,
node_id: str,
node_execution_id: str,
) -> Generator[NodeEventBase, None, None]:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
# transform message and handle file storage
from core.plugin.impl.plugin import PluginInstaller
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
)
text = ""
files: list[File] = []
json_list: list[dict | list] = []
agent_logs: list[AgentLogEvent] = []
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
llm_usage = LLMUsage.empty_usage()
variables: dict[str, Any] = {}
for message in message_stream:
if message.type in {
ToolInvokeMessage.MessageType.IMAGE_LINK,
ToolInvokeMessage.MessageType.BINARY_LINK,
ToolInvokeMessage.MessageType.IMAGE,
}:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
url = message.message.text
if message.meta:
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
else:
transfer_method = FileTransferMethod.TOOL_FILE
tool_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileNotFoundError(tool_file_id)
mapping = {
"tool_file_id": tool_file_id,
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
files.append(file)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
assert message.meta
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileNotFoundError(tool_file_id)
mapping = {
"tool_file_id": tool_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
)
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=message.message.text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
if node_type == NodeType.AGENT:
if isinstance(message.message.json_object, dict):
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
agent_execution_metadata = {
WorkflowNodeExecutionMetadataKey(key): value
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
else:
msg_metadata = {}
llm_usage = LLMUsage.empty_usage()
agent_execution_metadata = {}
if message.message.json_object:
json_list.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=stream_text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise AgentVariableTypeError(
"When 'stream' is True, 'variable_value' must be a string.",
variable_name=variable_name,
expected_type="str",
actual_type=type(variable_value).__name__,
)
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield StreamChunkEvent(
selector=[node_id, variable_name],
chunk=variable_value,
is_final=False,
)
else:
variables[variable_name] = variable_value
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
assert isinstance(message.meta, dict)
# Validate that meta contains a 'file' key
if "file" not in message.meta:
raise AgentNodeError("File message is missing 'file' key in meta")
# Validate that the file is an instance of File
if not isinstance(message.meta["file"], File):
raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
files.append(message.meta["file"])
elif message.type == ToolInvokeMessage.MessageType.LOG:
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
if message.message.metadata:
icon = tool_info.get("icon", "")
dict_metadata = dict(message.message.metadata)
if dict_metadata.get("provider"):
manager = PluginInstaller()
plugins = manager.list_plugins(tenant_id)
try:
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
)
icon = current_plugin.declaration.icon
except StopIteration:
pass
icon_dark = None
try:
builtin_tool = next(
provider
for provider in BuiltinToolManageService.list_builtin_tools(
user_id,
tenant_id,
)
if provider.name == dict_metadata["provider"]
)
icon = builtin_tool.icon
icon_dark = builtin_tool.icon_dark
except StopIteration:
pass
dict_metadata["icon"] = icon
dict_metadata["icon_dark"] = icon_dark
message.message.metadata = dict_metadata
agent_log = AgentLogEvent(
message_id=message.message.id,
node_execution_id=node_execution_id,
parent_id=message.message.parent_id,
error=message.message.error,
status=message.message.status.value,
data=message.message.data,
label=message.message.label,
metadata=message.message.metadata,
node_id=node_id,
)
# check if the agent log is already in the list
for log in agent_logs:
if log.message_id == agent_log.message_id:
# update the log
log.data = agent_log.data
log.status = agent_log.status
log.error = agent_log.error
log.label = agent_log.label
log.metadata = agent_log.metadata
break
else:
agent_logs.append(agent_log)
yield agent_log
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
json_output: list[dict[str, Any] | list[Any]] = []
# Step 1: append each agent log as its own dict.
if agent_logs:
for log in agent_logs:
json_output.append(
{
"id": log.message_id,
"parent_id": log.parent_id,
"error": log.error,
"status": log.status,
"data": log.data,
"label": log.label,
"metadata": log.metadata,
"node_id": log.node_id,
}
)
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
if json_list:
json_output.extend(json_list)
else:
json_output.append({"data": []})
# Send final chunk events for all streamed outputs
# Final chunk for text stream
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk="",
is_final=True,
)
# Final chunks for any streamed variables
for var_name in variables:
yield StreamChunkEvent(
selector=[node_id, var_name],
chunk="",
is_final=True,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
"text": text,
"usage": jsonable_encoder(llm_usage),
"files": ArrayFileSegment(value=files),
"json": json_output,
**variables,
},
metadata={
**agent_execution_metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
},
inputs=parameters_for_log,
llm_usage=llm_usage,
)
)

View File

@@ -6,14 +6,14 @@ from pydantic import BaseModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.tools.entities.tool_entities import ToolSelector
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
from dify_graph.enums import NodeType
class AgentNodeData(BaseNodeData):
type: NodeType = BuiltinNodeTypes.AGENT
agent_strategy_provider_name: str
type: NodeType = NodeType.AGENT
agent_strategy_provider_name: str # redundancy
agent_strategy_name: str
agent_strategy_label: str
agent_strategy_label: str # redundancy
memory: MemoryConfig | None = None
# The version of the tool parameter.
# If this value is None, it indicates this is a previous version

View File

@@ -119,3 +119,14 @@ class AgentVariableTypeError(AgentNodeError):
self.expected_type = expected_type
self.actual_type = actual_type
super().__init__(message)
class AgentMaxIterationError(AgentNodeError):
"""Exception raised when the agent exceeds the maximum iteration limit."""
def __init__(self, max_iteration: int):
self.max_iteration = max_iteration
super().__init__(
f"Agent exceeded the maximum iteration limit of {max_iteration}. "
f"The agent was unable to complete the task within the allowed number of iterations."
)

Some files were not shown because too many files have changed in this diff Show More