mirror of
https://github.com/langgenius/dify.git
synced 2026-03-17 21:37:03 +00:00
Compare commits
2 Commits
main
...
webhook-de
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d7d79f79f7 | ||
|
|
ac6d306ef8 |
@@ -70,7 +70,14 @@ def handle_webhook(webhook_id: str):
|
||||
|
||||
@bp.route("/webhook-debug/<string:webhook_id>", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"])
|
||||
def handle_webhook_debug(webhook_id: str):
|
||||
"""Handle webhook debug calls without triggering production workflow execution."""
|
||||
"""Handle webhook debug calls without triggering production workflow execution.
|
||||
|
||||
The debug webhook endpoint is only for draft inspection flows. It never enqueues
|
||||
Celery work for the published workflow; instead it dispatches an in-memory debug
|
||||
event to an active Variable Inspector listener. Returning a clear error when no
|
||||
listener is registered prevents a misleading 200 response for requests that are
|
||||
effectively dropped.
|
||||
"""
|
||||
try:
|
||||
webhook_trigger, _, node_config, webhook_data, error = _prepare_webhook_execution(webhook_id, is_debug=True)
|
||||
if error:
|
||||
@@ -94,11 +101,32 @@ def handle_webhook_debug(webhook_id: str):
|
||||
"method": webhook_data.get("method"),
|
||||
},
|
||||
)
|
||||
TriggerDebugEventBus.dispatch(
|
||||
dispatch_count = TriggerDebugEventBus.dispatch(
|
||||
tenant_id=webhook_trigger.tenant_id,
|
||||
event=event,
|
||||
pool_key=pool_key,
|
||||
)
|
||||
if dispatch_count == 0:
|
||||
logger.warning(
|
||||
"Webhook debug request dropped without an active listener for webhook %s (tenant=%s, app=%s, node=%s)",
|
||||
webhook_trigger.webhook_id,
|
||||
webhook_trigger.tenant_id,
|
||||
webhook_trigger.app_id,
|
||||
webhook_trigger.node_id,
|
||||
)
|
||||
return (
|
||||
jsonify(
|
||||
{
|
||||
"error": "No active debug listener",
|
||||
"message": (
|
||||
"The webhook debug URL only works while the Variable Inspector is listening. "
|
||||
"Use the published webhook URL to execute the workflow in Celery."
|
||||
),
|
||||
"execution_url": webhook_trigger.webhook_url,
|
||||
}
|
||||
),
|
||||
409,
|
||||
)
|
||||
response_data, status_code = WebhookService.generate_webhook_response(node_config)
|
||||
return jsonify(response_data), status_code
|
||||
|
||||
|
||||
@@ -441,7 +441,7 @@ class BaseAgentRunner(AppRunner):
|
||||
continue
|
||||
|
||||
result.append(self.organize_agent_user_prompt(message))
|
||||
agent_thoughts = message.agent_thoughts
|
||||
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
|
||||
if agent_thoughts:
|
||||
for agent_thought in agent_thoughts:
|
||||
tool_names_raw = agent_thought.tool
|
||||
|
||||
@@ -1,36 +1,13 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||
|
||||
|
||||
class SystemParametersDict(TypedDict):
|
||||
image_file_size_limit: int
|
||||
video_file_size_limit: int
|
||||
audio_file_size_limit: int
|
||||
file_size_limit: int
|
||||
workflow_file_upload_limit: int
|
||||
|
||||
|
||||
class AppParametersDict(TypedDict):
|
||||
opening_statement: str | None
|
||||
suggested_questions: list[str]
|
||||
suggested_questions_after_answer: dict[str, Any]
|
||||
speech_to_text: dict[str, Any]
|
||||
text_to_speech: dict[str, Any]
|
||||
retriever_resource: dict[str, Any]
|
||||
annotation_reply: dict[str, Any]
|
||||
more_like_this: dict[str, Any]
|
||||
user_input_form: list[dict[str, Any]]
|
||||
sensitive_word_avoidance: dict[str, Any]
|
||||
file_upload: dict[str, Any]
|
||||
system_parameters: SystemParametersDict
|
||||
|
||||
|
||||
def get_parameters_from_feature_dict(
|
||||
*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]
|
||||
) -> AppParametersDict:
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Mapping from feature dict to webapp parameters
|
||||
"""
|
||||
|
||||
@@ -3,7 +3,7 @@ import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, NewType, TypedDict, Union
|
||||
from typing import Any, NewType, Union
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -76,20 +76,6 @@ NodeExecutionId = NewType("NodeExecutionId", str)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AccountCreatedByDict(TypedDict):
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
|
||||
|
||||
class EndUserCreatedByDict(TypedDict):
|
||||
id: str
|
||||
user: str
|
||||
|
||||
|
||||
CreatedByDict = AccountCreatedByDict | EndUserCreatedByDict
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _NodeSnapshot:
|
||||
"""In-memory cache for node metadata between start and completion events."""
|
||||
@@ -263,19 +249,19 @@ class WorkflowResponseConverter:
|
||||
outputs_mapping = graph_runtime_state.outputs or {}
|
||||
encoded_outputs = WorkflowRuntimeTypeConverter().to_json_encodable(outputs_mapping)
|
||||
|
||||
created_by: CreatedByDict | dict[str, object] = {}
|
||||
created_by: Mapping[str, object] | None
|
||||
user = self._user
|
||||
if isinstance(user, Account):
|
||||
created_by = AccountCreatedByDict(
|
||||
id=user.id,
|
||||
name=user.name,
|
||||
email=user.email,
|
||||
)
|
||||
elif isinstance(user, EndUser):
|
||||
created_by = EndUserCreatedByDict(
|
||||
id=user.id,
|
||||
user=user.session_id,
|
||||
)
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"name": user.name,
|
||||
"email": user.email,
|
||||
}
|
||||
else:
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"user": user.session_id,
|
||||
}
|
||||
|
||||
return WorkflowFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import TypedDict
|
||||
|
||||
from core.tools.signature import sign_tool_file
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from dify_graph.file.enums import FileTransferMethod
|
||||
@@ -8,20 +6,7 @@ from models.model import MessageFile, UploadFile
|
||||
MAX_TOOL_FILE_EXTENSION_LENGTH = 10
|
||||
|
||||
|
||||
class MessageFileInfoDict(TypedDict):
|
||||
related_id: str
|
||||
extension: str
|
||||
filename: str
|
||||
size: int
|
||||
mime_type: str
|
||||
transfer_method: str
|
||||
type: str
|
||||
url: str
|
||||
upload_file_id: str
|
||||
remote_url: str | None
|
||||
|
||||
|
||||
def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> MessageFileInfoDict:
|
||||
def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> dict:
|
||||
"""
|
||||
Prepare file dictionary for message end stream response.
|
||||
|
||||
|
||||
@@ -177,11 +177,13 @@ class Account(UserMixin, TypeBase):
|
||||
|
||||
@classmethod
|
||||
def get_by_openid(cls, provider: str, open_id: str):
|
||||
account_integrate = db.session.execute(
|
||||
select(AccountIntegrate).where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
|
||||
).scalar_one_or_none()
|
||||
account_integrate = (
|
||||
db.session.query(AccountIntegrate)
|
||||
.where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if account_integrate:
|
||||
return db.session.scalar(select(Account).where(Account.id == account_integrate.account_id))
|
||||
return db.session.query(Account).where(Account.id == account_integrate.account_id).one_or_none()
|
||||
return None
|
||||
|
||||
# check current_user.current_tenant.current_role in ['admin', 'owner']
|
||||
|
||||
@@ -8,7 +8,6 @@ import os
|
||||
import pickle
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, TypedDict, cast
|
||||
@@ -146,25 +145,30 @@ class Dataset(Base):
|
||||
|
||||
@property
|
||||
def total_documents(self):
|
||||
return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0
|
||||
return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
|
||||
|
||||
@property
|
||||
def total_available_documents(self):
|
||||
return (
|
||||
db.session.scalar(
|
||||
select(func.count(Document.id)).where(
|
||||
Document.dataset_id == self.id,
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
db.session.query(func.count(Document.id))
|
||||
.where(
|
||||
Document.dataset_id == self.id,
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
or 0
|
||||
.scalar()
|
||||
)
|
||||
|
||||
@property
|
||||
def dataset_keyword_table(self):
|
||||
return db.session.scalar(select(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id))
|
||||
dataset_keyword_table = (
|
||||
db.session.query(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id).first()
|
||||
)
|
||||
if dataset_keyword_table:
|
||||
return dataset_keyword_table
|
||||
|
||||
return None
|
||||
|
||||
@property
|
||||
def index_struct_dict(self):
|
||||
@@ -191,66 +195,64 @@ class Dataset(Base):
|
||||
|
||||
@property
|
||||
def latest_process_rule(self):
|
||||
return db.session.scalar(
|
||||
select(DatasetProcessRule)
|
||||
return (
|
||||
db.session.query(DatasetProcessRule)
|
||||
.where(DatasetProcessRule.dataset_id == self.id)
|
||||
.order_by(DatasetProcessRule.created_at.desc())
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
|
||||
@property
|
||||
def app_count(self):
|
||||
return (
|
||||
db.session.scalar(
|
||||
select(func.count(AppDatasetJoin.id)).where(
|
||||
AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id
|
||||
)
|
||||
)
|
||||
or 0
|
||||
db.session.query(func.count(AppDatasetJoin.id))
|
||||
.where(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id)
|
||||
.scalar()
|
||||
)
|
||||
|
||||
@property
|
||||
def document_count(self):
|
||||
return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0
|
||||
return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
|
||||
|
||||
@property
|
||||
def available_document_count(self):
|
||||
return (
|
||||
db.session.scalar(
|
||||
select(func.count(Document.id)).where(
|
||||
Document.dataset_id == self.id,
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
db.session.query(func.count(Document.id))
|
||||
.where(
|
||||
Document.dataset_id == self.id,
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
or 0
|
||||
.scalar()
|
||||
)
|
||||
|
||||
@property
|
||||
def available_segment_count(self):
|
||||
return (
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.dataset_id == self.id,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
db.session.query(func.count(DocumentSegment.id))
|
||||
.where(
|
||||
DocumentSegment.dataset_id == self.id,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
or 0
|
||||
.scalar()
|
||||
)
|
||||
|
||||
@property
|
||||
def word_count(self):
|
||||
return db.session.scalar(
|
||||
select(func.coalesce(func.sum(Document.word_count), 0)).where(Document.dataset_id == self.id)
|
||||
return (
|
||||
db.session.query(Document)
|
||||
.with_entities(func.coalesce(func.sum(Document.word_count), 0))
|
||||
.where(Document.dataset_id == self.id)
|
||||
.scalar()
|
||||
)
|
||||
|
||||
@property
|
||||
def doc_form(self) -> str | None:
|
||||
if self.chunk_structure:
|
||||
return self.chunk_structure
|
||||
document = db.session.scalar(select(Document).where(Document.dataset_id == self.id).limit(1))
|
||||
document = db.session.query(Document).where(Document.dataset_id == self.id).first()
|
||||
if document:
|
||||
return document.doc_form
|
||||
return None
|
||||
@@ -268,8 +270,8 @@ class Dataset(Base):
|
||||
|
||||
@property
|
||||
def tags(self):
|
||||
tags = db.session.scalars(
|
||||
select(Tag)
|
||||
tags = (
|
||||
db.session.query(Tag)
|
||||
.join(TagBinding, Tag.id == TagBinding.tag_id)
|
||||
.where(
|
||||
TagBinding.target_id == self.id,
|
||||
@@ -277,7 +279,8 @@ class Dataset(Base):
|
||||
Tag.tenant_id == self.tenant_id,
|
||||
Tag.type == "knowledge",
|
||||
)
|
||||
).all()
|
||||
.all()
|
||||
)
|
||||
|
||||
return tags or []
|
||||
|
||||
@@ -285,8 +288,8 @@ class Dataset(Base):
|
||||
def external_knowledge_info(self):
|
||||
if self.provider != "external":
|
||||
return None
|
||||
external_knowledge_binding = db.session.scalar(
|
||||
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id)
|
||||
external_knowledge_binding = (
|
||||
db.session.query(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id).first()
|
||||
)
|
||||
if not external_knowledge_binding:
|
||||
return None
|
||||
@@ -307,7 +310,7 @@ class Dataset(Base):
|
||||
@property
|
||||
def is_published(self):
|
||||
if self.pipeline_id:
|
||||
pipeline = db.session.scalar(select(Pipeline).where(Pipeline.id == self.pipeline_id))
|
||||
pipeline = db.session.query(Pipeline).where(Pipeline.id == self.pipeline_id).first()
|
||||
if pipeline:
|
||||
return pipeline.is_published
|
||||
return False
|
||||
@@ -518,8 +521,10 @@ class Document(Base):
|
||||
if self.data_source_info:
|
||||
if self.data_source_type == "upload_file":
|
||||
data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info)
|
||||
file_detail = db.session.scalar(
|
||||
select(UploadFile).where(UploadFile.id == data_source_info_dict["upload_file_id"])
|
||||
file_detail = (
|
||||
db.session.query(UploadFile)
|
||||
.where(UploadFile.id == data_source_info_dict["upload_file_id"])
|
||||
.one_or_none()
|
||||
)
|
||||
if file_detail:
|
||||
return {
|
||||
@@ -552,23 +557,24 @@ class Document(Base):
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
|
||||
return db.session.query(Dataset).where(Dataset.id == self.dataset_id).one_or_none()
|
||||
|
||||
@property
|
||||
def segment_count(self):
|
||||
return (
|
||||
db.session.scalar(select(func.count(DocumentSegment.id)).where(DocumentSegment.document_id == self.id)) or 0
|
||||
)
|
||||
return db.session.query(DocumentSegment).where(DocumentSegment.document_id == self.id).count()
|
||||
|
||||
@property
|
||||
def hit_count(self):
|
||||
return db.session.scalar(
|
||||
select(func.coalesce(func.sum(DocumentSegment.hit_count), 0)).where(DocumentSegment.document_id == self.id)
|
||||
return (
|
||||
db.session.query(DocumentSegment)
|
||||
.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0))
|
||||
.where(DocumentSegment.document_id == self.id)
|
||||
.scalar()
|
||||
)
|
||||
|
||||
@property
|
||||
def uploader(self):
|
||||
user = db.session.scalar(select(Account).where(Account.id == self.created_by))
|
||||
user = db.session.query(Account).where(Account.id == self.created_by).first()
|
||||
return user.name if user else None
|
||||
|
||||
@property
|
||||
@@ -582,13 +588,14 @@ class Document(Base):
|
||||
@property
|
||||
def doc_metadata_details(self) -> list[DocMetadataDetailItem] | None:
|
||||
if self.doc_metadata:
|
||||
document_metadatas = db.session.scalars(
|
||||
select(DatasetMetadata)
|
||||
document_metadatas = (
|
||||
db.session.query(DatasetMetadata)
|
||||
.join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id)
|
||||
.where(
|
||||
DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id
|
||||
)
|
||||
).all()
|
||||
.all()
|
||||
)
|
||||
metadata_list: list[DocMetadataDetailItem] = []
|
||||
for metadata in document_metadatas:
|
||||
metadata_dict: DocMetadataDetailItem = {
|
||||
@@ -819,7 +826,7 @@ class DocumentSegment(Base):
|
||||
)
|
||||
|
||||
@property
|
||||
def child_chunks(self) -> Sequence[Any]:
|
||||
def child_chunks(self) -> list[Any]:
|
||||
if not self.document:
|
||||
return []
|
||||
process_rule = self.document.dataset_process_rule
|
||||
@@ -828,13 +835,16 @@ class DocumentSegment(Base):
|
||||
if rules_dict:
|
||||
rules = Rule.model_validate(rules_dict)
|
||||
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
|
||||
child_chunks = db.session.scalars(
|
||||
select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc())
|
||||
).all()
|
||||
child_chunks = (
|
||||
db.session.query(ChildChunk)
|
||||
.where(ChildChunk.segment_id == self.id)
|
||||
.order_by(ChildChunk.position.asc())
|
||||
.all()
|
||||
)
|
||||
return child_chunks or []
|
||||
return []
|
||||
|
||||
def get_child_chunks(self) -> Sequence[Any]:
|
||||
def get_child_chunks(self) -> list[Any]:
|
||||
if not self.document:
|
||||
return []
|
||||
process_rule = self.document.dataset_process_rule
|
||||
@@ -843,9 +853,12 @@ class DocumentSegment(Base):
|
||||
if rules_dict:
|
||||
rules = Rule.model_validate(rules_dict)
|
||||
if rules.parent_mode:
|
||||
child_chunks = db.session.scalars(
|
||||
select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc())
|
||||
).all()
|
||||
child_chunks = (
|
||||
db.session.query(ChildChunk)
|
||||
.where(ChildChunk.segment_id == self.id)
|
||||
.order_by(ChildChunk.position.asc())
|
||||
.all()
|
||||
)
|
||||
return child_chunks or []
|
||||
return []
|
||||
|
||||
@@ -994,15 +1007,15 @@ class ChildChunk(Base):
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
|
||||
return db.session.query(Dataset).where(Dataset.id == self.dataset_id).first()
|
||||
|
||||
@property
|
||||
def document(self):
|
||||
return db.session.scalar(select(Document).where(Document.id == self.document_id))
|
||||
return db.session.query(Document).where(Document.id == self.document_id).first()
|
||||
|
||||
@property
|
||||
def segment(self):
|
||||
return db.session.scalar(select(DocumentSegment).where(DocumentSegment.id == self.segment_id))
|
||||
return db.session.query(DocumentSegment).where(DocumentSegment.id == self.segment_id).first()
|
||||
|
||||
|
||||
class AppDatasetJoin(TypeBase):
|
||||
@@ -1063,7 +1076,7 @@ class DatasetQuery(TypeBase):
|
||||
if isinstance(queries, list):
|
||||
for query in queries:
|
||||
if query["content_type"] == QueryType.IMAGE_QUERY:
|
||||
file_info = db.session.scalar(select(UploadFile).where(UploadFile.id == query["content"]))
|
||||
file_info = db.session.query(UploadFile).filter_by(id=query["content"]).first()
|
||||
if file_info:
|
||||
query["file_info"] = {
|
||||
"id": file_info.id,
|
||||
@@ -1128,7 +1141,7 @@ class DatasetKeywordTable(TypeBase):
|
||||
super().__init__(object_hook=object_hook, *args, **kwargs)
|
||||
|
||||
# get dataset
|
||||
dataset = db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
|
||||
dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first()
|
||||
if not dataset:
|
||||
return None
|
||||
if self.data_source_type == "database":
|
||||
@@ -1522,7 +1535,7 @@ class PipelineCustomizedTemplate(TypeBase):
|
||||
|
||||
@property
|
||||
def created_user_name(self):
|
||||
account = db.session.scalar(select(Account).where(Account.id == self.created_by))
|
||||
account = db.session.query(Account).where(Account.id == self.created_by).first()
|
||||
if account:
|
||||
return account.name
|
||||
return ""
|
||||
@@ -1557,7 +1570,7 @@ class Pipeline(TypeBase):
|
||||
)
|
||||
|
||||
def retrieve_dataset(self, session: Session):
|
||||
return session.scalar(select(Dataset).where(Dataset.pipeline_id == self.id))
|
||||
return session.query(Dataset).where(Dataset.pipeline_id == self.id).first()
|
||||
|
||||
|
||||
class DocumentPipelineExecutionLog(TypeBase):
|
||||
|
||||
@@ -96,213 +96,3 @@ class ConversationStatus(StrEnum):
|
||||
"""Conversation Status Enum"""
|
||||
|
||||
NORMAL = "normal"
|
||||
|
||||
|
||||
class DataSourceType(StrEnum):
|
||||
"""Data Source Type for Dataset and Document"""
|
||||
|
||||
UPLOAD_FILE = "upload_file"
|
||||
NOTION_IMPORT = "notion_import"
|
||||
WEBSITE_CRAWL = "website_crawl"
|
||||
LOCAL_FILE = "local_file"
|
||||
ONLINE_DOCUMENT = "online_document"
|
||||
|
||||
|
||||
class ProcessRuleMode(StrEnum):
|
||||
"""Dataset Process Rule Mode"""
|
||||
|
||||
AUTOMATIC = "automatic"
|
||||
CUSTOM = "custom"
|
||||
HIERARCHICAL = "hierarchical"
|
||||
|
||||
|
||||
class IndexingStatus(StrEnum):
|
||||
"""Document Indexing Status"""
|
||||
|
||||
WAITING = "waiting"
|
||||
PARSING = "parsing"
|
||||
CLEANING = "cleaning"
|
||||
SPLITTING = "splitting"
|
||||
INDEXING = "indexing"
|
||||
PAUSED = "paused"
|
||||
COMPLETED = "completed"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class DocumentCreatedFrom(StrEnum):
|
||||
"""Document Created From"""
|
||||
|
||||
WEB = "web"
|
||||
API = "api"
|
||||
RAG_PIPELINE = "rag-pipeline"
|
||||
|
||||
|
||||
class ConversationFromSource(StrEnum):
|
||||
"""Conversation / Message from_source"""
|
||||
|
||||
API = "api"
|
||||
CONSOLE = "console"
|
||||
|
||||
|
||||
class FeedbackFromSource(StrEnum):
|
||||
"""MessageFeedback from_source"""
|
||||
|
||||
USER = "user"
|
||||
ADMIN = "admin"
|
||||
|
||||
|
||||
class InvokeFrom(StrEnum):
|
||||
"""How a conversation/message was invoked"""
|
||||
|
||||
SERVICE_API = "service-api"
|
||||
WEB_APP = "web-app"
|
||||
TRIGGER = "trigger"
|
||||
EXPLORE = "explore"
|
||||
DEBUGGER = "debugger"
|
||||
PUBLISHED_PIPELINE = "published"
|
||||
VALIDATION = "validation"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "InvokeFrom":
|
||||
return cls(value)
|
||||
|
||||
def to_source(self) -> str:
|
||||
source_mapping = {
|
||||
InvokeFrom.WEB_APP: "web_app",
|
||||
InvokeFrom.DEBUGGER: "dev",
|
||||
InvokeFrom.EXPLORE: "explore_app",
|
||||
InvokeFrom.TRIGGER: "trigger",
|
||||
InvokeFrom.SERVICE_API: "api",
|
||||
}
|
||||
return source_mapping.get(self, "dev")
|
||||
|
||||
|
||||
class DocumentDocType(StrEnum):
|
||||
"""Document doc_type classification"""
|
||||
|
||||
BOOK = "book"
|
||||
WEB_PAGE = "web_page"
|
||||
PAPER = "paper"
|
||||
SOCIAL_MEDIA_POST = "social_media_post"
|
||||
WIKIPEDIA_ENTRY = "wikipedia_entry"
|
||||
PERSONAL_DOCUMENT = "personal_document"
|
||||
BUSINESS_DOCUMENT = "business_document"
|
||||
IM_CHAT_LOG = "im_chat_log"
|
||||
SYNCED_FROM_NOTION = "synced_from_notion"
|
||||
SYNCED_FROM_GITHUB = "synced_from_github"
|
||||
OTHERS = "others"
|
||||
|
||||
|
||||
class TagType(StrEnum):
|
||||
"""Tag type"""
|
||||
|
||||
KNOWLEDGE = "knowledge"
|
||||
APP = "app"
|
||||
|
||||
|
||||
class DatasetMetadataType(StrEnum):
|
||||
"""Dataset metadata value type"""
|
||||
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
TIME = "time"
|
||||
|
||||
|
||||
class SegmentStatus(StrEnum):
|
||||
"""Document segment status"""
|
||||
|
||||
WAITING = "waiting"
|
||||
INDEXING = "indexing"
|
||||
COMPLETED = "completed"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class DatasetRuntimeMode(StrEnum):
|
||||
"""Dataset runtime mode"""
|
||||
|
||||
GENERAL = "general"
|
||||
RAG_PIPELINE = "rag_pipeline"
|
||||
|
||||
|
||||
class CollectionBindingType(StrEnum):
|
||||
"""Dataset collection binding type"""
|
||||
|
||||
DATASET = "dataset"
|
||||
ANNOTATION = "annotation"
|
||||
|
||||
|
||||
class DatasetQuerySource(StrEnum):
|
||||
"""Dataset query source"""
|
||||
|
||||
HIT_TESTING = "hit_testing"
|
||||
APP = "app"
|
||||
|
||||
|
||||
class TidbAuthBindingStatus(StrEnum):
|
||||
"""TiDB auth binding status"""
|
||||
|
||||
CREATING = "CREATING"
|
||||
ACTIVE = "ACTIVE"
|
||||
|
||||
|
||||
class MessageFileBelongsTo(StrEnum):
|
||||
"""MessageFile belongs_to"""
|
||||
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
|
||||
|
||||
class CredentialSourceType(StrEnum):
|
||||
"""Load balancing credential source type"""
|
||||
|
||||
PROVIDER = "provider"
|
||||
CUSTOM_MODEL = "custom_model"
|
||||
|
||||
|
||||
class PaymentStatus(StrEnum):
|
||||
"""Provider order payment status"""
|
||||
|
||||
WAIT_PAY = "wait_pay"
|
||||
PAID = "paid"
|
||||
FAILED = "failed"
|
||||
REFUNDED = "refunded"
|
||||
|
||||
|
||||
class BannerStatus(StrEnum):
|
||||
"""ExporleBanner status"""
|
||||
|
||||
ENABLED = "enabled"
|
||||
DISABLED = "disabled"
|
||||
|
||||
|
||||
class SummaryStatus(StrEnum):
|
||||
"""Document segment summary status"""
|
||||
|
||||
NOT_STARTED = "not_started"
|
||||
GENERATING = "generating"
|
||||
COMPLETED = "completed"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class MessageChainType(StrEnum):
|
||||
"""Message chain type"""
|
||||
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class ProviderQuotaType(StrEnum):
|
||||
PAID = "paid"
|
||||
"""hosted paid quota"""
|
||||
|
||||
FREE = "free"
|
||||
"""third-party free quota"""
|
||||
|
||||
TRIAL = "trial"
|
||||
"""hosted trial quota"""
|
||||
|
||||
@staticmethod
|
||||
def value_of(value: str) -> "ProviderQuotaType":
|
||||
for member in ProviderQuotaType:
|
||||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
@@ -380,12 +380,13 @@ class App(Base):
|
||||
|
||||
@property
|
||||
def site(self) -> Site | None:
|
||||
return db.session.scalar(select(Site).where(Site.app_id == self.id))
|
||||
site = db.session.query(Site).where(Site.app_id == self.id).first()
|
||||
return site
|
||||
|
||||
@property
|
||||
def app_model_config(self) -> AppModelConfig | None:
|
||||
if self.app_model_config_id:
|
||||
return db.session.scalar(select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id))
|
||||
return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
|
||||
|
||||
return None
|
||||
|
||||
@@ -394,7 +395,7 @@ class App(Base):
|
||||
if self.workflow_id:
|
||||
from .workflow import Workflow
|
||||
|
||||
return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id))
|
||||
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
|
||||
|
||||
return None
|
||||
|
||||
@@ -404,7 +405,8 @@ class App(Base):
|
||||
|
||||
@property
|
||||
def tenant(self) -> Tenant | None:
|
||||
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
|
||||
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
||||
return tenant
|
||||
|
||||
@property
|
||||
def is_agent(self) -> bool:
|
||||
@@ -544,9 +546,9 @@ class App(Base):
|
||||
return deleted_tools
|
||||
|
||||
@property
|
||||
def tags(self) -> Sequence[Tag]:
|
||||
tags = db.session.scalars(
|
||||
select(Tag)
|
||||
def tags(self) -> list[Tag]:
|
||||
tags = (
|
||||
db.session.query(Tag)
|
||||
.join(TagBinding, Tag.id == TagBinding.tag_id)
|
||||
.where(
|
||||
TagBinding.target_id == self.id,
|
||||
@@ -554,14 +556,15 @@ class App(Base):
|
||||
Tag.tenant_id == self.tenant_id,
|
||||
Tag.type == "app",
|
||||
)
|
||||
).all()
|
||||
.all()
|
||||
)
|
||||
|
||||
return tags or []
|
||||
|
||||
@property
|
||||
def author_name(self) -> str | None:
|
||||
if self.created_by:
|
||||
account = db.session.scalar(select(Account).where(Account.id == self.created_by))
|
||||
account = db.session.query(Account).where(Account.id == self.created_by).first()
|
||||
if account:
|
||||
return account.name
|
||||
|
||||
@@ -613,7 +616,8 @@ class AppModelConfig(TypeBase):
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
|
||||
@property
|
||||
def model_dict(self) -> ModelConfig:
|
||||
@@ -648,8 +652,8 @@ class AppModelConfig(TypeBase):
|
||||
|
||||
@property
|
||||
def annotation_reply_dict(self) -> AnnotationReplyConfig:
|
||||
annotation_setting = db.session.scalar(
|
||||
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id)
|
||||
annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
|
||||
)
|
||||
if annotation_setting:
|
||||
collection_binding_detail = annotation_setting.collection_binding_detail
|
||||
@@ -841,7 +845,8 @@ class RecommendedApp(Base): # bug
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
|
||||
|
||||
class InstalledApp(TypeBase):
|
||||
@@ -868,11 +873,13 @@ class InstalledApp(TypeBase):
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
|
||||
@property
|
||||
def tenant(self) -> Tenant | None:
|
||||
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
|
||||
tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
||||
return tenant
|
||||
|
||||
|
||||
class TrialApp(Base):
|
||||
@@ -892,7 +899,8 @@ class TrialApp(Base):
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
|
||||
|
||||
class AccountTrialAppRecord(Base):
|
||||
@@ -911,11 +919,13 @@ class AccountTrialAppRecord(Base):
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
return app
|
||||
|
||||
@property
|
||||
def user(self) -> Account | None:
|
||||
return db.session.scalar(select(Account).where(Account.id == self.account_id))
|
||||
user = db.session.query(Account).where(Account.id == self.account_id).first()
|
||||
return user
|
||||
|
||||
|
||||
class ExporleBanner(TypeBase):
|
||||
@@ -1107,8 +1117,8 @@ class Conversation(Base):
|
||||
else:
|
||||
model_config["configs"] = override_model_configs # type: ignore[typeddict-unknown-key]
|
||||
else:
|
||||
app_model_config = db.session.scalar(
|
||||
select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id)
|
||||
app_model_config = (
|
||||
db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
|
||||
)
|
||||
if app_model_config:
|
||||
model_config = app_model_config.to_dict()
|
||||
@@ -1131,43 +1141,36 @@ class Conversation(Base):
|
||||
|
||||
@property
|
||||
def annotated(self):
|
||||
return (
|
||||
db.session.scalar(
|
||||
select(func.count(MessageAnnotation.id)).where(MessageAnnotation.conversation_id == self.id)
|
||||
)
|
||||
or 0
|
||||
) > 0
|
||||
return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).count() > 0
|
||||
|
||||
@property
|
||||
def annotation(self):
|
||||
return db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).limit(1))
|
||||
return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).first()
|
||||
|
||||
@property
|
||||
def message_count(self):
|
||||
return db.session.scalar(select(func.count(Message.id)).where(Message.conversation_id == self.id)) or 0
|
||||
return db.session.query(Message).where(Message.conversation_id == self.id).count()
|
||||
|
||||
@property
|
||||
def user_feedback_stats(self):
|
||||
like = (
|
||||
db.session.scalar(
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "user",
|
||||
MessageFeedback.rating == "like",
|
||||
)
|
||||
db.session.query(MessageFeedback)
|
||||
.where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "user",
|
||||
MessageFeedback.rating == "like",
|
||||
)
|
||||
or 0
|
||||
.count()
|
||||
)
|
||||
|
||||
dislike = (
|
||||
db.session.scalar(
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "user",
|
||||
MessageFeedback.rating == "dislike",
|
||||
)
|
||||
db.session.query(MessageFeedback)
|
||||
.where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "user",
|
||||
MessageFeedback.rating == "dislike",
|
||||
)
|
||||
or 0
|
||||
.count()
|
||||
)
|
||||
|
||||
return {"like": like, "dislike": dislike}
|
||||
@@ -1175,25 +1178,23 @@ class Conversation(Base):
|
||||
@property
|
||||
def admin_feedback_stats(self):
|
||||
like = (
|
||||
db.session.scalar(
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "admin",
|
||||
MessageFeedback.rating == "like",
|
||||
)
|
||||
db.session.query(MessageFeedback)
|
||||
.where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "admin",
|
||||
MessageFeedback.rating == "like",
|
||||
)
|
||||
or 0
|
||||
.count()
|
||||
)
|
||||
|
||||
dislike = (
|
||||
db.session.scalar(
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "admin",
|
||||
MessageFeedback.rating == "dislike",
|
||||
)
|
||||
db.session.query(MessageFeedback)
|
||||
.where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "admin",
|
||||
MessageFeedback.rating == "dislike",
|
||||
)
|
||||
or 0
|
||||
.count()
|
||||
)
|
||||
|
||||
return {"like": like, "dislike": dislike}
|
||||
@@ -1255,19 +1256,22 @@ class Conversation(Base):
|
||||
|
||||
@property
|
||||
def first_message(self):
|
||||
return db.session.scalar(
|
||||
select(Message).where(Message.conversation_id == self.id).order_by(Message.created_at.asc())
|
||||
return (
|
||||
db.session.query(Message)
|
||||
.where(Message.conversation_id == self.id)
|
||||
.order_by(Message.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
return session.scalar(select(App).where(App.id == self.app_id))
|
||||
return session.query(App).where(App.id == self.app_id).first()
|
||||
|
||||
@property
|
||||
def from_end_user_session_id(self):
|
||||
if self.from_end_user_id:
|
||||
end_user = db.session.scalar(select(EndUser).where(EndUser.id == self.from_end_user_id))
|
||||
end_user = db.session.query(EndUser).where(EndUser.id == self.from_end_user_id).first()
|
||||
if end_user:
|
||||
return end_user.session_id
|
||||
|
||||
@@ -1276,7 +1280,7 @@ class Conversation(Base):
|
||||
@property
|
||||
def from_account_name(self) -> str | None:
|
||||
if self.from_account_id:
|
||||
account = db.session.scalar(select(Account).where(Account.id == self.from_account_id))
|
||||
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
|
||||
if account:
|
||||
return account.name
|
||||
|
||||
@@ -1501,15 +1505,21 @@ class Message(Base):
|
||||
|
||||
@property
|
||||
def user_feedback(self):
|
||||
return db.session.scalar(
|
||||
select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
|
||||
feedback = (
|
||||
db.session.query(MessageFeedback)
|
||||
.where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
|
||||
.first()
|
||||
)
|
||||
return feedback
|
||||
|
||||
@property
|
||||
def admin_feedback(self):
|
||||
return db.session.scalar(
|
||||
select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
|
||||
feedback = (
|
||||
db.session.query(MessageFeedback)
|
||||
.where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
|
||||
.first()
|
||||
)
|
||||
return feedback
|
||||
|
||||
@property
|
||||
def feedbacks(self):
|
||||
@@ -1518,27 +1528,28 @@ class Message(Base):
|
||||
|
||||
@property
|
||||
def annotation(self):
|
||||
annotation = db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.message_id == self.id))
|
||||
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == self.id).first()
|
||||
return annotation
|
||||
|
||||
@property
|
||||
def annotation_hit_history(self):
|
||||
annotation_history = db.session.scalar(
|
||||
select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id)
|
||||
annotation_history = (
|
||||
db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id).first()
|
||||
)
|
||||
if annotation_history:
|
||||
return db.session.scalar(
|
||||
select(MessageAnnotation).where(MessageAnnotation.id == annotation_history.annotation_id)
|
||||
annotation = (
|
||||
db.session.query(MessageAnnotation)
|
||||
.where(MessageAnnotation.id == annotation_history.annotation_id)
|
||||
.first()
|
||||
)
|
||||
return annotation
|
||||
return None
|
||||
|
||||
@property
|
||||
def app_model_config(self):
|
||||
conversation = db.session.scalar(select(Conversation).where(Conversation.id == self.conversation_id))
|
||||
conversation = db.session.query(Conversation).where(Conversation.id == self.conversation_id).first()
|
||||
if conversation:
|
||||
return db.session.scalar(
|
||||
select(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id)
|
||||
)
|
||||
return db.session.query(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id).first()
|
||||
|
||||
return None
|
||||
|
||||
@@ -1551,12 +1562,13 @@ class Message(Base):
|
||||
return json.loads(self.message_metadata) if self.message_metadata else {}
|
||||
|
||||
@property
|
||||
def agent_thoughts(self) -> Sequence[MessageAgentThought]:
|
||||
return db.session.scalars(
|
||||
select(MessageAgentThought)
|
||||
def agent_thoughts(self) -> list[MessageAgentThought]:
|
||||
return (
|
||||
db.session.query(MessageAgentThought)
|
||||
.where(MessageAgentThought.message_id == self.id)
|
||||
.order_by(MessageAgentThought.position.asc())
|
||||
).all()
|
||||
.all()
|
||||
)
|
||||
|
||||
@property
|
||||
def retriever_resources(self) -> Any:
|
||||
@@ -1567,7 +1579,7 @@ class Message(Base):
|
||||
from factories import file_factory
|
||||
|
||||
message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all()
|
||||
current_app = db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
current_app = db.session.query(App).where(App.id == self.app_id).first()
|
||||
if not current_app:
|
||||
raise ValueError(f"App {self.app_id} not found")
|
||||
|
||||
@@ -1731,7 +1743,8 @@ class MessageFeedback(TypeBase):
|
||||
|
||||
@property
|
||||
def from_account(self) -> Account | None:
|
||||
return db.session.scalar(select(Account).where(Account.id == self.from_account_id))
|
||||
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
|
||||
return account
|
||||
|
||||
def to_dict(self) -> MessageFeedbackDict:
|
||||
return {
|
||||
@@ -1804,11 +1817,13 @@ class MessageAnnotation(Base):
|
||||
|
||||
@property
|
||||
def account(self):
|
||||
return db.session.scalar(select(Account).where(Account.id == self.account_id))
|
||||
account = db.session.query(Account).where(Account.id == self.account_id).first()
|
||||
return account
|
||||
|
||||
@property
|
||||
def annotation_create_account(self):
|
||||
return db.session.scalar(select(Account).where(Account.id == self.account_id))
|
||||
account = db.session.query(Account).where(Account.id == self.account_id).first()
|
||||
return account
|
||||
|
||||
|
||||
class AppAnnotationHitHistory(TypeBase):
|
||||
@@ -1837,15 +1852,18 @@ class AppAnnotationHitHistory(TypeBase):
|
||||
|
||||
@property
|
||||
def account(self):
|
||||
return db.session.scalar(
|
||||
select(Account)
|
||||
account = (
|
||||
db.session.query(Account)
|
||||
.join(MessageAnnotation, MessageAnnotation.account_id == Account.id)
|
||||
.where(MessageAnnotation.id == self.annotation_id)
|
||||
.first()
|
||||
)
|
||||
return account
|
||||
|
||||
@property
|
||||
def annotation_create_account(self):
|
||||
return db.session.scalar(select(Account).where(Account.id == self.account_id))
|
||||
account = db.session.query(Account).where(Account.id == self.account_id).first()
|
||||
return account
|
||||
|
||||
|
||||
class AppAnnotationSetting(TypeBase):
|
||||
@@ -1878,9 +1896,12 @@ class AppAnnotationSetting(TypeBase):
|
||||
def collection_binding_detail(self):
|
||||
from .dataset import DatasetCollectionBinding
|
||||
|
||||
return db.session.scalar(
|
||||
select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == self.collection_binding_id)
|
||||
collection_binding_detail = (
|
||||
db.session.query(DatasetCollectionBinding)
|
||||
.where(DatasetCollectionBinding.id == self.collection_binding_id)
|
||||
.first()
|
||||
)
|
||||
return collection_binding_detail
|
||||
|
||||
|
||||
class OperationLog(TypeBase):
|
||||
@@ -1986,9 +2007,7 @@ class AppMCPServer(TypeBase):
|
||||
def generate_server_code(n: int) -> str:
|
||||
while True:
|
||||
result = generate_string(n)
|
||||
while (
|
||||
db.session.scalar(select(func.count(AppMCPServer.id)).where(AppMCPServer.server_code == result)) or 0
|
||||
) > 0:
|
||||
while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0:
|
||||
result = generate_string(n)
|
||||
|
||||
return result
|
||||
@@ -2049,7 +2068,7 @@ class Site(Base):
|
||||
def generate_code(n: int) -> str:
|
||||
while True:
|
||||
result = generate_string(n)
|
||||
while (db.session.scalar(select(func.count(Site.id)).where(Site.code == result)) or 0) > 0:
|
||||
while db.session.query(Site).where(Site.code == result).count() > 0:
|
||||
result = generate_string(n)
|
||||
|
||||
return result
|
||||
|
||||
@@ -6,7 +6,7 @@ from functools import cached_property
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, String, func, select, text
|
||||
from sqlalchemy import DateTime, String, func, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from libs.uuid_utils import uuidv7
|
||||
@@ -96,7 +96,7 @@ class Provider(TypeBase):
|
||||
@cached_property
|
||||
def credential(self):
|
||||
if self.credential_id:
|
||||
return db.session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id))
|
||||
return db.session.query(ProviderCredential).where(ProviderCredential.id == self.credential_id).first()
|
||||
|
||||
@property
|
||||
def credential_name(self):
|
||||
@@ -159,8 +159,10 @@ class ProviderModel(TypeBase):
|
||||
@cached_property
|
||||
def credential(self):
|
||||
if self.credential_id:
|
||||
return db.session.scalar(
|
||||
select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id)
|
||||
return (
|
||||
db.session.query(ProviderModelCredential)
|
||||
.where(ProviderModelCredential.id == self.credential_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -8,7 +8,7 @@ from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
from deprecated import deprecated
|
||||
from sqlalchemy import ForeignKey, String, func, select
|
||||
from sqlalchemy import ForeignKey, String, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
@@ -184,11 +184,11 @@ class ApiToolProvider(TypeBase):
|
||||
def user(self) -> Account | None:
|
||||
if not self.user_id:
|
||||
return None
|
||||
return db.session.scalar(select(Account).where(Account.id == self.user_id))
|
||||
return db.session.query(Account).where(Account.id == self.user_id).first()
|
||||
|
||||
@property
|
||||
def tenant(self) -> Tenant | None:
|
||||
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
|
||||
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
||||
|
||||
|
||||
class ToolLabelBinding(TypeBase):
|
||||
@@ -262,11 +262,11 @@ class WorkflowToolProvider(TypeBase):
|
||||
|
||||
@property
|
||||
def user(self) -> Account | None:
|
||||
return db.session.scalar(select(Account).where(Account.id == self.user_id))
|
||||
return db.session.query(Account).where(Account.id == self.user_id).first()
|
||||
|
||||
@property
|
||||
def tenant(self) -> Tenant | None:
|
||||
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
|
||||
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
||||
|
||||
@property
|
||||
def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]:
|
||||
@@ -277,7 +277,7 @@ class WorkflowToolProvider(TypeBase):
|
||||
|
||||
@property
|
||||
def app(self) -> App | None:
|
||||
return db.session.scalar(select(App).where(App.id == self.app_id))
|
||||
return db.session.query(App).where(App.id == self.app_id).first()
|
||||
|
||||
|
||||
class MCPToolProvider(TypeBase):
|
||||
@@ -334,7 +334,7 @@ class MCPToolProvider(TypeBase):
|
||||
encrypted_headers: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
|
||||
def load_user(self) -> Account | None:
|
||||
return db.session.scalar(select(Account).where(Account.id == self.user_id))
|
||||
return db.session.query(Account).where(Account.id == self.user_id).first()
|
||||
|
||||
@property
|
||||
def credentials(self) -> dict[str, Any]:
|
||||
|
||||
@@ -2,7 +2,7 @@ from datetime import datetime
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, func, select
|
||||
from sqlalchemy import DateTime, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import TypeBase
|
||||
@@ -38,7 +38,7 @@ class SavedMessage(TypeBase):
|
||||
|
||||
@property
|
||||
def message(self):
|
||||
return db.session.scalar(select(Message).where(Message.id == self.message_id))
|
||||
return db.session.query(Message).where(Message.id == self.message_id).first()
|
||||
|
||||
|
||||
class PinnedConversation(TypeBase):
|
||||
|
||||
@@ -679,14 +679,14 @@ class WorkflowRun(Base):
|
||||
def message(self):
|
||||
from .model import Message
|
||||
|
||||
return db.session.scalar(
|
||||
select(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id)
|
||||
return (
|
||||
db.session.query(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first()
|
||||
)
|
||||
|
||||
@property
|
||||
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
|
||||
def workflow(self):
|
||||
return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id))
|
||||
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
|
||||
@@ -182,9 +182,4 @@ tasks/app_generate/workflow_execute_task.py
|
||||
tasks/regenerate_summary_index_task.py
|
||||
tasks/trigger_processing_tasks.py
|
||||
tasks/workflow_cfs_scheduler/cfs_scheduler.py
|
||||
tasks/add_document_to_index_task.py
|
||||
tasks/create_segment_to_index_task.py
|
||||
tasks/disable_segment_from_index_task.py
|
||||
tasks/enable_segment_to_index_task.py
|
||||
tasks/remove_document_from_index_task.py
|
||||
tasks/workflow_execution_tasks.py
|
||||
|
||||
@@ -11,7 +11,7 @@ from core.tools.tool_manager import ToolManager
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from models import Account
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
from models.model import App, Conversation, EndUser, Message, MessageAgentThought
|
||||
|
||||
|
||||
class AgentService:
|
||||
@@ -47,7 +47,7 @@ class AgentService:
|
||||
if not message:
|
||||
raise ValueError(f"Message not found: {message_id}")
|
||||
|
||||
agent_thoughts = message.agent_thoughts
|
||||
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
|
||||
|
||||
if conversation.from_end_user_id:
|
||||
# only select name field
|
||||
|
||||
@@ -23,6 +23,7 @@ def mock_jsonify():
|
||||
|
||||
class DummyWebhookTrigger:
|
||||
webhook_id = "wh-1"
|
||||
webhook_url = "http://localhost:5001/triggers/webhook/wh-1"
|
||||
tenant_id = "tenant-1"
|
||||
app_id = "app-1"
|
||||
node_id = "node-1"
|
||||
@@ -104,7 +105,32 @@ class TestHandleWebhookDebug:
|
||||
@patch.object(module.WebhookService, "get_webhook_trigger_and_workflow")
|
||||
@patch.object(module.WebhookService, "extract_and_validate_webhook_data")
|
||||
@patch.object(module.WebhookService, "build_workflow_inputs", return_value={"x": 1})
|
||||
@patch.object(module.TriggerDebugEventBus, "dispatch")
|
||||
@patch.object(module.TriggerDebugEventBus, "dispatch", return_value=0)
|
||||
def test_debug_requires_active_listener(
|
||||
self,
|
||||
mock_dispatch,
|
||||
mock_build_inputs,
|
||||
mock_extract,
|
||||
mock_get,
|
||||
):
|
||||
mock_get.return_value = (DummyWebhookTrigger(), None, "node_config")
|
||||
mock_extract.return_value = {"method": "POST"}
|
||||
|
||||
response, status = module.handle_webhook_debug("wh-1")
|
||||
|
||||
assert status == 409
|
||||
assert response["error"] == "No active debug listener"
|
||||
assert response["message"] == (
|
||||
"The webhook debug URL only works while the Variable Inspector is listening. "
|
||||
"Use the published webhook URL to execute the workflow in Celery."
|
||||
)
|
||||
assert response["execution_url"] == DummyWebhookTrigger.webhook_url
|
||||
mock_dispatch.assert_called_once()
|
||||
|
||||
@patch.object(module.WebhookService, "get_webhook_trigger_and_workflow")
|
||||
@patch.object(module.WebhookService, "extract_and_validate_webhook_data")
|
||||
@patch.object(module.WebhookService, "build_workflow_inputs", return_value={"x": 1})
|
||||
@patch.object(module.TriggerDebugEventBus, "dispatch", return_value=1)
|
||||
@patch.object(module.WebhookService, "generate_webhook_response")
|
||||
def test_debug_success(
|
||||
self,
|
||||
|
||||
@@ -622,10 +622,28 @@ class TestAccountGetByOpenId:
|
||||
mock_account = Account(name="Test User", email="test@example.com")
|
||||
mock_account.id = account_id
|
||||
|
||||
# Mock db.session.execute().scalar_one_or_none() for AccountIntegrate lookup
|
||||
mock_db.session.execute.return_value.scalar_one_or_none.return_value = mock_account_integrate
|
||||
# Mock db.session.scalar() for Account lookup
|
||||
mock_db.session.scalar.return_value = mock_account
|
||||
# Mock the query chain
|
||||
mock_query = MagicMock()
|
||||
mock_where = MagicMock()
|
||||
mock_where.one_or_none.return_value = mock_account_integrate
|
||||
mock_query.where.return_value = mock_where
|
||||
mock_db.session.query.return_value = mock_query
|
||||
|
||||
# Mock the second query for account
|
||||
mock_account_query = MagicMock()
|
||||
mock_account_where = MagicMock()
|
||||
mock_account_where.one_or_none.return_value = mock_account
|
||||
mock_account_query.where.return_value = mock_account_where
|
||||
|
||||
# Setup query to return different results based on model
|
||||
def query_side_effect(model):
|
||||
if model.__name__ == "AccountIntegrate":
|
||||
return mock_query
|
||||
elif model.__name__ == "Account":
|
||||
return mock_account_query
|
||||
return MagicMock()
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
|
||||
# Act
|
||||
result = Account.get_by_openid(provider, open_id)
|
||||
@@ -640,8 +658,12 @@ class TestAccountGetByOpenId:
|
||||
provider = "github"
|
||||
open_id = "github_user_456"
|
||||
|
||||
# Mock db.session.execute().scalar_one_or_none() to return None
|
||||
mock_db.session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
# Mock the query chain to return None
|
||||
mock_query = MagicMock()
|
||||
mock_where = MagicMock()
|
||||
mock_where.one_or_none.return_value = None
|
||||
mock_query.where.return_value = mock_where
|
||||
mock_db.session.query.return_value = mock_query
|
||||
|
||||
# Act
|
||||
result = Account.get_by_openid(provider, open_id)
|
||||
|
||||
@@ -300,8 +300,10 @@ class TestAppModelConfig:
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
|
||||
# Mock database scalar to return None (no annotation setting found)
|
||||
with patch("models.model.db.session.scalar", return_value=None):
|
||||
# Mock database query to return None
|
||||
with patch("models.model.db.session.query", autospec=True) as mock_query:
|
||||
mock_query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = config.annotation_reply_dict
|
||||
|
||||
@@ -949,8 +951,10 @@ class TestSiteModel:
|
||||
|
||||
def test_site_generate_code(self):
|
||||
"""Test Site.generate_code static method."""
|
||||
# Mock database scalar to return 0 (no existing codes)
|
||||
with patch("models.model.db.session.scalar", return_value=0):
|
||||
# Mock database query to return 0 (no existing codes)
|
||||
with patch("models.model.db.session.query", autospec=True) as mock_query:
|
||||
mock_query.return_value.where.return_value.count.return_value = 0
|
||||
|
||||
# Act
|
||||
code = Site.generate_code(8)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user