Compare commits

..

30 Commits

Author SHA1 Message Date
crazywoola
f101956d0f fix: remove proxy 2024-12-20 18:33:45 +08:00
crazywoola
5c5221f2cc revert: these 2 settings 2024-12-20 17:53:33 +08:00
zxhlyh
ef7e47d162 fix: rerank switch (#11897) 2024-12-20 16:12:34 +08:00
github-actions[bot]
4211b9abbd chore: translate i18n files (#11892)
Co-authored-by: zxhlyh <16177003+zxhlyh@users.noreply.github.com>
2024-12-20 16:12:01 +08:00
zxhlyh
0c0120ef27 Feat/workflow retry (#11885) 2024-12-20 15:44:37 +08:00
-LAN-
dacd457478 feat: add workflow parallel depth limit configuration (#11460)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
2024-12-20 14:52:20 +08:00
yihong
7b03a0316d fix: better memory usage from 800+ to 500+ (#11796)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-20 14:51:43 +08:00
Novice
52201d95b1 chore: add retry index migration (#11887)
Co-authored-by: Novice Lee <novicelee@NoviPro.local>
2024-12-20 14:40:33 +08:00
github-actions[bot]
e2cde628bb chore: translate i18n files (#11855)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-12-20 14:19:47 +08:00
AkaraChen
3335fa78fc fix: node 22 build (#11883) 2024-12-20 14:14:27 +08:00
Novice
7abc7fa573 Feat: Retry on node execution errors (#11871)
Co-authored-by: Novice Lee <novicelee@NoviPro.local>
2024-12-20 14:14:06 +08:00
Novice
f6247fe67c Feat: Add partial success status to the app log (#11869)
Co-authored-by: Novice Lee <novicelee@NoviPro.local>
2024-12-20 14:13:44 +08:00
-LAN-
996a9135f6 feat(llm_node): support order in text and files (#11837)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-20 14:12:50 +08:00
-LAN-
3599751f93 chore(db): use a better way to export models and remove unused table (#11838)
Signed-off-by: -LAN- <laipz8200@outlook.com>
2024-12-20 14:12:29 +08:00
dependabot[bot]
2d186e1e76 chore(deps): bump nanoid from 3.3.7 to 3.3.8 in /web (#11876)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-12-20 13:46:54 +08:00
Dr.MerdanBay
bb2f46d7cc fix: add safe dictionary access for bedrock credentials (#11860) 2024-12-20 12:13:39 +09:00
yihong
463fbe2680 fix: better gard nan value from numpy for issue #11827 (#11864)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-20 09:28:32 +08:00
傻笑zz
95a7e50137 Fix comfyui tool https (#11859) 2024-12-20 09:27:21 +08:00
非法操作
9d93ad1f16 feat: add gemini-2.0-flash-thinking-exp-1219 (#11863) 2024-12-20 09:26:31 +08:00
stardust
44104797d6 fix: Enhance file type detection in HTTP Request node (#11797)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: 谭成 <tancheng.sh@chinatelecom.cn>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2024-12-20 02:21:41 +08:00
傻笑zz
1548501050 fix: comfyui tool supports https (#11823) 2024-12-19 23:05:27 +08:00
crazywoola
de3911e930 Fix/10584 wrong message when no custom tool available in custom tool list (#11851) 2024-12-19 21:19:08 +08:00
yihong
5a8a901560 fix: float values are not json for nan value close #11827 (#11840)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-19 20:50:20 +08:00
yihong
12d45e9114 fix: silicon change its model fix #11844 (#11847)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-19 20:50:09 +08:00
barabicu
d057067543 fix: remove ruff ignore SIM300 (#11810) 2024-12-19 18:30:51 +08:00
sino
560d375e0f feat(ark): add doubao-pro-256k and doubao-embedding-large (#11831) 2024-12-19 17:49:31 +08:00
Agung Besti
3388d6636c add-model-azure-gpt-4o-2024-11-20 (#11803)
Co-authored-by: agungbesti <agung.besti@insignia.co.id>
2024-12-19 12:36:11 +08:00
Charlie.Wei
2624a6dcd0 Fix explore app icon (#11808)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2024-12-18 21:24:21 +08:00
yihong
b5c2785e10 ci: fix config ci and it works (#11807)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-18 20:17:10 +08:00
yihong
493834d45d ci: add config ci more disscuss check #11706 (#11752)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2024-12-18 17:36:36 +08:00
174 changed files with 2676 additions and 597 deletions

View File

@@ -50,6 +50,9 @@ jobs:
- name: Run ModelRuntime
run: poetry run -C api bash dev/pytest/pytest_model_runtime.sh
- name: Run dify config tests
run: poetry run -C api python dev/pytest/pytest_config_tests.py
- name: Run Tool
run: poetry run -C api bash dev/pytest/pytest_tools.sh

View File

@@ -399,6 +399,7 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000
WORKFLOW_MAX_EXECUTION_STEPS=500
WORKFLOW_MAX_EXECUTION_TIME=1200
WORKFLOW_CALL_MAX_DEPTH=5
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
MAX_VARIABLE_SIZE=204800
# App configuration

View File

@@ -70,7 +70,6 @@ ignore = [
"SIM113", # eumerate-for-loop
"SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false
"SIM300", # yoda-conditions,
]
[lint.per-file-ignores]
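
For context, SIM300 is Ruff's yoda-conditions rule; with the ignore removed above, any comparison that puts the constant on the left is flagged. A minimal illustration (names are made up), mirroring the decorator fix shown below:

API_KEY = "expected-key"
token = "supplied-key"
flagged = API_KEY != token    # SIM300: constant on the left (Yoda condition)
preferred = token != API_KEY  # variable on the left, as in the admin_required fix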

View File

@@ -433,6 +433,11 @@ class WorkflowConfig(BaseSettings):
default=5,
)
WORKFLOW_PARALLEL_DEPTH_LIMIT: PositiveInt = Field(
description="Maximum allowed depth for nested parallel executions",
default=3,
)
MAX_VARIABLE_SIZE: PositiveInt = Field(
description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.",
default=200 * 1024,
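
A minimal sketch of how the new setting is picked up, assuming the usual pydantic-settings behavior of reading fields from environment variables (the class name here is illustrative, not the full WorkflowConfig):

import os
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings

class WorkflowConfigSketch(BaseSettings):
    WORKFLOW_PARALLEL_DEPTH_LIMIT: PositiveInt = Field(
        description="Maximum allowed depth for nested parallel executions",
        default=3,
    )

os.environ["WORKFLOW_PARALLEL_DEPTH_LIMIT"] = "5"
assert WorkflowConfigSketch().WORKFLOW_PARALLEL_DEPTH_LIMIT == 5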

View File

@@ -31,7 +31,7 @@ def admin_required(view):
if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
if dify_config.ADMIN_API_KEY != auth_token:
if auth_token != dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.")
return view(*args, **kwargs)

View File

@@ -6,6 +6,7 @@ from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from configs import dify_config
from controllers.console import api
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.wraps import get_app_model
@@ -426,7 +427,21 @@ class ConvertToWorkflowApi(Resource):
}
class WorkflowConfigApi(Resource):
"""Resource for workflow configuration."""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App):
return {
"parallel_depth_limit": dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
}
api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
api.add_resource(WorkflowConfigApi, "/apps/<uuid:app_id>/workflows/draft/config")
api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run")
api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
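
A hypothetical client call against the new draft-config route; the /console/api prefix, host, and bearer token are assumptions not shown in this diff, and <app_id>/<access-token> are placeholders:

import requests

resp = requests.get(
    "http://localhost:5001/console/api/apps/<app_id>/workflows/draft/config",
    headers={"Authorization": "Bearer <access-token>"},
)
print(resp.json())  # expected shape: {"parallel_depth_limit": 3}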

View File

@@ -5,8 +5,7 @@ from typing import Optional, Union
from controllers.console.app.error import AppNotFoundError
from extensions.ext_database import db
from libs.login import current_user
from models import App
from models.model import AppMode
from models import App, AppMode
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):

View File

@@ -13,6 +13,7 @@ app_fields = {
"name": fields.String,
"mode": fields.String,
"icon": fields.String,
"icon_type": fields.String,
"icon_url": AppIconUrlField,
"icon_background": fields.String,
}

View File

@@ -22,6 +22,7 @@ from core.app.entities.queue_entities import (
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@@ -328,6 +329,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
elif isinstance(
event,
QueueNodeRetryEvent,
):
workflow_node_execution = self._handle_workflow_node_execution_retried(
workflow_run=workflow_run, event=event
)
response = self._workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):

View File

@@ -18,6 +18,7 @@ from core.app.entities.queue_entities import (
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@@ -286,9 +287,25 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_failed_response:
yield node_failed_response
elif isinstance(
event,
QueueNodeRetryEvent,
):
workflow_node_execution = self._handle_workflow_node_execution_retried(
workflow_run=workflow_run, event=event
)
response = self._workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
yield response
elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")

View File

@@ -11,6 +11,7 @@ from core.app.entities.queue_entities import (
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@@ -38,6 +39,7 @@ from core.workflow.graph_engine.entities.event import (
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
@@ -420,6 +422,36 @@ class WorkflowBasedAppRunner(AppRunner):
error=event.error if isinstance(event, IterationRunFailedEvent) else None,
)
)
elif isinstance(event, NodeRunRetryEvent):
self._publish_event(
QueueNodeRetryEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
error=event.error,
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
retry_index=event.retry_index,
start_index=event.start_index,
)
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""

View File

@@ -43,6 +43,7 @@ class QueueEvent(StrEnum):
ERROR = "error"
PING = "ping"
STOP = "stop"
RETRY = "retry"
class AppQueueEvent(BaseModel):
@@ -313,6 +314,37 @@ class QueueNodeSucceededEvent(AppQueueEvent):
iteration_duration_map: Optional[dict[str, float]] = None
class QueueNodeRetryEvent(AppQueueEvent):
"""QueueNodeRetryEvent entity"""
event: QueueEvent = QueueEvent.RETRY
node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
error: str
retry_index: int # retry index
start_index: int # start index
class QueueNodeInIterationFailedEvent(AppQueueEvent):
"""
QueueNodeInIterationFailedEvent entity

View File

@@ -52,6 +52,7 @@ class StreamEvent(Enum):
WORKFLOW_FINISHED = "workflow_finished"
NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished"
NODE_RETRY = "node_retry"
PARALLEL_BRANCH_STARTED = "parallel_branch_started"
PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
ITERATION_STARTED = "iteration_started"
@@ -342,6 +343,75 @@ class NodeFinishStreamResponse(StreamResponse):
}
class NodeRetryStreamResponse(StreamResponse):
"""
NodeRetryStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
id: str
node_id: str
node_type: str
title: str
index: int
predecessor_node_id: Optional[str] = None
inputs: Optional[dict] = None
process_data: Optional[dict] = None
outputs: Optional[dict] = None
status: str
error: Optional[str] = None
elapsed_time: float
execution_metadata: Optional[dict] = None
created_at: int
finished_at: int
files: Optional[Sequence[Mapping[str, Any]]] = []
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
retry_index: int = 0
event: StreamEvent = StreamEvent.NODE_RETRY
workflow_run_id: str
data: Data
def to_ignore_detail_dict(self):
return {
"event": self.event.value,
"task_id": self.task_id,
"workflow_run_id": self.workflow_run_id,
"data": {
"id": self.data.id,
"node_id": self.data.node_id,
"node_type": self.data.node_type,
"title": self.data.title,
"index": self.data.index,
"predecessor_node_id": self.data.predecessor_node_id,
"inputs": None,
"process_data": None,
"outputs": None,
"status": self.data.status,
"error": None,
"elapsed_time": self.data.elapsed_time,
"execution_metadata": None,
"created_at": self.data.created_at,
"finished_at": self.data.finished_at,
"files": [],
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id,
"retry_index": self.data.retry_index,
},
}
class ParallelBranchStartStreamResponse(StreamResponse):
"""
ParallelBranchStartStreamResponse entity
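
For reference, a node_retry stream payload assembled from the Data fields above would look roughly like this (all values are made up):

node_retry_event = {
    "event": "node_retry",
    "task_id": "task-123",
    "workflow_run_id": "run-456",
    "data": {
        "id": "exec-789",
        "node_id": "http_request_1",
        "node_type": "http-request",
        "title": "HTTP Request",
        "index": 3,
        "status": "retry",
        "error": "connection timed out",
        "elapsed_time": 1.2,
        "created_at": 1734672000,
        "finished_at": 1734672001,
        "retry_index": 1,
    },
}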

View File

@@ -15,6 +15,7 @@ from core.app.entities.queue_entities import (
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
@@ -26,6 +27,7 @@ from core.app.entities.task_entities import (
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
NodeFinishStreamResponse,
NodeRetryStreamResponse,
NodeStartStreamResponse,
ParallelBranchFinishedStreamResponse,
ParallelBranchStartStreamResponse,
@@ -423,6 +425,52 @@ class WorkflowCycleManage:
return workflow_node_execution
def _handle_workflow_node_execution_retried(
self, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
) -> WorkflowNodeExecution:
"""
Handle a retried workflow node execution
:param event: queue node retry event
:return:
"""
created_at = event.start_at
finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - created_at).total_seconds()
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = WorkflowEntry.handle_special_values(event.outputs)
workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.tenant_id = workflow_run.tenant_id
workflow_node_execution.app_id = workflow_run.app_id
workflow_node_execution.workflow_id = workflow_run.workflow_id
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
workflow_node_execution.workflow_run_id = workflow_run.id
workflow_node_execution.node_execution_id = event.node_execution_id
workflow_node_execution.node_id = event.node_id
workflow_node_execution.node_type = event.node_type.value
workflow_node_execution.title = event.node_data.title
workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value
workflow_node_execution.created_by_role = workflow_run.created_by_role
workflow_node_execution.created_by = workflow_run.created_by
workflow_node_execution.created_at = created_at
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.error = event.error
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.execution_metadata = json.dumps(
{
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
}
)
workflow_node_execution.index = event.start_index
db.session.add(workflow_node_execution)
db.session.commit()
db.session.refresh(workflow_node_execution)
return workflow_node_execution
#################################################
# to stream responses #
#################################################
@@ -587,6 +635,51 @@ class WorkflowCycleManage:
),
)
def _workflow_node_retry_to_stream_response(
self,
event: QueueNodeRetryEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeRetryStreamResponse]:
"""
Workflow node finish to stream response.
:param event: queue node succeeded or failed event
:param task_id: task id
:param workflow_node_execution: workflow node execution
:return:
"""
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None
return NodeRetryStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_run_id,
data=NodeRetryStreamResponse.Data(
id=workflow_node_execution.id,
node_id=workflow_node_execution.node_id,
node_type=workflow_node_execution.node_type,
index=workflow_node_execution.index,
title=workflow_node_execution.title,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs_dict,
process_data=workflow_node_execution.process_data_dict,
outputs=workflow_node_execution.outputs_dict,
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
execution_metadata=workflow_node_execution.execution_metadata_dict,
created_at=int(workflow_node_execution.created_at.timestamp()),
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
retry_index=event.retry_index,
),
)
def _workflow_parallel_branch_start_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse:

View File

@@ -1,15 +1,14 @@
import base64
from configs import dify_config
from core.file import file_repository
from core.helper import ssrf_proxy
from core.model_runtime.entities import (
AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent,
MultiModalPromptMessageContent,
VideoPromptMessageContent,
)
from extensions.ext_database import db
from extensions.ext_storage import storage
from . import helpers
@@ -41,7 +40,7 @@ def to_prompt_message_content(
/,
*,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
):
) -> MultiModalPromptMessageContent:
if f.extension is None:
raise ValueError("Missing file extension")
if f.mime_type is None:
@@ -70,16 +69,13 @@ def to_prompt_message_content(
def download(f: File, /):
if f.transfer_method == FileTransferMethod.TOOL_FILE:
tool_file = file_repository.get_tool_file(session=db.session(), file=f)
return _download_file_content(tool_file.file_key)
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
return _download_file_content(upload_file.key)
# remote file
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status()
return response.content
if f.transfer_method in (FileTransferMethod.TOOL_FILE, FileTransferMethod.LOCAL_FILE):
return _download_file_content(f._storage_key)
elif f.transfer_method == FileTransferMethod.REMOTE_URL:
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status()
return response.content
raise ValueError(f"unsupported transfer method: {f.transfer_method}")
def _download_file_content(path: str, /):
@@ -110,11 +106,9 @@ def _get_encoded_string(f: File, /):
response.raise_for_status()
data = response.content
case FileTransferMethod.LOCAL_FILE:
upload_file = file_repository.get_upload_file(session=db.session(), file=f)
data = _download_file_content(upload_file.key)
data = _download_file_content(f._storage_key)
case FileTransferMethod.TOOL_FILE:
tool_file = file_repository.get_tool_file(session=db.session(), file=f)
data = _download_file_content(tool_file.file_key)
data = _download_file_content(f._storage_key)
encoded_string = base64.b64encode(data).decode("utf-8")
return encoded_string

View File

@@ -1,32 +0,0 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from models import ToolFile, UploadFile
from .models import File
def get_upload_file(*, session: Session, file: File):
if file.related_id is None:
raise ValueError("Missing file related_id")
stmt = select(UploadFile).filter(
UploadFile.id == file.related_id,
UploadFile.tenant_id == file.tenant_id,
)
record = session.scalar(stmt)
if not record:
raise ValueError(f"upload file {file.related_id} not found")
return record
def get_tool_file(*, session: Session, file: File):
if file.related_id is None:
raise ValueError("Missing file related_id")
stmt = select(ToolFile).filter(
ToolFile.id == file.related_id,
ToolFile.tenant_id == file.tenant_id,
)
record = session.scalar(stmt)
if not record:
raise ValueError(f"tool file {file.related_id} not found")
return record

View File

@@ -47,6 +47,38 @@ class File(BaseModel):
mime_type: Optional[str] = None
size: int = -1
# Those properties are private, should not be exposed to the outside.
_storage_key: str
def __init__(
self,
*,
id: Optional[str] = None,
tenant_id: str,
type: FileType,
transfer_method: FileTransferMethod,
remote_url: Optional[str] = None,
related_id: Optional[str] = None,
filename: Optional[str] = None,
extension: Optional[str] = None,
mime_type: Optional[str] = None,
size: int = -1,
storage_key: str,
):
super().__init__(
id=id,
tenant_id=tenant_id,
type=type,
transfer_method=transfer_method,
remote_url=remote_url,
related_id=related_id,
filename=filename,
extension=extension,
mime_type=mime_type,
size=size,
)
self._storage_key = storage_key
def to_dict(self) -> Mapping[str, str | int | None]:
data = self.model_dump(mode="json")
return {
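
The underscore-prefixed attribute relies on Pydantic's private-attribute handling: it is assigned after validation and excluded from model_dump(), so the storage key stays internal. A minimal sketch under that assumption (class and field names are illustrative):

from pydantic import BaseModel

class Record(BaseModel):
    name: str
    # leading underscore makes this a Pydantic private attribute
    _storage_key: str = ""

    def __init__(self, *, name: str, storage_key: str):
        super().__init__(name=name)
        self._storage_key = storage_key

r = Record(name="report.pdf", storage_key="upload-files/abc123")
assert "_storage_key" not in r.model_dump()
assert r._storage_key == "upload-files/abc123"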

View File

@@ -1,6 +1,5 @@
import base64
from extensions.ext_database import db
from libs import rsa
@@ -14,6 +13,7 @@ def obfuscated_token(token: str):
def encrypt_token(tenant_id: str, token: str):
from models.account import Tenant
from models.engine import db
if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()):
raise ValueError(f"Tenant with id {tenant_id} not found")

View File

@@ -45,7 +45,6 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
)
retries = 0
stream = kwargs.pop("stream", False)
while retries <= max_retries:
try:
if dify_config.SSRF_PROXY_ALL_URL:

View File

@@ -4,6 +4,7 @@ from .message_entities import (
AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent,
MultiModalPromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
@@ -27,6 +28,7 @@ __all__ = [
"LLMResultChunkDelta",
"LLMUsage",
"ModelPropertyKey",
"MultiModalPromptMessageContent",
"PromptMessage",
"PromptMessage",
"PromptMessageContent",

View File

@@ -84,10 +84,10 @@ class MultiModalPromptMessageContent(PromptMessageContent):
"""
type: PromptMessageContentType
format: str = Field(..., description="the format of multi-modal file")
base64_data: str = Field("", description="the base64 data of multi-modal file")
url: str = Field("", description="the url of multi-modal file")
mime_type: str = Field(..., description="the mime type of multi-modal file")
format: str = Field(default=..., description="the format of multi-modal file")
base64_data: str = Field(default="", description="the base64 data of multi-modal file")
url: str = Field(default="", description="the url of multi-modal file")
mime_type: str = Field(default=..., description="the mime type of multi-modal file")
@computed_field(return_type=str)
@property

View File

@@ -819,6 +819,82 @@ LLM_BASE_MODELS = [
),
),
),
AzureBaseModel(
base_model_name="gpt-4o-2024-11-20",
entity=AIModelEntity(
model="fake-deployment-name",
label=I18nObject(
en_US="fake-deployment-name-label",
),
model_type=ModelType.LLM,
features=[
ModelFeature.AGENT_THOUGHT,
ModelFeature.VISION,
ModelFeature.MULTI_TOOL_CALL,
ModelFeature.STREAM_TOOL_CALL,
],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.MODE: LLMMode.CHAT.value,
ModelPropertyKey.CONTEXT_SIZE: 128000,
},
parameter_rules=[
ParameterRule(
name="temperature",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
),
ParameterRule(
name="top_p",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
),
ParameterRule(
name="presence_penalty",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
),
ParameterRule(
name="frequency_penalty",
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
),
_get_max_tokens(default=512, min_val=1, max_val=16384),
ParameterRule(
name="seed",
label=I18nObject(zh_Hans="种子", en_US="Seed"),
type="int",
help=AZURE_DEFAULT_PARAM_SEED_HELP,
required=False,
precision=2,
min=0,
max=1,
),
ParameterRule(
name="response_format",
label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
type="string",
help=I18nObject(
zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output"
),
required=False,
options=["text", "json_object", "json_schema"],
),
ParameterRule(
name="json_schema",
label=I18nObject(en_US="JSON Schema"),
type="text",
help=I18nObject(
zh_Hans="设置返回的json schemallm将按照它返回",
en_US="Set a response json schema will ensure LLM to adhere it.",
),
required=False,
),
],
pricing=PriceConfig(
input=5.00,
output=15.00,
unit=0.000001,
currency="USD",
),
),
),
AzureBaseModel(
base_model_name="gpt-4-turbo",
entity=AIModelEntity(

View File

@@ -171,6 +171,12 @@ model_credential_schema:
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4o-2024-11-20
value: gpt-4o-2024-11-20
show_on:
- variable: __model_type
value: llm
- label:
en_US: gpt-4-turbo
value: gpt-4-turbo

View File

@@ -92,7 +92,10 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()
embedding = (average / np.linalg.norm(average)).tolist()
if np.isnan(embedding).any():
raise ValueError("Normalized embedding is nan please try again")
embeddings[i] = embedding
# calc usage
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

View File

@@ -1,11 +1,19 @@
from collections.abc import Mapping
import boto3
from botocore.config import Config
from core.model_runtime.errors.invoke import InvokeBadRequestError
def get_bedrock_client(service_name: str, credentials: Mapping[str, str]):
region_name = credentials.get("aws_region")
if not region_name:
raise InvokeBadRequestError("aws_region is required")
client_config = Config(region_name=region_name)
aws_access_key_id = credentials.get("aws_access_key_id")
aws_secret_access_key = credentials.get("aws_secret_access_key")
def get_bedrock_client(service_name, credentials=None):
client_config = Config(region_name=credentials["aws_region"])
aws_access_key_id = credentials["aws_access_key_id"]
aws_secret_access_key = credentials["aws_secret_access_key"]
if aws_access_key_id and aws_secret_access_key:
# use aksk to call bedrock
client = boto3.client(
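
The rewrite swaps bare indexing, which raises KeyError on a missing key, for .get() plus an explicit check. The same pattern in isolation (the helper name is illustrative, with ValueError standing in for InvokeBadRequestError):

from collections.abc import Mapping

def require_credential(credentials: Mapping[str, str], key: str) -> str:
    value = credentials.get(key)
    if not value:
        raise ValueError(f"{key} is required")
    return value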

View File

@@ -62,7 +62,10 @@ class BedrockRerankModel(RerankModel):
}
)
modelId = model
region = credentials["aws_region"]
region = credentials.get("aws_region")
# region is a required field
if not region:
raise InvokeBadRequestError("aws_region is required in credentials")
model_package_arn = f"arn:aws:bedrock:{region}::foundation-model/{modelId}"
rerankingConfiguration = {
"type": "BEDROCK_RERANKING_MODEL",

View File

@@ -88,7 +88,10 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()
embedding = (average / np.linalg.norm(average)).tolist()
if np.isnan(embedding).any():
raise ValueError("Normalized embedding is nan please try again")
embeddings[i] = embedding
# calc usage
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

View File

@@ -1,4 +1,5 @@
- gemini-2.0-flash-exp
- gemini-2.0-flash-thinking-exp-1219
- gemini-1.5-pro
- gemini-1.5-pro-latest
- gemini-1.5-pro-001

View File

@@ -0,0 +1,39 @@
model: gemini-2.0-flash-thinking-exp-1219
label:
en_US: Gemini 2.0 Flash Thinking Exp 1219
model_type: llm
features:
- agent-thought
- vision
- document
- video
- audio
model_properties:
mode: chat
context_size: 32767
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

View File

@@ -97,7 +97,10 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()
embedding = (average / np.linalg.norm(average)).tolist()
if np.isnan(embedding).any():
raise ValueError("Normalized embedding is nan please try again")
embeddings[i] = embedding
# calc usage
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

View File

@@ -119,7 +119,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
embeddings.append(result[0].get("embedding"))
return [list(map(float, e)) for e in embeddings]
elif "texts" == text_input_key:
elif text_input_key == "texts":
result = client.run(
replicate_model_version,
input={

View File

@@ -18,7 +18,7 @@ class SiliconflowProvider(ModelProvider):
try:
model_instance = self.get_model_instance(ModelType.LLM)
model_instance.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials=credentials)
model_instance.validate_credentials(model="deepseek-ai/DeepSeek-V2.5", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:

View File

@@ -100,7 +100,10 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel):
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()
embedding = (average / np.linalg.norm(average)).tolist()
if np.isnan(embedding).any():
raise ValueError("Normalized embedding is nan please try again")
embeddings[i] = embedding
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

View File

@@ -4,11 +4,10 @@ import json
import logging
import time
from collections.abc import Generator
from typing import Optional, Union, cast
from typing import TYPE_CHECKING, Optional, Union, cast
import google.auth.transport.requests
import requests
import vertexai.generative_models as glm
from anthropic import AnthropicVertex, Stream
from anthropic.types import (
ContentBlockDeltaEvent,
@@ -19,8 +18,6 @@ from anthropic.types import (
MessageStreamEvent,
)
from google.api_core import exceptions
from google.cloud import aiplatform
from google.oauth2 import service_account
from PIL import Image
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@@ -47,6 +44,9 @@ from core.model_runtime.errors.invoke import (
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
if TYPE_CHECKING:
import vertexai.generative_models as glm
logger = logging.getLogger(__name__)
@@ -102,6 +102,8 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
:param stream: is stream response
:return: full response or stream response chunk generator result
"""
from google.oauth2 import service_account
# use Anthropic official SDK references
# - https://github.com/anthropics/anthropic-sdk-python
service_account_key = credentials.get("vertex_service_account_key", "")
@@ -406,13 +408,15 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
return text.rstrip()
def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> glm.Tool:
def _convert_tools_to_glm_tool(self, tools: list[PromptMessageTool]) -> "glm.Tool":
"""
Convert tool messages to glm tools
:param tools: tool messages
:return: glm tools
"""
import vertexai.generative_models as glm
return glm.Tool(
function_declarations=[
glm.FunctionDeclaration(
@@ -473,6 +477,10 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
:param user: unique user id
:return: full response or stream response chunk generator result
"""
import vertexai.generative_models as glm
from google.cloud import aiplatform
from google.oauth2 import service_account
config_kwargs = model_parameters.copy()
config_kwargs["max_output_tokens"] = config_kwargs.pop("max_tokens_to_sample", None)
@@ -522,7 +530,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_response(
self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage]
self, model: str, credentials: dict, response: "glm.GenerationResponse", prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
Handle llm response
@@ -554,7 +562,7 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
return result
def _handle_generate_stream_response(
self, model: str, credentials: dict, response: glm.GenerationResponse, prompt_messages: list[PromptMessage]
self, model: str, credentials: dict, response: "glm.GenerationResponse", prompt_messages: list[PromptMessage]
) -> Generator:
"""
Handle llm stream response
@@ -638,13 +646,15 @@ class VertexAiLargeLanguageModel(LargeLanguageModel):
return message_text
def _format_message_to_glm_content(self, message: PromptMessage) -> glm.Content:
def _format_message_to_glm_content(self, message: PromptMessage) -> "glm.Content":
"""
Format a single message into glm.Content for Google API
:param message: one PromptMessage
:return: glm Content representation of message
"""
import vertexai.generative_models as glm
if isinstance(message, UserPromptMessage):
glm_content = glm.Content(role="user", parts=[])

View File

@@ -2,12 +2,9 @@ import base64
import json
import time
from decimal import Decimal
from typing import Optional
from typing import TYPE_CHECKING, Optional
import tiktoken
from google.cloud import aiplatform
from google.oauth2 import service_account
from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.common_entities import I18nObject
@@ -24,6 +21,11 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.vertex_ai._common import _CommonVertexAi
if TYPE_CHECKING:
from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
else:
VertexTextEmbeddingModel = None
class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
"""
@@ -48,6 +50,10 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
:param input_type: input type
:return: embeddings result
"""
from google.cloud import aiplatform
from google.oauth2 import service_account
from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
service_account_key = credentials.get("vertex_service_account_key", "")
project_id = credentials["vertex_project_id"]
location = credentials["vertex_location"]
@@ -100,6 +106,10 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel):
:param credentials: model credentials
:return:
"""
from google.cloud import aiplatform
from google.oauth2 import service_account
from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel
try:
service_account_key = credentials.get("vertex_service_account_key", "")
project_id = credentials["vertex_project_id"]

View File

@@ -40,6 +40,10 @@ configs: dict[str, ModelConfig] = {
properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL],
),
"Doubao-pro-256k": ModelConfig(
properties=ModelProperties(context_size=262144, max_tokens=4096, mode=LLMMode.CHAT),
features=[],
),
"Doubao-pro-128k": ModelConfig(
properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT),
features=[ModelFeature.TOOL_CALL],

View File

@@ -12,6 +12,7 @@ class ModelConfig(BaseModel):
ModelConfigs = {
"Doubao-embedding": ModelConfig(properties=ModelProperties(context_size=4096, max_chunks=32)),
"Doubao-embedding-large": ModelConfig(properties=ModelProperties(context_size=4096, max_chunks=32)),
}
@@ -21,7 +22,7 @@ def get_model_config(credentials: dict) -> ModelConfig:
if not model_configs:
return ModelConfig(
properties=ModelProperties(
context_size=int(credentials.get("context_size", 0)),
context_size=int(credentials.get("context_size", 4096)),
max_chunks=int(credentials.get("max_chunks", 1)),
)
)

View File

@@ -166,6 +166,12 @@ model_credential_schema:
show_on:
- variable: __model_type
value: llm
- label:
en_US: Doubao-pro-256k
value: Doubao-pro-256k
show_on:
- variable: __model_type
value: llm
- label:
en_US: Llama3-8B
value: Llama3-8B
@@ -220,6 +226,12 @@ model_credential_schema:
show_on:
- variable: __model_type
value: text-embedding
- label:
en_US: Doubao-embedding-large
value: Doubao-embedding-large
show_on:
- variable: __model_type
value: text-embedding
- label:
en_US: Custom
zh_Hans: 自定义

View File

@@ -1,18 +1,19 @@
import re
from typing import Optional
import jieba
from jieba.analyse import default_tfidf
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
class JiebaKeywordTableHandler:
def __init__(self):
default_tfidf.stop_words = STOPWORDS
import jieba.analyse
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
jieba.analyse.default_tfidf.stop_words = STOPWORDS
def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
"""Extract keywords with JIEBA tfidf."""
import jieba
keywords = jieba.analyse.extract_tags(
sentence=text,
topK=max_keywords_per_chunk,
@@ -22,6 +23,8 @@ class JiebaKeywordTableHandler:
def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
"""Get subtokens from a list of tokens., filtering for stopwords."""
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
results = set()
for token in tokens:
results.add(token)

View File

@@ -6,10 +6,8 @@ from contextlib import contextmanager
from typing import Any
import jieba.posseg as pseg
import nltk
import numpy
import oracledb
from nltk.corpus import stopwords
from pydantic import BaseModel, model_validator
from configs import dify_config
@@ -202,6 +200,10 @@ class OracleVector(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# lazy import
import nltk
from nltk.corpus import stopwords
top_k = kwargs.get("top_k", 5)
# just not implement fetch by score_threshold now, may be later
score_threshold = float(kwargs.get("score_threshold") or 0.0)

View File

@@ -65,6 +65,11 @@ class CacheEmbedding(Embeddings):
for vector in embedding_result.embeddings:
try:
normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
# stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan
if np.isnan(normalized_embedding).any():
# for issue #11827 float values are not json compliant
logger.warning(f"Normalized embedding is nan: {normalized_embedding}")
continue
embedding_queue_embeddings.append(normalized_embedding)
except IntegrityError:
db.session.rollback()
@@ -111,6 +116,8 @@ class CacheEmbedding(Embeddings):
embedding_results = embedding_result.embeddings[0]
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
if np.isnan(embedding_results).any():
raise ValueError("Normalized embedding is nan please try again")
except Exception as ex:
if dify_config.DEBUG:
logging.exception(f"Failed to embed query text '{text[:10]}...({len(text)} chars)'")

View File

@@ -11,7 +11,10 @@ class ComfyUIProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
ws = websocket.WebSocket()
base_url = URL(credentials.get("base_url"))
ws_address = f"ws://{base_url.authority}/ws?clientId=test123"
ws_protocol = "ws"
if base_url.scheme == "https":
ws_protocol = "wss"
ws_address = f"{ws_protocol}://{base_url.authority}/ws?clientId=test123"
try:
ws.connect(ws_address)

View File

@@ -40,7 +40,10 @@ class ComfyUiClient:
def open_websocket_connection(self) -> tuple[WebSocket, str]:
client_id = str(uuid.uuid4())
ws = WebSocket()
ws_address = f"ws://{self.base_url.authority}/ws?clientId={client_id}"
ws_protocol = "ws"
if self.base_url.scheme == "https":
ws_protocol = "wss"
ws_address = f"{ws_protocol}://{self.base_url.authority}/ws?clientId={client_id}"
ws.connect(ws_address)
return ws, client_id
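
Both ComfyUI fixes apply the same scheme mapping. A self-contained sketch, assuming yarl's URL (which the client already uses); the function name is illustrative:

from yarl import URL

def build_ws_address(base_url: str, client_id: str) -> str:
    base = URL(base_url)
    ws_protocol = "wss" if base.scheme == "https" else "ws"
    return f"{ws_protocol}://{base.authority}/ws?clientId={client_id}"

assert build_ws_address("https://comfy.example.com:8443", "test123") == (
    "wss://comfy.example.com:8443/ws?clientId=test123"
)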

View File

@@ -45,3 +45,6 @@ class NodeRunResult(BaseModel):
error: Optional[str] = None # error message if status is failed
error_type: Optional[str] = None # error type if status is failed
# single step node run retry
retry_index: int = 0

View File

@@ -97,6 +97,13 @@ class NodeInIterationFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")
class NodeRunRetryEvent(BaseNodeEvent):
error: str = Field(..., description="error")
retry_index: int = Field(..., description="which retry attempt is about to be performed")
start_at: datetime = Field(..., description="retry start time")
start_index: int = Field(..., description="retry start index")
###########################################
# Parallel Branch Events
###########################################

View File

@@ -4,6 +4,7 @@ from typing import Any, Optional, cast
from pydantic import BaseModel, Field
from configs import dify_config
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes import NodeType
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
@@ -170,7 +171,9 @@ class Graph(BaseModel):
for parallel in parallel_mapping.values():
if parallel.parent_parallel_id:
cls._check_exceed_parallel_limit(
parallel_mapping=parallel_mapping, level_limit=3, parent_parallel_id=parallel.parent_parallel_id
parallel_mapping=parallel_mapping,
level_limit=dify_config.WORKFLOW_PARALLEL_DEPTH_LIMIT,
parent_parallel_id=parallel.parent_parallel_id,
)
# init answer stream generate routes

View File

@@ -5,6 +5,7 @@ import uuid
from collections.abc import Generator, Mapping
from concurrent.futures import ThreadPoolExecutor, wait
from copy import copy, deepcopy
from datetime import UTC, datetime
from typing import Any, Optional, cast
from flask import Flask, current_app
@@ -25,6 +26,7 @@ from core.workflow.graph_engine.entities.event import (
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
@@ -581,7 +583,7 @@ class GraphEngine:
def _run_node(
self,
node_instance: BaseNode,
node_instance: BaseNode[BaseNodeData],
route_node_state: RouteNodeState,
parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None,
@@ -607,36 +609,121 @@ class GraphEngine:
)
db.session.close()
max_retries = node_instance.node_data.retry_config.max_retries
retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
retries = 0
should_continue_retry = True
while should_continue_retry and retries <= max_retries:
try:
# run node
retry_start_at = datetime.now(UTC).replace(tzinfo=None)
generator = node_instance.run()
for item in generator:
if isinstance(item, GraphEngineEvent):
if isinstance(item, BaseIterationEvent):
# add parallel info to iteration event
item.parallel_id = parallel_id
item.parallel_start_node_id = parallel_start_node_id
item.parent_parallel_id = parent_parallel_id
item.parent_parallel_start_node_id = parent_parallel_start_node_id
try:
# run node
generator = node_instance.run()
for item in generator:
if isinstance(item, GraphEngineEvent):
if isinstance(item, BaseIterationEvent):
# add parallel info to iteration event
item.parallel_id = parallel_id
item.parallel_start_node_id = parallel_start_node_id
item.parent_parallel_id = parent_parallel_id
item.parent_parallel_start_node_id = parent_parallel_start_node_id
yield item
else:
if isinstance(item, RunCompletedEvent):
run_result = item.run_result
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if (
retries == max_retries
and node_instance.node_type == NodeType.HTTP_REQUEST
and run_result.outputs
and not node_instance.should_continue_on_error
):
run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
if node_instance.should_retry and retries < max_retries:
retries += 1
self.graph_runtime_state.node_run_steps += 1
route_node_state.node_run_result = run_result
yield NodeRunRetryEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
error=run_result.error,
retry_index=retries,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
start_at=retry_start_at,
start_index=self.graph_runtime_state.node_run_steps,
)
time.sleep(retry_interval)
continue
route_node_state.set_finished(run_result=run_result)
yield item
else:
if isinstance(item, RunCompletedEvent):
run_result = item.run_result
route_node_state.set_finished(run_result=run_result)
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if node_instance.should_continue_on_error:
# if run failed, handle error
run_result = self._handle_continue_on_error(
node_instance,
item.run_result,
self.graph_runtime_state.variable_pool,
handle_exceptions=handle_exceptions,
)
route_node_state.node_run_result = run_result
route_node_state.status = RouteNodeState.Status.EXCEPTION
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_instance.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
)
yield NodeRunExceptionEvent(
error=run_result.error or "System Error",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
should_continue_retry = False
else:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or "Unknown error.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
should_continue_retry = False
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
node_instance.node_id
):
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
# plus state total_tokens
self.graph_runtime_state.total_tokens += int(
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
)
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if node_instance.should_continue_on_error:
# if run failed, handle error
run_result = self._handle_continue_on_error(
node_instance,
item.run_result,
self.graph_runtime_state.variable_pool,
handle_exceptions=handle_exceptions,
)
route_node_state.node_run_result = run_result
route_node_state.status = RouteNodeState.Status.EXCEPTION
if run_result.llm_usage:
# use the latest usage
self.graph_runtime_state.llm_usage += run_result.llm_usage
# append node output variables to variable pool
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
@@ -645,21 +732,23 @@ class GraphEngine:
variable_key_list=[variable_key],
variable_value=variable_value,
)
yield NodeRunExceptionEvent(
error=run_result.error or "System Error",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
else:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or "Unknown error.",
# add parallel info to run result metadata
if parallel_id and parallel_start_node_id:
if not run_result.metadata:
run_result.metadata = {}
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = (
parallel_start_node_id
)
if parent_parallel_id and parent_parallel_start_node_id:
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
parent_parallel_start_node_id
)
yield NodeRunSucceededEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
@@ -670,108 +759,59 @@ class GraphEngine:
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
should_continue_retry = False
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
node_instance.node_id
):
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
# plus state total_tokens
self.graph_runtime_state.total_tokens += int(
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
)
if run_result.llm_usage:
# use the latest usage
self.graph_runtime_state.llm_usage += run_result.llm_usage
# append node output variables to variable pool
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_instance.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
)
# add parallel info to run result metadata
if parallel_id and parallel_start_node_id:
if not run_result.metadata:
run_result.metadata = {}
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
if parent_parallel_id and parent_parallel_start_node_id:
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
parent_parallel_start_node_id
)
yield NodeRunSucceededEvent(
break
elif isinstance(item, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
chunk_content=item.chunk_content,
from_variable_selector=item.from_variable_selector,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
break
elif isinstance(item, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
chunk_content=item.chunk_content,
from_variable_selector=item.from_variable_selector,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
retriever_resources=item.retriever_resources,
context=item.context,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
except GenerateTaskStoppedError:
# trigger node run failed event
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = "Workflow stopped."
yield NodeRunFailedEvent(
error="Workflow stopped.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
return
except Exception as e:
logger.exception(f"Node {node_instance.node_data.title} run failed")
raise e
finally:
db.session.close()
elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
retriever_resources=item.retriever_resources,
context=item.context,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
except GenerateTaskStoppedError:
# trigger node run failed event
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = "Workflow stopped."
yield NodeRunFailedEvent(
error="Workflow stopped.",
id=node_instance.id,
node_id=node_instance.node_id,
node_type=node_instance.node_type,
node_data=node_instance.node_data,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
)
return
except Exception as e:
logger.exception(f"Node {node_instance.node_data.title} run failed")
raise e
finally:
db.session.close()
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
"""

View File

@@ -106,12 +106,25 @@ class DefaultValue(BaseModel):
return self
class RetryConfig(BaseModel):
"""node retry config"""
max_retries: int = 0 # max retry times
retry_interval: int = 0 # retry interval in milliseconds
retry_enabled: bool = False # whether retry is enabled
@property
def retry_interval_seconds(self) -> float:
return self.retry_interval / 1000
class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None
error_strategy: Optional[ErrorStrategy] = None
default_value: Optional[list[DefaultValue]] = None
version: str = "1"
retry_config: RetryConfig = RetryConfig()
@property
def default_value_dict(self):
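RetryConfig stores the interval in milliseconds and exposes it in seconds for use with time.sleep(). A minimal sketch of the conversion, assuming only pydantic is installed:

    from pydantic import BaseModel

    class RetryConfig(BaseModel):
        max_retries: int = 0       # max retry times
        retry_interval: int = 0    # retry interval in milliseconds
        retry_enabled: bool = False

        @property
        def retry_interval_seconds(self) -> float:
            # Stored in milliseconds; converted for time.sleep().
            return self.retry_interval / 1000

    config = RetryConfig(max_retries=2, retry_interval=500, retry_enabled=True)
    assert config.retry_interval_seconds == 0.5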

View File

@@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from models.workflow import WorkflowNodeExecutionStatus
@@ -147,3 +147,12 @@ class BaseNode(Generic[GenericNodeData]):
bool: if should continue on error
"""
return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
@property
def should_retry(self) -> bool:
"""judge if should retry
Returns:
bool: if should retry
"""
return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE
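The two properties gate independent features: should_continue_on_error requires an error_strategy, while should_retry requires retry_enabled plus a node type listed in RETRY_ON_ERROR_NODE_TYPE. A hedged sketch of the same gate with plain data and no Dify imports (the string values are illustrative):

    RETRY_ON_ERROR_NODE_TYPE = {"llm", "tool", "http-request"}

    def should_retry(node_type: str, retry_enabled: bool) -> bool:
        # Retry applies only to node types that opted in, and only when enabled.
        return retry_enabled and node_type in RETRY_ON_ERROR_NODE_TYPE

    assert should_retry("http-request", True)
    assert not should_retry("code", True)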

View File

@@ -8,12 +8,6 @@ import docx
import pandas as pd
import pypdfium2 # type: ignore
import yaml # type: ignore
from unstructured.partition.api import partition_via_api
from unstructured.partition.email import partition_email
from unstructured.partition.epub import partition_epub
from unstructured.partition.msg import partition_msg
from unstructured.partition.ppt import partition_ppt
from unstructured.partition.pptx import partition_pptx
from configs import dify_config
from core.file import File, FileTransferMethod, file_manager
@@ -256,6 +250,8 @@ def _extract_text_from_excel(file_content: bytes) -> str:
def _extract_text_from_ppt(file_content: bytes) -> str:
from unstructured.partition.ppt import partition_ppt
try:
with io.BytesIO(file_content) as file:
elements = partition_ppt(file=file)
@@ -265,6 +261,9 @@ def _extract_text_from_ppt(file_content: bytes) -> str:
def _extract_text_from_pptx(file_content: bytes) -> str:
from unstructured.partition.api import partition_via_api
from unstructured.partition.pptx import partition_pptx
try:
if dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY:
with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file:
@@ -287,6 +286,8 @@ def _extract_text_from_pptx(file_content: bytes) -> str:
def _extract_text_from_epub(file_content: bytes) -> str:
from unstructured.partition.epub import partition_epub
try:
with io.BytesIO(file_content) as file:
elements = partition_epub(file=file)
@@ -296,6 +297,8 @@ def _extract_text_from_epub(file_content: bytes) -> str:
def _extract_text_from_eml(file_content: bytes) -> str:
from unstructured.partition.email import partition_email
try:
with io.BytesIO(file_content) as file:
elements = partition_email(file=file)
@@ -305,6 +308,8 @@ def _extract_text_from_eml(file_content: bytes) -> str:
def _extract_text_from_msg(file_content: bytes) -> str:
from unstructured.partition.msg import partition_msg
try:
with io.BytesIO(file_content) as file:
elements = partition_msg(file=file)
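Each extractor now imports its unstructured partitioner inside the function body, so the dependency is loaded only when that file type is actually processed; this is part of the memory-usage reduction noted in the commit list. A sketch of the lazy-import idiom in isolation, assuming unstructured is installed:

    import io

    def _extract_text_from_epub(file_content: bytes) -> str:
        # Deferred import: unstructured is only loaded when an EPUB is parsed,
        # not at module import time.
        from unstructured.partition.epub import partition_epub

        with io.BytesIO(file_content) as file:
            elements = partition_epub(file=file)
        return "\n\n".join(str(element) for element in elements)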

View File

@@ -35,3 +35,4 @@ class FailBranchSourceHandle(StrEnum):
CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
RETRY_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.TOOL, NodeType.HTTP_REQUEST]

View File

@@ -1,4 +1,10 @@
from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from .event import (
ModelInvokeCompletedEvent,
RunCompletedEvent,
RunRetrieverResourceEvent,
RunRetryEvent,
RunStreamChunkEvent,
)
from .types import NodeEvent
__all__ = [
@@ -6,5 +12,6 @@ __all__ = [
"NodeEvent",
"RunCompletedEvent",
"RunRetrieverResourceEvent",
"RunRetryEvent",
"RunStreamChunkEvent",
]

View File

@@ -1,7 +1,10 @@
from datetime import datetime
from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.node_entities import NodeRunResult
from models.workflow import WorkflowNodeExecutionStatus
class RunCompletedEvent(BaseModel):
@@ -26,3 +29,25 @@ class ModelInvokeCompletedEvent(BaseModel):
text: str
usage: LLMUsage
finish_reason: str | None = None
class RunRetryEvent(BaseModel):
"""Node Run Retry event"""
error: str = Field(..., description="error")
retry_index: int = Field(..., description="Retry attempt number")
start_at: datetime = Field(..., description="Retry start time")
class SingleStepRetryEvent(BaseModel):
"""Single step retry event"""
status: str = WorkflowNodeExecutionStatus.RETRY.value
inputs: dict | None = Field(..., description="input")
error: str = Field(..., description="error")
outputs: dict = Field(..., description="output")
retry_index: int = Field(..., description="Retry attempt number")
elapsed_time: float = Field(..., description="elapsed time")
execution_metadata: dict | None = Field(..., description="execution metadata")
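With the duplicate error field removed, a SingleStepRetryEvent can be constructed as below; a minimal sketch with illustrative values, assuming the model as defined above:

    event = SingleStepRetryEvent(
        inputs={"query": "ping"},
        error="connection reset",
        outputs={},
        retry_index=1,
        elapsed_time=0.42,
        execution_metadata=None,
    )
    # status defaults to WorkflowNodeExecutionStatus.RETRY.value
    assert event.status == "retry"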

View File

@@ -45,6 +45,7 @@ class Executor:
headers: dict[str, str]
auth: HttpRequestNodeAuthorization
timeout: HttpRequestNodeTimeout
max_retries: int
boundary: str
@@ -54,6 +55,7 @@ class Executor:
node_data: HttpRequestNodeData,
timeout: HttpRequestNodeTimeout,
variable_pool: VariablePool,
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
):
# If authorization API key is present, convert the API key using the variable pool
if node_data.authorization.type == "api-key":
@@ -73,6 +75,7 @@ class Executor:
self.files = None
self.data = None
self.json = None
self.max_retries = max_retries
# init template
self.variable_pool = variable_pool
@@ -241,6 +244,7 @@ class Executor:
"params": self.params,
"timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
"follow_redirects": True,
"max_retries": self.max_retries,
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
try:
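max_retries is threaded through to the request call alongside the timeout. The underlying behavior is a bounded retry with backoff; a hedged sketch of what such a wrapper can look like (function name hypothetical; the actual ssrf_proxy implementation may differ):

    import time
    import httpx

    def request_with_retries(method: str, url: str, *, max_retries: int = 3, **kwargs):
        # Retry transient transport errors up to max_retries times,
        # backing off exponentially between attempts.
        for attempt in range(max_retries + 1):
            try:
                return httpx.request(method, url, **kwargs)
            except httpx.TransportError:
                if attempt == max_retries:
                    raise
                time.sleep(0.5 * (2 ** attempt))

The 0.5 * (2 ** n) backoff mirrors the retry_interval default of 0.5 * (2**2) seen in the HTTP node's retry_config below.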

View File

@@ -1,4 +1,5 @@
import logging
import mimetypes
from collections.abc import Mapping, Sequence
from typing import Any
@@ -51,6 +52,11 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
"max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
},
},
"retry_config": {
"max_retries": dify_config.SSRF_DEFAULT_MAX_RETRIES,
"retry_interval": 0.5 * (2**2),
"retry_enabled": True,
},
}
def _run(self) -> NodeRunResult:
@@ -60,12 +66,13 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
node_data=self.node_data,
timeout=self._get_request_timeout(self.node_data),
variable_pool=self.graph_runtime_state.variable_pool,
max_retries=0,
)
process_data["request"] = http_executor.to_log()
response = http_executor.invoke()
files = self.extract_files(url=http_executor.url, response=response)
if not response.response.is_success and self.should_continue_on_error:
if not response.response.is_success and (self.should_continue_on_error or self.should_retry):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
outputs={
@@ -156,20 +163,24 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
def extract_files(self, url: str, response: Response) -> list[File]:
"""
Extract files from response
Extract files from response by checking both Content-Type header and URL
"""
files = []
is_file = response.is_file
content_type = response.content_type
content = response.content
if is_file and content_type:
if is_file:
# Guess the MIME type from the Content-Type header or the URL's filename
filename = url.split("?")[0].split("/")[-1] or ""
mime_type = content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
tool_file = ToolFileManager.create_file_by_raw(
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=None,
file_binary=content,
mimetype=content_type,
mimetype=mime_type,
)
mapping = {
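The MIME-type fallback chain (Content-Type header, then URL extension, then octet-stream) can be checked in isolation; a small sketch using only the standard library:

    import mimetypes

    def guess_mime(url: str, content_type: str | None) -> str:
        # Prefer the Content-Type header, fall back to the URL's file
        # extension, then to a generic binary type.
        filename = url.split("?")[0].split("/")[-1] or ""
        return content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"

    assert guess_mime("https://example.com/logo.png?v=2", None) == "image/png"
    assert guess_mime("https://example.com/download", None) == "application/octet-stream"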

View File

@@ -50,6 +50,7 @@ class PromptConfig(BaseModel):
class LLMNodeChatModelMessage(ChatModelMessage):
text: str = ""
jinja2_text: Optional[str] = None

View File

@@ -145,8 +145,8 @@ class LLMNode(BaseNode[LLMNodeData]):
query = query_variable.text
prompt_messages, stop = self._fetch_prompt_messages(
user_query=query,
user_files=files,
sys_query=query,
sys_files=files,
context=context,
memory=memory,
model_config=model_config,
@@ -545,8 +545,8 @@ class LLMNode(BaseNode[LLMNodeData]):
def _fetch_prompt_messages(
self,
*,
user_query: str | None = None,
user_files: Sequence["File"],
sys_query: str | None = None,
sys_files: Sequence["File"],
context: str | None = None,
memory: TokenBufferMemory | None = None,
model_config: ModelConfigWithCredentialsEntity,
@@ -562,7 +562,7 @@ class LLMNode(BaseNode[LLMNodeData]):
if isinstance(prompt_template, list):
# For chat model
prompt_messages.extend(
_handle_list_messages(
self._handle_list_messages(
messages=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
@@ -581,14 +581,14 @@ class LLMNode(BaseNode[LLMNodeData]):
prompt_messages.extend(memory_messages)
# Add current query to the prompt messages
if user_query:
if sys_query:
message = LLMNodeChatModelMessage(
text=user_query,
text=sys_query,
role=PromptMessageRole.USER,
edition_type="basic",
)
prompt_messages.extend(
_handle_list_messages(
self._handle_list_messages(
messages=[message],
context="",
jinja2_variables=[],
@@ -635,24 +635,27 @@ class LLMNode(BaseNode[LLMNodeData]):
raise ValueError("Invalid prompt content type")
# Add current query to the prompt message
if user_query:
if sys_query:
if prompt_content_type == str:
prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query)
prompt_content = prompt_messages[0].content.replace("#sys.query#", sys_query)
prompt_messages[0].content = prompt_content
elif prompt_content_type == list:
for content_item in prompt_content:
if content_item.type == PromptMessageContentType.TEXT:
content_item.data = user_query + "\n" + content_item.data
content_item.data = sys_query + "\n" + content_item.data
else:
raise ValueError("Invalid prompt content type")
else:
raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
if vision_enabled and user_files:
# The sys_files will be deprecated later
if vision_enabled and sys_files:
file_prompts = []
for file in user_files:
for file in sys_files:
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
file_prompts.append(file_prompt)
# If last prompt is a user prompt, add files into its contents,
# otherwise append a new user prompt
if (
len(prompt_messages) > 0
and isinstance(prompt_messages[-1], UserPromptMessage)
@@ -662,7 +665,7 @@ class LLMNode(BaseNode[LLMNodeData]):
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
# Filter prompt messages
# Remove empty messages and filter unsupported content
filtered_prompt_messages = []
for prompt_message in prompt_messages:
if isinstance(prompt_message.content, list):
@@ -846,6 +849,58 @@ class LLMNode(BaseNode[LLMNodeData]):
},
}
def _handle_list_messages(
self,
*,
messages: Sequence[LLMNodeChatModelMessage],
context: Optional[str],
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
vision_detail_config: ImagePromptMessageContent.DETAIL,
) -> Sequence[PromptMessage]:
prompt_messages: list[PromptMessage] = []
for message in messages:
contents: list[PromptMessageContent] = []
if message.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=message.jinja2_text or "",
jinjia2_variables=jinja2_variables,
variable_pool=variable_pool,
)
contents.append(TextPromptMessageContent(data=result_text))
else:
# Get segment group from basic message
if context:
template = message.text.replace("{#context#}", context)
else:
template = message.text
segment_group = variable_pool.convert_template(template)
# Process segments for images
for segment in segment_group.value:
if isinstance(segment, ArrayFileSegment):
for file in segment.value:
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
contents.append(file_content)
elif isinstance(segment, FileSegment):
file = segment.value
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
contents.append(file_content)
else:
plain_text = segment.markdown.strip()
if plain_text:
contents.append(TextPromptMessageContent(data=plain_text))
prompt_message = _combine_message_content_with_role(contents=contents, role=message.role)
prompt_messages.append(prompt_message)
return prompt_messages
def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
match role:
@@ -880,68 +935,6 @@ def _render_jinja2_message(
return result_text
def _handle_list_messages(
*,
messages: Sequence[LLMNodeChatModelMessage],
context: Optional[str],
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
vision_detail_config: ImagePromptMessageContent.DETAIL,
) -> Sequence[PromptMessage]:
prompt_messages = []
for message in messages:
if message.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=message.jinja2_text or "",
jinjia2_variables=jinja2_variables,
variable_pool=variable_pool,
)
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=result_text)], role=message.role
)
prompt_messages.append(prompt_message)
else:
# Get segment group from basic message
if context:
template = message.text.replace("{#context#}", context)
else:
template = message.text
segment_group = variable_pool.convert_template(template)
# Process segments for images
file_contents = []
for segment in segment_group.value:
if isinstance(segment, ArrayFileSegment):
for file in segment.value:
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
if isinstance(segment, FileSegment):
file = segment.value
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
# Create message with text from all segments
plain_text = segment_group.text
if plain_text:
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=plain_text)], role=message.role
)
prompt_messages.append(prompt_message)
if file_contents:
# Create message with image contents
prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
prompt_messages.append(prompt_message)
return prompt_messages
def _calculate_rest_token(
*, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
) -> int:

View File

@@ -86,10 +86,10 @@ class QuestionClassifierNode(LLMNode):
)
prompt_messages, stop = self._fetch_prompt_messages(
prompt_template=prompt_template,
user_query=query,
sys_query=query,
memory=memory,
model_config=model_config,
user_files=files,
sys_files=files,
vision_enabled=node_data.vision.enabled,
vision_detail=node_data.vision.configs.detail,
variable_pool=variable_pool,

View File

@@ -1,18 +1,5 @@
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import MetaData
from dify_app import DifyApp
POSTGRES_INDEXES_NAMING_CONVENTION = {
"ix": "%(column_0_label)s_idx",
"uq": "%(table_name)s_%(column_0_name)s_key",
"ck": "%(table_name)s_%(constraint_name)s_check",
"fk": "%(table_name)s_%(column_0_name)s_fkey",
"pk": "%(table_name)s_pkey",
}
metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION)
db = SQLAlchemy(metadata=metadata)
from models import db
def init_app(app: DifyApp):

View File

@@ -3,4 +3,3 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
from events import event_handlers # noqa: F401
from models import account, dataset, model, source, task, tool, tools, web # noqa: F401

View File

@@ -139,6 +139,7 @@ def _build_from_local_file(
remote_url=row.source_url,
related_id=mapping.get("upload_file_id"),
size=row.size,
storage_key=row.key,
)
@@ -168,6 +169,7 @@ def _build_from_remote_url(
mime_type=mime_type,
extension=extension,
size=file_size,
storage_key="",
)
@@ -220,6 +222,7 @@ def _build_from_tool_file(
extension=extension,
mime_type=tool_file.mimetype,
size=tool_file.size,
storage_key=tool_file.file_key,
)

View File

@@ -85,7 +85,7 @@ message_detail_fields = {
}
feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer}
status_count_fields = {"success": fields.Integer, "failed": fields.Integer, "partial_success": fields.Integer}
model_config_fields = {
"opening_statement": fields.String,
"suggested_questions": fields.Raw,
@@ -166,6 +166,7 @@ conversation_with_summary_fields = {
"message_count": fields.Integer,
"user_feedback_stats": fields.Nested(feedback_stat_fields),
"admin_feedback_stats": fields.Nested(feedback_stat_fields),
"status_count": fields.Nested(status_count_fields),
}
conversation_with_summary_pagination_fields = {

View File

@@ -29,6 +29,7 @@ workflow_run_for_list_fields = {
"created_at": TimestampField,
"finished_at": TimestampField,
"exceptions_count": fields.Integer,
"retry_index": fields.Integer,
}
advanced_chat_workflow_run_for_list_fields = {
@@ -45,6 +46,7 @@ advanced_chat_workflow_run_for_list_fields = {
"created_at": TimestampField,
"finished_at": TimestampField,
"exceptions_count": fields.Integer,
"retry_index": fields.Integer,
}
advanced_chat_workflow_run_pagination_fields = {
@@ -79,6 +81,17 @@ workflow_run_detail_fields = {
"exceptions_count": fields.Integer,
}
retry_event_field = {
"error": fields.String,
"retry_index": fields.Integer,
"inputs": fields.Raw(attribute="inputs"),
"elapsed_time": fields.Float,
"execution_metadata": fields.Raw(attribute="execution_metadata_dict"),
"status": fields.String,
"outputs": fields.Raw(attribute="outputs"),
}
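retry_event_field is a flask_restful marshalling schema; each retry event recorded on a node execution is serialized through it. A tiny usage sketch with a subset of the fields, assuming flask_restful is installed:

    from flask_restful import fields, marshal

    retry_event_field = {
        "error": fields.String,
        "retry_index": fields.Integer,
        "elapsed_time": fields.Float,
    }

    event = {"error": "upstream timeout", "retry_index": 2, "elapsed_time": 0.8}
    assert marshal(event, retry_event_field) == {
        "error": "upstream timeout",
        "retry_index": 2,
        "elapsed_time": 0.8,
    }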
workflow_run_node_execution_fields = {
"id": fields.String,
"index": fields.Integer,
@@ -99,6 +112,7 @@ workflow_run_node_execution_fields = {
"created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
"created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
"finished_at": TimestampField,
"retry_events": fields.List(fields.Nested(retry_event_field)),
}
workflow_run_node_execution_list_fields = {

View File

@@ -13,7 +13,7 @@ from typing import Any, Optional, Union, cast
from zoneinfo import available_timezones
from flask import Response, stream_with_context
from flask_restful import fields # type: ignore
from flask_restful import fields
from configs import dify_config
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator

View File

@@ -0,0 +1,39 @@
"""remove unused tool_providers
Revision ID: 11b07f66c737
Revises: cf8f4fc45278
Create Date: 2024-12-19 17:46:25.780116
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '11b07f66c737'
down_revision = 'cf8f4fc45278'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('tool_providers')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tool_providers',
sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False),
sa.Column('tenant_id', sa.UUID(), autoincrement=False, nullable=False),
sa.Column('tool_name', sa.VARCHAR(length=40), autoincrement=False, nullable=False),
sa.Column('encrypted_credentials', sa.TEXT(), autoincrement=False, nullable=True),
sa.Column('is_enabled', sa.BOOLEAN(), server_default=sa.text('false'), autoincrement=False, nullable=False),
sa.Column('created_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False),
sa.Column('updated_at', postgresql.TIMESTAMP(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), autoincrement=False, nullable=False),
sa.PrimaryKeyConstraint('id', name='tool_provider_pkey'),
sa.UniqueConstraint('tenant_id', 'tool_name', name='unique_tool_provider_tool_name')
)
# ### end Alembic commands ###

View File

@@ -0,0 +1,33 @@
"""add retry_index field to node-execution model
Revision ID: e1944c35e15e
Revises: 11b07f66c737
Create Date: 2024-12-20 06:28:30.287197
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'e1944c35e15e'
down_revision = '11b07f66c737'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
batch_op.add_column(sa.Column('retry_index', sa.Integer(), server_default=sa.text('0'), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
batch_op.drop_column('retry_index')
# ### end Alembic commands ###

View File

@@ -1,53 +1,187 @@
from .account import Account, AccountIntegrate, InvitationCode, Tenant
from .dataset import Dataset, DatasetProcessRule, Document, DocumentSegment
from .account import (
Account,
AccountIntegrate,
AccountStatus,
InvitationCode,
Tenant,
TenantAccountJoin,
TenantAccountJoinRole,
TenantAccountRole,
TenantStatus,
)
from .api_based_extension import APIBasedExtension, APIBasedExtensionPoint
from .dataset import (
AppDatasetJoin,
Dataset,
DatasetCollectionBinding,
DatasetKeywordTable,
DatasetPermission,
DatasetPermissionEnum,
DatasetProcessRule,
DatasetQuery,
Document,
DocumentSegment,
Embedding,
ExternalKnowledgeApis,
ExternalKnowledgeBindings,
TidbAuthBinding,
Whitelist,
)
from .engine import db
from .enums import CreatedByRole, UserFrom, WorkflowRunTriggeredFrom
from .model import (
ApiRequest,
ApiToken,
App,
AppAnnotationHitHistory,
AppAnnotationSetting,
AppMode,
AppModelConfig,
Conversation,
DatasetRetrieverResource,
DifySetup,
EndUser,
IconType,
InstalledApp,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
OperationLog,
RecommendedApp,
Site,
Tag,
TagBinding,
TraceAppConfig,
UploadFile,
)
from .source import DataSourceOauthBinding
from .tools import ToolFile
from .provider import (
LoadBalancingModelConfig,
Provider,
ProviderModel,
ProviderModelSetting,
ProviderOrder,
ProviderQuotaType,
ProviderType,
TenantDefaultModel,
TenantPreferredModelProvider,
)
from .source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
from .task import CeleryTask, CeleryTaskSet
from .tools import (
ApiToolProvider,
BuiltinToolProvider,
PublishedAppTool,
ToolConversationVariables,
ToolFile,
ToolLabelBinding,
ToolModelInvoke,
WorkflowToolProvider,
)
from .web import PinnedConversation, SavedMessage
from .workflow import (
ConversationVariable,
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
WorkflowNodeExecutionTriggeredFrom,
WorkflowRun,
WorkflowRunStatus,
WorkflowType,
)
__all__ = [
"APIBasedExtension",
"APIBasedExtensionPoint",
"Account",
"AccountIntegrate",
"AccountStatus",
"ApiRequest",
"ApiToken",
"ApiToolProvider", # Added
"App",
"AppAnnotationHitHistory",
"AppAnnotationSetting",
"AppDatasetJoin",
"AppMode",
"AppModelConfig",
"BuiltinToolProvider", # Added
"CeleryTask",
"CeleryTaskSet",
"Conversation",
"ConversationVariable",
"CreatedByRole",
"DataSourceApiKeyAuthBinding",
"DataSourceOauthBinding",
"Dataset",
"DatasetCollectionBinding",
"DatasetKeywordTable",
"DatasetPermission",
"DatasetPermissionEnum",
"DatasetProcessRule",
"DatasetQuery",
"DatasetRetrieverResource",
"DifySetup",
"Document",
"DocumentSegment",
"Embedding",
"EndUser",
"ExternalKnowledgeApis",
"ExternalKnowledgeBindings",
"IconType",
"InstalledApp",
"InvitationCode",
"LoadBalancingModelConfig",
"Message",
"MessageAgentThought",
"MessageAnnotation",
"MessageChain",
"MessageFeedback",
"MessageFile",
"OperationLog",
"PinnedConversation",
"Provider",
"ProviderModel",
"ProviderModelSetting",
"ProviderOrder",
"ProviderQuotaType",
"ProviderType",
"PublishedAppTool",
"RecommendedApp",
"SavedMessage",
"Site",
"Tag",
"TagBinding",
"Tenant",
"TenantAccountJoin",
"TenantAccountJoinRole",
"TenantAccountRole",
"TenantDefaultModel",
"TenantPreferredModelProvider",
"TenantStatus",
"TidbAuthBinding",
"ToolConversationVariables",
"ToolFile",
"ToolLabelBinding",
"ToolModelInvoke",
"TraceAppConfig",
"UploadFile",
"UserFrom",
"Whitelist",
"Workflow",
"WorkflowAppLog",
"WorkflowAppLogCreatedFrom",
"WorkflowNodeExecution",
"WorkflowNodeExecutionStatus",
"WorkflowNodeExecutionTriggeredFrom",
"WorkflowRun",
"WorkflowRunStatus",
"WorkflowRunTriggeredFrom",
"WorkflowToolProvider",
"WorkflowType",
"db",
]

View File

@@ -3,8 +3,7 @@ import json
from flask_login import UserMixin
from extensions.ext_database import db
from .engine import db
from .types import StringUUID

View File

@@ -1,7 +1,6 @@
import enum
from extensions.ext_database import db
from .engine import db
from .types import StringUUID

View File

@@ -15,10 +15,10 @@ from sqlalchemy.dialects.postgresql import JSONB
from configs import dify_config
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from extensions.ext_storage import storage
from .account import Account
from .engine import db
from .model import App, Tag, TagBinding, UploadFile
from .types import StringUUID

api/models/engine.py (new file)
View File

@@ -0,0 +1,13 @@
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import MetaData
POSTGRES_INDEXES_NAMING_CONVENTION = {
"ix": "%(column_0_label)s_idx",
"uq": "%(table_name)s_%(column_0_name)s_key",
"ck": "%(table_name)s_%(constraint_name)s_check",
"fk": "%(table_name)s_%(column_0_name)s_fkey",
"pk": "%(table_name)s_pkey",
}
metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION)
db = SQLAlchemy(metadata=metadata)
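The naming convention gives every autogenerated constraint a deterministic Postgres-style name, which keeps Alembic migrations stable across databases. A small illustration with plain SQLAlchemy (no Flask required), using just the "pk" template:

    from sqlalchemy import Column, Integer, MetaData, Table
    from sqlalchemy.schema import CreateTable

    metadata = MetaData(naming_convention={"pk": "%(table_name)s_pkey"})
    users = Table("users", metadata, Column("id", Integer, primary_key=True))

    # The convention names the primary key constraint "users_pkey" in the DDL.
    assert "users_pkey" in str(CreateTable(users))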

View File

@@ -16,11 +16,12 @@ from configs import dify_config
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from core.file import helpers as file_helpers
from core.file.tool_file_parser import ToolFileParser
from extensions.ext_database import db
from libs.helper import generate_string
from models.enums import CreatedByRole
from models.workflow import WorkflowRunStatus
from .account import Account, Tenant
from .engine import db
from .types import StringUUID
@@ -560,13 +561,29 @@ class Conversation(db.Model):
@property
def inputs(self):
inputs = self._inputs.copy()
# Convert file mapping to File object
for key, value in inputs.items():
# NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now.
from factories import file_factory
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
inputs[key] = File.model_validate(value)
if value["transfer_method"] == FileTransferMethod.TOOL_FILE:
value["tool_file_id"] = value["related_id"]
elif value["transfer_method"] == FileTransferMethod.LOCAL_FILE:
value["upload_file_id"] = value["related_id"]
inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"])
elif isinstance(value, list) and all(
isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value
):
inputs[key] = [File.model_validate(item) for item in value]
inputs[key] = []
for item in value:
if item["transfer_method"] == FileTransferMethod.TOOL_FILE:
item["tool_file_id"] = item["related_id"]
elif item["transfer_method"] == FileTransferMethod.LOCAL_FILE:
item["upload_file_id"] = item["related_id"]
inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"]))
return inputs
@inputs.setter
@@ -679,6 +696,29 @@ class Conversation(db.Model):
return {"like": like, "dislike": dislike}
@property
def status_count(self):
messages = db.session.query(Message).filter(Message.conversation_id == self.id).all()
status_counts = {
WorkflowRunStatus.SUCCEEDED: 0,
WorkflowRunStatus.FAILED: 0,
WorkflowRunStatus.PARTIAL_SUCCESSED: 0,
}
for message in messages:
if message.workflow_run:
status_counts[message.workflow_run.status] += 1
return (
{
"success": status_counts[WorkflowRunStatus.SUCCEEDED],
"failed": status_counts[WorkflowRunStatus.FAILED],
"partial_success": status_counts[WorkflowRunStatus.PARTIAL_SUCCESSED],
}
if messages
else None
)
@property
def first_message(self):
return db.session.query(Message).filter(Message.conversation_id == self.id).first()
@@ -758,12 +798,25 @@ class Message(db.Model):
def inputs(self):
inputs = self._inputs.copy()
for key, value in inputs.items():
# NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now.
from factories import file_factory
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
inputs[key] = File.model_validate(value)
if value["transfer_method"] == FileTransferMethod.TOOL_FILE:
value["tool_file_id"] = value["related_id"]
elif value["transfer_method"] == FileTransferMethod.LOCAL_FILE:
value["upload_file_id"] = value["related_id"]
inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"])
elif isinstance(value, list) and all(
isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY for item in value
):
inputs[key] = [File.model_validate(item) for item in value]
inputs[key] = []
for item in value:
if item["transfer_method"] == FileTransferMethod.TOOL_FILE:
item["tool_file_id"] = item["related_id"]
elif item["transfer_method"] == FileTransferMethod.LOCAL_FILE:
item["upload_file_id"] = item["related_id"]
inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"]))
return inputs
@inputs.setter
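Both Conversation.inputs and Message.inputs now rebuild File objects through file_factory.build_from_mapping instead of File.model_validate, after translating the serialized related_id into the key the factory expects. A hedged sketch of just that translation step (helper name hypothetical; the transfer-method strings are illustrative):

    def _prepare_file_mapping(value: dict) -> dict:
        # build_from_mapping looks for a transfer-method-specific id key,
        # while stored inputs carry a generic "related_id".
        if value["transfer_method"] == "tool_file":
            value["tool_file_id"] = value["related_id"]
        elif value["transfer_method"] == "local_file":
            value["upload_file_id"] = value["related_id"]
        return value

    mapping = {"transfer_method": "local_file", "related_id": "1111"}
    assert _prepare_file_mapping(mapping)["upload_file_id"] == "1111"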

View File

@@ -1,7 +1,6 @@
from enum import Enum
from extensions.ext_database import db
from .engine import db
from .types import StringUUID

View File

@@ -2,8 +2,7 @@ import json
from sqlalchemy.dialects.postgresql import JSONB
from extensions.ext_database import db
from .engine import db
from .types import StringUUID

View File

@@ -2,7 +2,7 @@ from datetime import UTC, datetime
from celery import states
from extensions.ext_database import db
from .engine import db
class CeleryTask(db.Model):

View File

@@ -1,47 +0,0 @@
import json
from enum import Enum
from extensions.ext_database import db
from .types import StringUUID
class ToolProviderName(Enum):
SERPAPI = "serpapi"
@staticmethod
def value_of(value):
for member in ToolProviderName:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class ToolProvider(db.Model):
__tablename__ = "tool_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_provider_pkey"),
db.UniqueConstraint("tenant_id", "tool_name", name="unique_tool_provider_tool_name"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
tool_name = db.Column(db.String(40), nullable=False)
encrypted_credentials = db.Column(db.Text, nullable=True)
is_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
@property
def credentials_is_set(self):
"""
Returns True if the encrypted_config is not None, indicating that the token is set.
"""
return self.encrypted_credentials is not None
@property
def credentials(self):
"""
Returns the decrypted config.
"""
return json.loads(self.encrypted_credentials) if self.encrypted_credentials is not None else None

View File

@@ -8,8 +8,8 @@ from sqlalchemy.orm import Mapped, mapped_column
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
from extensions.ext_database import db
from .engine import db
from .model import Account, App, Tenant
from .types import StringUUID
@@ -82,7 +82,7 @@ class PublishedAppTool(db.Model):
return I18nObject(**json.loads(self.description))
@property
def app(self) -> App:
def app(self):
return db.session.query(App).filter(App.id == self.app_id).first()
@@ -201,10 +201,6 @@ class WorkflowToolProvider(db.Model):
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
@property
def schema_type(self) -> ApiProviderSchemaType:
return ApiProviderSchemaType.value_of(self.schema_type_str)
@property
def user(self) -> Account | None:
return db.session.query(Account).filter(Account.id == self.user_id).first()

View File

@@ -1,5 +1,4 @@
from extensions.ext_database import db
from .engine import db
from .model import Message
from .types import StringUUID

View File

@@ -12,12 +12,12 @@ import contexts
from constants import HIDDEN_VALUE
from core.helper import encrypter
from core.variables import SecretVariable, Variable
from extensions.ext_database import db
from factories import variable_factory
from libs import helper
from models.enums import CreatedByRole
from .account import Account
from .engine import db
from .types import StringUUID
@@ -399,7 +399,7 @@ class WorkflowRun(db.Model):
graph = db.Column(db.Text)
inputs = db.Column(db.Text)
status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped, partial-succeeded
outputs: Mapped[str] = mapped_column(sa.Text, default="{}")
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
error = db.Column(db.Text)
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0"))
total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
@@ -529,6 +529,7 @@ class WorkflowNodeExecutionStatus(Enum):
SUCCEEDED = "succeeded"
FAILED = "failed"
EXCEPTION = "exception"
RETRY = "retry"
@classmethod
def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus":
@@ -639,6 +640,7 @@ class WorkflowNodeExecution(db.Model):
created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(StringUUID, nullable=False)
finished_at = db.Column(db.DateTime)
retry_index = db.Column(db.Integer, server_default=db.text("0"))
@property
def created_by_account(self):

View File

@@ -15,6 +15,7 @@ from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.base.node import BaseNode
from core.workflow.nodes.enums import ErrorStrategy
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.event.event import SingleStepRetryEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
@@ -220,56 +221,99 @@ class WorkflowService:
# run draft workflow node
start_at = time.perf_counter()
retries = 0
max_retries = 0
should_retry = True
retry_events = []
try:
node_instance, generator = WorkflowEntry.single_step_run(
workflow=draft_workflow,
node_id=node_id,
user_inputs=user_inputs,
user_id=account.id,
)
node_instance = cast(BaseNode[BaseNodeData], node_instance)
node_run_result: NodeRunResult | None = None
for event in generator:
if isinstance(event, RunCompletedEvent):
node_run_result = event.run_result
while retries <= max_retries and should_retry:
retry_start_at = time.perf_counter()
node_instance, generator = WorkflowEntry.single_step_run(
workflow=draft_workflow,
node_id=node_id,
user_inputs=user_inputs,
user_id=account.id,
)
node_instance = cast(BaseNode[BaseNodeData], node_instance)
max_retries = (
node_instance.node_data.retry_config.max_retries if node_instance.node_data.retry_config else 0
)
retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
node_run_result: NodeRunResult | None = None
for event in generator:
if isinstance(event, RunCompletedEvent):
node_run_result = event.run_result
# sign output files
node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
break
# sign output files
node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
break
if not node_run_result:
raise ValueError("Node run failed with no run result")
# single step debug mode error handling return
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
node_error_args = {
"status": WorkflowNodeExecutionStatus.EXCEPTION,
"error": node_run_result.error,
"inputs": node_run_result.inputs,
"metadata": {"error_strategy": node_instance.node_data.error_strategy},
}
if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
node_run_result = NodeRunResult(
**node_error_args,
outputs={
**node_instance.node_data.default_value_dict,
"error_message": node_run_result.error,
"error_type": node_run_result.error_type,
},
)
else:
node_run_result = NodeRunResult(
**node_error_args,
outputs={
"error_message": node_run_result.error,
"error_type": node_run_result.error_type,
},
)
run_succeeded = node_run_result.status in (
WorkflowNodeExecutionStatus.SUCCEEDED,
WorkflowNodeExecutionStatus.EXCEPTION,
)
error = node_run_result.error if not run_succeeded else None
if not node_run_result:
raise ValueError("Node run failed with no run result")
# single step debug mode error handling return
if node_run_result.status == WorkflowNodeExecutionStatus.FAILED:
if (
retries == max_retries
and node_instance.node_type == NodeType.HTTP_REQUEST
and node_run_result.outputs
and not node_instance.should_continue_on_error
):
node_run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
should_retry = False
else:
if node_instance.should_retry:
node_run_result.status = WorkflowNodeExecutionStatus.RETRY
retries += 1
node_run_result.retry_index = retries
retry_events.append(
SingleStepRetryEvent(
inputs=WorkflowEntry.handle_special_values(node_run_result.inputs)
if node_run_result.inputs
else None,
error=node_run_result.error,
outputs=WorkflowEntry.handle_special_values(node_run_result.outputs)
if node_run_result.outputs
else None,
retry_index=node_run_result.retry_index,
elapsed_time=time.perf_counter() - retry_start_at,
execution_metadata=WorkflowEntry.handle_special_values(node_run_result.metadata)
if node_run_result.metadata
else None,
)
)
time.sleep(retry_interval)
else:
should_retry = False
if node_instance.should_continue_on_error:
node_error_args = {
"status": WorkflowNodeExecutionStatus.EXCEPTION,
"error": node_run_result.error,
"inputs": node_run_result.inputs,
"metadata": {"error_strategy": node_instance.node_data.error_strategy},
}
if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
node_run_result = NodeRunResult(
**node_error_args,
outputs={
**node_instance.node_data.default_value_dict,
"error_message": node_run_result.error,
"error_type": node_run_result.error_type,
},
)
else:
node_run_result = NodeRunResult(
**node_error_args,
outputs={
"error_message": node_run_result.error,
"error_type": node_run_result.error_type,
},
)
run_succeeded = node_run_result.status in (
WorkflowNodeExecutionStatus.SUCCEEDED,
WorkflowNodeExecutionStatus.EXCEPTION,
)
error = node_run_result.error if not run_succeeded else None
except WorkflowNodeRunFailedError as e:
node_instance = e.node_instance
run_succeeded = False
@@ -318,6 +362,7 @@ class WorkflowService:
db.session.add(workflow_node_execution)
db.session.commit()
workflow_node_execution.retry_events = retry_events
return workflow_node_execution
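Distilled to its control flow, the single-step debug path above re-runs the node until it succeeds, the retry budget is spent, or retries are disabled, recording one retry event per failed attempt. A minimal sketch of that loop shape, with no Dify imports:

    import time

    def run_with_retries(run_once, max_retries: int, retry_interval_s: float):
        # run_once() -> (ok: bool, result); retry on failure up to max_retries.
        retry_events = []
        for attempt in range(max_retries + 1):
            ok, result = run_once()
            if ok or attempt == max_retries:
                return result, retry_events
            retry_events.append({"retry_index": attempt + 1, "error": result})
            time.sleep(retry_interval_s)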

View File

@@ -21,13 +21,13 @@ class MockXinferenceClass:
if not re.match(r"https?:\/\/[^\s\/$.?#].[^\s]*$", self.base_url):
raise RuntimeError("404 Not Found")
if "generate" == model_uid:
if model_uid == "generate":
return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
if "chat" == model_uid:
if model_uid == "chat":
return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
if "embedding" == model_uid:
if model_uid == "embedding":
return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
if "rerank" == model_uid:
if model_uid == "rerank":
return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
raise RuntimeError("404 Not Found")

View File

@@ -34,9 +34,9 @@ def test_api_tool(setup_http_mock):
response = tool.do_http_request(tool.api_bundle.server_url, tool.api_bundle.method, headers, parameters)
assert response.status_code == 200
assert "/p_param" == response.request.url.path
assert b"query_param=q_param" == response.request.url.query
assert "h_param" == response.request.headers.get("header_param")
assert "application/json" == response.request.headers.get("content-type")
assert "cookie_param=c_param" == response.request.headers.get("cookie")
assert response.request.url.path == "/p_param"
assert response.request.url.query == b"query_param=q_param"
assert response.request.headers.get("header_param") == "h_param"
assert response.request.headers.get("content-type") == "application/json"
assert response.request.headers.get("cookie") == "cookie_param=c_param"
assert "b_param" in response.content.decode()

View File

@@ -384,7 +384,7 @@ def test_mock_404(setup_http_mock):
assert result.outputs is not None
resp = result.outputs
assert 404 == resp.get("status_code")
assert resp.get("status_code") == 404
assert "Not Found" in resp.get("body", "")

View File

@@ -59,6 +59,8 @@ def test_dify_config(example_env_file):
# annotated field with configured value
assert config.HTTP_REQUEST_MAX_WRITE_TIMEOUT == 30
assert config.WORKFLOW_PARALLEL_DEPTH_LIMIT == 3
# NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected.
# This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`.

View File

@@ -136,6 +136,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg",
storage_key="",
)
]

View File

@@ -1,34 +1,9 @@
import json
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType, FileUploadConfig
from core.file import File, FileTransferMethod, FileType, FileUploadConfig
from models.workflow import Workflow
def test_file_loads_and_dumps():
file = File(
id="file1",
tenant_id="tenant1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg",
)
file_dict = file.model_dump()
assert file_dict["dify_model_identity"] == FILE_MODEL_IDENTITY
assert file_dict["type"] == file.type.value
assert isinstance(file_dict["type"], str)
assert file_dict["transfer_method"] == file.transfer_method.value
assert isinstance(file_dict["transfer_method"], str)
assert "_extra_config" not in file_dict
file_obj = File.model_validate(file_dict)
assert file_obj.id == file.id
assert file_obj.tenant_id == file.tenant_id
assert file_obj.type == file.type
assert file_obj.transfer_method == file.transfer_method
assert file_obj.remote_url == file.remote_url
def test_file_to_dict():
file = File(
id="file1",
@@ -36,10 +11,11 @@ def test_file_to_dict():
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg",
storage_key="storage_key",
)
file_dict = file.to_dict()
assert "_extra_config" not in file_dict
assert "_storage_key" not in file_dict
assert "url" in file_dict

View File

@@ -51,6 +51,7 @@ def test_http_request_node_binary_file(monkeypatch):
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1111",
storage_key="",
),
),
)
@@ -138,6 +139,7 @@ def test_http_request_node_form_with_file(monkeypatch):
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1111",
storage_key="",
),
),
)

View File

@@ -21,7 +21,8 @@ from core.model_runtime.entities.message_entities import (
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment, StringSegment
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.nodes.answer import AnswerStreamGenerateRoute
@@ -157,6 +158,7 @@ def test_fetch_files_with_file_segment(llm_node):
filename="test.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
storage_key="",
)
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
@@ -173,6 +175,7 @@ def test_fetch_files_with_array_file_segment(llm_node):
filename="test1.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
storage_key="",
),
File(
id="2",
@@ -181,6 +184,7 @@ def test_fetch_files_with_array_file_segment(llm_node):
filename="test2.jpg",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="2",
storage_key="",
),
]
llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
@@ -224,14 +228,15 @@ def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
filename="test1.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
storage_key="",
)
]
fake_query = faker.sentence()
prompt_messages, _ = llm_node._fetch_prompt_messages(
user_query=fake_query,
user_files=files,
sys_query=fake_query,
sys_files=files,
context=None,
memory=None,
model_config=model_config,
@@ -283,8 +288,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
test_scenarios = [
LLMNodeTestScenario(
description="No files",
user_query=fake_query,
user_files=[],
sys_query=fake_query,
sys_files=[],
features=[],
vision_enabled=False,
vision_detail=None,
@@ -318,8 +323,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
),
LLMNodeTestScenario(
description="User files",
user_query=fake_query,
user_files=[
sys_query=fake_query,
sys_files=[
File(
tenant_id="test",
type=FileType.IMAGE,
@@ -328,6 +333,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
remote_url=fake_remote_url,
extension=".jpg",
mime_type="image/jpg",
storage_key="",
)
],
vision_enabled=True,
@@ -370,8 +376,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
),
LLMNodeTestScenario(
description="Prompt template with variable selector of File",
user_query=fake_query,
user_files=[],
sys_query=fake_query,
sys_files=[],
vision_enabled=False,
vision_detail=fake_vision_detail,
features=[ModelFeature.VISION],
@@ -403,6 +409,7 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
remote_url=fake_remote_url,
extension=".jpg",
mime_type="image/jpg",
storage_key="",
)
},
),
@@ -417,8 +424,8 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
# Call the method under test
prompt_messages, _ = llm_node._fetch_prompt_messages(
user_query=scenario.user_query,
user_files=scenario.user_files,
sys_query=scenario.sys_query,
sys_files=scenario.sys_files,
context=fake_context,
memory=memory,
model_config=model_config,
@@ -435,3 +442,29 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
assert (
prompt_messages == scenario.expected_messages
), f"Message content mismatch in scenario: {scenario.description}"
def test_handle_list_messages_basic(llm_node):
messages = [
LLMNodeChatModelMessage(
text="Hello, {#context#}",
role=PromptMessageRole.USER,
edition_type="basic",
)
]
context = "world"
jinja2_variables = []
variable_pool = llm_node.graph_runtime_state.variable_pool
vision_detail_config = ImagePromptMessageContent.DETAIL.HIGH
result = llm_node._handle_list_messages(
messages=messages,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
vision_detail_config=vision_detail_config,
)
assert len(result) == 1
assert isinstance(result[0], UserPromptMessage)
assert result[0].content == [TextPromptMessageContent(data="Hello, world")]

View File

@@ -12,8 +12,8 @@ class LLMNodeTestScenario(BaseModel):
"""Test scenario for LLM node testing."""
description: str = Field(..., description="Description of the test scenario")
user_query: str = Field(..., description="User query input")
user_files: Sequence[File] = Field(default_factory=list, description="List of user files")
sys_query: str = Field(..., description="User query input")
sys_files: Sequence[File] = Field(default_factory=list, description="List of user files")
vision_enabled: bool = Field(default=False, description="Whether vision is enabled")
vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled")
features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features")

View File

@@ -2,7 +2,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import (
GraphRunPartialSucceededEvent,
GraphRunSucceededEvent,
NodeRunExceptionEvent,
NodeRunStreamChunkEvent,
)
@@ -14,7 +13,9 @@ from models.workflow import WorkflowType
class ContinueOnErrorTestHelper:
@staticmethod
def get_code_node(code: str, error_strategy: str = "fail-branch", default_value: dict | None = None):
def get_code_node(
code: str, error_strategy: str = "fail-branch", default_value: dict | None = None, retry_config: dict = {}
):
"""Helper method to create a code node configuration"""
node = {
"id": "node",
@@ -26,6 +27,7 @@ class ContinueOnErrorTestHelper:
"code_language": "python3",
"code": "\n".join([line[4:] for line in code.split("\n")]),
"type": "code",
**retry_config,
},
}
if default_value:
@@ -34,7 +36,10 @@ class ContinueOnErrorTestHelper:
@staticmethod
def get_http_node(
error_strategy: str = "fail-branch", default_value: dict | None = None, authorization_success: bool = False
error_strategy: str = "fail-branch",
default_value: dict | None = None,
authorization_success: bool = False,
retry_config: dict = {},
):
"""Helper method to create a http node configuration"""
authorization = (
@@ -65,6 +70,7 @@ class ContinueOnErrorTestHelper:
"body": None,
"type": "http-request",
"error_strategy": error_strategy,
**retry_config,
},
}
if default_value:

View File

@@ -248,6 +248,7 @@ def test_array_file_contains_file_name():
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="1",
filename="ab",
storage_key="",
),
],
)

View File

@@ -57,6 +57,7 @@ def test_filter_files_by_type(list_operator_node):
tenant_id="tenant1",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related1",
storage_key="",
),
File(
filename="document1.pdf",
@@ -64,6 +65,7 @@ def test_filter_files_by_type(list_operator_node):
tenant_id="tenant1",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related2",
storage_key="",
),
File(
filename="image2.png",
@@ -71,6 +73,7 @@ def test_filter_files_by_type(list_operator_node):
tenant_id="tenant1",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related3",
storage_key="",
),
File(
filename="audio1.mp3",
@@ -78,6 +81,7 @@ def test_filter_files_by_type(list_operator_node):
tenant_id="tenant1",
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="related4",
storage_key="",
),
]
variable = ArrayFileSegment(value=files)
@@ -130,6 +134,7 @@ def test_get_file_extract_string_func():
mime_type="text/plain",
remote_url="https://example.com/test_file.txt",
related_id="test_related_id",
storage_key="",
)
# Test each case
@@ -150,6 +155,7 @@ def test_get_file_extract_string_func():
mime_type=None,
remote_url=None,
related_id="test_related_id",
storage_key="",
)
assert _get_file_extract_string_func(key="name")(empty_file) == ""

View File

@@ -0,0 +1,73 @@
from core.workflow.graph_engine.entities.event import (
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunSucceededEvent,
NodeRunRetryEvent,
)
from tests.unit_tests.core.workflow.nodes.test_continue_on_error import ContinueOnErrorTestHelper
DEFAULT_VALUE_EDGE = [
{
"id": "start-source-node-target",
"source": "start",
"target": "node",
"sourceHandle": "source",
},
{
"id": "node-source-answer-target",
"source": "node",
"target": "answer",
"sourceHandle": "source",
},
]
def test_retry_default_value_partial_success():
"""retry default value node with partial success status"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_http_node(
"default-value",
[{"key": "result", "type": "string", "value": "http node got error response"}],
retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
assert events[-1].outputs == {"answer": "http node got error response"}
assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events)
assert len(events) == 11
def test_retry_failed():
"""retry failed with success status"""
error_code = """
def main() -> dict:
return {
"result": 1 / 0,
}
"""
graph_config = {
"edges": DEFAULT_VALUE_EDGE,
"nodes": [
{"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
{"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
ContinueOnErrorTestHelper.get_http_node(
None,
None,
retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
),
],
}
graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
events = list(graph_engine.run())
assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
assert any(isinstance(e, GraphRunFailedEvent) for e in events)
assert len(events) == 8

View File

@@ -19,6 +19,7 @@ def file():
related_id="test_related_id",
remote_url="test_url",
filename="test_file.txt",
storage_key="",
)

View File

@@ -0,0 +1,111 @@
import yaml # type: ignore
from dotenv import dotenv_values
from pathlib import Path
BASE_API_AND_DOCKER_CONFIG_SET_DIFF = {
"APP_MAX_EXECUTION_TIME",
"BATCH_UPLOAD_LIMIT",
"CELERY_BEAT_SCHEDULER_TIME",
"CODE_EXECUTION_API_KEY",
"HTTP_REQUEST_MAX_CONNECT_TIMEOUT",
"HTTP_REQUEST_MAX_READ_TIMEOUT",
"HTTP_REQUEST_MAX_WRITE_TIMEOUT",
"KEYWORD_DATA_SOURCE_TYPE",
"LOGIN_LOCKOUT_DURATION",
"LOG_FORMAT",
"OCI_ACCESS_KEY",
"OCI_BUCKET_NAME",
"OCI_ENDPOINT",
"OCI_REGION",
"OCI_SECRET_KEY",
"REDIS_DB",
"RESEND_API_URL",
"RESPECT_XFORWARD_HEADERS_ENABLED",
"SENTRY_DSN",
"SSRF_DEFAULT_CONNECT_TIME_OUT",
"SSRF_DEFAULT_MAX_RETRIES",
"SSRF_DEFAULT_READ_TIME_OUT",
"SSRF_DEFAULT_TIME_OUT",
"SSRF_DEFAULT_WRITE_TIME_OUT",
"UPSTASH_VECTOR_TOKEN",
"UPSTASH_VECTOR_URL",
"USING_UGC_INDEX",
"WEAVIATE_BATCH_SIZE",
"WEAVIATE_GRPC_ENABLED",
}
BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF = {
"BATCH_UPLOAD_LIMIT",
"CELERY_BEAT_SCHEDULER_TIME",
"HTTP_REQUEST_MAX_CONNECT_TIMEOUT",
"HTTP_REQUEST_MAX_READ_TIMEOUT",
"HTTP_REQUEST_MAX_WRITE_TIMEOUT",
"KEYWORD_DATA_SOURCE_TYPE",
"LOGIN_LOCKOUT_DURATION",
"LOG_FORMAT",
"OPENDAL_FS_ROOT",
"OPENDAL_S3_ACCESS_KEY_ID",
"OPENDAL_S3_BUCKET",
"OPENDAL_S3_ENDPOINT",
"OPENDAL_S3_REGION",
"OPENDAL_S3_ROOT",
"OPENDAL_S3_SECRET_ACCESS_KEY",
"OPENDAL_S3_SERVER_SIDE_ENCRYPTION",
"PGVECTOR_MAX_CONNECTION",
"PGVECTOR_MIN_CONNECTION",
"PGVECTO_RS_DATABASE",
"PGVECTO_RS_HOST",
"PGVECTO_RS_PASSWORD",
"PGVECTO_RS_PORT",
"PGVECTO_RS_USER",
"RESPECT_XFORWARD_HEADERS_ENABLED",
"SCARF_NO_ANALYTICS",
"SSRF_DEFAULT_CONNECT_TIME_OUT",
"SSRF_DEFAULT_MAX_RETRIES",
"SSRF_DEFAULT_READ_TIME_OUT",
"SSRF_DEFAULT_TIME_OUT",
"SSRF_DEFAULT_WRITE_TIME_OUT",
"STORAGE_OPENDAL_SCHEME",
"SUPABASE_API_KEY",
"SUPABASE_BUCKET_NAME",
"SUPABASE_URL",
"USING_UGC_INDEX",
"VIKINGDB_CONNECTION_TIMEOUT",
"VIKINGDB_SOCKET_TIMEOUT",
"WEAVIATE_BATCH_SIZE",
"WEAVIATE_GRPC_ENABLED",
}
API_CONFIG_SET = set(dotenv_values(Path("api") / ".env.example").keys())
DOCKER_CONFIG_SET = set(dotenv_values(Path("docker") / ".env.example").keys())
with open(Path("docker") / "docker-compose.yaml") as f:
    DOCKER_COMPOSE_CONFIG_SET = set(yaml.safe_load(f.read())["x-shared-env"].keys())
def test_yaml_config():
    # set difference: keys in the API config but not in the Docker config, minus the known allowlist
DIFF_API_WITH_DOCKER = (
API_CONFIG_SET - DOCKER_CONFIG_SET - BASE_API_AND_DOCKER_CONFIG_SET_DIFF
)
if DIFF_API_WITH_DOCKER:
print(
f"API and Docker config sets are different with key: {DIFF_API_WITH_DOCKER}"
)
raise Exception("API and Docker config sets are different")
DIFF_API_WITH_DOCKER_COMPOSE = (
API_CONFIG_SET
- DOCKER_COMPOSE_CONFIG_SET
- BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF
)
if DIFF_API_WITH_DOCKER_COMPOSE:
print(
f"API and Docker Compose config sets are different with key: {DIFF_API_WITH_DOCKER_COMPOSE}"
)
raise Exception("API and Docker Compose config sets are different")
print("All tests passed!")
if __name__ == "__main__":
test_yaml_config()
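The comparison leans on how dotenv_values parses a file: every uncommented KEY=VALUE line becomes a dict entry (keys with empty values included), while commented-out keys are ignored, so only keys actually declared in each .env.example take part in the set arithmetic. A quick illustration with a hypothetical in-memory stream:

from io import StringIO
from dotenv import dotenv_values

example = StringIO("SENTRY_DSN=\n# DEBUG=true\nAPP_MAX_EXECUTION_TIME=1200\n")
print(dict(dotenv_values(stream=example)))
# {'SENTRY_DSN': '', 'APP_MAX_EXECUTION_TIME': '1200'} -- the commented key is dropped

Since the paths are relative ("api/.env.example", "docker/docker-compose.yaml"), the script is evidently meant to be run from the repository root.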

View File

@@ -107,6 +107,7 @@ ACCESS_TOKEN_EXPIRE_MINUTES=60
# The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer.
APP_MAX_ACTIVE_REQUESTS=0
APP_MAX_EXECUTION_TIME=1200
# ------------------------------
# Container Startup Related Configuration
@@ -606,6 +607,7 @@ UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
# Sentry Configuration
# Used for application monitoring and error log tracking.
# ------------------------------
SENTRY_DSN=
# API Service Sentry DSN address; defaults to empty. When empty,
# no monitoring information is reported to Sentry.
@@ -697,6 +699,7 @@ WORKFLOW_MAX_EXECUTION_STEPS=500
WORKFLOW_MAX_EXECUTION_TIME=1200
WORKFLOW_CALL_MAX_DEPTH=5
MAX_VARIABLE_SIZE=204800
WORKFLOW_PARALLEL_DEPTH_LIMIT=3
WORKFLOW_FILE_UPLOAD_LIMIT=10
# HTTP request node in workflow configuration
@@ -919,7 +922,3 @@ CREATE_TIDB_SERVICE_JOB_ENABLED=false
# Maximum number of threads submitted to the ThreadPool for parallel node execution
MAX_SUBMIT_COUNT=100
# Proxy
HTTP_PROXY=
HTTPS_PROXY=

View File

@@ -28,6 +28,7 @@ x-shared-env: &shared-api-worker-env
FILES_ACCESS_TIMEOUT: ${FILES_ACCESS_TIMEOUT:-300}
ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60}
APP_MAX_ACTIVE_REQUESTS: ${APP_MAX_ACTIVE_REQUESTS:-0}
APP_MAX_EXECUTION_TIME: ${APP_MAX_EXECUTION_TIME:-1200}
DIFY_BIND_ADDRESS: ${DIFY_BIND_ADDRESS:-0.0.0.0}
DIFY_PORT: ${DIFY_PORT:-5001}
SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-}
@@ -258,6 +259,7 @@ x-shared-env: &shared-api-worker-env
UPLOAD_IMAGE_FILE_SIZE_LIMIT: ${UPLOAD_IMAGE_FILE_SIZE_LIMIT:-10}
UPLOAD_VIDEO_FILE_SIZE_LIMIT: ${UPLOAD_VIDEO_FILE_SIZE_LIMIT:-100}
UPLOAD_AUDIO_FILE_SIZE_LIMIT: ${UPLOAD_AUDIO_FILE_SIZE_LIMIT:-50}
SENTRY_DSN: ${SENTRY_DSN:-}
API_SENTRY_DSN: ${API_SENTRY_DSN:-}
API_SENTRY_TRACES_SAMPLE_RATE: ${API_SENTRY_TRACES_SAMPLE_RATE:-1.0}
API_SENTRY_PROFILES_SAMPLE_RATE: ${API_SENTRY_PROFILES_SAMPLE_RATE:-1.0}
@@ -297,6 +299,7 @@ x-shared-env: &shared-api-worker-env
WORKFLOW_MAX_EXECUTION_TIME: ${WORKFLOW_MAX_EXECUTION_TIME:-1200}
WORKFLOW_CALL_MAX_DEPTH: ${WORKFLOW_CALL_MAX_DEPTH:-5}
MAX_VARIABLE_SIZE: ${MAX_VARIABLE_SIZE:-204800}
WORKFLOW_PARALLEL_DEPTH_LIMIT: ${WORKFLOW_PARALLEL_DEPTH_LIMIT:-3}
WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10}
HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760}
HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576}
@@ -383,8 +386,6 @@ x-shared-env: &shared-api-worker-env
CSP_WHITELIST: ${CSP_WHITELIST:-}
CREATE_TIDB_SERVICE_JOB_ENABLED: ${CREATE_TIDB_SERVICE_JOB_ENABLED:-false}
MAX_SUBMIT_COUNT: ${MAX_SUBMIT_COUNT:-100}
HTTP_PROXY: ${HTTP_PROXY:-}
HTTPS_PROXY: ${HTTPS_PROXY:-}
services:
# API service
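Every x-shared-env entry relies on compose's shell-style parameter expansion: ${VAR:-default} resolves to VAR's value from the environment (typically docker/.env) when it is set and non-empty, and to the literal default otherwise, which is why each key added to .env.example above reappears here with a matching default. A rough Python analogue of that lookup (illustrative, not how compose implements it):

import os

def compose_default(name: str, default: str) -> str:
    # ${NAME:-default} substitutes the default when NAME is unset OR empty;
    # the ${NAME-default} form would substitute only when NAME is unset.
    value = os.environ.get(name)
    return value if value else default

print(compose_default("WORKFLOW_PARALLEL_DEPTH_LIMIT", "3"))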

View File

@@ -16,6 +16,7 @@ import { createContext, useContext } from 'use-context-selector'
import { useShallow } from 'zustand/react/shallow'
import { useTranslation } from 'react-i18next'
import type { ChatItemInTree } from '../../base/chat/types'
import Indicator from '../../header/indicator'
import VarPanel from './var-panel'
import type { FeedbackFunc, FeedbackType, IChatItem, SubmitAnnotationFunc } from '@/app/components/base/chat/chat/type'
import type { Annotation, ChatConversationGeneralDetail, ChatConversationsResponse, ChatMessage, ChatMessagesRequest, CompletionConversationGeneralDetail, CompletionConversationsResponse, LogAnnotation } from '@/models/log'
@@ -57,6 +58,12 @@ type IDrawerContext = {
appDetail?: App
}
type StatusCount = {
success: number
failed: number
partial_success: number
}
const DrawerContext = createContext<IDrawerContext>({} as IDrawerContext)
/**
@@ -71,6 +78,33 @@ const HandThumbIconWithCount: FC<{ count: number; iconType: 'up' | 'down' }> = (
</div>
}
const statusTdRender = (statusCount: StatusCount) => {
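// Branch precedence: no failures and no partial successes -> Success; no failures -> Partial Success; otherwise show the failure count.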
if (statusCount.partial_success + statusCount.failed === 0) {
return (
<div className='inline-flex items-center gap-1 system-xs-semibold-uppercase'>
<Indicator color={'green'} />
<span className='text-util-colors-green-green-600'>Success</span>
</div>
)
}
else if (statusCount.failed === 0) {
return (
<div className='inline-flex items-center gap-1 system-xs-semibold-uppercase'>
<Indicator color={'green'} />
<span className='text-util-colors-green-green-600'>Partial Success</span>
</div>
)
}
else {
return (
<div className='inline-flex items-center gap-1 system-xs-semibold-uppercase'>
<Indicator color={'red'} />
<span className='text-util-colors-red-red-600'>{statusCount.failed} {statusCount.failed > 1 ? 'Failures' : 'Failure'}</span>
</div>
)
}
}
const getFormattedChatList = (messages: ChatMessage[], conversationId: string, timezone: string, format: string) => {
const newChatList: IChatItem[] = []
messages.forEach((item: ChatMessage) => {
@@ -496,8 +530,8 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
}
/**
 * Text App Conversation Detail Component
 */
const CompletionConversationDetailComp: FC<{ appId?: string; conversationId?: string }> = ({ appId, conversationId }) => {
// Text Generator App Session Details Including Message List
const detailParams = ({ url: `/apps/${appId}/completion-conversations/${conversationId}` })
@@ -542,8 +576,8 @@ const CompletionConversationDetailComp: FC<{ appId?: string; conversationId?: st
}
/**
 * Chat App Conversation Detail Component
 */
const ChatConversationDetailComp: FC<{ appId?: string; conversationId?: string }> = ({ appId, conversationId }) => {
const detailParams = { url: `/apps/${appId}/chat-conversations/${conversationId}` }
const { data: conversationDetail } = useSWR(() => (appId && conversationId) ? detailParams : null, fetchChatConversationDetail)
@@ -585,8 +619,8 @@ const ChatConversationDetailComp: FC<{ appId?: string; conversationId?: string }
}
/**
 * Conversation list component including basic information
 */
const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh }) => {
const { t } = useTranslation()
const { formatTime } = useTimestamp()
@@ -597,6 +631,7 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
const [showDrawer, setShowDrawer] = useState<boolean>(false) // Whether to display the chat details drawer
const [currentConversation, setCurrentConversation] = useState<ChatConversationGeneralDetail | CompletionConversationGeneralDetail | undefined>() // Currently selected conversation
const isChatMode = appDetail.mode !== 'completion' // Whether the app is a chat app
const isChatflow = appDetail.mode === 'advanced-chat' // Whether the app is a chatflow app
const { setShowPromptLogModal, setShowAgentLogModal } = useAppStore(useShallow(state => ({
setShowPromptLogModal: state.setShowPromptLogModal,
setShowAgentLogModal: state.setShowAgentLogModal,
@@ -639,6 +674,7 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
<td className='pl-2 pr-1 w-5 rounded-l-lg bg-background-section-burn whitespace-nowrap'></td>
<td className='pl-3 py-1.5 bg-background-section-burn whitespace-nowrap'>{isChatMode ? t('appLog.table.header.summary') : t('appLog.table.header.input')}</td>
<td className='pl-3 py-1.5 bg-background-section-burn whitespace-nowrap'>{t('appLog.table.header.endUser')}</td>
{isChatflow && <td className='pl-3 py-1.5 bg-background-section-burn whitespace-nowrap'>{t('appLog.table.header.status')}</td>}
<td className='pl-3 py-1.5 bg-background-section-burn whitespace-nowrap'>{isChatMode ? t('appLog.table.header.messageCount') : t('appLog.table.header.output')}</td>
<td className='pl-3 py-1.5 bg-background-section-burn whitespace-nowrap'>{t('appLog.table.header.userRate')}</td>
<td className='pl-3 py-1.5 bg-background-section-burn whitespace-nowrap'>{t('appLog.table.header.adminRate')}</td>
@@ -669,6 +705,9 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
{renderTdValue(leftValue || t('appLog.table.empty.noChat'), !leftValue, isChatMode && log.annotated)}
</td>
<td className='p-3 pr-2'>{renderTdValue(endUser || defaultValue, !endUser)}</td>
{isChatflow && <td className='p-3 pr-2 w-[160px]' style={{ maxWidth: isChatMode ? 300 : 200 }}>
{statusTdRender(log.status_count)}
</td>}
<td className='p-3 pr-2' style={{ maxWidth: isChatMode ? 100 : 200 }}>
{renderTdValue(rightValue === 0 ? 0 : (rightValue || t('appLog.table.empty.noOutput')), !rightValue, !isChatMode && !!log.annotation?.content, log.annotation)}
</td>

View File

@@ -63,6 +63,14 @@ const WorkflowAppLogList: FC<ILogs> = ({ logs, appDetail, onRefresh }) => {
</div>
)
}
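// Runs where node failures were tolerated (rather than aborting the workflow) finish with the 'partial-succeeded' status.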
if (status === 'partial-succeeded') {
return (
<div className='inline-flex items-center gap-1 system-xs-semibold-uppercase'>
<Indicator color={'green'} />
<span className='text-util-colors-green-green-600'>Partial Success</span>
</div>
)
}
}
const onCloseDrawer = () => {

View File

@@ -64,6 +64,12 @@ const WorkflowProcessItem = ({
setShowMessageLogModal(true)
}, [item, setCurrentLogItem, setCurrentLogModalActiveTab, setShowMessageLogModal])
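// Same modal wiring as the callback above: retry details are surfaced on the TRACING tab of the message log modal, so onShowRetryDetail below reuses it.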
const showRetryDetail = useCallback(() => {
setCurrentLogItem(item)
setCurrentLogModalActiveTab('TRACING')
setShowMessageLogModal(true)
}, [item, setCurrentLogItem, setCurrentLogModalActiveTab, setShowMessageLogModal])
return (
<div
className={cn(
@@ -105,6 +111,7 @@ const WorkflowProcessItem = ({
<TracingPanel
list={data.tracing}
onShowIterationDetail={showIterationDetail}
onShowRetryDetail={showRetryDetail}
hideNodeInfo={hideInfo}
hideNodeProcessDetail={hideProcessDetail}
/>

Some files were not shown because too many files have changed in this diff.