Compare commits

...

17 Commits

Author SHA1 Message Date
CodingOnStar
aa5a22991b refactor(toast): streamline toast component structure and improve cleanup logic
- Adjusted class names for consistency in the Toast component.
- Refactored the toastHandler.clear function to improve cleanup logic by using a dedicated unmountAndRemove function.
- Ensured proper handling of the timer for toast notifications.
2026-03-02 11:54:51 +08:00
CodingOnStar
4928917878 Merge remote-tracking branch 'origin/main' into refactor/base-comp 2026-03-02 11:51:03 +08:00
Coding On Star
335b500aea test: add unit tests for base components (#32818)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
2026-03-02 11:40:43 +08:00
CodingOnStar
b00afff61e fix(tests): correct import paths in chat and context block test files 2026-03-02 11:31:59 +08:00
CodingOnStar
691248f477 test: add unit tests for various components including Alert, AppUnavailable, Badge, ThemeSelector, ThemeSwitcher, ActionButton, and AgentLogModal 2026-03-02 11:11:08 +08:00
hj24
8cc775d9f2 fix: optimize workflow_run iter query (#32815) 2026-03-02 11:01:11 +08:00
yyh
1a33903887 feat(web): add root isolation layer for portal stacking context (#32807) 2026-03-02 10:59:56 +08:00
edvatar
00dbaef04f fix: use declared_attr.directive for WorkflowNodeExecutionModel.__table_args__ (#32656)
Signed-off-by: edvatar <88481784+toroleapinc@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-03-02 10:15:06 +08:00
edvatar
248202c220 fix: remove references to non-existent Document attributes in test (#32654)
Signed-off-by: edvatar <88481784+toroleapinc@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-03-02 10:14:56 +08:00
HanWenbo
691c9911c7 fix(ci): make pyrefly diff comments focus on diagnostics (#32778) 2026-03-02 10:11:23 +08:00
Copilot
baeea77c5b fix: typo in WebScraper plugin description: "Scrapper" → "Scraper" (#32790)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2026-03-02 10:10:28 +08:00
wangxiaolei
9da98e6c6c fix: fix import error (#32800) 2026-03-02 08:59:53 +08:00
99
a01de98721 refactor(workflow): decouple start node external dependencies (#32793)
2026-03-02 02:01:41 +08:00
-LAN-
17c1538e03 refactor(workflow): move PromptMessageMemory to model_runtime.memory (#32796)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-02 01:58:02 +08:00
-LAN-
69b3e94630 refactor: inject workflow node memory via protocol (#32784) 2026-03-02 01:55:49 +08:00
-LAN-
ef2b5d6107 refactor(api): move llm quota deduction to app graph layer (#32786) 2026-03-01 23:25:36 +08:00
非法操作
fa4b8910c8 chore: support code-inspector for vinext (#32788)
2026-03-01 20:27:57 +08:00
452 changed files with 1725 additions and 1184 deletions

View File

@@ -29,16 +29,22 @@ jobs:
- name: Install dependencies
run: uv sync --project api --dev
- name: Prepare diagnostics extractor
run: |
git show ${{ github.event.pull_request.head.sha }}:api/libs/pyrefly_diagnostics.py > /tmp/pyrefly_diagnostics.py
- name: Run pyrefly on PR branch
run: |
uv run --directory api pyrefly check > /tmp/pyrefly_pr.txt 2>&1 || true
uv run --directory api --dev pyrefly check 2>&1 \
| uv run --directory api python /tmp/pyrefly_diagnostics.py > /tmp/pyrefly_pr.txt || true
- name: Checkout base branch
run: git checkout ${{ github.base_ref }}
- name: Run pyrefly on base branch
run: |
uv run --directory api pyrefly check > /tmp/pyrefly_base.txt 2>&1 || true
uv run --directory api --dev pyrefly check 2>&1 \
| uv run --directory api python /tmp/pyrefly_diagnostics.py > /tmp/pyrefly_base.txt || true
- name: Compute diff
run: |

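The extractor keeps only pyrefly's headline and location lines, so the "Compute diff" step compares stable text instead of code excerpts and caret markers. A minimal sketch of that contract, assuming a typical pyrefly diagnostic shape (the sample output below is hypothetical):

    # Assumes the api/ package is importable, as in the workflow steps above.
    from libs.pyrefly_diagnostics import extract_diagnostics

    raw = (
        "ERROR Expected str, got int [bad-assignment]\n"
        " --> api/app.py:12:5\n"
        "   12 | x: str = 1\n"
        "      |          ^\n"
    )
    # Per the module docstring, only the ERROR headline and its --> location
    # line should survive; the excerpt and caret lines are dropped.
    print(extract_diagnostics(raw))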
View File

@@ -29,6 +29,8 @@ ignore_imports =
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
@@ -52,7 +54,6 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
# TODO(QuantumGhost): use DI to avoid depending on global DB.
@@ -107,14 +108,11 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
core.workflow.nodes.llm.llm_utils -> configs
core.workflow.nodes.llm.llm_utils -> core.model_manager
core.workflow.nodes.llm.protocols -> core.model_manager
core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
core.workflow.nodes.llm.llm_utils -> models.model
core.workflow.nodes.llm.llm_utils -> models.provider
core.workflow.nodes.llm.llm_utils -> services.credit_pool_service
core.workflow.nodes.llm.node -> core.tools.signature
core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
@@ -131,12 +129,10 @@ ignore_imports =
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.model_providers.__base.large_language_model
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform
core.workflow.nodes.start.entities -> core.app.app_config.entities
core.workflow.nodes.start.start_node -> core.app.app_config.entities
core.workflow.workflow_entry -> core.app.apps.exc
core.workflow.workflow_entry -> core.app.entities.app_invoke_entities
core.workflow.workflow_entry -> core.app.workflow.layers.llm_quota
core.workflow.workflow_entry -> core.app.workflow.node_factory
core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer
@@ -150,7 +146,6 @@ ignore_imports =
core.workflow.nodes.llm.node -> core.model_manager
core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.node -> core.prompt.utils.prompt_message_util
core.workflow.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities
@@ -172,7 +167,6 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
@@ -180,7 +174,7 @@ ignore_imports =
core.workflow.workflow_entry -> extensions.otel.runtime
core.workflow.nodes.agent.agent_node -> models
core.workflow.nodes.base.node -> models.enums
core.workflow.nodes.llm.llm_utils -> models.provider_ids
core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.llm.node -> models.model
core.workflow.workflow_entry -> models.enums
core.workflow.nodes.agent.agent_node -> services

View File

@@ -8,9 +8,9 @@ from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_model
from controllers.console.app.mcp_server import AppMCPServerStatus
from controllers.mcp import mcp_ns
from core.app.app_config.entities import VariableEntity
from core.mcp import types as mcp_types
from core.mcp.server.streamable_http import handle_mcp_request
from core.workflow.variables.input_entities import VariableEntity
from extensions.ext_database import db
from libs import helper
from models.model import App, AppMCPServer, AppMode, EndUser

View File

@@ -1,7 +1,8 @@
import re
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
from core.app.app_config.entities import ExternalDataVariableEntity
from core.external_data_tool.factory import ExternalDataToolFactory
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
_ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
[

View File

@@ -2,12 +2,12 @@ from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal
from jsonschema import Draft7Validator, SchemaError
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.file import FileTransferMethod, FileType, FileUploadConfig
from core.workflow.file import FileUploadConfig
from core.workflow.variables.input_entities import VariableEntity as WorkflowVariableEntity
from models.model import AppMode
@@ -90,61 +90,7 @@ class PromptTemplateEntity(BaseModel):
advanced_completion_prompt_template: AdvancedCompletionPromptTemplateEntity | None = None
class VariableEntityType(StrEnum):
TEXT_INPUT = "text-input"
SELECT = "select"
PARAGRAPH = "paragraph"
NUMBER = "number"
EXTERNAL_DATA_TOOL = "external_data_tool"
FILE = "file"
FILE_LIST = "file-list"
CHECKBOX = "checkbox"
JSON_OBJECT = "json_object"
class VariableEntity(BaseModel):
"""
Variable Entity.
"""
# `variable` records the name of the variable in user inputs.
variable: str
label: str
description: str = ""
type: VariableEntityType
required: bool = False
hide: bool = False
default: Any = None
max_length: int | None = None
options: Sequence[str] = Field(default_factory=list)
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
json_schema: dict | None = Field(default=None)
@field_validator("description", mode="before")
@classmethod
def convert_none_description(cls, v: Any) -> str:
return v or ""
@field_validator("options", mode="before")
@classmethod
def convert_none_options(cls, v: Any) -> Sequence[str]:
return v or []
@field_validator("json_schema")
@classmethod
def validate_json_schema(cls, schema: dict | None) -> dict | None:
if schema is None:
return None
try:
Draft7Validator.check_schema(schema)
except SchemaError as e:
raise ValueError(f"Invalid JSON schema: {e.message}")
return schema
class RagPipelineVariableEntity(VariableEntity):
class RagPipelineVariableEntity(WorkflowVariableEntity):
"""
Rag Pipeline Variable Entity.
"""
@@ -314,7 +260,7 @@ class AppConfig(BaseModel):
app_id: str
app_mode: AppMode
additional_features: AppAdditionalFeatures | None = None
variables: list[VariableEntity] = []
variables: list[WorkflowVariableEntity] = []
sensitive_word_avoidance: SensitiveWordAvoidanceEntity | None = None

View File

@@ -1,6 +1,7 @@
import re
from core.app.app_config.entities import RagPipelineVariableEntity, VariableEntity
from core.app.app_config.entities import RagPipelineVariableEntity
from core.workflow.variables.input_entities import VariableEntity
from models.workflow import Workflow

View File

@@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Any, Union, final
from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.enums import NodeType
from core.workflow.file import File, FileUploadConfig
@@ -12,13 +11,14 @@ from core.workflow.repositories.draft_variable_repository import (
DraftVariableSaverFactory,
NoopDraftVariableSaver,
)
from core.workflow.variables.input_entities import VariableEntityType
from factories import file_factory
from libs.orjson import orjson_dumps
from models import Account, EndUser
from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
if TYPE_CHECKING:
from core.app.app_config.entities import VariableEntity
from core.workflow.variables.input_entities import VariableEntity
class BaseAppGenerator:

View File

@@ -1 +1,5 @@
"""LLM-related application services."""
from .quota import deduct_llm_quota, ensure_llm_quota_available
__all__ = ["deduct_llm_quota", "ensure_llm_quota_available"]

api/core/app/llm/quota.py (new file, 93 additions)
View File

@@ -0,0 +1,93 @@
from sqlalchemy import update
from sqlalchemy.orm import Session
from configs import dify_config
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.provider import Provider, ProviderType
from models.provider_ids import ModelProviderID
def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
provider_model = provider_configuration.get_provider_model(
model_type=model_instance.model_type_instance.model_type,
model=model_instance.model_name,
)
if provider_model and provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {model_instance.provider} quota exceeded.")
def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
system_configuration = provider_configuration.system_configuration
quota_unit = None
for quota_configuration in system_configuration.quota_configurations:
if quota_configuration.quota_type == system_configuration.current_quota_type:
quota_unit = quota_configuration.quota_unit
if quota_configuration.quota_limit == -1:
return
break
used_quota = None
if quota_unit:
if quota_unit == QuotaUnit.TOKENS:
used_quota = usage.total_tokens
elif quota_unit == QuotaUnit.CREDITS:
used_quota = dify_config.get_model_credits(model_instance.model_name)
else:
used_quota = 1
if used_quota is not None and system_configuration.current_quota_type is not None:
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
)
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
else:
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
)
session.execute(stmt)
session.commit()
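
Both helpers return immediately for non-system providers, so call sites can invoke them unconditionally. A minimal usage sketch (the invoke call is illustrative, not part of this diff):

    from core.app.llm import deduct_llm_quota, ensure_llm_quota_available

    # Raises QuotaExceededError up front if the system provider's quota is spent.
    ensure_llm_quota_available(model_instance=model_instance)
    result = model_instance.invoke_llm(prompt_messages=prompt_messages)  # illustrative call
    # Deducts tokens, credits, or one unit, depending on the provider's quota unit.
    deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=result.usage)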

View File

@@ -1,9 +1,11 @@
"""Workflow-level GraphEngine layers that depend on outer infrastructure."""
from .llm_quota import LLMQuotaLayer
from .observability import ObservabilityLayer
from .persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
__all__ = [
"LLMQuotaLayer",
"ObservabilityLayer",
"PersistenceWorkflowInfo",
"WorkflowPersistenceLayer",

View File

@@ -0,0 +1,128 @@
"""
LLM quota deduction layer for GraphEngine.
This layer centralizes model-quota deduction outside node implementations.
"""
import logging
from typing import TYPE_CHECKING, cast, final
from typing_extensions import override
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from core.workflow.enums import NodeType
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase
from core.workflow.graph_events.node import NodeRunSucceededEvent
from core.workflow.nodes.base.node import Node
if TYPE_CHECKING:
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
logger = logging.getLogger(__name__)
@final
class LLMQuotaLayer(GraphEngineLayer):
"""Graph layer that applies LLM quota deduction after node execution."""
def __init__(self) -> None:
super().__init__()
self._abort_sent = False
@override
def on_graph_start(self) -> None:
self._abort_sent = False
@override
def on_event(self, event: GraphEngineEvent) -> None:
_ = event
@override
def on_graph_end(self, error: Exception | None) -> None:
_ = error
@override
def on_node_run_start(self, node: Node) -> None:
if self._abort_sent:
return
model_instance = self._extract_model_instance(node)
if model_instance is None:
return
try:
ensure_llm_quota_available(model_instance=model_instance)
except QuotaExceededError as exc:
self._set_stop_event(node)
self._send_abort_command(reason=str(exc))
logger.warning("LLM quota check failed, node_id=%s, error=%s", node.id, exc)
@override
def on_node_run_end(
self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
) -> None:
if error is not None or not isinstance(result_event, NodeRunSucceededEvent):
return
model_instance = self._extract_model_instance(node)
if model_instance is None:
return
try:
deduct_llm_quota(
tenant_id=node.tenant_id,
model_instance=model_instance,
usage=result_event.node_run_result.llm_usage,
)
except QuotaExceededError as exc:
self._set_stop_event(node)
self._send_abort_command(reason=str(exc))
logger.warning("LLM quota deduction exceeded, node_id=%s, error=%s", node.id, exc)
except Exception:
logger.exception("LLM quota deduction failed, node_id=%s", node.id)
@staticmethod
def _set_stop_event(node: Node) -> None:
stop_event = getattr(node.graph_runtime_state, "stop_event", None)
if stop_event is not None:
stop_event.set()
def _send_abort_command(self, *, reason: str) -> None:
if not self.command_channel or self._abort_sent:
return
try:
self.command_channel.send_command(
AbortCommand(
command_type=CommandType.ABORT,
reason=reason,
)
)
self._abort_sent = True
except Exception:
logger.exception("Failed to send quota abort command")
@staticmethod
def _extract_model_instance(node: Node) -> ModelInstance | None:
try:
match node.node_type:
case NodeType.LLM:
return cast("LLMNode", node).model_instance
case NodeType.PARAMETER_EXTRACTOR:
return cast("ParameterExtractorNode", node).model_instance
case NodeType.QUESTION_CLASSIFIER:
return cast("QuestionClassifierNode", node).model_instance
case _:
return None
except AttributeError:
logger.warning(
"LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s",
node.id,
)
return None
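
As later hunks in this diff show, the layer is attached via graph_engine.layer() in WorkflowEntry and in the sub-graph engines built by the iteration and loop nodes, keeping quota enforcement out of the node implementations:

    from core.app.workflow.layers.llm_quota import LLMQuotaLayer

    graph_engine.layer(LLMQuotaLayer())  # repeated for each sub-graph engine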

View File

@@ -1,6 +1,8 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, cast, final
from sqlalchemy import select
from sqlalchemy.orm import Session
from typing_extensions import override
from configs import dify_config
@@ -11,14 +13,16 @@ from core.helper.code_executor.code_executor import (
CodeExecutor,
)
from core.helper.ssrf_proxy import ssrf_proxy
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import NodeType
from core.workflow.enums import NodeType, SystemVariableKey
from core.workflow.file.file_manager import file_manager
from core.workflow.graph.graph import NodeFactory
from core.workflow.nodes.base.node import Node
@@ -29,11 +33,9 @@ from core.workflow.nodes.datasource import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm import llm_utils
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.llm.protocols import PromptMessageMemory
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
@@ -41,12 +43,34 @@ from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
)
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.variables.segments import StringSegment
from extensions.ext_database import db
from models.model import Conversation
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
def fetch_memory(
*,
conversation_id: str | None,
app_id: str,
node_data_memory: MemoryConfig | None,
model_instance: ModelInstance,
) -> TokenBufferMemory | None:
if not node_data_memory or not conversation_id:
return None
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
class DefaultWorkflowCodeExecutor:
def execute(
self,
@@ -221,6 +245,7 @@ class DifyNodeFactory(NodeFactory):
if node_type == NodeType.QUESTION_CLASSIFIER:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return QuestionClassifierNode(
id=node_id,
config=node_config,
@@ -229,10 +254,12 @@ class DifyNodeFactory(NodeFactory):
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
)
if node_type == NodeType.PARAMETER_EXTRACTOR:
model_instance = self._build_model_instance_for_llm_node(node_data)
memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance)
return ParameterExtractorNode(
id=node_id,
config=node_config,
@@ -241,6 +268,7 @@ class DifyNodeFactory(NodeFactory):
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
memory=memory,
)
return node_class(
@@ -295,8 +323,14 @@ class DifyNodeFactory(NodeFactory):
return None
node_memory = MemoryConfig.model_validate(raw_memory_config)
return llm_utils.fetch_memory(
variable_pool=self.graph_runtime_state.variable_pool,
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID]
)
conversation_id = (
conversation_id_variable.value if isinstance(conversation_id_variable, StringSegment) else None
)
return fetch_memory(
conversation_id=conversation_id,
app_id=self.graph_init_params.app_id,
node_data_memory=node_memory,
model_instance=model_instance,
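
For illustration, calling the relocated fetch_memory directly might look like this (the surrounding variables are hypothetical):

    memory = fetch_memory(
        conversation_id=conversation_id,  # None outside a conversation; the helper then returns None
        app_id=app_id,
        node_data_memory=memory_config,   # a MemoryConfig, or None when memory is disabled
        model_instance=model_instance,
    )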

View File

@@ -4,10 +4,10 @@ from collections.abc import Mapping
from typing import Any, cast
from configs import dify_config
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.mcp import types as mcp_types
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
from models.model import App, AppMCPServer, AppMode, EndUser
from services.app_generate_service import AppGenerateService

View File

@@ -0,0 +1,3 @@
from .prompt_message_memory import DEFAULT_MEMORY_MAX_TOKEN_LIMIT, PromptMessageMemory
__all__ = ["DEFAULT_MEMORY_MAX_TOKEN_LIMIT", "PromptMessageMemory"]

View File

@@ -0,0 +1,18 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Protocol
from core.model_runtime.entities import PromptMessage
DEFAULT_MEMORY_MAX_TOKEN_LIMIT = 2000
class PromptMessageMemory(Protocol):
"""Port for loading memory as prompt messages."""
def get_history_prompt_messages(
self, max_token_limit: int = DEFAULT_MEMORY_MAX_TOKEN_LIMIT, message_limit: int | None = None
) -> Sequence[PromptMessage]:
"""Return historical prompt messages constrained by token/message limits."""
...
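
Because this is a Protocol, any object with a matching get_history_prompt_messages method satisfies it structurally; in this diff the concrete implementation is TokenBufferMemory. A minimal illustrative stand-in (not part of the diff):

    from collections.abc import Sequence

    from core.model_runtime.entities import PromptMessage
    from core.model_runtime.memory import DEFAULT_MEMORY_MAX_TOKEN_LIMIT, PromptMessageMemory

    class StaticMemory:
        def __init__(self, messages: Sequence[PromptMessage]) -> None:
            self._messages = messages

        def get_history_prompt_messages(
            self, max_token_limit: int = DEFAULT_MEMORY_MAX_TOKEN_LIMIT, message_limit: int | None = None
        ) -> Sequence[PromptMessage]:
            # A real implementation would also trim to the token budget.
            return self._messages if message_limit is None else self._messages[-message_limit:]

    memory: PromptMessageMemory = StaticMemory([])  # structural typing: no inheritance required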

View File

@@ -2,6 +2,7 @@ import tempfile
from binascii import hexlify, unhexlify
from collections.abc import Generator
from core.app.llm import deduct_llm_quota
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import (
@@ -29,7 +30,6 @@ from core.plugin.entities.request import (
)
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
from core.workflow.nodes.llm import llm_utils
from models.account import Tenant
@@ -63,16 +63,14 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
def handle() -> Generator[LLMResultChunk, None, None]:
for chunk in response:
if chunk.delta.usage:
llm_utils.deduct_llm_quota(
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
)
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage)
chunk.prompt_messages = []
yield chunk
return handle()
else:
if response.usage:
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
yield LLMResultChunk(
@@ -126,16 +124,14 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
for chunk in response:
if chunk.delta.usage:
llm_utils.deduct_llm_quota(
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
)
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage)
chunk.prompt_messages = []
yield chunk
return handle()
else:
if response.usage:
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
def handle_non_streaming(
response: LLMResultWithStructuredOutput,

View File

@@ -8,6 +8,7 @@ from typing import Any, cast
logger = logging.getLogger(__name__)
from core.app.llm import deduct_llm_quota
from core.entities.knowledge_entities import PreviewDetail
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
from core.model_manager import ModelInstance
@@ -35,7 +36,6 @@ from core.rag.models.document import AttachmentDocument, Document, MultimodalGen
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.workflow.file import File, FileTransferMethod, FileType, file_manager
from core.workflow.nodes.llm import llm_utils
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
from libs import helper
@@ -474,7 +474,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
# Deduct quota for summary generation (same as workflow nodes)
try:
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
except Exception as e:
# Log but don't fail summary generation if quota deduction fails
logger.warning("Failed to deduct quota for summary generation: %s", str(e))

View File

@@ -2,6 +2,7 @@ from collections.abc import Generator, Sequence
from typing import Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app.llm import deduct_llm_quota
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
@@ -9,7 +10,6 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.rag.retrieval.output_parser.react_output import ReactAction
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
from core.workflow.nodes.llm import llm_utils
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
@@ -162,7 +162,7 @@ class ReactMultiDatasetRouter:
text, usage = self._handle_invoke_result(invoke_result=invoke_result)
# deduct quota
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
return text, usage

View File

@@ -6,9 +6,9 @@ identity:
zh_Hans: 网页抓取
pt_BR: WebScraper
description:
en_US: Web Scrapper tool kit is used to scrape web
en_US: Web Scraper tool kit is used to scrape web
zh_Hans: 一个用于抓取网页的工具。
pt_BR: Web Scrapper tool kit is used to scrape web
pt_BR: Web Scraper tool kit is used to scrape web
icon: icon.svg
tags:
- productivity

View File

@@ -1,11 +1,11 @@
from collections.abc import Mapping, Sequence
from typing import Any
from core.app.app_config.entities import VariableEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
from core.workflow.enums import NodeType
from core.workflow.nodes.base.entities import OutputVariableEntity
from core.workflow.variables.input_entities import VariableEntity
class WorkflowToolConfigurationUtils:

View File

@@ -5,7 +5,6 @@ from collections.abc import Mapping
from pydantic import Field
from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.db.session_factory import session_factory
from core.plugin.entities.parameters import PluginParameterOption
@@ -23,6 +22,7 @@ from core.tools.entities.tool_entities import (
)
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.tool import WorkflowTool
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
from extensions.ext_database import db
from models.account import Account
from models.model import App, AppMode

View File

@@ -588,6 +588,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
def _create_graph_engine(self, index: int, item: object):
# Import dependencies
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
@@ -642,5 +643,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
config=GraphEngineConfig(),
)
graph_engine.layer(LLMQuotaLayer())
return graph_engine

View File

@@ -1,26 +1,19 @@
from collections.abc import Sequence
from typing import cast
-from sqlalchemy import select, update
-from sqlalchemy.orm import Session
-from configs import dify_config
-from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
-from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
-from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities import PromptMessageRole
+from core.model_runtime.entities.message_entities import (
+    ImagePromptMessageContent,
+    PromptMessage,
+    TextPromptMessageContent,
+)
 from core.model_runtime.entities.model_entities import AIModelEntity
+from core.model_runtime.memory import PromptMessageMemory
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.prompt.entities.advanced_prompt_entities import MemoryConfig
-from core.workflow.enums import SystemVariableKey
 from core.workflow.file.models import File
 from core.workflow.runtime import VariablePool
-from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
-from extensions.ext_database import db
-from libs.datetime_utils import naive_utc_now
-from models.model import Conversation
-from models.provider import Provider, ProviderType
-from models.provider_ids import ModelProviderID
+from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment
from .exc import InvalidVariableTypeError
@@ -48,88 +41,51 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc
raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
-def fetch_memory(
-    variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
-) -> TokenBufferMemory | None:
-    if not node_data_memory:
-        return None
-    # get conversation id
-    conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
-    if not isinstance(conversation_id_variable, StringSegment):
-        return None
-    conversation_id = conversation_id_variable.value
-    with Session(db.engine, expire_on_commit=False) as session:
-        stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
-        conversation = session.scalar(stmt)
-    if not conversation:
-        return None
-    memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
-    return memory
-
-
-def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
-    provider_model_bundle = model_instance.provider_model_bundle
-    provider_configuration = provider_model_bundle.configuration
-    if provider_configuration.using_provider_type != ProviderType.SYSTEM:
-        return
-    system_configuration = provider_configuration.system_configuration
-    quota_unit = None
-    for quota_configuration in system_configuration.quota_configurations:
-        if quota_configuration.quota_type == system_configuration.current_quota_type:
-            quota_unit = quota_configuration.quota_unit
-            if quota_configuration.quota_limit == -1:
-                return
-            break
-    used_quota = None
-    if quota_unit:
-        if quota_unit == QuotaUnit.TOKENS:
-            used_quota = usage.total_tokens
-        elif quota_unit == QuotaUnit.CREDITS:
-            used_quota = dify_config.get_model_credits(model_instance.model_name)
-        else:
-            used_quota = 1
-    if used_quota is not None and system_configuration.current_quota_type is not None:
-        if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
-            from services.credit_pool_service import CreditPoolService
-
-            CreditPoolService.check_and_deduct_credits(
-                tenant_id=tenant_id,
-                credits_required=used_quota,
-            )
-        elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
-            from services.credit_pool_service import CreditPoolService
-
-            CreditPoolService.check_and_deduct_credits(
-                tenant_id=tenant_id,
-                credits_required=used_quota,
-                pool_type="paid",
-            )
-        else:
-            with Session(db.engine) as session:
-                stmt = (
-                    update(Provider)
-                    .where(
-                        Provider.tenant_id == tenant_id,
-                        # TODO: Use provider name with prefix after the data migration.
-                        Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
-                        Provider.provider_type == ProviderType.SYSTEM.value,
-                        Provider.quota_type == system_configuration.current_quota_type.value,
-                        Provider.quota_limit > Provider.quota_used,
-                    )
-                    .values(
-                        quota_used=Provider.quota_used + used_quota,
-                        last_used=naive_utc_now(),
-                    )
-                )
-                session.execute(stmt)
-                session.commit()
+def convert_history_messages_to_text(
+    *,
+    history_messages: Sequence[PromptMessage],
+    human_prefix: str,
+    ai_prefix: str,
+) -> str:
+    string_messages: list[str] = []
+    for message in history_messages:
+        if message.role == PromptMessageRole.USER:
+            role = human_prefix
+        elif message.role == PromptMessageRole.ASSISTANT:
+            role = ai_prefix
+        else:
+            continue
+        if isinstance(message.content, list):
+            content_parts = []
+            for content in message.content:
+                if isinstance(content, TextPromptMessageContent):
+                    content_parts.append(content.data)
+                elif isinstance(content, ImagePromptMessageContent):
+                    content_parts.append("[image]")
+            inner_msg = "\n".join(content_parts)
+            string_messages.append(f"{role}: {inner_msg}")
+        else:
+            string_messages.append(f"{role}: {message.content}")
+    return "\n".join(string_messages)
+
+
+def fetch_memory_text(
+    *,
+    memory: PromptMessageMemory,
+    max_token_limit: int,
+    message_limit: int | None = None,
+    human_prefix: str = "Human",
+    ai_prefix: str = "Assistant",
+) -> str:
+    history_messages = memory.get_history_prompt_messages(
+        max_token_limit=max_token_limit,
+        message_limit=message_limit,
+    )
+    return convert_history_messages_to_text(
+        history_messages=history_messages,
+        human_prefix=human_prefix,
+        ai_prefix=ai_prefix,
+    )

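A quick illustrative call (the messages are hypothetical; AssistantPromptMessage is assumed to exist alongside UserPromptMessage in message_entities):

    from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
    from core.workflow.nodes.llm import llm_utils

    text = llm_utils.convert_history_messages_to_text(
        history_messages=[UserPromptMessage(content="hi"), AssistantPromptMessage(content="hello")],
        human_prefix="Human",
        ai_prefix="Assistant",
    )
    # -> "Human: hi\nAssistant: hello"; roles other than user/assistant are skipped.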
View File

@@ -37,6 +37,7 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
@@ -62,7 +63,7 @@ from core.workflow.node_events import (
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory, PromptMessageMemory
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.runtime import VariablePool
from core.workflow.variables import (
ArrayFileSegment,
@@ -278,8 +279,6 @@ class LLMNode(Node[LLMNodeData]):
else None
)
# deduct quota
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
break
elif isinstance(event, LLMStructuredOutput):
structured_output = event
@@ -1234,6 +1233,10 @@ class LLMNode(Node[LLMNodeData]):
def retry(self) -> bool:
return self.node_data.retry_config.retry_enabled
@property
def model_instance(self) -> ModelInstance:
return self._model_instance
def _combine_message_content_with_role(
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
@@ -1336,48 +1339,16 @@ def _handle_memory_completion_mode(
)
if not memory_config.role_prefix:
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
memory_messages = memory.get_history_prompt_messages(
memory_text = llm_utils.fetch_memory_text(
memory=memory,
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
)
memory_text = _convert_history_messages_to_text(
history_messages=memory_messages,
human_prefix=memory_config.role_prefix.user,
ai_prefix=memory_config.role_prefix.assistant,
)
return memory_text
def _convert_history_messages_to_text(
*,
history_messages: Sequence[PromptMessage],
human_prefix: str,
ai_prefix: str,
) -> str:
string_messages: list[str] = []
for message in history_messages:
if message.role == PromptMessageRole.USER:
role = human_prefix
elif message.role == PromptMessageRole.ASSISTANT:
role = ai_prefix
else:
continue
if isinstance(message.content, list):
content_parts = []
for content in message.content:
if isinstance(content, TextPromptMessageContent):
content_parts.append(content.data)
elif isinstance(content, ImagePromptMessageContent):
content_parts.append("[image]")
inner_msg = "\n".join(content_parts)
string_messages.append(f"{role}: {inner_msg}")
else:
string_messages.append(f"{role}: {message.content}")
return "\n".join(string_messages)
def _handle_completion_template(
*,
template: LLMNodeCompletionModelPromptTemplate,

View File

@@ -1,10 +1,8 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import Any, Protocol
from core.model_manager import ModelInstance
from core.model_runtime.entities import PromptMessage
class CredentialsProvider(Protocol):
@@ -21,13 +19,3 @@ class ModelFactory(Protocol):
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
"""Create a model instance that is ready for schema lookup and invocation."""
...
class PromptMessageMemory(Protocol):
"""Port for loading memory as prompt messages for LLM nodes."""
def get_history_prompt_messages(
self, max_token_limit: int = 2000, message_limit: int | None = None
) -> Sequence[PromptMessage]:
"""Return historical prompt messages constrained by token/message limits."""
...

View File

@@ -413,6 +413,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
# Import dependencies
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
@@ -454,5 +455,6 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
config=GraphEngineConfig(),
)
graph_engine.layer(LLMQuotaLayer())
return graph_engine

View File

@@ -5,7 +5,6 @@ import uuid
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import ImagePromptMessageContent
from core.model_runtime.entities.llm_entities import LLMUsage
@@ -18,13 +17,18 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import (
NodeType,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.file import File
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
@@ -97,6 +101,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
_model_instance: ModelInstance
_credentials_provider: "CredentialsProvider"
_model_factory: "ModelFactory"
_memory: PromptMessageMemory | None
def __init__(
self,
@@ -108,6 +113,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
credentials_provider: "CredentialsProvider",
model_factory: "ModelFactory",
model_instance: ModelInstance,
memory: PromptMessageMemory | None = None,
) -> None:
super().__init__(
id=id,
@@ -118,6 +124,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
self._credentials_provider = credentials_provider
self._model_factory = model_factory
self._model_instance = model_instance
self._memory = memory
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@@ -163,13 +170,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
except ValueError as exc:
raise ModelSchemaNotFoundError("Model schema not found") from exc
# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
node_data_memory=node_data.memory,
model_instance=model_instance,
)
memory = self._memory
if (
set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}
@@ -308,9 +309,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
usage = invoke_result.usage
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
# deduct quota
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
return text, usage, tool_call
def _generate_function_call_prompt(
@@ -319,7 +317,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
@@ -407,7 +405,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@@ -445,7 +443,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@@ -470,7 +468,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
files=files,
context="",
memory_config=node_data.memory,
memory=memory,
# AdvancedPromptTransform is still typed against TokenBufferMemory.
memory=cast(Any, memory),
model_instance=model_instance,
image_detail_config=vision_detail,
)
@@ -483,7 +482,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_instance: ModelInstance,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@@ -715,7 +714,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
max_token_limit: int = 2000,
) -> list[ChatModelMessage]:
model_mode = ModelMode(node_data.model.mode)
@@ -724,8 +723,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
instruction = variable_pool.convert_template(node_data.instruction or "").text
if memory and node_data.memory and node_data.memory.window:
memory_str = memory.get_history_prompt_text(
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
memory_str = llm_utils.fetch_memory_text(
memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
)
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
@@ -742,7 +741,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)
@@ -751,8 +750,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
instruction = variable_pool.convert_template(node_data.instruction or "").text
if memory and node_data.memory and node_data.memory.window:
memory_str = memory.get_history_prompt_text(
max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
memory_str = llm_utils.fetch_memory_text(
memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size
)
if model_mode == ModelMode.CHAT:
system_prompt_messages = ChatModelMessage(
@@ -828,6 +827,10 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
return rest_tokens
@property
def model_instance(self) -> ModelInstance:
return self._model_instance
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

View File

@@ -3,9 +3,9 @@ import re
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
from core.model_runtime.memory import PromptMessageMemory
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
@@ -56,6 +56,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
_credentials_provider: "CredentialsProvider"
_model_factory: "ModelFactory"
_model_instance: ModelInstance
_memory: PromptMessageMemory | None
def __init__(
self,
@@ -67,6 +68,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
credentials_provider: "CredentialsProvider",
model_factory: "ModelFactory",
model_instance: ModelInstance,
memory: PromptMessageMemory | None = None,
llm_file_saver: LLMFileSaver | None = None,
):
super().__init__(
@@ -81,6 +83,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self._credentials_provider = credentials_provider
self._model_factory = model_factory
self._model_instance = model_instance
self._memory = memory
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
@@ -103,13 +106,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
variables = {"query": query}
# fetch model instance
model_instance = self._model_instance
# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
node_data_memory=node_data.memory,
model_instance=model_instance,
)
memory = self._memory
# fetch instruction
node_data.instruction = node_data.instruction or ""
node_data.instruction = variable_pool.convert_template(node_data.instruction).text
@@ -240,6 +237,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
llm_usage=usage,
)
@property
def model_instance(self) -> ModelInstance:
return self._model_instance
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
@@ -323,7 +324,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self,
node_data: QuestionClassifierNodeData,
query: str,
memory: TokenBufferMemory | None,
memory: PromptMessageMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)
@@ -336,7 +337,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
input_text = query
memory_str = ""
if memory:
memory_str = memory.get_history_prompt_text(
memory_str = llm_utils.fetch_memory_text(
memory=memory,
max_token_limit=max_token_limit,
message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None,
)

View File

@@ -2,8 +2,8 @@ from collections.abc import Sequence
from pydantic import Field
from core.app.app_config.entities import VariableEntity
from core.workflow.nodes.base import BaseNodeData
from core.workflow.variables.input_entities import VariableEntity
class StartNodeData(BaseNodeData):

View File

@@ -2,12 +2,12 @@ from typing import Any
from jsonschema import Draft7Validator, ValidationError
from core.app.app_config.entities import VariableEntityType
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.variables.input_entities import VariableEntityType
class StartNode(Node[StartNodeData]):

View File

@@ -1,3 +1,4 @@
from .input_entities import VariableEntity, VariableEntityType
from .segment_group import SegmentGroup
from .segments import (
ArrayAnySegment,
@@ -64,4 +65,6 @@ __all__ = [
"StringVariable",
"Variable",
"VariableBase",
"VariableEntity",
"VariableEntityType",
]

View File

@@ -0,0 +1,62 @@
from collections.abc import Sequence
from enum import StrEnum
from typing import Any
from jsonschema import Draft7Validator, SchemaError
from pydantic import BaseModel, Field, field_validator
from core.workflow.file import FileTransferMethod, FileType
class VariableEntityType(StrEnum):
TEXT_INPUT = "text-input"
SELECT = "select"
PARAGRAPH = "paragraph"
NUMBER = "number"
EXTERNAL_DATA_TOOL = "external_data_tool"
FILE = "file"
FILE_LIST = "file-list"
CHECKBOX = "checkbox"
JSON_OBJECT = "json_object"
class VariableEntity(BaseModel):
"""
Shared variable entity used by workflow runtime and app configuration.
"""
# `variable` records the name of the variable in user inputs.
variable: str
label: str
description: str = ""
type: VariableEntityType
required: bool = False
hide: bool = False
default: Any = None
max_length: int | None = None
options: Sequence[str] = Field(default_factory=list)
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
json_schema: dict[str, Any] | None = Field(default=None)
@field_validator("description", mode="before")
@classmethod
def convert_none_description(cls, value: Any) -> str:
return value or ""
@field_validator("options", mode="before")
@classmethod
def convert_none_options(cls, value: Any) -> Sequence[str]:
return value or []
@field_validator("json_schema")
@classmethod
def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None:
if schema is None:
return None
try:
Draft7Validator.check_schema(schema)
except SchemaError as error:
raise ValueError(f"Invalid JSON schema: {error.message}")
return schema
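
A hedged usage example of the relocated entity, using only what the new file above defines; pydantic surfaces the validator's error as a ValidationError, which subclasses ValueError.

from core.workflow.variables.input_entities import VariableEntity, VariableEntityType

# A well-formed Draft 7 schema passes validation...
ok = VariableEntity(
    variable="payload",
    label="Payload",
    type=VariableEntityType.JSON_OBJECT,
    json_schema={"type": "object", "properties": {"name": {"type": "string"}}},
)

# ...while a malformed one is rejected by Draft7Validator.check_schema.
try:
    VariableEntity(
        variable="broken",
        label="Broken",
        type=VariableEntityType.JSON_OBJECT,
        json_schema={"type": 42},  # "type" must be a string or array of strings
    )
except ValueError as exc:
    print(exc)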

View File

@@ -6,6 +6,7 @@ from typing import Any, cast
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.layers.observability import ObservabilityLayer
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
@@ -106,6 +107,7 @@ class WorkflowEntry:
max_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
)
self.graph_engine.layer(limits_layer)
self.graph_engine.layer(LLMQuotaLayer())
# Add observability layer when OTel is enabled
if dify_config.ENABLE_OTEL or is_instrument_flag_enabled():
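
WorkflowEntry now registers an LLMQuotaLayer alongside the existing limits and observability layers. The layer's contract can be inferred from the unit tests added later in this diff; the sketch below is only that inference, and the base class, the AbortCommand constructor shape, and the exact set of quota-checked node types are assumptions.

# Imports taken from, or inferred from, the tests later in this diff.
from core.errors.error import QuotaExceededError
from core.workflow.enums import NodeType
from core.workflow.graph_engine.entities.commands import AbortCommand  # name assumed

class LLMQuotaLayer:  # the real class presumably extends a graph-engine layer base
    _QUOTA_NODE_TYPES = {NodeType.LLM, NodeType.QUESTION_CLASSIFIER}

    def on_node_run_start(self, node) -> None:
        if node.node_type not in self._QUOTA_NODE_TYPES:
            return
        try:
            ensure_llm_quota_available(model_instance=node.model_instance)
        except QuotaExceededError as error:
            self._abort(node, reason=str(error))

    def on_node_run_end(self, node, error, result_event) -> None:
        if node.node_type not in self._QUOTA_NODE_TYPES:
            return
        try:
            deduct_llm_quota(
                tenant_id=node.tenant_id,
                model_instance=node.model_instance,
                usage=result_event.node_run_result.llm_usage,
            )
        except QuotaExceededError as quota_error:
            self._abort(node, reason=str(quota_error))
        except Exception:
            pass  # other deduction failures are swallowed (see the ValueError test)

    def _abort(self, node, *, reason: str) -> None:
        # Stop the run immediately and tell the engine to abort.
        node.graph_runtime_state.stop_event.set()
        self.command_channel.send_command(AbortCommand(reason=reason))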

View File

@@ -0,0 +1,48 @@
"""Helpers for producing concise pyrefly diagnostics for CI diff output."""
from __future__ import annotations
import sys
_DIAGNOSTIC_PREFIXES = ("ERROR ", "WARNING ")
_LOCATION_PREFIX = "-->"
def extract_diagnostics(raw_output: str) -> str:
"""Extract stable diagnostic lines from pyrefly output.
The full pyrefly output includes code excerpts and carets, which create noisy
diffs. This helper keeps only:
- diagnostic headline lines (``ERROR ...`` / ``WARNING ...``)
- the following location line (``--> path:line:column``), when present
"""
lines = raw_output.splitlines()
diagnostics: list[str] = []
for index, line in enumerate(lines):
if line.startswith(_DIAGNOSTIC_PREFIXES):
diagnostics.append(line.rstrip())
next_index = index + 1
if next_index < len(lines):
next_line = lines[next_index]
if next_line.lstrip().startswith(_LOCATION_PREFIX):
diagnostics.append(next_line.rstrip())
if not diagnostics:
return ""
return "\n".join(diagnostics) + "\n"
def main() -> int:
"""Read pyrefly output from stdin and print normalized diagnostics."""
raw_output = sys.stdin.read()
sys.stdout.write(extract_diagnostics(raw_output))
return 0
if __name__ == "__main__":
sys.exit(main())
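
A hedged sketch of how a CI step might feed pyrefly through this filter before diffing base and PR diagnostics; the exact pyrefly invocation, its use of stdout, and the output file name are assumptions.

import subprocess
from libs.pyrefly_diagnostics import extract_diagnostics

result = subprocess.run(["pyrefly", "check"], capture_output=True, text=True)
# pyrefly is assumed to emit diagnostics on stdout; adjust if it uses stderr.
with open("pyrefly_diagnostics.txt", "w") as fh:
    fh.write(extract_diagnostics(result.stdout))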

View File

@@ -787,7 +787,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
__tablename__ = "workflow_node_executions"
@declared_attr
@declared_attr.directive
@classmethod
def __table_args__(cls) -> Any:
return (
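
For context on the one-line fix above: in SQLAlchemy 2.0, plain declared_attr is typed for mapped attributes, so applying it to __table_args__ trips type checkers, while declared_attr.directive marks the method as a declarative directive with a non-Mapped return type. A minimal illustration on a hypothetical model (not Dify's actual table):

from typing import Any
from sqlalchemy.orm import DeclarativeBase, Mapped, declared_attr, mapped_column

class Base(DeclarativeBase):
    pass

class ExampleModel(Base):
    __tablename__ = "examples"
    id: Mapped[str] = mapped_column(primary_key=True)

    # `.directive` signals a declarative directive rather than a Mapped column,
    # so the non-Mapped return type checks cleanly.
    @declared_attr.directive
    @classmethod
    def __table_args__(cls) -> Any:
        return ({"sqlite_autoincrement": True},)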

View File

@@ -29,7 +29,7 @@ from typing import Any, cast
import sqlalchemy as sa
from pydantic import ValidationError
from sqlalchemy import and_, delete, func, null, or_, select
from sqlalchemy import and_, delete, func, null, or_, select, tuple_
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, selectinload, sessionmaker
@@ -423,9 +423,10 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
if last_seen:
stmt = stmt.where(
or_(
WorkflowRun.created_at > last_seen[0],
and_(WorkflowRun.created_at == last_seen[0], WorkflowRun.id > last_seen[1]),
tuple_(WorkflowRun.created_at, WorkflowRun.id)
> tuple_(
sa.literal(last_seen[0], type_=sa.DateTime()),
sa.literal(last_seen[1], type_=WorkflowRun.id.type),
)
)
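
The rewritten predicate uses a SQL row-value comparison, which orders (created_at, id) pairs lexicographically and is easier for query planners to match against a composite index than the expanded OR/AND form; row-value comparisons are supported by PostgreSQL and MySQL, though not by every backend. A self-contained sketch of the same pattern on a stand-in table (names are illustrative, not the actual WorkflowRun model):

from datetime import datetime
import sqlalchemy as sa
from sqlalchemy import select, tuple_

runs = sa.table(
    "workflow_runs",
    sa.column("created_at", sa.DateTime()),
    sa.column("id", sa.String()),
)

# Keyset cursor carried over from the previous page.
last_created_at, last_id = datetime(2026, 3, 1, 12, 0), "run-00042"

# Compiles to: (created_at, id) > (:param_1, :param_2)
stmt = (
    select(runs)
    .where(
        tuple_(runs.c.created_at, runs.c.id)
        > tuple_(
            sa.literal(last_created_at, type_=sa.DateTime()),
            sa.literal(last_id, type_=sa.String()),
        )
    )
    .order_by(runs.c.created_at, runs.c.id)
    .limit(100)
)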

View File

@@ -8,7 +8,6 @@ from core.app.app_config.entities import (
ExternalDataVariableEntity,
ModelConfigEntity,
PromptTemplateEntity,
VariableEntity,
)
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
@@ -20,6 +19,7 @@ from core.prompt.simple_prompt_transform import SimplePromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.file.models import FileUploadConfig
from core.workflow.nodes import NodeType
from core.workflow.variables.input_entities import VariableEntity
from events.app_event import app_was_created
from extensions.ext_database import db
from models import Account

View File

@@ -9,7 +9,6 @@ from sqlalchemy import exists, select
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.app.app_config.entities import VariableEntityType
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -40,6 +39,7 @@ from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import load_into_variable_pool
from core.workflow.variables import VariableBase
from core.workflow.variables.input_entities import VariableEntityType
from core.workflow.variables.variables import Variable
from core.workflow.workflow_entry import WorkflowEntry
from enums.cloud_plan import CloudPlan

View File

@@ -5,7 +5,7 @@ from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_manager import ModelInstance
from core.model_runtime.entities import AssistantPromptMessage
from core.model_runtime.entities import AssistantPromptMessage, UserPromptMessage
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
@@ -22,19 +22,17 @@ from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_mod
def get_mocked_fetch_memory(memory_text: str):
class MemoryMock:
def get_history_prompt_text(
def get_history_prompt_messages(
self,
human_prefix: str = "Human",
ai_prefix: str = "Assistant",
max_token_limit: int = 2000,
message_limit: int | None = None,
):
return memory_text
return [UserPromptMessage(content=memory_text), AssistantPromptMessage(content="mocked answer")]
return MagicMock(return_value=MemoryMock())
def init_parameter_extractor_node(config: dict):
def init_parameter_extractor_node(config: dict, memory=None):
graph_config = {
"edges": [
{
@@ -79,6 +77,7 @@ def init_parameter_extractor_node(config: dict):
credentials_provider=MagicMock(spec=CredentialsProvider),
model_factory=MagicMock(spec=ModelFactory),
model_instance=MagicMock(spec=ModelInstance),
memory=memory,
)
return node
@@ -350,7 +349,7 @@ def test_extract_json_from_tool_call():
assert result["location"] == "kawaii"
def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
def test_chat_parameter_extractor_with_memory(setup_model_mock):
"""
Test chat parameter extractor with memory.
"""
@@ -373,6 +372,7 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
"memory": {"window": {"enabled": True, "size": 50}},
},
},
memory=get_mocked_fetch_memory("customized memory")(),
)
node._model_instance = get_mocked_fetch_model_instance(
@@ -381,8 +381,6 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)()
# Test the mock before running the actual test
monkeypatch.setattr("core.workflow.nodes.llm.llm_utils.fetch_memory", get_mocked_fetch_memory("customized memory"))
db.session.close = MagicMock()
result = node._run()

View File

@@ -10,11 +10,10 @@ from core.app.app_config.entities import (
ExternalDataVariableEntity,
ModelConfigEntity,
PromptTemplateEntity,
VariableEntity,
VariableEntityType,
)
from core.model_runtime.entities.llm_entities import LLMMode
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
from models import Account, Tenant
from models.api_based_extension import APIBasedExtension
from models.model import App, AppMode, AppModelConfig

View File

@@ -147,8 +147,7 @@ class TestDisableSegmentsFromIndexTask:
document.cleaning_completed_at = fake.date_time_this_year()
document.splitting_completed_at = fake.date_time_this_year()
document.tokens = fake.random_int(min=50, max=500)
document.indexing_started_at = fake.date_time_this_year()
document.indexing_completed_at = fake.date_time_this_year()
document.completed_at = fake.date_time_this_year()
document.indexing_status = "completed"
document.enabled = True
document.archived = False

View File

@@ -1,7 +1,7 @@
import pytest
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.base_app_generator import BaseAppGenerator
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
def test_validate_inputs_with_zero():

View File

@@ -4,7 +4,6 @@ from unittest.mock import Mock, patch
import jsonschema
import pytest
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
from core.mcp import types
from core.mcp.server.streamable_http import (
@@ -19,6 +18,7 @@ from core.mcp.server.streamable_http import (
prepare_tool_arguments,
process_mapping_response,
)
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
from models.model import App, AppMCPServer, AppMode, EndUser

View File

@@ -0,0 +1,174 @@
import threading
from datetime import datetime
from unittest.mock import MagicMock, patch
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.errors.error import QuotaExceededError
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.commands import CommandType
from core.workflow.graph_events.node import NodeRunSucceededEvent
from core.workflow.node_events import NodeRunResult
def _build_succeeded_event() -> NodeRunSucceededEvent:
return NodeRunSucceededEvent(
id="execution-id",
node_id="llm-node-id",
node_type=NodeType.LLM,
start_at=datetime.now(),
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={"question": "hello"},
llm_usage=LLMUsage.empty_usage(),
),
)
def test_deduct_quota_called_for_successful_llm_node() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node.id = "llm-node-id"
node.execution_id = "execution-id"
node.node_type = NodeType.LLM
node.tenant_id = "tenant-id"
node.model_instance = object()
result_event = _build_succeeded_event()
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
layer.on_node_run_end(node=node, error=None, result_event=result_event)
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
model_instance=node.model_instance,
usage=result_event.node_run_result.llm_usage,
)
def test_deduct_quota_called_for_question_classifier_node() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node.id = "question-classifier-node-id"
node.execution_id = "execution-id"
node.node_type = NodeType.QUESTION_CLASSIFIER
node.tenant_id = "tenant-id"
node.model_instance = object()
result_event = _build_succeeded_event()
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
layer.on_node_run_end(node=node, error=None, result_event=result_event)
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
model_instance=node.model_instance,
usage=result_event.node_run_result.llm_usage,
)
def test_non_llm_node_is_ignored() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node.id = "start-node-id"
node.execution_id = "execution-id"
node.node_type = NodeType.START
node.tenant_id = "tenant-id"
node._model_instance = object()
result_event = _build_succeeded_event()
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
layer.on_node_run_end(node=node, error=None, result_event=result_event)
mock_deduct.assert_not_called()
def test_quota_error_is_handled_in_layer() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node.id = "llm-node-id"
node.execution_id = "execution-id"
node.node_type = NodeType.LLM
node.tenant_id = "tenant-id"
node.model_instance = object()
result_event = _build_succeeded_event()
with patch(
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
autospec=True,
side_effect=ValueError("quota exceeded"),
):
layer.on_node_run_end(node=node, error=None, result_event=result_event)
def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
layer = LLMQuotaLayer()
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = MagicMock()
node.id = "llm-node-id"
node.execution_id = "execution-id"
node.node_type = NodeType.LLM
node.tenant_id = "tenant-id"
node.model_instance = object()
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
result_event = _build_succeeded_event()
with patch(
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
autospec=True,
side_effect=QuotaExceededError("No credits remaining"),
):
layer.on_node_run_end(node=node, error=None, result_event=result_event)
assert stop_event.is_set()
layer.command_channel.send_command.assert_called_once()
abort_command = layer.command_channel.send_command.call_args.args[0]
assert abort_command.command_type == CommandType.ABORT
assert abort_command.reason == "No credits remaining"
def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
layer = LLMQuotaLayer()
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = MagicMock()
node.id = "llm-node-id"
node.node_type = NodeType.LLM
node.model_instance = object()
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
with patch(
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available",
autospec=True,
side_effect=QuotaExceededError("Model provider openai quota exceeded."),
):
layer.on_node_run_start(node)
assert stop_event.is_set()
layer.command_channel.send_command.assert_called_once()
abort_command = layer.command_channel.send_command.call_args.args[0]
assert abort_command.command_type == CommandType.ABORT
assert abort_command.reason == "Model provider openai quota exceeded."
def test_quota_precheck_passes_without_abort() -> None:
layer = LLMQuotaLayer()
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = MagicMock()
node.id = "llm-node-id"
node.node_type = NodeType.LLM
node.model_instance = object()
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available", autospec=True) as mock_check:
layer.on_node_run_start(node)
assert not stop_event.is_set()
mock_check.assert_called_once_with(model_instance=node.model_instance)
layer.command_channel.send_command.assert_not_called()

View File

@@ -4,12 +4,12 @@ import time
import pytest
from pydantic import ValidationError as PydanticValidationError
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.workflow.entities import GraphInitParams
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
def make_start_node(user_inputs, variables):

View File

@@ -0,0 +1,51 @@
from libs.pyrefly_diagnostics import extract_diagnostics
def test_extract_diagnostics_keeps_only_summary_and_location_lines() -> None:
# Arrange
raw_output = """INFO Checking project configured at `/tmp/project/pyrefly.toml`
ERROR `result` may be uninitialized [unbound-name]
--> controllers/console/app/annotation.py:126:16
|
126 | return result, 200
| ^^^^^^
|
ERROR Object of class `App` has no attribute `access_mode` [missing-attribute]
--> controllers/console/app/app.py:574:13
|
574 | app_model.access_mode = app_setting.access_mode
| ^^^^^^^^^^^^^^^^^^^^^
"""
# Act
diagnostics = extract_diagnostics(raw_output)
# Assert
assert diagnostics == (
"ERROR `result` may be uninitialized [unbound-name]\n"
" --> controllers/console/app/annotation.py:126:16\n"
"ERROR Object of class `App` has no attribute `access_mode` [missing-attribute]\n"
" --> controllers/console/app/app.py:574:13\n"
)
def test_extract_diagnostics_handles_error_without_location_line() -> None:
# Arrange
raw_output = "ERROR unexpected pyrefly output format [bad-format]\n"
# Act
diagnostics = extract_diagnostics(raw_output)
# Assert
assert diagnostics == "ERROR unexpected pyrefly output format [bad-format]\n"
def test_extract_diagnostics_returns_empty_for_non_error_output() -> None:
# Arrange
raw_output = "INFO Checking project configured at `/tmp/project/pyrefly.toml`\n"
# Act
diagnostics = extract_diagnostics(raw_output)
# Assert
assert diagnostics == ""

View File

@@ -13,12 +13,11 @@ from core.app.app_config.entities import (
ExternalDataVariableEntity,
ModelConfigEntity,
PromptTemplateEntity,
VariableEntity,
VariableEntityType,
)
from core.helper import encrypter
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.variables.input_entities import VariableEntity, VariableEntityType
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
from models.model import AppMode
from services.workflow.workflow_converter import WorkflowConverter

View File

@@ -1,5 +1,5 @@
import { fireEvent, render, screen } from '@testing-library/react'
import Alert from './alert'
import Alert from '../alert'
describe('Alert', () => {
const defaultProps = {

View File

@@ -1,5 +1,5 @@
import { render, screen } from '@testing-library/react'
import AppUnavailable from './app-unavailable'
import AppUnavailable from '../app-unavailable'
describe('AppUnavailable', () => {
beforeEach(() => {

View File

@@ -1,5 +1,5 @@
import { render, screen } from '@testing-library/react'
import Badge from './badge'
import Badge from '../badge'
describe('Badge', () => {
describe('Rendering', () => {

View File

@@ -1,5 +1,5 @@
import { fireEvent, render, screen } from '@testing-library/react'
import ThemeSelector from './theme-selector'
import ThemeSelector from '../theme-selector'
// Mock next-themes with controllable state
let mockTheme = 'system'

View File

@@ -1,5 +1,5 @@
import { fireEvent, render, screen } from '@testing-library/react'
import ThemeSwitcher from './theme-switcher'
import ThemeSwitcher from '../theme-switcher'
let mockTheme = 'system'
const mockSetTheme = vi.fn()

View File

@@ -1,5 +1,5 @@
import { render, screen } from '@testing-library/react'
import { ActionButton, ActionButtonState } from './index'
import { ActionButton, ActionButtonState } from '../index'
describe('ActionButton', () => {
it('renders button with default props', () => {

View File

@@ -4,7 +4,7 @@ import type { AgentLogDetailResponse } from '@/models/log'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { ToastContext } from '@/app/components/base/toast'
import { fetchAgentLogDetail } from '@/service/log'
import AgentLogDetail from './detail'
import AgentLogDetail from '../detail'
vi.mock('@/service/log', () => ({
fetchAgentLogDetail: vi.fn(),

View File

@@ -3,7 +3,7 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { useClickAway } from 'ahooks'
import { ToastContext } from '@/app/components/base/toast'
import { fetchAgentLogDetail } from '@/service/log'
import AgentLogModal from './index'
import AgentLogModal from '../index'
vi.mock('@/service/log', () => ({
fetchAgentLogDetail: vi.fn(),

View File

@@ -1,6 +1,6 @@
import type { AgentIteration } from '@/models/log'
import { render, screen } from '@testing-library/react'
import Iteration from './iteration'
import Iteration from '../iteration'
vi.mock('@/app/components/workflow/nodes/_base/components/editor/code-editor', () => ({
default: ({ title, value }: { title: React.ReactNode, value: string | object }) => (

View File

@@ -1,6 +1,6 @@
import { render, screen } from '@testing-library/react'
import * as React from 'react'
import ResultPanel from './result'
import ResultPanel from '../result'
vi.mock('@/app/components/workflow/nodes/_base/components/editor/code-editor', () => ({
default: ({ title, value }: { title: React.ReactNode, value: string | object }) => (

View File

@@ -2,7 +2,7 @@ import { fireEvent, render, screen } from '@testing-library/react'
import * as React from 'react'
import { describe, expect, it, vi } from 'vitest'
import { BlockEnum } from '@/app/components/workflow/types'
import ToolCallItem from './tool-call'
import ToolCallItem from '../tool-call'
vi.mock('@/app/components/workflow/nodes/_base/components/editor/code-editor', () => ({
default: ({ title, value }: { title: React.ReactNode, value: string | object }) => (

View File

@@ -1,7 +1,7 @@
import type { AgentIteration } from '@/models/log'
import { render, screen } from '@testing-library/react'
import { describe, expect, it, vi } from 'vitest'
import TracingPanel from './tracing'
import TracingPanel from '../tracing'
vi.mock('@/app/components/workflow/block-icon', () => ({
default: () => <div data-testid="block-icon" />,

View File

@@ -1,5 +1,5 @@
import { render, screen } from '@testing-library/react'
import AnswerIcon from '.'
import AnswerIcon from '..'
describe('AnswerIcon', () => {
it('renders default emoji when no icon or image is provided', () => {

View File

@@ -1,5 +1,5 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import ImageInput from './ImageInput'
import ImageInput from '../ImageInput'
const createObjectURLMock = vi.fn(() => 'blob:mock-url')
const revokeObjectURLMock = vi.fn()

View File

@@ -1,5 +1,5 @@
import { act, renderHook } from '@testing-library/react'
import { useDraggableUploader } from './hooks'
import { useDraggableUploader } from '../hooks'
type MockDragEventOverrides = {
dataTransfer?: { files: File[] }

View File

@@ -3,7 +3,7 @@ import type { ImageFile } from '@/types/app'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { TransferMethod } from '@/types/app'
import AppIconPicker from './index'
import AppIconPicker from '../index'
import 'vitest-canvas-mock'
type LocalFileUploaderOptions = {
@@ -93,7 +93,7 @@ vi.mock('react-easy-crop', () => ({
),
}))
vi.mock('../image-uploader/hooks', () => ({
vi.mock('../../image-uploader/hooks', () => ({
useLocalFileUploader: (options: LocalFileUploaderOptions) => {
mocks.onUpload = options.onUpload
return { handleLocalFileUpload: mocks.handleLocalFileUpload }

View File

@@ -1,4 +1,4 @@
import getCroppedImg, { checkIsAnimatedImage, createImage, getMimeType, getRadianAngle, rotateSize } from './utils'
import getCroppedImg, { checkIsAnimatedImage, createImage, getMimeType, getRadianAngle, rotateSize } from '../utils'
type ImageLoadEventType = 'load' | 'error'

View File

@@ -1,5 +1,5 @@
import { fireEvent, render, screen } from '@testing-library/react'
import AppIcon from './index'
import AppIcon from '../index'
// Mock emoji-mart initialization
vi.mock('emoji-mart', () => ({

View File

@@ -2,7 +2,7 @@ import { act, render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import i18next from 'i18next'
import { useParams, usePathname } from 'next/navigation'
import AudioBtn from './index'
import AudioBtn from '../index'
const mockPlayAudio = vi.fn()
const mockPauseAudio = vi.fn()

View File

@@ -4,7 +4,7 @@ import { vi } from 'vitest'
import useThemeMock from '@/hooks/use-theme'
import { Theme } from '@/types/app'
import AudioPlayer from './AudioPlayer'
import AudioPlayer from '../AudioPlayer'
vi.mock('@/hooks/use-theme', () => ({
default: vi.fn(() => ({ theme: 'light' })),

View File

@@ -3,12 +3,12 @@ import * as React from 'react'
// AudioGallery.spec.tsx
import { describe, expect, it, vi } from 'vitest'
import AudioGallery from './index'
import AudioGallery from '../index'
// Mock AudioPlayer so we only assert prop forwarding
const audioPlayerMock = vi.fn()
vi.mock('./AudioPlayer', () => ({
vi.mock('../AudioPlayer', () => ({
default: (props: { srcs: string[] }) => {
audioPlayerMock(props)
return <div data-testid="audio-player" />

View File

@@ -1,6 +1,6 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { sleep } from '@/utils'
import AutoHeightTextarea from './index'
import AutoHeightTextarea from '../index'
vi.mock('@/utils', async () => {
const actual = await vi.importActual('@/utils')

View File

@@ -1,5 +1,5 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import Avatar from './index'
import Avatar from '../index'
describe('Avatar', () => {
beforeEach(() => {

View File

@@ -1,5 +1,5 @@
import { fireEvent, render, screen } from '@testing-library/react'
import Badge, { BadgeState, BadgeVariants } from './index'
import Badge, { BadgeState, BadgeVariants } from '../index'
describe('Badge', () => {
describe('Rendering', () => {

View File

@@ -1,7 +1,7 @@
import { cleanup, fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import Toast from '@/app/components/base/toast'
import BlockInput, { getInputKeys } from './index'
import BlockInput, { getInputKeys } from '../index'
vi.mock('@/utils/var', () => ({
checkKeys: vi.fn((_keys: string[]) => ({

View File

@@ -1,5 +1,5 @@
import { fireEvent, render, screen } from '@testing-library/react'
import AddButton from './add-button'
import AddButton from '../add-button'
describe('AddButton', () => {
describe('Rendering', () => {

View File

@@ -1,6 +1,6 @@
import { cleanup, fireEvent, render } from '@testing-library/react'
import * as React from 'react'
import Button from './index'
import Button from '../index'
afterEach(cleanup)
// https://testing-library.com/docs/queries/about

View File

@@ -1,5 +1,5 @@
import { fireEvent, render, screen } from '@testing-library/react'
import SyncButton from './sync-button'
import SyncButton from '../sync-button'
describe('SyncButton', () => {
describe('Rendering', () => {

View File

@@ -1,7 +1,7 @@
import type { Mock } from 'vitest'
import { act, fireEvent, render, screen } from '@testing-library/react'
import useEmblaCarousel from 'embla-carousel-react'
import { Carousel, useCarousel } from './index'
import { Carousel, useCarousel } from '../index'
vi.mock('embla-carousel-react', () => ({
default: vi.fn(),

View File

@@ -1,5 +1,5 @@
import type { ChatConfig, ChatItemInTree } from '../types'
import type { ChatWithHistoryContextValue } from './context'
import type { ChatConfig, ChatItemInTree } from '../../types'
import type { ChatWithHistoryContextValue } from '../context'
import type { FileEntity } from '@/app/components/base/file-uploader/types'
import type { AppData, AppMeta, ConversationItem } from '@/models/share'
import type { HumanInputFormData } from '@/types/workflow'
@@ -12,17 +12,17 @@ import {
stopChatMessageResponding,
} from '@/service/share'
import { TransferMethod } from '@/types/app'
import { useChat } from '../chat/hooks'
import { useChat } from '../../chat/hooks'
import { isValidGeneratedAnswer } from '../utils'
import ChatWrapper from './chat-wrapper'
import { useChatWithHistoryContext } from './context'
import { isValidGeneratedAnswer } from '../../utils'
import ChatWrapper from '../chat-wrapper'
import { useChatWithHistoryContext } from '../context'
vi.mock('../chat/hooks', () => ({
vi.mock('../../chat/hooks', () => ({
useChat: vi.fn(),
}))
vi.mock('./context', () => ({
vi.mock('../context', () => ({
useChatWithHistoryContext: vi.fn(),
}))
@@ -37,7 +37,7 @@ vi.mock('next/navigation', () => ({
useParams: vi.fn(() => ({ token: 'test-token' })),
}))
vi.mock('../utils', () => ({
vi.mock('../../utils', () => ({
isValidGeneratedAnswer: vi.fn(),
getLastAnswer: vi.fn(),
}))

View File

@@ -1,12 +1,12 @@
import type { ChatConfig } from '../types'
import type { ChatWithHistoryContextValue } from './context'
import type { ChatConfig } from '../../types'
import type { ChatWithHistoryContextValue } from '../context'
import type { AppData, AppMeta, ConversationItem } from '@/models/share'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import { useChatWithHistoryContext } from './context'
import HeaderInMobile from './header-in-mobile'
import { useChatWithHistoryContext } from '../context'
import HeaderInMobile from '../header-in-mobile'
vi.mock('@/hooks/use-breakpoints', () => ({
default: vi.fn(),
@@ -17,7 +17,7 @@ vi.mock('@/hooks/use-breakpoints', () => ({
},
}))
vi.mock('./context', () => ({
vi.mock('../context', () => ({
useChatWithHistoryContext: vi.fn(),
ChatWithHistoryContext: { Provider: ({ children }: { children: React.ReactNode }) => <div>{children}</div> },
}))
@@ -33,7 +33,7 @@ vi.mock('next/navigation', () => ({
useParams: vi.fn(() => ({})),
}))
vi.mock('../embedded-chatbot/theme/theme-context', () => ({
vi.mock('../../embedded-chatbot/theme/theme-context', () => ({
useThemeContext: vi.fn(() => ({
buildTheme: vi.fn(),
})),

View File

@@ -1,5 +1,5 @@
import type { ReactNode } from 'react'
import type { ChatConfig } from '../types'
import type { ChatConfig } from '../../types'
import type { AppConversationData, AppData, AppMeta, ConversationItem } from '@/models/share'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { act, renderHook, waitFor } from '@testing-library/react'
@@ -11,8 +11,8 @@ import {
generationConversationName,
} from '@/service/share'
import { shareQueryKeys } from '@/service/use-share'
import { CONVERSATION_ID_INFO } from '../constants'
import { useChatWithHistory } from './hooks'
import { CONVERSATION_ID_INFO } from '../../constants'
import { useChatWithHistory } from '../hooks'
vi.mock('@/hooks/use-app-favicon', () => ({
useAppFavicon: vi.fn(),
@@ -40,8 +40,8 @@ vi.mock('@/context/web-app-context', () => ({
useWebAppStore: (selector?: (state: typeof mockStoreState) => unknown) => useWebAppStoreMock(selector),
}))
vi.mock('../utils', async () => {
const actual = await vi.importActual<typeof import('../utils')>('../utils')
vi.mock('../../utils', async () => {
const actual = await vi.importActual<typeof import('../../utils')>('../../utils')
return {
...actual,
getProcessedSystemVariablesFromUrlParams: vi.fn().mockResolvedValue({ user_id: 'user-1' }),

View File

@@ -1,5 +1,5 @@
import type { RefObject } from 'react'
import type { ChatConfig } from '../types'
import type { ChatConfig } from '../../types'
import type { InstalledApp } from '@/models/explore'
import type { AppConversationData, AppData, AppMeta, ConversationItem } from '@/models/share'
import { fireEvent, render, screen } from '@testing-library/react'
@@ -7,11 +7,11 @@ import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import useDocumentTitle from '@/hooks/use-document-title'
import { useChatWithHistory } from './hooks'
import ChatWithHistory from './index'
import { useChatWithHistory } from '../hooks'
import ChatWithHistory from '../index'
// --- Mocks ---
vi.mock('./hooks', () => ({
vi.mock('../hooks', () => ({
useChatWithHistory: vi.fn(),
}))
@@ -40,7 +40,7 @@ vi.mock('next/navigation', () => ({
}))
const mockBuildTheme = vi.fn()
vi.mock('../embedded-chatbot/theme/theme-context', () => ({
vi.mock('../../embedded-chatbot/theme/theme-context', () => ({
useThemeContext: vi.fn(() => ({
buildTheme: mockBuildTheme,
})),

View File

@@ -1,13 +1,13 @@
import type { ChatWithHistoryContextValue } from '../context'
import type { ChatWithHistoryContextValue } from '../../context'
import type { AppData, ConversationItem } from '@/models/share'
import { render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { useChatWithHistoryContext } from '../context'
import Header from './index'
import { useChatWithHistoryContext } from '../../context'
import Header from '../index'
// Mock context module
vi.mock('../context', () => ({
vi.mock('../../context', () => ({
useChatWithHistoryContext: vi.fn(),
}))

View File

@@ -2,7 +2,7 @@ import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import MobileOperationDropdown from './mobile-operation-dropdown'
import MobileOperationDropdown from '../mobile-operation-dropdown'
describe('MobileOperationDropdown Component', () => {
const defaultProps = {

View File

@@ -2,7 +2,7 @@ import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import Operation from './operation'
import Operation from '../operation'
describe('Operation Component', () => {
const defaultProps = {

View File

@@ -1,10 +1,10 @@
import type { ChatWithHistoryContextValue } from '../context'
import type { ChatWithHistoryContextValue } from '../../context'
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { InputVarType } from '@/app/components/workflow/types'
import InputsFormContent from './content'
import InputsFormContent from '../content'
// Keep lightweight mocks for non-base project components
vi.mock('@/app/components/workflow/nodes/_base/components/before-run-form/bool-input', () => ({
@@ -90,7 +90,7 @@ const createMockContext = (overrides: Partial<ChatWithHistoryContextValue> = {})
// Create a real context for testing to support controlled component behavior
const MockContext = React.createContext<ChatWithHistoryContextValue>(createMockContext())
vi.mock('../context', () => ({
vi.mock('../../context', () => ({
useChatWithHistoryContext: () => React.useContext(MockContext),
}))

View File

@@ -1,11 +1,11 @@
import type { ChatWithHistoryContextValue } from '../context'
import type { ChatWithHistoryContextValue } from '../../context'
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { InputVarType } from '@/app/components/workflow/types'
import { useChatWithHistoryContext } from '../context'
import InputsFormNode from './index'
import { useChatWithHistoryContext } from '../../context'
import InputsFormNode from '../index'
// Mocks for components used by InputsFormContent (the real sibling)
vi.mock('@/app/components/workflow/nodes/_base/components/before-run-form/bool-input', () => ({
@@ -31,7 +31,7 @@ vi.mock('@/app/components/base/file-uploader', () => ({
),
}))
vi.mock('../context', () => ({
vi.mock('../../context', () => ({
useChatWithHistoryContext: vi.fn(),
}))

View File

@@ -1,11 +1,11 @@
import type { ChatWithHistoryContextValue } from '../context'
import type { ChatWithHistoryContextValue } from '../../context'
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { InputVarType } from '@/app/components/workflow/types'
import { useChatWithHistoryContext } from '../context'
import ViewFormDropdown from './view-form-dropdown'
import { useChatWithHistoryContext } from '../../context'
import ViewFormDropdown from '../view-form-dropdown'
// Mocks for components used by InputsFormContent (the real sibling)
vi.mock('@/app/components/workflow/nodes/_base/components/before-run-form/bool-input', () => ({
@@ -31,7 +31,7 @@ vi.mock('@/app/components/base/file-uploader', () => ({
),
}))
vi.mock('../context', () => ({
vi.mock('../../context', () => ({
useChatWithHistoryContext: vi.fn(),
}))

View File

@@ -1,13 +1,13 @@
import type { ChatWithHistoryContextValue } from '../context'
import type { ChatWithHistoryContextValue } from '../../context'
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { useChatWithHistoryContext } from '../context'
import Sidebar from './index'
import { useChatWithHistoryContext } from '../../context'
import Sidebar from '../index'
// Mock List to allow us to trigger operations
vi.mock('./list', () => ({
vi.mock('../list', () => ({
default: ({ list, onOperate, title }: { list: Array<{ id: string, name: string }>, onOperate: (type: string, item: { id: string, name: string }) => void, title?: string }) => (
<div>
{title && <div>{title}</div>}
@@ -25,7 +25,7 @@ vi.mock('./list', () => ({
}))
// Mock context hook
vi.mock('../context', () => ({
vi.mock('../../context', () => ({
useChatWithHistoryContext: vi.fn(),
}))

View File

@@ -2,7 +2,7 @@ import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import Item from './item'
import Item from '../item'
// Mock Operation to verify its usage
vi.mock('@/app/components/base/chat/chat-with-history/sidebar/operation', () => ({

View File

@@ -1,10 +1,10 @@
import { render, screen } from '@testing-library/react'
import * as React from 'react'
import { describe, expect, it, vi } from 'vitest'
import List from './list'
import List from '../list'
// Mock Item to verify its usage
vi.mock('./item', () => ({
vi.mock('../item', () => ({
default: ({ item }: { item: { name: string } }) => (
<div data-testid="mock-item">
{item.name}

View File

@@ -2,7 +2,7 @@ import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import Operation from './operation'
import Operation from '../operation'
// Mock PortalToFollowElem components to render children in place
vi.mock('@/app/components/base/portal-to-follow-elem', () => ({

View File

@@ -2,7 +2,7 @@ import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import RenameModal from './rename-modal'
import RenameModal from '../rename-modal'
describe('RenameModal', () => {
const defaultProps = {

View File

@@ -1,7 +1,7 @@
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { describe, expect, it, vi } from 'vitest'
import ContentSwitch from './content-switch'
import ContentSwitch from '../content-switch'
describe('ContentSwitch', () => {
const defaultProps = {

View File

@@ -1,9 +1,9 @@
import type { ChatItem } from '../types'
import type { ChatContextValue } from './context'
import type { ChatItem } from '../../types'
import type { ChatContextValue } from '../context'
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { vi } from 'vitest'
import { ChatContextProvider, useChatContext } from './context'
import { ChatContextProvider, useChatContext } from '../context'
const TestConsumer = () => {
const context = useChatContext()

View File

@@ -1,10 +1,10 @@
import type { ChatConfig, ChatItem, OnSend } from '../types'
import type { ChatProps } from './index'
import type { ChatConfig, ChatItem, OnSend } from '../../types'
import type { ChatProps } from '../index'
import { act, render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { useStore as useAppStore } from '@/app/components/app/store'
import Chat from './index'
import Chat from '../index'
// ─── Why each mock exists ─────────────────────────────────────────────────────
//
@@ -24,7 +24,7 @@ import Chat from './index'
// TryToAsk only uses Button (base), Divider (base), i18n (global mock).
// ─────────────────────────────────────────────────────────────────────────────
vi.mock('./answer', () => ({
vi.mock('../answer', () => ({
default: ({ item, responding }: { item: ChatItem, responding?: boolean }) => (
<div
data-testid="answer-item"
@@ -36,13 +36,13 @@ vi.mock('./answer', () => ({
),
}))
vi.mock('./question', () => ({
vi.mock('../question', () => ({
default: ({ item }: { item: ChatItem }) => (
<div data-testid="question-item" data-id={item.id}>{item.content}</div>
),
}))
vi.mock('./chat-input-area', () => ({
vi.mock('../chat-input-area', () => ({
default: ({ disabled, readonly }: { disabled?: boolean, readonly?: boolean }) => (
<div
data-testid="chat-input-area"

View File

@@ -1,5 +1,5 @@
import type { Theme } from '../embedded-chatbot/theme/theme-context'
import type { ChatConfig, ChatItem, OnRegenerate } from '../types'
import type { Theme } from '../../embedded-chatbot/theme/theme-context'
import type { ChatConfig, ChatItem, OnRegenerate } from '../../types'
import type { FileEntity } from '@/app/components/base/file-uploader/types'
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
@@ -7,10 +7,10 @@ import copy from 'copy-to-clipboard'
import * as React from 'react'
import { vi } from 'vitest'
import Toast from '../../toast'
import { ThemeBuilder } from '../embedded-chatbot/theme/theme-context'
import { ChatContextProvider } from './context'
import Question from './question'
import Toast from '../../../toast'
import { ThemeBuilder } from '../../embedded-chatbot/theme/theme-context'
import { ChatContextProvider } from '../context'
import Question from '../question'
// Global Mocks
vi.mock('@react-aria/interactions', () => ({

View File

@@ -1,7 +1,7 @@
import type { OnSend } from '../types'
import type { OnSend } from '../../types'
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import TryToAsk from './try-to-ask'
import TryToAsk from '../try-to-ask'
describe('TryToAsk', () => {
const mockOnSend: OnSend = vi.fn()

View File

@@ -1,9 +1,9 @@
import type { ChatItem } from '../../types'
import type { ChatItem } from '../../../types'
import type { IThoughtProps } from '@/app/components/base/chat/chat/thought'
import type { FileEntity } from '@/app/components/base/file-uploader/types'
import type { MarkdownProps } from '@/app/components/base/markdown'
import { render, screen } from '@testing-library/react'
import AgentContent from './agent-content'
import AgentContent from '../agent-content'
// Mock Markdown component used only in tests
vi.mock('@/app/components/base/markdown', () => ({

View File

@@ -1,7 +1,7 @@
import type { ChatItem } from '../../types'
import type { ChatItem } from '../../../types'
import type { MarkdownProps } from '@/app/components/base/markdown'
import { render, screen } from '@testing-library/react'
import BasicContent from './basic-content'
import BasicContent from '../basic-content'
// Mock Markdown component used only in tests
vi.mock('@/app/components/base/markdown', () => ({

View File

@@ -1,7 +1,7 @@
import type { HumanInputFilledFormData } from '@/types/workflow'
import { render, screen } from '@testing-library/react'
import { describe, expect, it } from 'vitest'
import HumanInputFilledFormList from './human-input-filled-form-list'
import HumanInputFilledFormList from '../human-input-filled-form-list'
/**
* Type-safe factory.

Some files were not shown because too many files have changed in this diff.