Mirror of https://github.com/langgenius/dify.git, synced 2025-12-27 01:27:24 +00:00
Compare commits: 0.14.1...feat/node- (23 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | fde3fe0ab6 |  |
|  | 07528f82b9 |  |
|  | bb2f46d7cc |  |
|  | 463fbe2680 |  |
|  | 95a7e50137 |  |
|  | 9d93ad1f16 |  |
|  | 44104797d6 |  |
|  | 1548501050 |  |
|  | de3911e930 |  |
|  | 5a8a901560 |  |
|  | 12d45e9114 |  |
|  | d057067543 |  |
|  | 560d375e0f |  |
|  | 127291a90f |  |
|  | 9e0c28791d |  |
|  | 3388d6636c |  |
|  | 2624a6dcd0 |  |
|  | b5c2785e10 |  |
|  | 493834d45d |  |
|  | b411087bb7 |  |
|  | 357769c72e |  |
|  | 853b9af09c |  |
|  | b99f1a09f4 |  |
.github/workflows/api-tests.yml (vendored): 3 changes
@@ -50,6 +50,9 @@ jobs:
       - name: Run ModelRuntime
         run: poetry run -C api bash dev/pytest/pytest_model_runtime.sh

+      - name: Run dify config tests
+        run: poetry run -C api python dev/pytest/pytest_config_tests.py
+
       - name: Run Tool
         run: poetry run -C api bash dev/pytest/pytest_tools.sh
@@ -70,7 +70,6 @@ ignore = [
     "SIM113", # eumerate-for-loop
     "SIM117", # multiple-with-statements
     "SIM210", # if-expr-with-true-false
-    "SIM300", # yoda-conditions,
 ]

 [lint.per-file-ignores]
@@ -31,7 +31,7 @@ def admin_required(view):
         if auth_scheme != "bearer":
             raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")

-        if dify_config.ADMIN_API_KEY != auth_token:
+        if auth_token != dify_config.ADMIN_API_KEY:
             raise Unauthorized("API key is invalid.")

         return view(*args, **kwargs)
@@ -13,6 +13,7 @@ app_fields = {
     "name": fields.String,
     "mode": fields.String,
     "icon": fields.String,
     "icon_type": fields.String,
     "icon_url": AppIconUrlField,
     "icon_background": fields.String,
 }
@@ -22,6 +22,7 @@ from core.app.entities.queue_entities import (
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
     QueueNodeInIterationFailedEvent,
+    QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueParallelBranchRunFailedEvent,
@@ -328,6 +329,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 workflow_node_execution=workflow_node_execution,
             )

             if response:
                 yield response
+        elif isinstance(
+            event,
+            QueueNodeRetryEvent,
+        ):
+            workflow_node_execution = self._handle_workflow_node_execution_retried(
+                workflow_run=workflow_run, event=event
+            )
+
+            response = self._workflow_node_retry_to_stream_response(
+                event=event,
+                task_id=self._application_generate_entity.task_id,
+                workflow_node_execution=workflow_node_execution,
+            )
+
+            if response:
+                yield response
         elif isinstance(event, QueueParallelBranchRunStartedEvent):
@@ -18,6 +18,7 @@ from core.app.entities.queue_entities import (
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
     QueueNodeInIterationFailedEvent,
+    QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueParallelBranchRunFailedEvent,
@@ -286,9 +287,25 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 task_id=self._application_generate_entity.task_id,
                 workflow_node_execution=workflow_node_execution,
             )

             if node_failed_response:
                 yield node_failed_response
+        elif isinstance(
+            event,
+            QueueNodeRetryEvent,
+        ):
+            workflow_node_execution = self._handle_workflow_node_execution_retried(
+                workflow_run=workflow_run, event=event
+            )
+
+            response = self._workflow_node_retry_to_stream_response(
+                event=event,
+                task_id=self._application_generate_entity.task_id,
+                workflow_node_execution=workflow_node_execution,
+            )
+
+            if response:
+                yield response

         elif isinstance(event, QueueParallelBranchRunStartedEvent):
             if not workflow_run:
                 raise Exception("Workflow run not initialized.")
@@ -11,6 +11,7 @@ from core.app.entities.queue_entities import (
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
     QueueNodeInIterationFailedEvent,
+    QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueParallelBranchRunFailedEvent,
@@ -38,6 +39,7 @@ from core.workflow.graph_engine.entities.event import (
     NodeRunExceptionEvent,
     NodeRunFailedEvent,
     NodeRunRetrieverResourceEvent,
+    NodeRunRetryEvent,
     NodeRunStartedEvent,
     NodeRunStreamChunkEvent,
     NodeRunSucceededEvent,
@@ -420,6 +422,36 @@ class WorkflowBasedAppRunner(AppRunner):
                         error=event.error if isinstance(event, IterationRunFailedEvent) else None,
                     )
                 )
+            elif isinstance(event, NodeRunRetryEvent):
+                self._publish_event(
+                    QueueNodeRetryEvent(
+                        node_execution_id=event.id,
+                        node_id=event.node_id,
+                        node_type=event.node_type,
+                        node_data=event.node_data,
+                        parallel_id=event.parallel_id,
+                        parallel_start_node_id=event.parallel_start_node_id,
+                        parent_parallel_id=event.parent_parallel_id,
+                        parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                        start_at=event.start_at,
+                        inputs=event.route_node_state.node_run_result.inputs
+                        if event.route_node_state.node_run_result
+                        else {},
+                        process_data=event.route_node_state.node_run_result.process_data
+                        if event.route_node_state.node_run_result
+                        else {},
+                        outputs=event.route_node_state.node_run_result.outputs
+                        if event.route_node_state.node_run_result
+                        else {},
+                        error=event.error,
+                        execution_metadata=event.route_node_state.node_run_result.metadata
+                        if event.route_node_state.node_run_result
+                        else {},
+                        in_iteration_id=event.in_iteration_id,
+                        retry_index=event.retry_index,
+                        start_index=event.start_index,
+                    )
+                )

     def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
         """
@@ -43,6 +43,7 @@ class QueueEvent(StrEnum):
     ERROR = "error"
     PING = "ping"
     STOP = "stop"
+    RETRY = "retry"


 class AppQueueEvent(BaseModel):
@@ -313,6 +314,37 @@ class QueueNodeSucceededEvent(AppQueueEvent):
     iteration_duration_map: Optional[dict[str, float]] = None


+class QueueNodeRetryEvent(AppQueueEvent):
+    """QueueNodeRetryEvent entity"""
+
+    event: QueueEvent = QueueEvent.RETRY
+
+    node_execution_id: str
+    node_id: str
+    node_type: NodeType
+    node_data: BaseNodeData
+    parallel_id: Optional[str] = None
+    """parallel id if node is in parallel"""
+    parallel_start_node_id: Optional[str] = None
+    """parallel start node id if node is in parallel"""
+    parent_parallel_id: Optional[str] = None
+    """parent parallel id if node is in parallel"""
+    parent_parallel_start_node_id: Optional[str] = None
+    """parent parallel start node id if node is in parallel"""
+    in_iteration_id: Optional[str] = None
+    """iteration id if node is in iteration"""
+    start_at: datetime
+
+    inputs: Optional[dict[str, Any]] = None
+    process_data: Optional[dict[str, Any]] = None
+    outputs: Optional[dict[str, Any]] = None
+    execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
+
+    error: str
+    retry_index: int  # retry index
+    start_index: int  # start index
+
+
 class QueueNodeInIterationFailedEvent(AppQueueEvent):
     """
     QueueNodeInIterationFailedEvent entity
@@ -52,6 +52,7 @@ class StreamEvent(Enum):
     WORKFLOW_FINISHED = "workflow_finished"
     NODE_STARTED = "node_started"
     NODE_FINISHED = "node_finished"
+    NODE_RETRY = "node_retry"
     PARALLEL_BRANCH_STARTED = "parallel_branch_started"
     PARALLEL_BRANCH_FINISHED = "parallel_branch_finished"
     ITERATION_STARTED = "iteration_started"
@@ -342,6 +343,75 @@ class NodeFinishStreamResponse(StreamResponse):
         }


+class NodeRetryStreamResponse(StreamResponse):
+    """
+    NodeRetryStreamResponse entity
+    """
+
+    class Data(BaseModel):
+        """
+        Data entity
+        """
+
+        id: str
+        node_id: str
+        node_type: str
+        title: str
+        index: int
+        predecessor_node_id: Optional[str] = None
+        inputs: Optional[dict] = None
+        process_data: Optional[dict] = None
+        outputs: Optional[dict] = None
+        status: str
+        error: Optional[str] = None
+        elapsed_time: float
+        execution_metadata: Optional[dict] = None
+        created_at: int
+        finished_at: int
+        files: Optional[Sequence[Mapping[str, Any]]] = []
+        parallel_id: Optional[str] = None
+        parallel_start_node_id: Optional[str] = None
+        parent_parallel_id: Optional[str] = None
+        parent_parallel_start_node_id: Optional[str] = None
+        iteration_id: Optional[str] = None
+        retry_index: int = 0
+
+    event: StreamEvent = StreamEvent.NODE_RETRY
+    workflow_run_id: str
+    data: Data
+
+    def to_ignore_detail_dict(self):
+        return {
+            "event": self.event.value,
+            "task_id": self.task_id,
+            "workflow_run_id": self.workflow_run_id,
+            "data": {
+                "id": self.data.id,
+                "node_id": self.data.node_id,
+                "node_type": self.data.node_type,
+                "title": self.data.title,
+                "index": self.data.index,
+                "predecessor_node_id": self.data.predecessor_node_id,
+                "inputs": None,
+                "process_data": None,
+                "outputs": None,
+                "status": self.data.status,
+                "error": None,
+                "elapsed_time": self.data.elapsed_time,
+                "execution_metadata": None,
+                "created_at": self.data.created_at,
+                "finished_at": self.data.finished_at,
+                "files": [],
+                "parallel_id": self.data.parallel_id,
+                "parallel_start_node_id": self.data.parallel_start_node_id,
+                "parent_parallel_id": self.data.parent_parallel_id,
+                "parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
+                "iteration_id": self.data.iteration_id,
+                "retry_index": self.data.retry_index,
+            },
+        }
+
+
 class ParallelBranchStartStreamResponse(StreamResponse):
     """
     ParallelBranchStartStreamResponse entity
@@ -15,6 +15,7 @@ from core.app.entities.queue_entities import (
     QueueNodeExceptionEvent,
     QueueNodeFailedEvent,
     QueueNodeInIterationFailedEvent,
+    QueueNodeRetryEvent,
     QueueNodeStartedEvent,
     QueueNodeSucceededEvent,
     QueueParallelBranchRunFailedEvent,
@@ -26,6 +27,7 @@ from core.app.entities.task_entities import (
     IterationNodeNextStreamResponse,
     IterationNodeStartStreamResponse,
     NodeFinishStreamResponse,
+    NodeRetryStreamResponse,
     NodeStartStreamResponse,
     ParallelBranchFinishedStreamResponse,
     ParallelBranchStartStreamResponse,
@@ -423,6 +425,52 @@ class WorkflowCycleManage:

         return workflow_node_execution

+    def _handle_workflow_node_execution_retried(
+        self, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
+    ) -> WorkflowNodeExecution:
+        """
+        Workflow node execution retried
+        :param event: queue node retry event
+        :return:
+        """
+        created_at = event.start_at
+        finished_at = datetime.now(UTC).replace(tzinfo=None)
+        elapsed_time = (finished_at - created_at).total_seconds()
+        inputs = WorkflowEntry.handle_special_values(event.inputs)
+        outputs = WorkflowEntry.handle_special_values(event.outputs)
+
+        workflow_node_execution = WorkflowNodeExecution()
+        workflow_node_execution.tenant_id = workflow_run.tenant_id
+        workflow_node_execution.app_id = workflow_run.app_id
+        workflow_node_execution.workflow_id = workflow_run.workflow_id
+        workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
+        workflow_node_execution.workflow_run_id = workflow_run.id
+        workflow_node_execution.node_execution_id = event.node_execution_id
+        workflow_node_execution.node_id = event.node_id
+        workflow_node_execution.node_type = event.node_type.value
+        workflow_node_execution.title = event.node_data.title
+        workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value
+        workflow_node_execution.created_by_role = workflow_run.created_by_role
+        workflow_node_execution.created_by = workflow_run.created_by
+        workflow_node_execution.created_at = created_at
+        workflow_node_execution.finished_at = finished_at
+        workflow_node_execution.elapsed_time = elapsed_time
+        workflow_node_execution.error = event.error
+        workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
+        workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
+        workflow_node_execution.execution_metadata = json.dumps(
+            {
+                NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
+            }
+        )
+        workflow_node_execution.index = event.start_index
+
+        db.session.add(workflow_node_execution)
+        db.session.commit()
+        db.session.refresh(workflow_node_execution)
+
+        return workflow_node_execution
+
     #################################################
     # to stream responses #
     #################################################
@@ -587,6 +635,51 @@ class WorkflowCycleManage:
             ),
         )

+    def _workflow_node_retry_to_stream_response(
+        self,
+        event: QueueNodeRetryEvent,
+        task_id: str,
+        workflow_node_execution: WorkflowNodeExecution,
+    ) -> Optional[NodeRetryStreamResponse]:
+        """
+        Workflow node retry to stream response.
+        :param event: queue node retry event
+        :param task_id: task id
+        :param workflow_node_execution: workflow node execution
+        :return:
+        """
+        if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
+            return None
+
+        return NodeRetryStreamResponse(
+            task_id=task_id,
+            workflow_run_id=workflow_node_execution.workflow_run_id,
+            data=NodeRetryStreamResponse.Data(
+                id=workflow_node_execution.id,
+                node_id=workflow_node_execution.node_id,
+                node_type=workflow_node_execution.node_type,
+                index=workflow_node_execution.index,
+                title=workflow_node_execution.title,
+                predecessor_node_id=workflow_node_execution.predecessor_node_id,
+                inputs=workflow_node_execution.inputs_dict,
+                process_data=workflow_node_execution.process_data_dict,
+                outputs=workflow_node_execution.outputs_dict,
+                status=workflow_node_execution.status,
+                error=workflow_node_execution.error,
+                elapsed_time=workflow_node_execution.elapsed_time,
+                execution_metadata=workflow_node_execution.execution_metadata_dict,
+                created_at=int(workflow_node_execution.created_at.timestamp()),
+                finished_at=int(workflow_node_execution.finished_at.timestamp()),
+                files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
+                parallel_id=event.parallel_id,
+                parallel_start_node_id=event.parallel_start_node_id,
+                parent_parallel_id=event.parent_parallel_id,
+                parent_parallel_start_node_id=event.parent_parallel_start_node_id,
+                iteration_id=event.in_iteration_id,
+                retry_index=event.retry_index,
+            ),
+        )
+
     def _workflow_parallel_branch_start_to_stream_response(
         self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
     ) -> ParallelBranchStartStreamResponse:
@@ -45,7 +45,6 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
         )

     retries = 0
     stream = kwargs.pop("stream", False)
     while retries <= max_retries:
         try:
             if dify_config.SSRF_PROXY_ALL_URL:
@@ -819,6 +819,82 @@ LLM_BASE_MODELS = [
             ),
         ),
     ),
+    AzureBaseModel(
+        base_model_name="gpt-4o-2024-11-20",
+        entity=AIModelEntity(
+            model="fake-deployment-name",
+            label=I18nObject(
+                en_US="fake-deployment-name-label",
+            ),
+            model_type=ModelType.LLM,
+            features=[
+                ModelFeature.AGENT_THOUGHT,
+                ModelFeature.VISION,
+                ModelFeature.MULTI_TOOL_CALL,
+                ModelFeature.STREAM_TOOL_CALL,
+            ],
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={
+                ModelPropertyKey.MODE: LLMMode.CHAT.value,
+                ModelPropertyKey.CONTEXT_SIZE: 128000,
+            },
+            parameter_rules=[
+                ParameterRule(
+                    name="temperature",
+                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
+                ),
+                ParameterRule(
+                    name="top_p",
+                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
+                ),
+                ParameterRule(
+                    name="presence_penalty",
+                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
+                ),
+                ParameterRule(
+                    name="frequency_penalty",
+                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
+                ),
+                _get_max_tokens(default=512, min_val=1, max_val=16384),
+                ParameterRule(
+                    name="seed",
+                    label=I18nObject(zh_Hans="种子", en_US="Seed"),
+                    type="int",
+                    help=AZURE_DEFAULT_PARAM_SEED_HELP,
+                    required=False,
+                    precision=2,
+                    min=0,
+                    max=1,
+                ),
+                ParameterRule(
+                    name="response_format",
+                    label=I18nObject(zh_Hans="回复格式", en_US="response_format"),
+                    type="string",
+                    help=I18nObject(
+                        zh_Hans="指定模型必须输出的格式", en_US="specifying the format that the model must output"
+                    ),
+                    required=False,
+                    options=["text", "json_object", "json_schema"],
+                ),
+                ParameterRule(
+                    name="json_schema",
+                    label=I18nObject(en_US="JSON Schema"),
+                    type="text",
+                    help=I18nObject(
+                        zh_Hans="设置返回的json schema,llm将按照它返回",
+                        en_US="Set a response json schema will ensure LLM to adhere it.",
+                    ),
+                    required=False,
+                ),
+            ],
+            pricing=PriceConfig(
+                input=5.00,
+                output=15.00,
+                unit=0.000001,
+                currency="USD",
+            ),
+        ),
+    ),
     AzureBaseModel(
         base_model_name="gpt-4-turbo",
         entity=AIModelEntity(
@@ -171,6 +171,12 @@ model_credential_schema:
           show_on:
             - variable: __model_type
               value: llm
+        - label:
+            en_US: gpt-4o-2024-11-20
+          value: gpt-4o-2024-11-20
+          show_on:
+            - variable: __model_type
+              value: llm
         - label:
             en_US: gpt-4-turbo
           value: gpt-4-turbo
@@ -92,7 +92,10 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel):
                 average = embeddings_batch[0]
             else:
                 average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
-            embeddings[i] = (average / np.linalg.norm(average)).tolist()
+            embedding = (average / np.linalg.norm(average)).tolist()
+            if np.isnan(embedding).any():
+                raise ValueError("Normalized embedding is nan please try again")
+            embeddings[i] = embedding

         # calc usage
         usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
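The guard added here (and in the cohere, openai, and upstage embedding hunks below) targets NaN values produced when a zero vector is L2-normalized, i.e. divided by a zero norm. Not part of the diff: a minimal standalone sketch of the same pattern, with made-up sample data.

```python
import numpy as np

def normalize_embedding(vectors: list[list[float]], weights: list[int]) -> list[float]:
    # Weighted average of per-chunk embeddings for one input text.
    average = np.average(np.asarray(vectors), axis=0, weights=weights)
    # L2-normalize; a zero vector divides by zero and yields NaN entries.
    normalized = (average / np.linalg.norm(average)).tolist()
    if np.isnan(normalized).any():
        raise ValueError("Normalized embedding is nan please try again")
    return normalized

# Two chunk embeddings weighted by token counts (illustrative numbers).
print(normalize_embedding([[0.1, 0.2], [0.3, 0.4]], weights=[10, 6]))
```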
@@ -1,11 +1,19 @@
+from collections.abc import Mapping
+
 import boto3
 from botocore.config import Config

+from core.model_runtime.errors.invoke import InvokeBadRequestError
+
+
+def get_bedrock_client(service_name: str, credentials: Mapping[str, str]):
+    region_name = credentials.get("aws_region")
+    if not region_name:
+        raise InvokeBadRequestError("aws_region is required")
+    client_config = Config(region_name=region_name)
+    aws_access_key_id = credentials.get("aws_access_key_id")
+    aws_secret_access_key = credentials.get("aws_secret_access_key")
-def get_bedrock_client(service_name, credentials=None):
-    client_config = Config(region_name=credentials["aws_region"])
-    aws_access_key_id = credentials["aws_access_key_id"]
-    aws_secret_access_key = credentials["aws_secret_access_key"]
     if aws_access_key_id and aws_secret_access_key:
         # use aksk to call bedrock
         client = boto3.client(
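For illustration only (not from the diff), a hedged sketch of calling the refactored helper; the service name and credential values are placeholders:

```python
# Placeholder credentials; per the refactor above, get_bedrock_client raises
# InvokeBadRequestError when aws_region is missing.
credentials = {
    "aws_region": "us-east-1",
    "aws_access_key_id": "AKIA...",        # optional: the aksk branch is skipped if omitted
    "aws_secret_access_key": "...",
}
client = get_bedrock_client("bedrock-runtime", credentials)
```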
@@ -62,7 +62,10 @@ class BedrockRerankModel(RerankModel):
             }
         )
         modelId = model
-        region = credentials["aws_region"]
+        region = credentials.get("aws_region")
+        # region is a required field
+        if not region:
+            raise InvokeBadRequestError("aws_region is required in credentials")
         model_package_arn = f"arn:aws:bedrock:{region}::foundation-model/{modelId}"
         rerankingConfiguration = {
             "type": "BEDROCK_RERANKING_MODEL",
@@ -88,7 +88,10 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
                 average = embeddings_batch[0]
             else:
                 average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
-            embeddings[i] = (average / np.linalg.norm(average)).tolist()
+            embedding = (average / np.linalg.norm(average)).tolist()
+            if np.isnan(embedding).any():
+                raise ValueError("Normalized embedding is nan please try again")
+            embeddings[i] = embedding

         # calc usage
         usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
@@ -1,4 +1,5 @@
 - gemini-2.0-flash-exp
+- gemini-2.0-flash-thinking-exp-1219
 - gemini-1.5-pro
 - gemini-1.5-pro-latest
 - gemini-1.5-pro-001
@@ -0,0 +1,39 @@
+model: gemini-2.0-flash-thinking-exp-1219
+label:
+  en_US: Gemini 2.0 Flash Thinking Exp 1219
+model_type: llm
+features:
+  - agent-thought
+  - vision
+  - document
+  - video
+  - audio
+model_properties:
+  mode: chat
+  context_size: 32767
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: top_k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+  - name: max_output_tokens
+    use_template: max_tokens
+    default: 8192
+    min: 1
+    max: 8192
+  - name: json_schema
+    use_template: json_schema
+pricing:
+  input: '0.00'
+  output: '0.00'
+  unit: '0.000001'
+  currency: USD
@@ -97,7 +97,10 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
                 average = embeddings_batch[0]
             else:
                 average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
-            embeddings[i] = (average / np.linalg.norm(average)).tolist()
+            embedding = (average / np.linalg.norm(average)).tolist()
+            if np.isnan(embedding).any():
+                raise ValueError("Normalized embedding is nan please try again")
+            embeddings[i] = embedding

         # calc usage
         usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
@@ -119,7 +119,7 @@ class ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel):
             embeddings.append(result[0].get("embedding"))

         return [list(map(float, e)) for e in embeddings]
-    elif "texts" == text_input_key:
+    elif text_input_key == "texts":
         result = client.run(
             replicate_model_version,
             input={
@@ -18,7 +18,7 @@ class SiliconflowProvider(ModelProvider):
         try:
             model_instance = self.get_model_instance(ModelType.LLM)

-            model_instance.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials=credentials)
+            model_instance.validate_credentials(model="deepseek-ai/DeepSeek-V2.5", credentials=credentials)
         except CredentialsValidateFailedError as ex:
             raise ex
         except Exception as ex:
@@ -100,7 +100,10 @@ class UpstageTextEmbeddingModel(_CommonUpstage, TextEmbeddingModel):
                 average = embeddings_batch[0]
             else:
                 average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
-            embeddings[i] = (average / np.linalg.norm(average)).tolist()
+            embedding = (average / np.linalg.norm(average)).tolist()
+            if np.isnan(embedding).any():
+                raise ValueError("Normalized embedding is nan please try again")
+            embeddings[i] = embedding

         usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
@@ -40,6 +40,10 @@ configs: dict[str, ModelConfig] = {
         properties=ModelProperties(context_size=32768, max_tokens=4096, mode=LLMMode.CHAT),
         features=[ModelFeature.TOOL_CALL],
     ),
+    "Doubao-pro-256k": ModelConfig(
+        properties=ModelProperties(context_size=262144, max_tokens=4096, mode=LLMMode.CHAT),
+        features=[],
+    ),
     "Doubao-pro-128k": ModelConfig(
         properties=ModelProperties(context_size=131072, max_tokens=4096, mode=LLMMode.CHAT),
         features=[ModelFeature.TOOL_CALL],
@@ -12,6 +12,7 @@ class ModelConfig(BaseModel):

 ModelConfigs = {
     "Doubao-embedding": ModelConfig(properties=ModelProperties(context_size=4096, max_chunks=32)),
+    "Doubao-embedding-large": ModelConfig(properties=ModelProperties(context_size=4096, max_chunks=32)),
 }


@@ -21,7 +22,7 @@ def get_model_config(credentials: dict) -> ModelConfig:
     if not model_configs:
         return ModelConfig(
             properties=ModelProperties(
-                context_size=int(credentials.get("context_size", 0)),
+                context_size=int(credentials.get("context_size", 4096)),
                 max_chunks=int(credentials.get("max_chunks", 1)),
             )
         )
@@ -166,6 +166,12 @@ model_credential_schema:
           show_on:
             - variable: __model_type
               value: llm
+        - label:
+            en_US: Doubao-pro-256k
+          value: Doubao-pro-256k
+          show_on:
+            - variable: __model_type
+              value: llm
         - label:
             en_US: Llama3-8B
           value: Llama3-8B
@@ -220,6 +226,12 @@ model_credential_schema:
           show_on:
             - variable: __model_type
               value: text-embedding
+        - label:
+            en_US: Doubao-embedding-large
+          value: Doubao-embedding-large
+          show_on:
+            - variable: __model_type
+              value: text-embedding
         - label:
             en_US: Custom
             zh_Hans: 自定义
@@ -65,6 +65,11 @@ class CacheEmbedding(Embeddings):
             for vector in embedding_result.embeddings:
                 try:
                     normalized_embedding = (vector / np.linalg.norm(vector)).tolist()
+                    # stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan
+                    if np.isnan(normalized_embedding).any():
+                        # for issue #11827 float values are not json compliant
+                        logger.warning(f"Normalized embedding is nan: {normalized_embedding}")
+                        continue
                     embedding_queue_embeddings.append(normalized_embedding)
                 except IntegrityError:
                     db.session.rollback()
@@ -111,6 +116,8 @@ class CacheEmbedding(Embeddings):

             embedding_results = embedding_result.embeddings[0]
             embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
+            if np.isnan(embedding_results).any():
+                raise ValueError("Normalized embedding is nan please try again")
         except Exception as ex:
             if dify_config.DEBUG:
                 logging.exception(f"Failed to embed query text '{text[:10]}...({len(text)} chars)'")
@@ -11,7 +11,10 @@ class ComfyUIProvider(BuiltinToolProviderController):
     def _validate_credentials(self, credentials: dict[str, Any]) -> None:
         ws = websocket.WebSocket()
         base_url = URL(credentials.get("base_url"))
-        ws_address = f"ws://{base_url.authority}/ws?clientId=test123"
+        ws_protocol = "ws"
+        if base_url.scheme == "https":
+            ws_protocol = "wss"
+        ws_address = f"{ws_protocol}://{base_url.authority}/ws?clientId=test123"

         try:
             ws.connect(ws_address)
@@ -40,7 +40,10 @@ class ComfyUiClient:
     def open_websocket_connection(self) -> tuple[WebSocket, str]:
         client_id = str(uuid.uuid4())
         ws = WebSocket()
-        ws_address = f"ws://{self.base_url.authority}/ws?clientId={client_id}"
+        ws_protocol = "ws"
+        if self.base_url.scheme == "https":
+            ws_protocol = "wss"
+        ws_address = f"{ws_protocol}://{self.base_url.authority}/ws?clientId={client_id}"
         ws.connect(ws_address)
         return ws, client_id
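Both ComfyUI changes apply the same scheme mapping so that TLS deployments get a secure websocket. A self-contained sketch of the idea, using the standard library rather than the URL type in the snippets (an assumption on my part):

```python
from urllib.parse import urlparse

def websocket_address(base_url: str, client_id: str) -> str:
    # Map http -> ws and https -> wss, keeping host:port intact.
    parsed = urlparse(base_url)
    ws_protocol = "wss" if parsed.scheme == "https" else "ws"
    return f"{ws_protocol}://{parsed.netloc}/ws?clientId={client_id}"

print(websocket_address("https://comfyui.example.com:8188", "test123"))
# -> wss://comfyui.example.com:8188/ws?clientId=test123
```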
@@ -45,3 +45,6 @@ class NodeRunResult(BaseModel):

     error: Optional[str] = None  # error message if status is failed
     error_type: Optional[str] = None  # error type if status is failed
+
+    # single step node run retry
+    retry_index: int = 0
@@ -97,6 +97,13 @@ class NodeInIterationFailedEvent(BaseNodeEvent):
     error: str = Field(..., description="error")


+class NodeRunRetryEvent(BaseNodeEvent):
+    error: str = Field(..., description="error")
+    retry_index: int = Field(..., description="which retry attempt is about to be performed")
+    start_at: datetime = Field(..., description="retry start time")
+    start_index: int = Field(..., description="retry start index")
+
+
 ###########################################
 # Parallel Branch Events
 ###########################################
@@ -5,6 +5,7 @@ import uuid
 from collections.abc import Generator, Mapping
 from concurrent.futures import ThreadPoolExecutor, wait
 from copy import copy, deepcopy
+from datetime import UTC, datetime
 from typing import Any, Optional, cast

 from flask import Flask, current_app
@@ -25,6 +26,7 @@ from core.workflow.graph_engine.entities.event import (
     NodeRunExceptionEvent,
     NodeRunFailedEvent,
     NodeRunRetrieverResourceEvent,
+    NodeRunRetryEvent,
     NodeRunStartedEvent,
     NodeRunStreamChunkEvent,
     NodeRunSucceededEvent,
@@ -581,7 +583,7 @@ class GraphEngine:

     def _run_node(
         self,
-        node_instance: BaseNode,
+        node_instance: BaseNode[BaseNodeData],
         route_node_state: RouteNodeState,
         parallel_id: Optional[str] = None,
         parallel_start_node_id: Optional[str] = None,
@@ -607,36 +609,121 @@ class GraphEngine:
                )

        db.session.close()
        max_retries = node_instance.node_data.retry_config.max_retries
        retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
        retries = 0
        should_continue_retry = True
        while should_continue_retry and retries <= max_retries:
            try:
                # run node
                retry_start_at = datetime.now(UTC).replace(tzinfo=None)
                generator = node_instance.run()
                for item in generator:
                    if isinstance(item, GraphEngineEvent):
                        if isinstance(item, BaseIterationEvent):
                            # add parallel info to iteration event
                            item.parallel_id = parallel_id
                            item.parallel_start_node_id = parallel_start_node_id
                            item.parent_parallel_id = parent_parallel_id
                            item.parent_parallel_start_node_id = parent_parallel_start_node_id

        try:
            # run node
            generator = node_instance.run()
            for item in generator:
                if isinstance(item, GraphEngineEvent):
                    if isinstance(item, BaseIterationEvent):
                        # add parallel info to iteration event
                        item.parallel_id = parallel_id
                        item.parallel_start_node_id = parallel_start_node_id
                        item.parent_parallel_id = parent_parallel_id
                        item.parent_parallel_start_node_id = parent_parallel_start_node_id
                        yield item
                    else:
                        if isinstance(item, RunCompletedEvent):
                            run_result = item.run_result
                            if run_result.status == WorkflowNodeExecutionStatus.FAILED:
                                if (
                                    retries == max_retries
                                    and node_instance.node_type == NodeType.HTTP_REQUEST
                                    and run_result.outputs
                                    and not node_instance.should_continue_on_error
                                ):
                                    run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
                                if node_instance.should_retry and retries < max_retries:
                                    retries += 1
                                    self.graph_runtime_state.node_run_steps += 1
                                    route_node_state.node_run_result = run_result
                                    yield NodeRunRetryEvent(
                                        id=node_instance.id,
                                        node_id=node_instance.node_id,
                                        node_type=node_instance.node_type,
                                        node_data=node_instance.node_data,
                                        route_node_state=route_node_state,
                                        error=run_result.error,
                                        retry_index=retries,
                                        parallel_id=parallel_id,
                                        parallel_start_node_id=parallel_start_node_id,
                                        parent_parallel_id=parent_parallel_id,
                                        parent_parallel_start_node_id=parent_parallel_start_node_id,
                                        start_at=retry_start_at,
                                        start_index=self.graph_runtime_state.node_run_steps,
                                    )
                                    time.sleep(retry_interval)
                                    continue
                            route_node_state.set_finished(run_result=run_result)

                    yield item
                else:
                    if isinstance(item, RunCompletedEvent):
                        run_result = item.run_result
                        route_node_state.set_finished(run_result=run_result)
                        if run_result.status == WorkflowNodeExecutionStatus.FAILED:
                            if node_instance.should_continue_on_error:
                                # if run failed, handle error
                                run_result = self._handle_continue_on_error(
                                    node_instance,
                                    item.run_result,
                                    self.graph_runtime_state.variable_pool,
                                    handle_exceptions=handle_exceptions,
                                )
                                route_node_state.node_run_result = run_result
                                route_node_state.status = RouteNodeState.Status.EXCEPTION
                                if run_result.outputs:
                                    for variable_key, variable_value in run_result.outputs.items():
                                        # append variables to variable pool recursively
                                        self._append_variables_recursively(
                                            node_id=node_instance.node_id,
                                            variable_key_list=[variable_key],
                                            variable_value=variable_value,
                                        )
                                yield NodeRunExceptionEvent(
                                    error=run_result.error or "System Error",
                                    id=node_instance.id,
                                    node_id=node_instance.node_id,
                                    node_type=node_instance.node_type,
                                    node_data=node_instance.node_data,
                                    route_node_state=route_node_state,
                                    parallel_id=parallel_id,
                                    parallel_start_node_id=parallel_start_node_id,
                                    parent_parallel_id=parent_parallel_id,
                                    parent_parallel_start_node_id=parent_parallel_start_node_id,
                                )
                                should_continue_retry = False
                            else:
                                yield NodeRunFailedEvent(
                                    error=route_node_state.failed_reason or "Unknown error.",
                                    id=node_instance.id,
                                    node_id=node_instance.node_id,
                                    node_type=node_instance.node_type,
                                    node_data=node_instance.node_data,
                                    route_node_state=route_node_state,
                                    parallel_id=parallel_id,
                                    parallel_start_node_id=parallel_start_node_id,
                                    parent_parallel_id=parent_parallel_id,
                                    parent_parallel_start_node_id=parent_parallel_start_node_id,
                                )
                                should_continue_retry = False
                        elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
                            if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
                                node_instance.node_id
                            ):
                                run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
                            if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
                                # plus state total_tokens
                                self.graph_runtime_state.total_tokens += int(
                                    run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)  # type: ignore[arg-type]
                                )

                        if run_result.status == WorkflowNodeExecutionStatus.FAILED:
                            if node_instance.should_continue_on_error:
                                # if run failed, handle error
                                run_result = self._handle_continue_on_error(
                                    node_instance,
                                    item.run_result,
                                    self.graph_runtime_state.variable_pool,
                                    handle_exceptions=handle_exceptions,
                                )
                                route_node_state.node_run_result = run_result
                                route_node_state.status = RouteNodeState.Status.EXCEPTION
                        if run_result.llm_usage:
                            # use the latest usage
                            self.graph_runtime_state.llm_usage += run_result.llm_usage

                        # append node output variables to variable pool
                        if run_result.outputs:
                            for variable_key, variable_value in run_result.outputs.items():
                                # append variables to variable pool recursively
@@ -645,21 +732,23 @@ class GraphEngine:
                                variable_key_list=[variable_key],
                                variable_value=variable_value,
                            )
                        yield NodeRunExceptionEvent(
                            error=run_result.error or "System Error",
                            id=node_instance.id,
                            node_id=node_instance.node_id,
                            node_type=node_instance.node_type,
                            node_data=node_instance.node_data,
                            route_node_state=route_node_state,
                            parallel_id=parallel_id,
                            parallel_start_node_id=parallel_start_node_id,
                            parent_parallel_id=parent_parallel_id,
                            parent_parallel_start_node_id=parent_parallel_start_node_id,
                        )
                    else:
                        yield NodeRunFailedEvent(
                            error=route_node_state.failed_reason or "Unknown error.",

                # add parallel info to run result metadata
                if parallel_id and parallel_start_node_id:
                    if not run_result.metadata:
                        run_result.metadata = {}

                    run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
                    run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = (
                        parallel_start_node_id
                    )
                    if parent_parallel_id and parent_parallel_start_node_id:
                        run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
                        run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
                            parent_parallel_start_node_id
                        )

                yield NodeRunSucceededEvent(
                    id=node_instance.id,
                    node_id=node_instance.node_id,
                    node_type=node_instance.node_type,
@@ -670,108 +759,59 @@ class GraphEngine:
                    parent_parallel_id=parent_parallel_id,
                    parent_parallel_start_node_id=parent_parallel_start_node_id,
                )
                should_continue_retry = False

            elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
                if node_instance.should_continue_on_error and self.graph.edge_mapping.get(
                    node_instance.node_id
                ):
                    run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
                if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
                    # plus state total_tokens
                    self.graph_runtime_state.total_tokens += int(
                        run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)  # type: ignore[arg-type]
                    )

                if run_result.llm_usage:
                    # use the latest usage
                    self.graph_runtime_state.llm_usage += run_result.llm_usage

                # append node output variables to variable pool
                if run_result.outputs:
                    for variable_key, variable_value in run_result.outputs.items():
                        # append variables to variable pool recursively
                        self._append_variables_recursively(
                            node_id=node_instance.node_id,
                            variable_key_list=[variable_key],
                            variable_value=variable_value,
                        )

                # add parallel info to run result metadata
                if parallel_id and parallel_start_node_id:
                    if not run_result.metadata:
                        run_result.metadata = {}

                    run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
                    run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
                    if parent_parallel_id and parent_parallel_start_node_id:
                        run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
                        run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = (
                            parent_parallel_start_node_id
                        )

                yield NodeRunSucceededEvent(
                break
            elif isinstance(item, RunStreamChunkEvent):
                yield NodeRunStreamChunkEvent(
                    id=node_instance.id,
                    node_id=node_instance.node_id,
                    node_type=node_instance.node_type,
                    node_data=node_instance.node_data,
                    chunk_content=item.chunk_content,
                    from_variable_selector=item.from_variable_selector,
                    route_node_state=route_node_state,
                    parallel_id=parallel_id,
                    parallel_start_node_id=parallel_start_node_id,
                    parent_parallel_id=parent_parallel_id,
                    parent_parallel_start_node_id=parent_parallel_start_node_id,
                )

                break
            elif isinstance(item, RunStreamChunkEvent):
                yield NodeRunStreamChunkEvent(
                    id=node_instance.id,
                    node_id=node_instance.node_id,
                    node_type=node_instance.node_type,
                    node_data=node_instance.node_data,
                    chunk_content=item.chunk_content,
                    from_variable_selector=item.from_variable_selector,
                    route_node_state=route_node_state,
                    parallel_id=parallel_id,
                    parallel_start_node_id=parallel_start_node_id,
                    parent_parallel_id=parent_parallel_id,
                    parent_parallel_start_node_id=parent_parallel_start_node_id,
                )
            elif isinstance(item, RunRetrieverResourceEvent):
                yield NodeRunRetrieverResourceEvent(
                    id=node_instance.id,
                    node_id=node_instance.node_id,
                    node_type=node_instance.node_type,
                    node_data=node_instance.node_data,
                    retriever_resources=item.retriever_resources,
                    context=item.context,
                    route_node_state=route_node_state,
                    parallel_id=parallel_id,
                    parallel_start_node_id=parallel_start_node_id,
                    parent_parallel_id=parent_parallel_id,
                    parent_parallel_start_node_id=parent_parallel_start_node_id,
                )
        except GenerateTaskStoppedError:
            # trigger node run failed event
            route_node_state.status = RouteNodeState.Status.FAILED
            route_node_state.failed_reason = "Workflow stopped."
            yield NodeRunFailedEvent(
                error="Workflow stopped.",
                id=node_instance.id,
                node_id=node_instance.node_id,
                node_type=node_instance.node_type,
                node_data=node_instance.node_data,
                route_node_state=route_node_state,
                parallel_id=parallel_id,
                parallel_start_node_id=parallel_start_node_id,
                parent_parallel_id=parent_parallel_id,
                parent_parallel_start_node_id=parent_parallel_start_node_id,
            )
            return
        except Exception as e:
            logger.exception(f"Node {node_instance.node_data.title} run failed")
            raise e
        finally:
            db.session.close()
        elif isinstance(item, RunRetrieverResourceEvent):
            yield NodeRunRetrieverResourceEvent(
                id=node_instance.id,
                node_id=node_instance.node_id,
                node_type=node_instance.node_type,
                node_data=node_instance.node_data,
                retriever_resources=item.retriever_resources,
                context=item.context,
                route_node_state=route_node_state,
                parallel_id=parallel_id,
                parallel_start_node_id=parallel_start_node_id,
                parent_parallel_id=parent_parallel_id,
                parent_parallel_start_node_id=parent_parallel_start_node_id,
            )
    except GenerateTaskStoppedError:
        # trigger node run failed event
        route_node_state.status = RouteNodeState.Status.FAILED
        route_node_state.failed_reason = "Workflow stopped."
        yield NodeRunFailedEvent(
            error="Workflow stopped.",
            id=node_instance.id,
            node_id=node_instance.node_id,
            node_type=node_instance.node_type,
            node_data=node_instance.node_data,
            route_node_state=route_node_state,
            parallel_id=parallel_id,
            parallel_start_node_id=parallel_start_node_id,
            parent_parallel_id=parent_parallel_id,
            parent_parallel_start_node_id=parent_parallel_start_node_id,
        )
        return
    except Exception as e:
        logger.exception(f"Node {node_instance.node_data.title} run failed")
        raise e
    finally:
        db.session.close()

    def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
        """
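The loop added in these graph-engine hunks is a bounded-retry pattern: run the node, and on failure emit a retry event, sleep for the configured interval, and run again until the budget is spent. A self-contained sketch of the control flow (names and the simulated task are illustrative, not the engine's API; the engine retries on a FAILED result rather than on exceptions):

```python
import time

def run_with_retry(task, max_retries: int = 3, retry_interval_seconds: float = 0.01):
    """Bounded retry: re-run `task` until it succeeds or the budget is spent."""
    retries = 0
    while retries <= max_retries:
        try:
            return task()
        except Exception as exc:
            if retries >= max_retries:
                raise  # budget exhausted; surface the last error
            retries += 1
            # The engine yields a NodeRunRetryEvent at this point; we just log.
            print(f"retry {retries} after error: {exc}")
            time.sleep(retry_interval_seconds)

# A task that fails twice, then succeeds.
attempts = {"count": 0}
def flaky_task():
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise RuntimeError("transient failure")
    return "ok"

print(run_with_retry(flaky_task))  # two retry lines, then "ok"
```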
@@ -106,12 +106,25 @@ class DefaultValue(BaseModel):
         return self


+class RetryConfig(BaseModel):
+    """node retry config"""
+
+    max_retries: int = 0  # max retry times
+    retry_interval: int = 0  # retry interval in milliseconds
+    retry_enabled: bool = False  # whether retry is enabled
+
+    @property
+    def retry_interval_seconds(self) -> float:
+        return self.retry_interval / 1000
+
+
 class BaseNodeData(ABC, BaseModel):
     title: str
     desc: Optional[str] = None
     error_strategy: Optional[ErrorStrategy] = None
     default_value: Optional[list[DefaultValue]] = None
     version: str = "1"
+    retry_config: RetryConfig = RetryConfig()

     @property
     def default_value_dict(self):
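Note that retry_interval is stored in milliseconds and converted on read. A quick hypothetical check of the conversion (assumes the RetryConfig model above is importable):

```python
config = RetryConfig(max_retries=3, retry_interval=500, retry_enabled=True)
assert config.retry_interval_seconds == 0.5  # 500 ms -> 0.5 s
```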
@@ -4,7 +4,7 @@ from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast

 from core.workflow.entities.node_entities import NodeRunResult
-from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType
+from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
 from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
 from models.workflow import WorkflowNodeExecutionStatus

@@ -147,3 +147,12 @@ class BaseNode(Generic[GenericNodeData]):
             bool: if should continue on error
         """
         return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
+
+    @property
+    def should_retry(self) -> bool:
+        """judge if should retry
+
+        Returns:
+            bool: if should retry
+        """
+        return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE
@@ -35,3 +35,4 @@ class FailBranchSourceHandle(StrEnum):


 CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
+RETRY_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.TOOL, NodeType.HTTP_REQUEST]
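Retry eligibility is therefore a conjunction of the per-node flag and this type whitelist. A sketch of the same gating as BaseNode.should_retry, with a reduced stand-in for the real enum:

```python
from enum import StrEnum

class NodeType(StrEnum):  # reduced stand-in for the real enum
    LLM = "llm"
    CODE = "code"
    TOOL = "tool"
    HTTP_REQUEST = "http-request"

RETRY_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.TOOL, NodeType.HTTP_REQUEST]

def should_retry(retry_enabled: bool, node_type: NodeType) -> bool:
    # Both the config flag and the whitelist must agree.
    return retry_enabled and node_type in RETRY_ON_ERROR_NODE_TYPE

print(should_retry(True, NodeType.HTTP_REQUEST))  # True
print(should_retry(True, NodeType.CODE))          # False: CODE is not retryable
```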
@@ -1,4 +1,10 @@
-from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
+from .event import (
+    ModelInvokeCompletedEvent,
+    RunCompletedEvent,
+    RunRetrieverResourceEvent,
+    RunRetryEvent,
+    RunStreamChunkEvent,
+)
 from .types import NodeEvent

 __all__ = [
@@ -6,5 +12,6 @@ __all__ = [
     "NodeEvent",
     "RunCompletedEvent",
     "RunRetrieverResourceEvent",
+    "RunRetryEvent",
     "RunStreamChunkEvent",
 ]
@@ -1,7 +1,10 @@
+from datetime import datetime
+
 from pydantic import BaseModel, Field

 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.workflow.entities.node_entities import NodeRunResult
+from models.workflow import WorkflowNodeExecutionStatus


 class RunCompletedEvent(BaseModel):
@@ -26,3 +29,25 @@ class ModelInvokeCompletedEvent(BaseModel):
     text: str
     usage: LLMUsage
     finish_reason: str | None = None
+
+
+class RunRetryEvent(BaseModel):
+    """Node Run Retry event"""
+
+    error: str = Field(..., description="error")
+    retry_index: int = Field(..., description="Retry attempt number")
+    start_at: datetime = Field(..., description="Retry start time")
+
+
+class SingleStepRetryEvent(BaseModel):
+    """Single step retry event"""
+
+    status: str = WorkflowNodeExecutionStatus.RETRY.value
+
+    inputs: dict | None = Field(..., description="input")
+    error: str = Field(..., description="error")
+    outputs: dict = Field(..., description="output")
+    retry_index: int = Field(..., description="Retry attempt number")
+    elapsed_time: float = Field(..., description="elapsed time")
+    execution_metadata: dict | None = Field(..., description="execution metadata")
@@ -45,6 +45,7 @@ class Executor:
     headers: dict[str, str]
     auth: HttpRequestNodeAuthorization
     timeout: HttpRequestNodeTimeout
+    max_retries: int

     boundary: str

@@ -54,6 +55,7 @@ class Executor:
         node_data: HttpRequestNodeData,
         timeout: HttpRequestNodeTimeout,
         variable_pool: VariablePool,
+        max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
     ):
         # If authorization API key is present, convert the API key using the variable pool
         if node_data.authorization.type == "api-key":
@@ -73,6 +75,7 @@ class Executor:
         self.files = None
         self.data = None
         self.json = None
+        self.max_retries = max_retries

         # init template
         self.variable_pool = variable_pool
@@ -241,6 +244,7 @@ class Executor:
             "params": self.params,
             "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write),
             "follow_redirects": True,
+            "max_retries": self.max_retries,
         }
         # request_args = {k: v for k, v in request_args.items() if v is not None}
         try:
@@ -1,4 +1,5 @@
 import logging
+import mimetypes
 from collections.abc import Mapping, Sequence
 from typing import Any

@@ -51,6 +52,11 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
                 "max_write_timeout": dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT,
             },
         },
+        "retry_config": {
+            "max_retries": dify_config.SSRF_DEFAULT_MAX_RETRIES,
+            "retry_interval": 0.5 * (2**2),
+            "retry_enabled": True,
+        },
     }

     def _run(self) -> NodeRunResult:
@@ -60,12 +66,13 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
             node_data=self.node_data,
             timeout=self._get_request_timeout(self.node_data),
             variable_pool=self.graph_runtime_state.variable_pool,
+            max_retries=0,
         )
         process_data["request"] = http_executor.to_log()

         response = http_executor.invoke()
         files = self.extract_files(url=http_executor.url, response=response)
-        if not response.response.is_success and self.should_continue_on_error:
+        if not response.response.is_success and (self.should_continue_on_error or self.should_retry):
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED,
                 outputs={
@@ -156,20 +163,24 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):

     def extract_files(self, url: str, response: Response) -> list[File]:
         """
-        Extract files from response
+        Extract files from response by checking both Content-Type header and URL
         """
         files = []
         is_file = response.is_file
         content_type = response.content_type
         content = response.content

-        if is_file and content_type:
+        if is_file:
+            # Guess file extension from URL or Content-Type header
+            filename = url.split("?")[0].split("/")[-1] or ""
+            mime_type = content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
+
             tool_file = ToolFileManager.create_file_by_raw(
                 user_id=self.user_id,
                 tenant_id=self.tenant_id,
                 conversation_id=None,
                 file_binary=content,
-                mimetype=content_type,
+                mimetype=mime_type,
             )

             mapping = {
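The new fallback derives a filename from the URL path, asks the standard library to guess its MIME type, and only then falls back to a generic type. A small sketch of that resolution order with a placeholder URL:

```python
import mimetypes

url = "https://example.com/files/report.pdf?token=abc"  # placeholder URL
filename = url.split("?")[0].split("/")[-1] or ""
mime_type = mimetypes.guess_type(filename)[0] or "application/octet-stream"
print(filename, mime_type)  # report.pdf application/pdf
```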
@@ -29,6 +29,7 @@ workflow_run_for_list_fields = {
     "created_at": TimestampField,
     "finished_at": TimestampField,
     "exceptions_count": fields.Integer,
+    "retry_index": fields.Integer,
 }

 advanced_chat_workflow_run_for_list_fields = {
@@ -45,6 +46,7 @@ advanced_chat_workflow_run_for_list_fields = {
     "created_at": TimestampField,
     "finished_at": TimestampField,
     "exceptions_count": fields.Integer,
+    "retry_index": fields.Integer,
 }

 advanced_chat_workflow_run_pagination_fields = {
@@ -79,6 +81,17 @@ workflow_run_detail_fields = {
     "exceptions_count": fields.Integer,
 }

+retry_event_field = {
+    "error": fields.String,
+    "retry_index": fields.Integer,
+    "inputs": fields.Raw(attribute="inputs"),
+    "elapsed_time": fields.Float,
+    "execution_metadata": fields.Raw(attribute="execution_metadata_dict"),
+    "status": fields.String,
+    "outputs": fields.Raw(attribute="outputs"),
+}
+
+
 workflow_run_node_execution_fields = {
     "id": fields.String,
     "index": fields.Integer,
@@ -99,6 +112,7 @@ workflow_run_node_execution_fields = {
     "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
     "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True),
     "finished_at": TimestampField,
+    "retry_events": fields.List(fields.Nested(retry_event_field)),
 }

 workflow_run_node_execution_list_fields = {
@@ -0,0 +1,33 @@
+"""add retry_index field to node-execution model
+
+Revision ID: 348cb0a93d53
+Revises: cf8f4fc45278
+Create Date: 2024-12-16 01:23:13.093432
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '348cb0a93d53'
+down_revision = 'cf8f4fc45278'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('retry_index', sa.Integer(), server_default=sa.text('0'), nullable=True))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
+        batch_op.drop_column('retry_index')
+
+    # ### end Alembic commands ###
@@ -529,6 +529,7 @@ class WorkflowNodeExecutionStatus(Enum):
     SUCCEEDED = "succeeded"
     FAILED = "failed"
     EXCEPTION = "exception"
+    RETRY = "retry"

     @classmethod
     def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus":
@@ -639,6 +640,7 @@ class WorkflowNodeExecution(db.Model):
     created_by_role = db.Column(db.String(255), nullable=False)
     created_by = db.Column(StringUUID, nullable=False)
     finished_at = db.Column(db.DateTime)
+    retry_index = db.Column(db.Integer, server_default=db.text("0"))

     @property
     def created_by_account(self):
@@ -15,6 +15,7 @@ from core.workflow.nodes.base.entities import BaseNodeData
 from core.workflow.nodes.base.node import BaseNode
 from core.workflow.nodes.enums import ErrorStrategy
 from core.workflow.nodes.event import RunCompletedEvent
+from core.workflow.nodes.event.event import SingleStepRetryEvent
 from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
 from core.workflow.workflow_entry import WorkflowEntry
 from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
@@ -220,56 +221,99 @@ class WorkflowService:

        # run draft workflow node
        start_at = time.perf_counter()
        retries = 0
        max_retries = 0
        should_retry = True
        retry_events = []

        try:
            node_instance, generator = WorkflowEntry.single_step_run(
                workflow=draft_workflow,
                node_id=node_id,
                user_inputs=user_inputs,
                user_id=account.id,
            )
            node_instance = cast(BaseNode[BaseNodeData], node_instance)
            node_run_result: NodeRunResult | None = None
            for event in generator:
                if isinstance(event, RunCompletedEvent):
                    node_run_result = event.run_result
            while retries <= max_retries and should_retry:
                retry_start_at = time.perf_counter()
                node_instance, generator = WorkflowEntry.single_step_run(
                    workflow=draft_workflow,
                    node_id=node_id,
                    user_inputs=user_inputs,
                    user_id=account.id,
                )
                node_instance = cast(BaseNode[BaseNodeData], node_instance)
                max_retries = (
                    node_instance.node_data.retry_config.max_retries if node_instance.node_data.retry_config else 0
                )
                retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
                node_run_result: NodeRunResult | None = None
                for event in generator:
                    if isinstance(event, RunCompletedEvent):
                        node_run_result = event.run_result

                        # sign output files
                        node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
                        break
                # sign output files
                node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
                break

            if not node_run_result:
                raise ValueError("Node run failed with no run result")
            # single step debug mode error handling return
            if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
                node_error_args = {
                    "status": WorkflowNodeExecutionStatus.EXCEPTION,
                    "error": node_run_result.error,
                    "inputs": node_run_result.inputs,
                    "metadata": {"error_strategy": node_instance.node_data.error_strategy},
                }
                if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            **node_instance.node_data.default_value_dict,
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
                else:
                    node_run_result = NodeRunResult(
                        **node_error_args,
                        outputs={
                            "error_message": node_run_result.error,
                            "error_type": node_run_result.error_type,
                        },
                    )
            run_succeeded = node_run_result.status in (
                WorkflowNodeExecutionStatus.SUCCEEDED,
                WorkflowNodeExecutionStatus.EXCEPTION,
            )
            error = node_run_result.error if not run_succeeded else None
                if not node_run_result:
                    raise ValueError("Node run failed with no run result")
                # single step debug mode error handling return
                if node_run_result.status == WorkflowNodeExecutionStatus.FAILED:
                    if (
                        retries == max_retries
                        and node_instance.node_type == NodeType.HTTP_REQUEST
                        and node_run_result.outputs
                        and not node_instance.should_continue_on_error
                    ):
                        node_run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
                        should_retry = False
                    else:
                        if node_instance.should_retry:
                            node_run_result.status = WorkflowNodeExecutionStatus.RETRY
                            retries += 1
                            node_run_result.retry_index = retries
                            retry_events.append(
                                SingleStepRetryEvent(
                                    inputs=WorkflowEntry.handle_special_values(node_run_result.inputs)
|
||||
if node_run_result.inputs
|
||||
else None,
|
||||
error=node_run_result.error,
|
||||
outputs=WorkflowEntry.handle_special_values(node_run_result.outputs)
|
||||
if node_run_result.outputs
|
||||
else None,
|
||||
retry_index=node_run_result.retry_index,
|
||||
elapsed_time=time.perf_counter() - retry_start_at,
|
||||
execution_metadata=WorkflowEntry.handle_special_values(node_run_result.metadata)
|
||||
if node_run_result.metadata
|
||||
else None,
|
||||
)
|
||||
)
|
||||
time.sleep(retry_interval)
|
||||
else:
|
||||
should_retry = False
|
||||
if node_instance.should_continue_on_error:
|
||||
node_error_args = {
|
||||
"status": WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
"error": node_run_result.error,
|
||||
"inputs": node_run_result.inputs,
|
||||
"metadata": {"error_strategy": node_instance.node_data.error_strategy},
|
||||
}
|
||||
if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
|
||||
node_run_result = NodeRunResult(
|
||||
**node_error_args,
|
||||
outputs={
|
||||
**node_instance.node_data.default_value_dict,
|
||||
"error_message": node_run_result.error,
|
||||
"error_type": node_run_result.error_type,
|
||||
},
|
||||
)
|
||||
else:
|
||||
node_run_result = NodeRunResult(
|
||||
**node_error_args,
|
||||
outputs={
|
||||
"error_message": node_run_result.error,
|
||||
"error_type": node_run_result.error_type,
|
||||
},
|
||||
)
|
||||
run_succeeded = node_run_result.status in (
|
||||
WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
)
|
||||
error = node_run_result.error if not run_succeeded else None
|
||||
except WorkflowNodeRunFailedError as e:
|
||||
node_instance = e.node_instance
|
||||
run_succeeded = False
|
||||
@@ -318,6 +362,7 @@ class WorkflowService:
|
||||
|
||||
db.session.add(workflow_node_execution)
|
||||
db.session.commit()
|
||||
workflow_node_execution.retry_events = retry_events
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
|
||||
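The reworked loop re-runs the node until it succeeds, retry is disabled, or max_retries is exhausted, appending a SingleStepRetryEvent for each failed attempt and sleeping retry_interval in between. A standalone distillation of that control flow (illustrative names, not the service API):

import time

def run_with_retry(run_once, max_retries: int, retry_interval: float):
    """Illustrative sketch of the single-step retry loop, not the service code."""
    retry_events = []
    attempt = 0
    while True:
        started = time.perf_counter()
        result = run_once()
        if result["status"] != "failed" or attempt >= max_retries:
            return result, retry_events
        attempt += 1
        retry_events.append({
            "retry_index": attempt,
            "error": result.get("error"),
            "elapsed_time": time.perf_counter() - started,
        })
        time.sleep(retry_interval)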
@@ -21,13 +21,13 @@ class MockXinferenceClass:
         if not re.match(r"https?:\/\/[^\s\/$.?#].[^\s]*$", self.base_url):
             raise RuntimeError("404 Not Found")

-        if "generate" == model_uid:
+        if model_uid == "generate":
             return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
-        if "chat" == model_uid:
+        if model_uid == "chat":
             return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
-        if "embedding" == model_uid:
+        if model_uid == "embedding":
             return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
-        if "rerank" == model_uid:
+        if model_uid == "rerank":
             return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
         raise RuntimeError("404 Not Found")
@@ -34,9 +34,9 @@ def test_api_tool(setup_http_mock):
     response = tool.do_http_request(tool.api_bundle.server_url, tool.api_bundle.method, headers, parameters)

     assert response.status_code == 200
-    assert "/p_param" == response.request.url.path
-    assert b"query_param=q_param" == response.request.url.query
-    assert "h_param" == response.request.headers.get("header_param")
-    assert "application/json" == response.request.headers.get("content-type")
-    assert "cookie_param=c_param" == response.request.headers.get("cookie")
+    assert response.request.url.path == "/p_param"
+    assert response.request.url.query == b"query_param=q_param"
+    assert response.request.headers.get("header_param") == "h_param"
+    assert response.request.headers.get("content-type") == "application/json"
+    assert response.request.headers.get("cookie") == "cookie_param=c_param"
     assert "b_param" in response.content.decode()
@@ -384,7 +384,7 @@ def test_mock_404(setup_http_mock):
     assert result.outputs is not None
     resp = result.outputs

-    assert 404 == resp.get("status_code")
+    assert resp.get("status_code") == 404
     assert "Not Found" in resp.get("body", "")
@@ -2,7 +2,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
 from core.workflow.enums import SystemVariableKey
 from core.workflow.graph_engine.entities.event import (
     GraphRunPartialSucceededEvent,
     GraphRunSucceededEvent,
     NodeRunExceptionEvent,
     NodeRunStreamChunkEvent,
 )
@@ -14,7 +13,9 @@ from models.workflow import WorkflowType

 class ContinueOnErrorTestHelper:
     @staticmethod
-    def get_code_node(code: str, error_strategy: str = "fail-branch", default_value: dict | None = None):
+    def get_code_node(
+        code: str, error_strategy: str = "fail-branch", default_value: dict | None = None, retry_config: dict = {}
+    ):
         """Helper method to create a code node configuration"""
         node = {
             "id": "node",
@@ -26,6 +27,7 @@ class ContinueOnErrorTestHelper:
                 "code_language": "python3",
                 "code": "\n".join([line[4:] for line in code.split("\n")]),
                 "type": "code",
+                **retry_config,
             },
         }
         if default_value:
@@ -34,7 +36,10 @@ class ContinueOnErrorTestHelper:

     @staticmethod
     def get_http_node(
-        error_strategy: str = "fail-branch", default_value: dict | None = None, authorization_success: bool = False
+        error_strategy: str = "fail-branch",
+        default_value: dict | None = None,
+        authorization_success: bool = False,
+        retry_config: dict = {},
     ):
         """Helper method to create a http node configuration"""
         authorization = (
@@ -65,6 +70,7 @@ class ContinueOnErrorTestHelper:
                 "body": None,
                 "type": "http-request",
                 "error_strategy": error_strategy,
+                **retry_config,
             },
         }
         if default_value:
73 api/tests/unit_tests/core/workflow/nodes/test_retry.py Normal file
@@ -0,0 +1,73 @@
from core.workflow.graph_engine.entities.event import (
    GraphRunFailedEvent,
    GraphRunPartialSucceededEvent,
    GraphRunSucceededEvent,
    NodeRunRetryEvent,
)
from tests.unit_tests.core.workflow.nodes.test_continue_on_error import ContinueOnErrorTestHelper

DEFAULT_VALUE_EDGE = [
    {
        "id": "start-source-node-target",
        "source": "start",
        "target": "node",
        "sourceHandle": "source",
    },
    {
        "id": "node-source-answer-target",
        "source": "node",
        "target": "answer",
        "sourceHandle": "source",
    },
]


def test_retry_default_value_partial_success():
    """retry default value node with partial success status"""
    graph_config = {
        "edges": DEFAULT_VALUE_EDGE,
        "nodes": [
            {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
            {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
            ContinueOnErrorTestHelper.get_http_node(
                "default-value",
                [{"key": "result", "type": "string", "value": "http node got error response"}],
                retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
            ),
        ],
    }

    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
    events = list(graph_engine.run())
    assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
    assert events[-1].outputs == {"answer": "http node got error response"}
    assert any(isinstance(e, GraphRunPartialSucceededEvent) for e in events)
    assert len(events) == 11


def test_retry_failed():
    """retry failed with success status"""
    error_code = """
    def main() -> dict:
        return {
            "result": 1 / 0,
        }
    """

    graph_config = {
        "edges": DEFAULT_VALUE_EDGE,
        "nodes": [
            {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"},
            {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"},
            ContinueOnErrorTestHelper.get_http_node(
                None,
                None,
                retry_config={"retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True}},
            ),
        ],
    }
    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
    events = list(graph_engine.run())
    assert sum(1 for e in events if isinstance(e, NodeRunRetryEvent)) == 2
    assert any(isinstance(e, GraphRunFailedEvent) for e in events)
    assert len(events) == 8
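Note the shape of the retry_config argument in these tests: the helper spreads the outer dict into the node's "data", so the nested "retry_config" object ends up alongside the other node fields. The resulting node data looks roughly like this (shape inferred from the helper diff above):

node_data = {
    "title": "http",
    "type": "http-request",
    # ...remaining HTTP-node fields from the helper...
    "retry_config": {"max_retries": 2, "retry_interval": 1000, "retry_enabled": True},
}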
111 dev/pytest/pytest_config_tests.py Normal file
@@ -0,0 +1,111 @@
import yaml  # type: ignore
from dotenv import dotenv_values
from pathlib import Path

BASE_API_AND_DOCKER_CONFIG_SET_DIFF = {
    "APP_MAX_EXECUTION_TIME",
    "BATCH_UPLOAD_LIMIT",
    "CELERY_BEAT_SCHEDULER_TIME",
    "CODE_EXECUTION_API_KEY",
    "HTTP_REQUEST_MAX_CONNECT_TIMEOUT",
    "HTTP_REQUEST_MAX_READ_TIMEOUT",
    "HTTP_REQUEST_MAX_WRITE_TIMEOUT",
    "KEYWORD_DATA_SOURCE_TYPE",
    "LOGIN_LOCKOUT_DURATION",
    "LOG_FORMAT",
    "OCI_ACCESS_KEY",
    "OCI_BUCKET_NAME",
    "OCI_ENDPOINT",
    "OCI_REGION",
    "OCI_SECRET_KEY",
    "REDIS_DB",
    "RESEND_API_URL",
    "RESPECT_XFORWARD_HEADERS_ENABLED",
    "SENTRY_DSN",
    "SSRF_DEFAULT_CONNECT_TIME_OUT",
    "SSRF_DEFAULT_MAX_RETRIES",
    "SSRF_DEFAULT_READ_TIME_OUT",
    "SSRF_DEFAULT_TIME_OUT",
    "SSRF_DEFAULT_WRITE_TIME_OUT",
    "UPSTASH_VECTOR_TOKEN",
    "UPSTASH_VECTOR_URL",
    "USING_UGC_INDEX",
    "WEAVIATE_BATCH_SIZE",
    "WEAVIATE_GRPC_ENABLED",
}

BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF = {
    "BATCH_UPLOAD_LIMIT",
    "CELERY_BEAT_SCHEDULER_TIME",
    "HTTP_REQUEST_MAX_CONNECT_TIMEOUT",
    "HTTP_REQUEST_MAX_READ_TIMEOUT",
    "HTTP_REQUEST_MAX_WRITE_TIMEOUT",
    "KEYWORD_DATA_SOURCE_TYPE",
    "LOGIN_LOCKOUT_DURATION",
    "LOG_FORMAT",
    "OPENDAL_FS_ROOT",
    "OPENDAL_S3_ACCESS_KEY_ID",
    "OPENDAL_S3_BUCKET",
    "OPENDAL_S3_ENDPOINT",
    "OPENDAL_S3_REGION",
    "OPENDAL_S3_ROOT",
    "OPENDAL_S3_SECRET_ACCESS_KEY",
    "OPENDAL_S3_SERVER_SIDE_ENCRYPTION",
    "PGVECTOR_MAX_CONNECTION",
    "PGVECTOR_MIN_CONNECTION",
    "PGVECTO_RS_DATABASE",
    "PGVECTO_RS_HOST",
    "PGVECTO_RS_PASSWORD",
    "PGVECTO_RS_PORT",
    "PGVECTO_RS_USER",
    "RESPECT_XFORWARD_HEADERS_ENABLED",
    "SCARF_NO_ANALYTICS",
    "SSRF_DEFAULT_CONNECT_TIME_OUT",
    "SSRF_DEFAULT_MAX_RETRIES",
    "SSRF_DEFAULT_READ_TIME_OUT",
    "SSRF_DEFAULT_TIME_OUT",
    "SSRF_DEFAULT_WRITE_TIME_OUT",
    "STORAGE_OPENDAL_SCHEME",
    "SUPABASE_API_KEY",
    "SUPABASE_BUCKET_NAME",
    "SUPABASE_URL",
    "USING_UGC_INDEX",
    "VIKINGDB_CONNECTION_TIMEOUT",
    "VIKINGDB_SOCKET_TIMEOUT",
    "WEAVIATE_BATCH_SIZE",
    "WEAVIATE_GRPC_ENABLED",
}

API_CONFIG_SET = set(dotenv_values(Path("api") / Path(".env.example")).keys())
DOCKER_CONFIG_SET = set(dotenv_values(Path("docker") / Path(".env.example")).keys())
DOCKER_COMPOSE_CONFIG_SET = set()

with open(Path("docker") / Path("docker-compose.yaml")) as f:
    DOCKER_COMPOSE_CONFIG_SET = set(yaml.safe_load(f.read())["x-shared-env"].keys())


def test_yaml_config():
    # compare the api keys against the docker keys, ignoring the known differences
    DIFF_API_WITH_DOCKER = API_CONFIG_SET - DOCKER_CONFIG_SET - BASE_API_AND_DOCKER_CONFIG_SET_DIFF
    if DIFF_API_WITH_DOCKER:
        print(f"API and Docker config sets are different with key: {DIFF_API_WITH_DOCKER}")
        raise Exception("API and Docker config sets are different")
    DIFF_API_WITH_DOCKER_COMPOSE = (
        API_CONFIG_SET - DOCKER_COMPOSE_CONFIG_SET - BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF
    )
    if DIFF_API_WITH_DOCKER_COMPOSE:
        print(f"API and Docker Compose config sets are different with key: {DIFF_API_WITH_DOCKER_COMPOSE}")
        raise Exception("API and Docker Compose config sets are different")
    print("All tests passed!")


if __name__ == "__main__":
    test_yaml_config()
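The test is a plain set difference: every key present in api/.env.example but absent from the Docker files must appear in the corresponding known-difference set, otherwise the run fails. In miniature (toy keys, not the real config):

api_keys = {"SHARED_KEY", "API_ONLY_KEY"}  # from api/.env.example
docker_keys = {"SHARED_KEY"}               # from docker/.env.example
known_diff = {"API_ONLY_KEY"}              # documented intentional difference
assert api_keys - docker_keys - known_diff == set()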
@@ -107,6 +107,7 @@ ACCESS_TOKEN_EXPIRE_MINUTES=60

 # The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer.
 APP_MAX_ACTIVE_REQUESTS=0
+APP_MAX_EXECUTION_TIME=1200

 # ------------------------------
 # Container Startup Related Configuration
@@ -606,6 +607,7 @@ UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
 # Sentry Configuration
 # Used for application monitoring and error log tracking.
 # ------------------------------
+SENTRY_DSN=

 # API Service Sentry DSN address; default is empty. When empty,
 # no monitoring information is reported to Sentry.
@@ -18,6 +18,7 @@ x-shared-env: &shared-api-worker-env
   LOG_DATEFORMAT: ${LOG_DATEFORMAT:-"%Y-%m-%d %H:%M:%S"}
   LOG_TZ: ${LOG_TZ:-UTC}
   DEBUG: ${DEBUG:-false}
+  SENTRY_DSN: ${SENTRY_DSN:-}
   FLASK_DEBUG: ${FLASK_DEBUG:-false}
   SECRET_KEY: ${SECRET_KEY:-sk-9f73s3ljTXVcMT3Blb3ljTqtsKiGHXVcMT3BlbkFJLK7U}
   INIT_PASSWORD: ${INIT_PASSWORD:-}
@@ -28,6 +29,7 @@ x-shared-env: &shared-api-worker-env
   FILES_ACCESS_TIMEOUT: ${FILES_ACCESS_TIMEOUT:-300}
   ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60}
   APP_MAX_ACTIVE_REQUESTS: ${APP_MAX_ACTIVE_REQUESTS:-0}
+  APP_MAX_EXECUTION_TIME: ${APP_MAX_EXECUTION_TIME:-1200}
   DIFY_BIND_ADDRESS: ${DIFY_BIND_ADDRESS:-0.0.0.0}
   DIFY_PORT: ${DIFY_PORT:-5001}
   SERVER_WORKER_AMOUNT: ${SERVER_WORKER_AMOUNT:-}
@@ -1,13 +1,19 @@
 'use client'
+import { useSearchParams } from 'next/navigation'
 import { useTranslation } from 'react-i18next'

 const Empty = () => {
   const { t } = useTranslation()
+  const searchParams = useSearchParams()

   return (
     <div className='flex flex-col items-center'>
       <div className="shrink-0 w-[163px] h-[149px] bg-cover bg-no-repeat bg-[url('~@/app/components/tools/add-tool-modal/empty.png')]"></div>
-      <div className='mb-1 text-[13px] font-medium text-text-primary leading-[18px]'>{t('tools.addToolModal.emptyTitle')}</div>
-      <div className='text-[13px] text-text-tertiary leading-[18px]'>{t('tools.addToolModal.emptyTip')}</div>
+      <div className='mb-1 text-[13px] font-medium text-text-primary leading-[18px]'>
+        {t(`tools.addToolModal.${searchParams.get('category') === 'workflow' ? 'emptyTitle' : 'emptyTitleCustom'}`)}
+      </div>
+      <div className='text-[13px] text-text-tertiary leading-[18px]'>
+        {t(`tools.addToolModal.${searchParams.get('category') === 'workflow' ? 'emptyTip' : 'emptyTipCustom'}`)}
+      </div>
     </div>
   )
 }
@@ -31,6 +31,8 @@ const translation = {
     manageInTools: 'Manage in Tools',
     emptyTitle: 'No workflow tool available',
     emptyTip: 'Go to "Workflow -> Publish as Tool"',
+    emptyTitleCustom: 'No custom tool available',
+    emptyTipCustom: 'Create a custom tool',
   },
   createTool: {
     title: 'Create Custom Tool',
@@ -31,6 +31,8 @@ const translation = {
     manageInTools: '去工具列表管理',
     emptyTitle: '没有可用的工作流工具',
     emptyTip: '去 “工作流 -> 发布为工具” 添加',
+    emptyTitleCustom: '没有可用的自定义工具',
+    emptyTipCustom: '创建自定义工具',
   },
   createTool: {
     title: '创建自定义工具',