Compare commits

..

2 Commits

Author SHA1 Message Date
-LAN-
d7d79f79f7 test: assert webhook debug error message 2026-03-17 21:11:37 +08:00
-LAN-
ac6d306ef8 Clarify webhook debug endpoint behavior 2026-03-17 21:05:51 +08:00
18 changed files with 330 additions and 481 deletions

View File

@@ -70,7 +70,14 @@ def handle_webhook(webhook_id: str):
@bp.route("/webhook-debug/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
def handle_webhook_debug(webhook_id: str):
"""Handle webhook debug calls without triggering production workflow execution."""
"""Handle webhook debug calls without triggering production workflow execution.
The debug webhook endpoint is only for draft inspection flows. It never enqueues
Celery work for the published workflow; instead it dispatches an in-memory debug
event to an active Variable Inspector listener. Returning a clear error when no
listener is registered prevents a misleading 200 response for requests that are
effectively dropped.
"""
try:
webhook_trigger, _, node_config, webhook_data, error = _prepare_webhook_execution(webhook_id, is_debug=True)
if error:
@@ -94,11 +101,32 @@ def handle_webhook_debug(webhook_id: str):
"method": webhook_data.get("method"),
},
)
TriggerDebugEventBus.dispatch(
dispatch_count = TriggerDebugEventBus.dispatch(
tenant_id=webhook_trigger.tenant_id,
event=event,
pool_key=pool_key,
)
if dispatch_count == 0:
logger.warning(
"Webhook debug request dropped without an active listener for webhook %s (tenant=%s, app=%s, node=%s)",
webhook_trigger.webhook_id,
webhook_trigger.tenant_id,
webhook_trigger.app_id,
webhook_trigger.node_id,
)
return (
jsonify(
{
"error": "No active debug listener",
"message": (
"The webhook debug URL only works while the Variable Inspector is listening. "
"Use the published webhook URL to execute the workflow in Celery."
),
"execution_url": webhook_trigger.webhook_url,
}
),
409,
)
response_data, status_code = WebhookService.generate_webhook_response(node_config)
return jsonify(response_data), status_code

View File

@@ -441,7 +441,7 @@ class BaseAgentRunner(AppRunner):
continue
result.append(self.organize_agent_user_prompt(message))
agent_thoughts = message.agent_thoughts
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
if agent_thoughts:
for agent_thought in agent_thoughts:
tool_names_raw = agent_thought.tool

View File

@@ -1,36 +1,13 @@
from collections.abc import Mapping
from typing import Any, TypedDict
from typing import Any
from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
class SystemParametersDict(TypedDict):
image_file_size_limit: int
video_file_size_limit: int
audio_file_size_limit: int
file_size_limit: int
workflow_file_upload_limit: int
class AppParametersDict(TypedDict):
opening_statement: str | None
suggested_questions: list[str]
suggested_questions_after_answer: dict[str, Any]
speech_to_text: dict[str, Any]
text_to_speech: dict[str, Any]
retriever_resource: dict[str, Any]
annotation_reply: dict[str, Any]
more_like_this: dict[str, Any]
user_input_form: list[dict[str, Any]]
sensitive_word_avoidance: dict[str, Any]
file_upload: dict[str, Any]
system_parameters: SystemParametersDict
def get_parameters_from_feature_dict(
*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]
) -> AppParametersDict:
) -> Mapping[str, Any]:
"""
Mapping from feature dict to webapp parameters
"""

View File

@@ -3,7 +3,7 @@ import time
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from datetime import datetime
from typing import Any, NewType, TypedDict, Union
from typing import Any, NewType, Union
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -76,20 +76,6 @@ NodeExecutionId = NewType("NodeExecutionId", str)
logger = logging.getLogger(__name__)
class AccountCreatedByDict(TypedDict):
id: str
name: str
email: str
class EndUserCreatedByDict(TypedDict):
id: str
user: str
CreatedByDict = AccountCreatedByDict | EndUserCreatedByDict
@dataclass(slots=True)
class _NodeSnapshot:
"""In-memory cache for node metadata between start and completion events."""
@@ -263,19 +249,19 @@ class WorkflowResponseConverter:
outputs_mapping = graph_runtime_state.outputs or {}
encoded_outputs = WorkflowRuntimeTypeConverter().to_json_encodable(outputs_mapping)
created_by: CreatedByDict | dict[str, object] = {}
created_by: Mapping[str, object] | None
user = self._user
if isinstance(user, Account):
created_by = AccountCreatedByDict(
id=user.id,
name=user.name,
email=user.email,
)
elif isinstance(user, EndUser):
created_by = EndUserCreatedByDict(
id=user.id,
user=user.session_id,
)
created_by = {
"id": user.id,
"name": user.name,
"email": user.email,
}
else:
created_by = {
"id": user.id,
"user": user.session_id,
}
return WorkflowFinishStreamResponse(
task_id=task_id,

View File

@@ -1,5 +1,3 @@
from typing import TypedDict
from core.tools.signature import sign_tool_file
from dify_graph.file import helpers as file_helpers
from dify_graph.file.enums import FileTransferMethod
@@ -8,20 +6,7 @@ from models.model import MessageFile, UploadFile
MAX_TOOL_FILE_EXTENSION_LENGTH = 10
class MessageFileInfoDict(TypedDict):
related_id: str
extension: str
filename: str
size: int
mime_type: str
transfer_method: str
type: str
url: str
upload_file_id: str
remote_url: str | None
def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> MessageFileInfoDict:
def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> dict:
"""
Prepare file dictionary for message end stream response.

View File

@@ -177,11 +177,13 @@ class Account(UserMixin, TypeBase):
@classmethod
def get_by_openid(cls, provider: str, open_id: str):
account_integrate = db.session.execute(
select(AccountIntegrate).where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
).scalar_one_or_none()
account_integrate = (
db.session.query(AccountIntegrate)
.where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
.one_or_none()
)
if account_integrate:
return db.session.scalar(select(Account).where(Account.id == account_integrate.account_id))
return db.session.query(Account).where(Account.id == account_integrate.account_id).one_or_none()
return None
# check current_user.current_tenant.current_role in ['admin', 'owner']

View File

@@ -8,7 +8,6 @@ import os
import pickle
import re
import time
from collections.abc import Sequence
from datetime import datetime
from json import JSONDecodeError
from typing import Any, TypedDict, cast
@@ -146,25 +145,30 @@ class Dataset(Base):
@property
def total_documents(self):
return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0
return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
@property
def total_available_documents(self):
return (
db.session.scalar(
select(func.count(Document.id)).where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
)
db.session.query(func.count(Document.id))
.where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
)
or 0
.scalar()
)
@property
def dataset_keyword_table(self):
return db.session.scalar(select(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id))
dataset_keyword_table = (
db.session.query(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id).first()
)
if dataset_keyword_table:
return dataset_keyword_table
return None
@property
def index_struct_dict(self):
@@ -191,66 +195,64 @@ class Dataset(Base):
@property
def latest_process_rule(self):
return db.session.scalar(
select(DatasetProcessRule)
return (
db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.dataset_id == self.id)
.order_by(DatasetProcessRule.created_at.desc())
.limit(1)
.first()
)
@property
def app_count(self):
return (
db.session.scalar(
select(func.count(AppDatasetJoin.id)).where(
AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id
)
)
or 0
db.session.query(func.count(AppDatasetJoin.id))
.where(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id)
.scalar()
)
@property
def document_count(self):
return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0
return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
@property
def available_document_count(self):
return (
db.session.scalar(
select(func.count(Document.id)).where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
)
db.session.query(func.count(Document.id))
.where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
)
or 0
.scalar()
)
@property
def available_segment_count(self):
return (
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.dataset_id == self.id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
)
db.session.query(func.count(DocumentSegment.id))
.where(
DocumentSegment.dataset_id == self.id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
)
or 0
.scalar()
)
@property
def word_count(self):
return db.session.scalar(
select(func.coalesce(func.sum(Document.word_count), 0)).where(Document.dataset_id == self.id)
return (
db.session.query(Document)
.with_entities(func.coalesce(func.sum(Document.word_count), 0))
.where(Document.dataset_id == self.id)
.scalar()
)
@property
def doc_form(self) -> str | None:
if self.chunk_structure:
return self.chunk_structure
document = db.session.scalar(select(Document).where(Document.dataset_id == self.id).limit(1))
document = db.session.query(Document).where(Document.dataset_id == self.id).first()
if document:
return document.doc_form
return None
@@ -268,8 +270,8 @@ class Dataset(Base):
@property
def tags(self):
tags = db.session.scalars(
select(Tag)
tags = (
db.session.query(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
.where(
TagBinding.target_id == self.id,
@@ -277,7 +279,8 @@ class Dataset(Base):
Tag.tenant_id == self.tenant_id,
Tag.type == "knowledge",
)
).all()
.all()
)
return tags or []
@@ -285,8 +288,8 @@ class Dataset(Base):
def external_knowledge_info(self):
if self.provider != "external":
return None
external_knowledge_binding = db.session.scalar(
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id)
external_knowledge_binding = (
db.session.query(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id).first()
)
if not external_knowledge_binding:
return None
@@ -307,7 +310,7 @@ class Dataset(Base):
@property
def is_published(self):
if self.pipeline_id:
pipeline = db.session.scalar(select(Pipeline).where(Pipeline.id == self.pipeline_id))
pipeline = db.session.query(Pipeline).where(Pipeline.id == self.pipeline_id).first()
if pipeline:
return pipeline.is_published
return False
@@ -518,8 +521,10 @@ class Document(Base):
if self.data_source_info:
if self.data_source_type == "upload_file":
data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info)
file_detail = db.session.scalar(
select(UploadFile).where(UploadFile.id == data_source_info_dict["upload_file_id"])
file_detail = (
db.session.query(UploadFile)
.where(UploadFile.id == data_source_info_dict["upload_file_id"])
.one_or_none()
)
if file_detail:
return {
@@ -552,23 +557,24 @@ class Document(Base):
@property
def dataset(self):
return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
return db.session.query(Dataset).where(Dataset.id == self.dataset_id).one_or_none()
@property
def segment_count(self):
return (
db.session.scalar(select(func.count(DocumentSegment.id)).where(DocumentSegment.document_id == self.id)) or 0
)
return db.session.query(DocumentSegment).where(DocumentSegment.document_id == self.id).count()
@property
def hit_count(self):
return db.session.scalar(
select(func.coalesce(func.sum(DocumentSegment.hit_count), 0)).where(DocumentSegment.document_id == self.id)
return (
db.session.query(DocumentSegment)
.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0))
.where(DocumentSegment.document_id == self.id)
.scalar()
)
@property
def uploader(self):
user = db.session.scalar(select(Account).where(Account.id == self.created_by))
user = db.session.query(Account).where(Account.id == self.created_by).first()
return user.name if user else None
@property
@@ -582,13 +588,14 @@ class Document(Base):
@property
def doc_metadata_details(self) -> list[DocMetadataDetailItem] | None:
if self.doc_metadata:
document_metadatas = db.session.scalars(
select(DatasetMetadata)
document_metadatas = (
db.session.query(DatasetMetadata)
.join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id)
.where(
DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id
)
).all()
.all()
)
metadata_list: list[DocMetadataDetailItem] = []
for metadata in document_metadatas:
metadata_dict: DocMetadataDetailItem = {
@@ -819,7 +826,7 @@ class DocumentSegment(Base):
)
@property
def child_chunks(self) -> Sequence[Any]:
def child_chunks(self) -> list[Any]:
if not self.document:
return []
process_rule = self.document.dataset_process_rule
@@ -828,13 +835,16 @@ class DocumentSegment(Base):
if rules_dict:
rules = Rule.model_validate(rules_dict)
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = db.session.scalars(
select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc())
).all()
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
return child_chunks or []
return []
def get_child_chunks(self) -> Sequence[Any]:
def get_child_chunks(self) -> list[Any]:
if not self.document:
return []
process_rule = self.document.dataset_process_rule
@@ -843,9 +853,12 @@ class DocumentSegment(Base):
if rules_dict:
rules = Rule.model_validate(rules_dict)
if rules.parent_mode:
child_chunks = db.session.scalars(
select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc())
).all()
child_chunks = (
db.session.query(ChildChunk)
.where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
return child_chunks or []
return []
@@ -994,15 +1007,15 @@ class ChildChunk(Base):
@property
def dataset(self):
return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
return db.session.query(Dataset).where(Dataset.id == self.dataset_id).first()
@property
def document(self):
return db.session.scalar(select(Document).where(Document.id == self.document_id))
return db.session.query(Document).where(Document.id == self.document_id).first()
@property
def segment(self):
return db.session.scalar(select(DocumentSegment).where(DocumentSegment.id == self.segment_id))
return db.session.query(DocumentSegment).where(DocumentSegment.id == self.segment_id).first()
class AppDatasetJoin(TypeBase):
@@ -1063,7 +1076,7 @@ class DatasetQuery(TypeBase):
if isinstance(queries, list):
for query in queries:
if query["content_type"] == QueryType.IMAGE_QUERY:
file_info = db.session.scalar(select(UploadFile).where(UploadFile.id == query["content"]))
file_info = db.session.query(UploadFile).filter_by(id=query["content"]).first()
if file_info:
query["file_info"] = {
"id": file_info.id,
@@ -1128,7 +1141,7 @@ class DatasetKeywordTable(TypeBase):
super().__init__(object_hook=object_hook, *args, **kwargs)
# get dataset
dataset = db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first()
if not dataset:
return None
if self.data_source_type == "database":
@@ -1522,7 +1535,7 @@ class PipelineCustomizedTemplate(TypeBase):
@property
def created_user_name(self):
account = db.session.scalar(select(Account).where(Account.id == self.created_by))
account = db.session.query(Account).where(Account.id == self.created_by).first()
if account:
return account.name
return ""
@@ -1557,7 +1570,7 @@ class Pipeline(TypeBase):
)
def retrieve_dataset(self, session: Session):
return session.scalar(select(Dataset).where(Dataset.pipeline_id == self.id))
return session.query(Dataset).where(Dataset.pipeline_id == self.id).first()
class DocumentPipelineExecutionLog(TypeBase):

View File

@@ -96,213 +96,3 @@ class ConversationStatus(StrEnum):
"""Conversation Status Enum"""
NORMAL = "normal"
class DataSourceType(StrEnum):
"""Data Source Type for Dataset and Document"""
UPLOAD_FILE = "upload_file"
NOTION_IMPORT = "notion_import"
WEBSITE_CRAWL = "website_crawl"
LOCAL_FILE = "local_file"
ONLINE_DOCUMENT = "online_document"
class ProcessRuleMode(StrEnum):
"""Dataset Process Rule Mode"""
AUTOMATIC = "automatic"
CUSTOM = "custom"
HIERARCHICAL = "hierarchical"
class IndexingStatus(StrEnum):
"""Document Indexing Status"""
WAITING = "waiting"
PARSING = "parsing"
CLEANING = "cleaning"
SPLITTING = "splitting"
INDEXING = "indexing"
PAUSED = "paused"
COMPLETED = "completed"
ERROR = "error"
class DocumentCreatedFrom(StrEnum):
"""Document Created From"""
WEB = "web"
API = "api"
RAG_PIPELINE = "rag-pipeline"
class ConversationFromSource(StrEnum):
"""Conversation / Message from_source"""
API = "api"
CONSOLE = "console"
class FeedbackFromSource(StrEnum):
"""MessageFeedback from_source"""
USER = "user"
ADMIN = "admin"
class InvokeFrom(StrEnum):
"""How a conversation/message was invoked"""
SERVICE_API = "service-api"
WEB_APP = "web-app"
TRIGGER = "trigger"
EXPLORE = "explore"
DEBUGGER = "debugger"
PUBLISHED_PIPELINE = "published"
VALIDATION = "validation"
@classmethod
def value_of(cls, value: str) -> "InvokeFrom":
return cls(value)
def to_source(self) -> str:
source_mapping = {
InvokeFrom.WEB_APP: "web_app",
InvokeFrom.DEBUGGER: "dev",
InvokeFrom.EXPLORE: "explore_app",
InvokeFrom.TRIGGER: "trigger",
InvokeFrom.SERVICE_API: "api",
}
return source_mapping.get(self, "dev")
class DocumentDocType(StrEnum):
"""Document doc_type classification"""
BOOK = "book"
WEB_PAGE = "web_page"
PAPER = "paper"
SOCIAL_MEDIA_POST = "social_media_post"
WIKIPEDIA_ENTRY = "wikipedia_entry"
PERSONAL_DOCUMENT = "personal_document"
BUSINESS_DOCUMENT = "business_document"
IM_CHAT_LOG = "im_chat_log"
SYNCED_FROM_NOTION = "synced_from_notion"
SYNCED_FROM_GITHUB = "synced_from_github"
OTHERS = "others"
class TagType(StrEnum):
"""Tag type"""
KNOWLEDGE = "knowledge"
APP = "app"
class DatasetMetadataType(StrEnum):
"""Dataset metadata value type"""
STRING = "string"
NUMBER = "number"
TIME = "time"
class SegmentStatus(StrEnum):
"""Document segment status"""
WAITING = "waiting"
INDEXING = "indexing"
COMPLETED = "completed"
ERROR = "error"
class DatasetRuntimeMode(StrEnum):
"""Dataset runtime mode"""
GENERAL = "general"
RAG_PIPELINE = "rag_pipeline"
class CollectionBindingType(StrEnum):
"""Dataset collection binding type"""
DATASET = "dataset"
ANNOTATION = "annotation"
class DatasetQuerySource(StrEnum):
"""Dataset query source"""
HIT_TESTING = "hit_testing"
APP = "app"
class TidbAuthBindingStatus(StrEnum):
"""TiDB auth binding status"""
CREATING = "CREATING"
ACTIVE = "ACTIVE"
class MessageFileBelongsTo(StrEnum):
"""MessageFile belongs_to"""
USER = "user"
ASSISTANT = "assistant"
class CredentialSourceType(StrEnum):
"""Load balancing credential source type"""
PROVIDER = "provider"
CUSTOM_MODEL = "custom_model"
class PaymentStatus(StrEnum):
"""Provider order payment status"""
WAIT_PAY = "wait_pay"
PAID = "paid"
FAILED = "failed"
REFUNDED = "refunded"
class BannerStatus(StrEnum):
"""ExporleBanner status"""
ENABLED = "enabled"
DISABLED = "disabled"
class SummaryStatus(StrEnum):
"""Document segment summary status"""
NOT_STARTED = "not_started"
GENERATING = "generating"
COMPLETED = "completed"
ERROR = "error"
class MessageChainType(StrEnum):
"""Message chain type"""
SYSTEM = "system"
class ProviderQuotaType(StrEnum):
PAID = "paid"
"""hosted paid quota"""
FREE = "free"
"""third-party free quota"""
TRIAL = "trial"
"""hosted trial quota"""
@staticmethod
def value_of(value: str) -> "ProviderQuotaType":
for member in ProviderQuotaType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")

View File

@@ -380,12 +380,13 @@ class App(Base):
@property
def site(self) -> Site | None:
return db.session.scalar(select(Site).where(Site.app_id == self.id))
site = db.session.query(Site).where(Site.app_id == self.id).first()
return site
@property
def app_model_config(self) -> AppModelConfig | None:
if self.app_model_config_id:
return db.session.scalar(select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id))
return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
return None
@@ -394,7 +395,7 @@ class App(Base):
if self.workflow_id:
from .workflow import Workflow
return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id))
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
return None
@@ -404,7 +405,8 @@ class App(Base):
@property
def tenant(self) -> Tenant | None:
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return tenant
@property
def is_agent(self) -> bool:
@@ -544,9 +546,9 @@ class App(Base):
return deleted_tools
@property
def tags(self) -> Sequence[Tag]:
tags = db.session.scalars(
select(Tag)
def tags(self) -> list[Tag]:
tags = (
db.session.query(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
.where(
TagBinding.target_id == self.id,
@@ -554,14 +556,15 @@ class App(Base):
Tag.tenant_id == self.tenant_id,
Tag.type == "app",
)
).all()
.all()
)
return tags or []
@property
def author_name(self) -> str | None:
if self.created_by:
account = db.session.scalar(select(Account).where(Account.id == self.created_by))
account = db.session.query(Account).where(Account.id == self.created_by).first()
if account:
return account.name
@@ -613,7 +616,8 @@ class AppModelConfig(TypeBase):
@property
def app(self) -> App | None:
return db.session.scalar(select(App).where(App.id == self.app_id))
app = db.session.query(App).where(App.id == self.app_id).first()
return app
@property
def model_dict(self) -> ModelConfig:
@@ -648,8 +652,8 @@ class AppModelConfig(TypeBase):
@property
def annotation_reply_dict(self) -> AnnotationReplyConfig:
annotation_setting = db.session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id)
annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
)
if annotation_setting:
collection_binding_detail = annotation_setting.collection_binding_detail
@@ -841,7 +845,8 @@ class RecommendedApp(Base): # bug
@property
def app(self) -> App | None:
return db.session.scalar(select(App).where(App.id == self.app_id))
app = db.session.query(App).where(App.id == self.app_id).first()
return app
class InstalledApp(TypeBase):
@@ -868,11 +873,13 @@ class InstalledApp(TypeBase):
@property
def app(self) -> App | None:
return db.session.scalar(select(App).where(App.id == self.app_id))
app = db.session.query(App).where(App.id == self.app_id).first()
return app
@property
def tenant(self) -> Tenant | None:
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return tenant
class TrialApp(Base):
@@ -892,7 +899,8 @@ class TrialApp(Base):
@property
def app(self) -> App | None:
return db.session.scalar(select(App).where(App.id == self.app_id))
app = db.session.query(App).where(App.id == self.app_id).first()
return app
class AccountTrialAppRecord(Base):
@@ -911,11 +919,13 @@ class AccountTrialAppRecord(Base):
@property
def app(self) -> App | None:
return db.session.scalar(select(App).where(App.id == self.app_id))
app = db.session.query(App).where(App.id == self.app_id).first()
return app
@property
def user(self) -> Account | None:
return db.session.scalar(select(Account).where(Account.id == self.account_id))
user = db.session.query(Account).where(Account.id == self.account_id).first()
return user
class ExporleBanner(TypeBase):
@@ -1107,8 +1117,8 @@ class Conversation(Base):
else:
model_config["configs"] = override_model_configs # type: ignore[typeddict-unknown-key]
else:
app_model_config = db.session.scalar(
select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id)
app_model_config = (
db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
)
if app_model_config:
model_config = app_model_config.to_dict()
@@ -1131,43 +1141,36 @@ class Conversation(Base):
@property
def annotated(self):
return (
db.session.scalar(
select(func.count(MessageAnnotation.id)).where(MessageAnnotation.conversation_id == self.id)
)
or 0
) > 0
return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).count() > 0
@property
def annotation(self):
return db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).limit(1))
return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).first()
@property
def message_count(self):
return db.session.scalar(select(func.count(Message.id)).where(Message.conversation_id == self.id)) or 0
return db.session.query(Message).where(Message.conversation_id == self.id).count()
@property
def user_feedback_stats(self):
like = (
db.session.scalar(
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "like",
)
db.session.query(MessageFeedback)
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "like",
)
or 0
.count()
)
dislike = (
db.session.scalar(
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "dislike",
)
db.session.query(MessageFeedback)
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "dislike",
)
or 0
.count()
)
return {"like": like, "dislike": dislike}
@@ -1175,25 +1178,23 @@ class Conversation(Base):
@property
def admin_feedback_stats(self):
like = (
db.session.scalar(
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "like",
)
db.session.query(MessageFeedback)
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "like",
)
or 0
.count()
)
dislike = (
db.session.scalar(
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "dislike",
)
db.session.query(MessageFeedback)
.where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "dislike",
)
or 0
.count()
)
return {"like": like, "dislike": dislike}
@@ -1255,19 +1256,22 @@ class Conversation(Base):
@property
def first_message(self):
return db.session.scalar(
select(Message).where(Message.conversation_id == self.id).order_by(Message.created_at.asc())
return (
db.session.query(Message)
.where(Message.conversation_id == self.id)
.order_by(Message.created_at.asc())
.first()
)
@property
def app(self) -> App | None:
with Session(db.engine, expire_on_commit=False) as session:
return session.scalar(select(App).where(App.id == self.app_id))
return session.query(App).where(App.id == self.app_id).first()
@property
def from_end_user_session_id(self):
if self.from_end_user_id:
end_user = db.session.scalar(select(EndUser).where(EndUser.id == self.from_end_user_id))
end_user = db.session.query(EndUser).where(EndUser.id == self.from_end_user_id).first()
if end_user:
return end_user.session_id
@@ -1276,7 +1280,7 @@ class Conversation(Base):
@property
def from_account_name(self) -> str | None:
if self.from_account_id:
account = db.session.scalar(select(Account).where(Account.id == self.from_account_id))
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
if account:
return account.name
@@ -1501,15 +1505,21 @@ class Message(Base):
@property
def user_feedback(self):
return db.session.scalar(
select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
feedback = (
db.session.query(MessageFeedback)
.where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
.first()
)
return feedback
@property
def admin_feedback(self):
return db.session.scalar(
select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
feedback = (
db.session.query(MessageFeedback)
.where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
.first()
)
return feedback
@property
def feedbacks(self):
@@ -1518,27 +1528,28 @@ class Message(Base):
@property
def annotation(self):
annotation = db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.message_id == self.id))
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == self.id).first()
return annotation
@property
def annotation_hit_history(self):
annotation_history = db.session.scalar(
select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id)
annotation_history = (
db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id).first()
)
if annotation_history:
return db.session.scalar(
select(MessageAnnotation).where(MessageAnnotation.id == annotation_history.annotation_id)
annotation = (
db.session.query(MessageAnnotation)
.where(MessageAnnotation.id == annotation_history.annotation_id)
.first()
)
return annotation
return None
@property
def app_model_config(self):
conversation = db.session.scalar(select(Conversation).where(Conversation.id == self.conversation_id))
conversation = db.session.query(Conversation).where(Conversation.id == self.conversation_id).first()
if conversation:
return db.session.scalar(
select(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id)
)
return db.session.query(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id).first()
return None
@@ -1551,12 +1562,13 @@ class Message(Base):
return json.loads(self.message_metadata) if self.message_metadata else {}
@property
def agent_thoughts(self) -> Sequence[MessageAgentThought]:
return db.session.scalars(
select(MessageAgentThought)
def agent_thoughts(self) -> list[MessageAgentThought]:
return (
db.session.query(MessageAgentThought)
.where(MessageAgentThought.message_id == self.id)
.order_by(MessageAgentThought.position.asc())
).all()
.all()
)
@property
def retriever_resources(self) -> Any:
@@ -1567,7 +1579,7 @@ class Message(Base):
from factories import file_factory
message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all()
current_app = db.session.scalar(select(App).where(App.id == self.app_id))
current_app = db.session.query(App).where(App.id == self.app_id).first()
if not current_app:
raise ValueError(f"App {self.app_id} not found")
@@ -1731,7 +1743,8 @@ class MessageFeedback(TypeBase):
@property
def from_account(self) -> Account | None:
return db.session.scalar(select(Account).where(Account.id == self.from_account_id))
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
return account
def to_dict(self) -> MessageFeedbackDict:
return {
@@ -1804,11 +1817,13 @@ class MessageAnnotation(Base):
@property
def account(self):
return db.session.scalar(select(Account).where(Account.id == self.account_id))
account = db.session.query(Account).where(Account.id == self.account_id).first()
return account
@property
def annotation_create_account(self):
return db.session.scalar(select(Account).where(Account.id == self.account_id))
account = db.session.query(Account).where(Account.id == self.account_id).first()
return account
class AppAnnotationHitHistory(TypeBase):
@@ -1837,15 +1852,18 @@ class AppAnnotationHitHistory(TypeBase):
@property
def account(self):
return db.session.scalar(
select(Account)
account = (
db.session.query(Account)
.join(MessageAnnotation, MessageAnnotation.account_id == Account.id)
.where(MessageAnnotation.id == self.annotation_id)
.first()
)
return account
@property
def annotation_create_account(self):
return db.session.scalar(select(Account).where(Account.id == self.account_id))
account = db.session.query(Account).where(Account.id == self.account_id).first()
return account
class AppAnnotationSetting(TypeBase):
@@ -1878,9 +1896,12 @@ class AppAnnotationSetting(TypeBase):
def collection_binding_detail(self):
from .dataset import DatasetCollectionBinding
return db.session.scalar(
select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == self.collection_binding_id)
collection_binding_detail = (
db.session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == self.collection_binding_id)
.first()
)
return collection_binding_detail
class OperationLog(TypeBase):
@@ -1986,9 +2007,7 @@ class AppMCPServer(TypeBase):
def generate_server_code(n: int) -> str:
while True:
result = generate_string(n)
while (
db.session.scalar(select(func.count(AppMCPServer.id)).where(AppMCPServer.server_code == result)) or 0
) > 0:
while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0:
result = generate_string(n)
return result
@@ -2049,7 +2068,7 @@ class Site(Base):
def generate_code(n: int) -> str:
while True:
result = generate_string(n)
while (db.session.scalar(select(func.count(Site.id)).where(Site.code == result)) or 0) > 0:
while db.session.query(Site).where(Site.code == result).count() > 0:
result = generate_string(n)
return result

View File

@@ -6,7 +6,7 @@ from functools import cached_property
from uuid import uuid4
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func, select, text
from sqlalchemy import DateTime, String, func, text
from sqlalchemy.orm import Mapped, mapped_column
from libs.uuid_utils import uuidv7
@@ -96,7 +96,7 @@ class Provider(TypeBase):
@cached_property
def credential(self):
if self.credential_id:
return db.session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id))
return db.session.query(ProviderCredential).where(ProviderCredential.id == self.credential_id).first()
@property
def credential_name(self):
@@ -159,8 +159,10 @@ class ProviderModel(TypeBase):
@cached_property
def credential(self):
if self.credential_id:
return db.session.scalar(
select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id)
return (
db.session.query(ProviderModelCredential)
.where(ProviderModelCredential.id == self.credential_id)
.first()
)
@property

View File

@@ -8,7 +8,7 @@ from uuid import uuid4
import sqlalchemy as sa
from deprecated import deprecated
from sqlalchemy import ForeignKey, String, func, select
from sqlalchemy import ForeignKey, String, func
from sqlalchemy.orm import Mapped, mapped_column
from core.tools.entities.common_entities import I18nObject
@@ -184,11 +184,11 @@ class ApiToolProvider(TypeBase):
def user(self) -> Account | None:
if not self.user_id:
return None
return db.session.scalar(select(Account).where(Account.id == self.user_id))
return db.session.query(Account).where(Account.id == self.user_id).first()
@property
def tenant(self) -> Tenant | None:
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
class ToolLabelBinding(TypeBase):
@@ -262,11 +262,11 @@ class WorkflowToolProvider(TypeBase):
@property
def user(self) -> Account | None:
return db.session.scalar(select(Account).where(Account.id == self.user_id))
return db.session.query(Account).where(Account.id == self.user_id).first()
@property
def tenant(self) -> Tenant | None:
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
@property
def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]:
@@ -277,7 +277,7 @@ class WorkflowToolProvider(TypeBase):
@property
def app(self) -> App | None:
return db.session.scalar(select(App).where(App.id == self.app_id))
return db.session.query(App).where(App.id == self.app_id).first()
class MCPToolProvider(TypeBase):
@@ -334,7 +334,7 @@ class MCPToolProvider(TypeBase):
encrypted_headers: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
def load_user(self) -> Account | None:
return db.session.scalar(select(Account).where(Account.id == self.user_id))
return db.session.query(Account).where(Account.id == self.user_id).first()
@property
def credentials(self) -> dict[str, Any]:

View File

@@ -2,7 +2,7 @@ from datetime import datetime
from uuid import uuid4
import sqlalchemy as sa
from sqlalchemy import DateTime, func, select
from sqlalchemy import DateTime, func
from sqlalchemy.orm import Mapped, mapped_column
from .base import TypeBase
@@ -38,7 +38,7 @@ class SavedMessage(TypeBase):
@property
def message(self):
return db.session.scalar(select(Message).where(Message.id == self.message_id))
return db.session.query(Message).where(Message.id == self.message_id).first()
class PinnedConversation(TypeBase):

View File

@@ -679,14 +679,14 @@ class WorkflowRun(Base):
def message(self):
from .model import Message
return db.session.scalar(
select(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id)
return (
db.session.query(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first()
)
@property
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
def workflow(self):
return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id))
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
def to_dict(self):
return {

View File

@@ -182,9 +182,4 @@ tasks/app_generate/workflow_execute_task.py
tasks/regenerate_summary_index_task.py
tasks/trigger_processing_tasks.py
tasks/workflow_cfs_scheduler/cfs_scheduler.py
tasks/add_document_to_index_task.py
tasks/create_segment_to_index_task.py
tasks/disable_segment_from_index_task.py
tasks/enable_segment_to_index_task.py
tasks/remove_document_from_index_task.py
tasks/workflow_execution_tasks.py

View File

@@ -11,7 +11,7 @@ from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from libs.login import current_user
from models import Account
from models.model import App, Conversation, EndUser, Message
from models.model import App, Conversation, EndUser, Message, MessageAgentThought
class AgentService:
@@ -47,7 +47,7 @@ class AgentService:
if not message:
raise ValueError(f"Message not found: {message_id}")
agent_thoughts = message.agent_thoughts
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
if conversation.from_end_user_id:
# only select name field

View File

@@ -23,6 +23,7 @@ def mock_jsonify():
class DummyWebhookTrigger:
webhook_id = "wh-1"
webhook_url = "http://localhost:5001/triggers/webhook/wh-1"
tenant_id = "tenant-1"
app_id = "app-1"
node_id = "node-1"
@@ -104,7 +105,32 @@ class TestHandleWebhookDebug:
@patch.object(module.WebhookService, "get_webhook_trigger_and_workflow")
@patch.object(module.WebhookService, "extract_and_validate_webhook_data")
@patch.object(module.WebhookService, "build_workflow_inputs", return_value={"x": 1})
@patch.object(module.TriggerDebugEventBus, "dispatch")
@patch.object(module.TriggerDebugEventBus, "dispatch", return_value=0)
def test_debug_requires_active_listener(
self,
mock_dispatch,
mock_build_inputs,
mock_extract,
mock_get,
):
mock_get.return_value = (DummyWebhookTrigger(), None, "node_config")
mock_extract.return_value = {"method": "POST"}
response, status = module.handle_webhook_debug("wh-1")
assert status == 409
assert response["error"] == "No active debug listener"
assert response["message"] == (
"The webhook debug URL only works while the Variable Inspector is listening. "
"Use the published webhook URL to execute the workflow in Celery."
)
assert response["execution_url"] == DummyWebhookTrigger.webhook_url
mock_dispatch.assert_called_once()
@patch.object(module.WebhookService, "get_webhook_trigger_and_workflow")
@patch.object(module.WebhookService, "extract_and_validate_webhook_data")
@patch.object(module.WebhookService, "build_workflow_inputs", return_value={"x": 1})
@patch.object(module.TriggerDebugEventBus, "dispatch", return_value=1)
@patch.object(module.WebhookService, "generate_webhook_response")
def test_debug_success(
self,

View File

@@ -622,10 +622,28 @@ class TestAccountGetByOpenId:
mock_account = Account(name="Test User", email="test@example.com")
mock_account.id = account_id
# Mock db.session.execute().scalar_one_or_none() for AccountIntegrate lookup
mock_db.session.execute.return_value.scalar_one_or_none.return_value = mock_account_integrate
# Mock db.session.scalar() for Account lookup
mock_db.session.scalar.return_value = mock_account
# Mock the query chain
mock_query = MagicMock()
mock_where = MagicMock()
mock_where.one_or_none.return_value = mock_account_integrate
mock_query.where.return_value = mock_where
mock_db.session.query.return_value = mock_query
# Mock the second query for account
mock_account_query = MagicMock()
mock_account_where = MagicMock()
mock_account_where.one_or_none.return_value = mock_account
mock_account_query.where.return_value = mock_account_where
# Setup query to return different results based on model
def query_side_effect(model):
if model.__name__ == "AccountIntegrate":
return mock_query
elif model.__name__ == "Account":
return mock_account_query
return MagicMock()
mock_db.session.query.side_effect = query_side_effect
# Act
result = Account.get_by_openid(provider, open_id)
@@ -640,8 +658,12 @@ class TestAccountGetByOpenId:
provider = "github"
open_id = "github_user_456"
# Mock db.session.execute().scalar_one_or_none() to return None
mock_db.session.execute.return_value.scalar_one_or_none.return_value = None
# Mock the query chain to return None
mock_query = MagicMock()
mock_where = MagicMock()
mock_where.one_or_none.return_value = None
mock_query.where.return_value = mock_where
mock_db.session.query.return_value = mock_query
# Act
result = Account.get_by_openid(provider, open_id)

View File

@@ -300,8 +300,10 @@ class TestAppModelConfig:
created_by=str(uuid4()),
)
# Mock database scalar to return None (no annotation setting found)
with patch("models.model.db.session.scalar", return_value=None):
# Mock database query to return None
with patch("models.model.db.session.query", autospec=True) as mock_query:
mock_query.return_value.where.return_value.first.return_value = None
# Act
result = config.annotation_reply_dict
@@ -949,8 +951,10 @@ class TestSiteModel:
def test_site_generate_code(self):
"""Test Site.generate_code static method."""
# Mock database scalar to return 0 (no existing codes)
with patch("models.model.db.session.scalar", return_value=0):
# Mock database query to return 0 (no existing codes)
with patch("models.model.db.session.query", autospec=True) as mock_query:
mock_query.return_value.where.return_value.count.return_value = 0
# Act
code = Site.generate_code(8)