mirror of
https://github.com/langgenius/dify.git
synced 2026-03-15 20:27:02 +00:00
Compare commits
42 Commits
copilot/re
...
copilot/an
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
85396a2af2 | ||
|
|
6c19e75969 | ||
|
|
9970f4449a | ||
|
|
cbb19cce39 | ||
|
|
0aef09d630 | ||
|
|
d2208ad43e | ||
|
|
4a2ba058bb | ||
|
|
654e41d47f | ||
|
|
ec5409756e | ||
|
|
8b1ea3a8f5 | ||
|
|
f2d3feca66 | ||
|
|
0590b09958 | ||
|
|
66f9fde2fe | ||
|
|
1811a855ab | ||
|
|
322cd37de1 | ||
|
|
2cc0de9c1b | ||
|
|
46098b2be6 | ||
|
|
7dcf94f48f | ||
|
|
7869551afd | ||
|
|
c925d17e8f | ||
|
|
dc2a53d834 | ||
|
|
05ab107e73 | ||
|
|
c016793efb | ||
|
|
a5bcbaebb7 | ||
|
|
f50e44b24a | ||
|
|
09347d5e8b | ||
|
|
299a893ac5 | ||
|
|
c477571553 | ||
|
|
d01acfc490 | ||
|
|
f05f0be55f | ||
|
|
e74cda6535 | ||
|
|
0490756ab2 | ||
|
|
dc31b07533 | ||
|
|
d1eaa41dd1 | ||
|
|
7ffa6c1849 | ||
|
|
ad81513b6a | ||
|
|
f751864ab3 | ||
|
|
49dcf5e0d9 | ||
|
|
741d48560d | ||
|
|
6bd1be9e16 | ||
|
|
f76de73be4 | ||
|
|
98ba091a50 |
@@ -7,7 +7,7 @@ cd web && pnpm install
|
||||
pipx install uv
|
||||
|
||||
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
|
||||
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
|
||||
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev:inspect\"" >> ~/.bashrc
|
||||
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
|
||||
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc
|
||||
|
||||
6
.github/dependabot.yml
vendored
6
.github/dependabot.yml
vendored
@@ -25,6 +25,10 @@ updates:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 2
|
||||
groups:
|
||||
lexical:
|
||||
patterns:
|
||||
- "lexical"
|
||||
- "@lexical/*"
|
||||
storybook:
|
||||
patterns:
|
||||
- "storybook"
|
||||
@@ -33,5 +37,7 @@ updates:
|
||||
patterns:
|
||||
- "*"
|
||||
exclude-patterns:
|
||||
- "lexical"
|
||||
- "@lexical/*"
|
||||
- "storybook"
|
||||
- "@storybook/*"
|
||||
|
||||
2
.vscode/launch.json.template
vendored
2
.vscode/launch.json.template
vendored
@@ -37,7 +37,7 @@
|
||||
"-c",
|
||||
"1",
|
||||
"-Q",
|
||||
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution",
|
||||
"dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution",
|
||||
"--loglevel",
|
||||
"INFO"
|
||||
],
|
||||
|
||||
5
Makefile
5
Makefile
@@ -68,8 +68,9 @@ lint:
|
||||
@echo "✅ Linting complete"
|
||||
|
||||
type-check:
|
||||
@echo "📝 Running type checks (basedpyright + mypy)..."
|
||||
@echo "📝 Running type checks (basedpyright + pyrefly + mypy)..."
|
||||
@./dev/basedpyright-check $(PATH_TO_CHECK)
|
||||
@./dev/pyrefly-check-local
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@echo "✅ Type checks complete"
|
||||
|
||||
@@ -131,7 +132,7 @@ help:
|
||||
@echo " make format - Format code with ruff"
|
||||
@echo " make check - Check code with ruff"
|
||||
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
|
||||
@echo " make type-check - Run type checks (basedpyright, mypy)"
|
||||
@echo " make type-check - Run type checks (basedpyright, pyrefly, mypy)"
|
||||
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
|
||||
@echo ""
|
||||
@echo "Docker Build Targets:"
|
||||
|
||||
@@ -133,7 +133,7 @@ Star Dify on GitHub and be instantly notified of new releases.
|
||||
|
||||
### Custom configurations
|
||||
|
||||
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
||||
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
|
||||
|
||||
#### Customizing Suggested Questions
|
||||
|
||||
|
||||
@@ -62,6 +62,22 @@ This is the default standard for backend code in this repo. Follow it for new co
|
||||
|
||||
- Code should usually include type annotations that match the repo’s current Python version (avoid untyped public APIs and “mystery” values).
|
||||
- Prefer modern typing forms (e.g. `list[str]`, `dict[str, int]`) and avoid `Any` unless there’s a strong reason.
|
||||
- For dictionary-like data with known keys and value types, prefer `TypedDict` over `dict[...]` or `Mapping[...]`.
|
||||
- For optional keys in typed payloads, use `NotRequired[...]` (or `total=False` when most fields are optional).
|
||||
- Keep `dict[...]` / `Mapping[...]` for truly dynamic key spaces where the key set is unknown.
|
||||
|
||||
```python
|
||||
from datetime import datetime
|
||||
from typing import NotRequired, TypedDict
|
||||
|
||||
|
||||
class UserProfile(TypedDict):
|
||||
user_id: str
|
||||
email: str
|
||||
created_at: datetime
|
||||
nickname: NotRequired[str]
|
||||
```
|
||||
|
||||
- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
|
||||
|
||||
```python
|
||||
|
||||
165
api/commands.py
165
api/commands.py
@@ -30,6 +30,7 @@ from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.opendal_storage import OpenDALStorage
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.db_migration_lock import DbMigrationAutoRenewLock
|
||||
from libs.helper import email as email_validate
|
||||
from libs.password import hash_password, password_pattern, valid_password
|
||||
@@ -2598,15 +2599,29 @@ def migrate_oss(
|
||||
@click.option(
|
||||
"--start-from",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
required=True,
|
||||
required=False,
|
||||
default=None,
|
||||
help="Lower bound (inclusive) for created_at.",
|
||||
)
|
||||
@click.option(
|
||||
"--end-before",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
required=True,
|
||||
required=False,
|
||||
default=None,
|
||||
help="Upper bound (exclusive) for created_at.",
|
||||
)
|
||||
@click.option(
|
||||
"--from-days-ago",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Relative lower bound in days ago (inclusive). Must be used with --before-days.",
|
||||
)
|
||||
@click.option(
|
||||
"--before-days",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Relative upper bound in days ago (exclusive). Required for relative mode.",
|
||||
)
|
||||
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.")
|
||||
@click.option(
|
||||
"--graceful-period",
|
||||
@@ -2618,8 +2633,10 @@ def migrate_oss(
|
||||
def clean_expired_messages(
|
||||
batch_size: int,
|
||||
graceful_period: int,
|
||||
start_from: datetime.datetime,
|
||||
end_before: datetime.datetime,
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime | None,
|
||||
from_days_ago: int | None,
|
||||
before_days: int | None,
|
||||
dry_run: bool,
|
||||
):
|
||||
"""
|
||||
@@ -2630,18 +2647,70 @@ def clean_expired_messages(
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
abs_mode = start_from is not None and end_before is not None
|
||||
rel_mode = before_days is not None
|
||||
|
||||
if abs_mode and rel_mode:
|
||||
raise click.UsageError(
|
||||
"Options are mutually exclusive: use either (--start-from,--end-before) "
|
||||
"or (--from-days-ago,--before-days)."
|
||||
)
|
||||
|
||||
if from_days_ago is not None and before_days is None:
|
||||
raise click.UsageError("--from-days-ago must be used together with --before-days.")
|
||||
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
raise click.UsageError("Both --start-from and --end-before are required when using absolute time range.")
|
||||
|
||||
if not abs_mode and not rel_mode:
|
||||
raise click.UsageError(
|
||||
"You must provide either (--start-from,--end-before) or (--before-days [--from-days-ago])."
|
||||
)
|
||||
|
||||
if rel_mode:
|
||||
assert before_days is not None
|
||||
if before_days < 0:
|
||||
raise click.UsageError("--before-days must be >= 0.")
|
||||
if from_days_ago is not None:
|
||||
if from_days_ago < 0:
|
||||
raise click.UsageError("--from-days-ago must be >= 0.")
|
||||
if from_days_ago <= before_days:
|
||||
raise click.UsageError("--from-days-ago must be greater than --before-days.")
|
||||
|
||||
# Create policy based on billing configuration
|
||||
# NOTE: graceful_period will be ignored when billing is disabled.
|
||||
policy = create_message_clean_policy(graceful_period_days=graceful_period)
|
||||
|
||||
# Create and run the cleanup service
|
||||
service = MessagesCleanService.from_time_range(
|
||||
policy=policy,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
if abs_mode:
|
||||
assert start_from is not None
|
||||
assert end_before is not None
|
||||
service = MessagesCleanService.from_time_range(
|
||||
policy=policy,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
elif from_days_ago is None:
|
||||
assert before_days is not None
|
||||
service = MessagesCleanService.from_days(
|
||||
policy=policy,
|
||||
days=before_days,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
else:
|
||||
assert before_days is not None
|
||||
assert from_days_ago is not None
|
||||
now = naive_utc_now()
|
||||
service = MessagesCleanService.from_time_range(
|
||||
policy=policy,
|
||||
start_from=now - datetime.timedelta(days=from_days_ago),
|
||||
end_before=now - datetime.timedelta(days=before_days),
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
stats = service.run()
|
||||
|
||||
end_at = time.perf_counter()
|
||||
@@ -2668,3 +2737,77 @@ def clean_expired_messages(
|
||||
raise
|
||||
|
||||
click.echo(click.style("messages cleanup completed.", fg="green"))
|
||||
|
||||
|
||||
@click.command("export-app-messages", help="Export messages for an app to JSONL.GZ.")
|
||||
@click.option("--app-id", required=True, help="Application ID to export messages for.")
|
||||
@click.option(
|
||||
"--start-from",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
default=None,
|
||||
help="Optional lower bound (inclusive) for created_at.",
|
||||
)
|
||||
@click.option(
|
||||
"--end-before",
|
||||
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
|
||||
required=True,
|
||||
help="Upper bound (exclusive) for created_at.",
|
||||
)
|
||||
@click.option(
|
||||
"--filename",
|
||||
required=True,
|
||||
help="Base filename (relative path). Do not include suffix like .jsonl.gz.",
|
||||
)
|
||||
@click.option("--use-cloud-storage", is_flag=True, default=False, help="Upload to cloud storage instead of local file.")
|
||||
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for cursor pagination.")
|
||||
@click.option("--dry-run", is_flag=True, default=False, help="Scan only, print stats without writing any file.")
|
||||
def export_app_messages(
|
||||
app_id: str,
|
||||
start_from: datetime.datetime | None,
|
||||
end_before: datetime.datetime,
|
||||
filename: str,
|
||||
use_cloud_storage: bool,
|
||||
batch_size: int,
|
||||
dry_run: bool,
|
||||
):
|
||||
if start_from and start_from >= end_before:
|
||||
raise click.UsageError("--start-from must be before --end-before.")
|
||||
|
||||
from services.retention.conversation.message_export_service import AppMessageExportService
|
||||
|
||||
try:
|
||||
validated_filename = AppMessageExportService.validate_export_filename(filename)
|
||||
except ValueError as e:
|
||||
raise click.BadParameter(str(e), param_hint="--filename") from e
|
||||
|
||||
click.echo(click.style(f"export_app_messages: starting export for app {app_id}.", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
service = AppMessageExportService(
|
||||
app_id=app_id,
|
||||
end_before=end_before,
|
||||
filename=validated_filename,
|
||||
start_from=start_from,
|
||||
batch_size=batch_size,
|
||||
use_cloud_storage=use_cloud_storage,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
stats = service.run()
|
||||
|
||||
elapsed = time.perf_counter() - start_at
|
||||
click.echo(
|
||||
click.style(
|
||||
f"export_app_messages: completed in {elapsed:.2f}s\n"
|
||||
f" - Batches: {stats.batches}\n"
|
||||
f" - Total messages: {stats.total_messages}\n"
|
||||
f" - Messages with feedback: {stats.messages_with_feedback}\n"
|
||||
f" - Total feedbacks: {stats.total_feedbacks}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
elapsed = time.perf_counter() - start_at
|
||||
logger.exception("export_app_messages failed")
|
||||
click.echo(click.style(f"export_app_messages: failed after {elapsed:.2f}s - {e}", fg="red"))
|
||||
raise
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any, cast
|
||||
|
||||
from controllers.common import fields
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import AppUnavailableError
|
||||
@@ -23,14 +25,14 @@ class AppParameterApi(InstalledAppResource):
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = workflow.features_dict
|
||||
features_dict: dict[str, Any] = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app_model.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
features_dict = cast(dict[str, Any], app_model_config.to_dict())
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any, cast
|
||||
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.common.fields import Parameters
|
||||
@@ -33,14 +35,14 @@ class AppParameterApi(Resource):
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = workflow.features_dict
|
||||
features_dict: dict[str, Any] = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app_model.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
features_dict = cast(dict[str, Any], app_model_config.to_dict())
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@@ -57,14 +58,14 @@ class AppParameterApi(WebApiResource):
|
||||
if workflow is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = workflow.features_dict
|
||||
features_dict: dict[str, Any] = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app_model.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
features_dict = cast(dict[str, Any], app_model_config.to_dict())
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
|
||||
@@ -239,7 +239,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
||||
def get(self, app_model, end_user, message_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotCompletionAppError()
|
||||
raise NotChatAppError()
|
||||
|
||||
message_id = str(message_id)
|
||||
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import SensitiveWordAvoidanceEntity
|
||||
from core.moderation.factory import ModerationFactory
|
||||
|
||||
|
||||
class SensitiveWordAvoidanceConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> SensitiveWordAvoidanceEntity | None:
|
||||
def convert(cls, config: Mapping[str, Any]) -> SensitiveWordAvoidanceEntity | None:
|
||||
sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance")
|
||||
if not sensitive_word_avoidance_dict:
|
||||
return None
|
||||
@@ -12,7 +15,7 @@ class SensitiveWordAvoidanceConfigManager:
|
||||
if sensitive_word_avoidance_dict.get("enabled"):
|
||||
return SensitiveWordAvoidanceEntity(
|
||||
type=sensitive_word_avoidance_dict.get("type"),
|
||||
config=sensitive_word_avoidance_dict.get("config"),
|
||||
config=sensitive_word_avoidance_dict.get("config", {}),
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from typing import Any, cast
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity
|
||||
from core.agent.prompt.template import REACT_PROMPT_TEMPLATES
|
||||
from models.model import AppModelConfigDict
|
||||
|
||||
|
||||
class AgentConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> AgentEntity | None:
|
||||
def convert(cls, config: AppModelConfigDict) -> AgentEntity | None:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@@ -28,17 +31,17 @@ class AgentConfigManager:
|
||||
|
||||
agent_tools = []
|
||||
for tool in agent_dict.get("tools", []):
|
||||
keys = tool.keys()
|
||||
if len(keys) >= 4:
|
||||
if "enabled" not in tool or not tool["enabled"]:
|
||||
tool_dict = cast(dict[str, Any], tool)
|
||||
if len(tool_dict) >= 4:
|
||||
if "enabled" not in tool_dict or not tool_dict["enabled"]:
|
||||
continue
|
||||
|
||||
agent_tool_properties = {
|
||||
"provider_type": tool["provider_type"],
|
||||
"provider_id": tool["provider_id"],
|
||||
"tool_name": tool["tool_name"],
|
||||
"tool_parameters": tool.get("tool_parameters", {}),
|
||||
"credential_id": tool.get("credential_id", None),
|
||||
"provider_type": tool_dict["provider_type"],
|
||||
"provider_id": tool_dict["provider_id"],
|
||||
"tool_name": tool_dict["tool_name"],
|
||||
"tool_parameters": tool_dict.get("tool_parameters", {}),
|
||||
"credential_id": tool_dict.get("credential_id", None),
|
||||
}
|
||||
|
||||
agent_tools.append(AgentToolEntity.model_validate(agent_tool_properties))
|
||||
@@ -47,7 +50,8 @@ class AgentConfigManager:
|
||||
"react_router",
|
||||
"router",
|
||||
}:
|
||||
agent_prompt = agent_dict.get("prompt", None) or {}
|
||||
agent_prompt_raw = agent_dict.get("prompt", None)
|
||||
agent_prompt: dict[str, Any] = agent_prompt_raw if isinstance(agent_prompt_raw, dict) else {}
|
||||
# check model mode
|
||||
model_mode = config.get("model", {}).get("mode", "completion")
|
||||
if model_mode == "completion":
|
||||
@@ -75,7 +79,7 @@ class AgentConfigManager:
|
||||
strategy=strategy,
|
||||
prompt=agent_prompt_entity,
|
||||
tools=agent_tools,
|
||||
max_iteration=agent_dict.get("max_iteration", 10),
|
||||
max_iteration=cast(int, agent_dict.get("max_iteration", 10)),
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import uuid
|
||||
from typing import Literal, cast
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
@@ -8,13 +8,13 @@ from core.app.app_config.entities import (
|
||||
ModelConfig,
|
||||
)
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from models.model import AppMode
|
||||
from models.model import AppMode, AppModelConfigDict
|
||||
from services.dataset_service import DatasetService
|
||||
|
||||
|
||||
class DatasetConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> DatasetEntity | None:
|
||||
def convert(cls, config: AppModelConfigDict) -> DatasetEntity | None:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@@ -25,11 +25,15 @@ class DatasetConfigManager:
|
||||
datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []})
|
||||
|
||||
for dataset in datasets.get("datasets", []):
|
||||
if not isinstance(dataset, dict):
|
||||
continue
|
||||
keys = list(dataset.keys())
|
||||
if len(keys) == 0 or keys[0] != "dataset":
|
||||
continue
|
||||
|
||||
dataset = dataset["dataset"]
|
||||
if not isinstance(dataset, dict):
|
||||
continue
|
||||
|
||||
if "enabled" not in dataset or not dataset["enabled"]:
|
||||
continue
|
||||
@@ -47,15 +51,14 @@ class DatasetConfigManager:
|
||||
agent_dict = config.get("agent_mode", {})
|
||||
|
||||
for tool in agent_dict.get("tools", []):
|
||||
keys = tool.keys()
|
||||
if len(keys) == 1:
|
||||
if len(tool) == 1:
|
||||
# old standard
|
||||
key = list(tool.keys())[0]
|
||||
|
||||
if key != "dataset":
|
||||
continue
|
||||
|
||||
tool_item = tool[key]
|
||||
tool_item = cast(dict[str, Any], tool)[key]
|
||||
|
||||
if "enabled" not in tool_item or not tool_item["enabled"]:
|
||||
continue
|
||||
|
||||
@@ -5,12 +5,13 @@ from core.app.app_config.entities import ModelConfigEntity
|
||||
from core.provider_manager import ProviderManager
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from models.model import AppModelConfigDict
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
|
||||
class ModelConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> ModelConfigEntity:
|
||||
def convert(cls, config: AppModelConfigDict) -> ModelConfigEntity:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@@ -22,7 +23,7 @@ class ModelConfigManager:
|
||||
if not model_config:
|
||||
raise ValueError("model is required")
|
||||
|
||||
completion_params = model_config.get("completion_params")
|
||||
completion_params = model_config.get("completion_params") or {}
|
||||
stop = []
|
||||
if "stop" in completion_params:
|
||||
stop = completion_params["stop"]
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
AdvancedChatMessageEntity,
|
||||
AdvancedChatPromptTemplateEntity,
|
||||
@@ -6,12 +8,12 @@ from core.app.app_config.entities import (
|
||||
)
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from dify_graph.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from models.model import AppMode
|
||||
from models.model import AppMode, AppModelConfigDict
|
||||
|
||||
|
||||
class PromptTemplateConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> PromptTemplateEntity:
|
||||
def convert(cls, config: AppModelConfigDict) -> PromptTemplateEntity:
|
||||
if not config.get("prompt_type"):
|
||||
raise ValueError("prompt_type is required")
|
||||
|
||||
@@ -40,14 +42,15 @@ class PromptTemplateConfigManager:
|
||||
advanced_completion_prompt_template = None
|
||||
completion_prompt_config = config.get("completion_prompt_config", {})
|
||||
if completion_prompt_config:
|
||||
completion_prompt_template_params = {
|
||||
completion_prompt_template_params: dict[str, Any] = {
|
||||
"prompt": completion_prompt_config["prompt"]["text"],
|
||||
}
|
||||
|
||||
if "conversation_histories_role" in completion_prompt_config:
|
||||
conv_role = completion_prompt_config.get("conversation_histories_role")
|
||||
if conv_role:
|
||||
completion_prompt_template_params["role_prefix"] = {
|
||||
"user": completion_prompt_config["conversation_histories_role"]["user_prefix"],
|
||||
"assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"],
|
||||
"user": conv_role["user_prefix"],
|
||||
"assistant": conv_role["assistant_prefix"],
|
||||
}
|
||||
|
||||
advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import re
|
||||
from typing import cast
|
||||
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity
|
||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||
from dify_graph.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from models.model import AppModelConfigDict
|
||||
|
||||
_ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
|
||||
[
|
||||
@@ -18,7 +20,7 @@ _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset(
|
||||
|
||||
class BasicVariablesConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]:
|
||||
def convert(cls, config: AppModelConfigDict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@@ -51,7 +53,9 @@ class BasicVariablesConfigManager:
|
||||
|
||||
external_data_variables.append(
|
||||
ExternalDataVariableEntity(
|
||||
variable=variable["variable"], type=variable["type"], config=variable["config"]
|
||||
variable=variable["variable"],
|
||||
type=variable.get("type", ""),
|
||||
config=variable.get("config", {}),
|
||||
)
|
||||
)
|
||||
elif variable_type in {
|
||||
@@ -64,10 +68,10 @@ class BasicVariablesConfigManager:
|
||||
variable = variables[variable_type]
|
||||
variable_entities.append(
|
||||
VariableEntity(
|
||||
type=variable_type,
|
||||
variable=variable.get("variable"),
|
||||
type=cast(VariableEntityType, variable_type),
|
||||
variable=variable["variable"],
|
||||
description=variable.get("description") or "",
|
||||
label=variable.get("label"),
|
||||
label=variable["label"],
|
||||
required=variable.get("required", False),
|
||||
max_length=variable.get("max_length"),
|
||||
options=variable.get("options") or [],
|
||||
|
||||
@@ -281,7 +281,7 @@ class EasyUIBasedAppConfig(AppConfig):
|
||||
|
||||
app_model_config_from: EasyUIBasedAppModelConfigFrom
|
||||
app_model_config_id: str
|
||||
app_model_config_dict: dict
|
||||
app_model_config_dict: dict[str, Any]
|
||||
model: ModelConfigEntity
|
||||
prompt_template: PromptTemplateEntity
|
||||
dataset: DatasetEntity | None = None
|
||||
|
||||
@@ -516,8 +516,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
graph_runtime_state=validated_state,
|
||||
)
|
||||
|
||||
yield from self._handle_advanced_chat_message_end_event(
|
||||
QueueAdvancedChatMessageEndEvent(), graph_runtime_state=validated_state
|
||||
)
|
||||
yield workflow_finish_resp
|
||||
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
||||
|
||||
def _handle_workflow_partial_success_event(
|
||||
self,
|
||||
@@ -538,6 +540,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
exceptions_count=event.exceptions_count,
|
||||
)
|
||||
|
||||
yield from self._handle_advanced_chat_message_end_event(
|
||||
QueueAdvancedChatMessageEndEvent(), graph_runtime_state=validated_state
|
||||
)
|
||||
yield workflow_finish_resp
|
||||
|
||||
def _handle_workflow_paused_event(
|
||||
@@ -854,6 +859,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
yield from self._handle_workflow_paused_event(event)
|
||||
break
|
||||
|
||||
case QueueWorkflowSucceededEvent():
|
||||
yield from self._handle_workflow_succeeded_event(event, trace_manager=trace_manager)
|
||||
break
|
||||
|
||||
case QueueWorkflowPartialSuccessEvent():
|
||||
yield from self._handle_workflow_partial_success_event(event, trace_manager=trace_manager)
|
||||
break
|
||||
|
||||
case QueueStopEvent():
|
||||
yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager)
|
||||
break
|
||||
|
||||
@@ -20,7 +20,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
|
||||
)
|
||||
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from models.model import App, AppMode, AppModelConfig, Conversation
|
||||
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation
|
||||
|
||||
OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"]
|
||||
|
||||
@@ -40,7 +40,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
||||
app_model: App,
|
||||
app_model_config: AppModelConfig,
|
||||
conversation: Conversation | None = None,
|
||||
override_config_dict: dict | None = None,
|
||||
override_config_dict: AppModelConfigDict | None = None,
|
||||
) -> AgentChatAppConfig:
|
||||
"""
|
||||
Convert app model config to agent chat app config
|
||||
@@ -61,7 +61,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
||||
app_model_config_dict = app_model_config.to_dict()
|
||||
config_dict = app_model_config_dict.copy()
|
||||
else:
|
||||
config_dict = override_config_dict or {}
|
||||
if not override_config_dict:
|
||||
raise Exception("override_config_dict is required when config_from is ARGS")
|
||||
config_dict = override_config_dict
|
||||
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
app_config = AgentChatAppConfig(
|
||||
@@ -70,7 +72,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
||||
app_mode=app_mode,
|
||||
app_model_config_from=config_from,
|
||||
app_model_config_id=app_model_config.id,
|
||||
app_model_config_dict=config_dict,
|
||||
app_model_config_dict=cast(dict[str, Any], config_dict),
|
||||
model=ModelConfigManager.convert(config=config_dict),
|
||||
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
|
||||
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
|
||||
@@ -86,7 +88,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
||||
return app_config
|
||||
|
||||
@classmethod
|
||||
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]):
|
||||
def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> AppModelConfigDict:
|
||||
"""
|
||||
Validate for agent chat app model config
|
||||
|
||||
@@ -157,7 +159,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager):
|
||||
# Filter out extra parameters
|
||||
filtered_config = {key: config.get(key) for key in related_config_keys}
|
||||
|
||||
return filtered_config
|
||||
return cast(AppModelConfigDict, filtered_config)
|
||||
|
||||
@classmethod
|
||||
def validate_agent_mode_and_set_defaults(
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
|
||||
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
|
||||
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
|
||||
@@ -13,7 +15,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor
|
||||
SuggestedQuestionsAfterAnswerConfigManager,
|
||||
)
|
||||
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
|
||||
from models.model import App, AppMode, AppModelConfig, Conversation
|
||||
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation
|
||||
|
||||
|
||||
class ChatAppConfig(EasyUIBasedAppConfig):
|
||||
@@ -31,7 +33,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
|
||||
app_model: App,
|
||||
app_model_config: AppModelConfig,
|
||||
conversation: Conversation | None = None,
|
||||
override_config_dict: dict | None = None,
|
||||
override_config_dict: AppModelConfigDict | None = None,
|
||||
) -> ChatAppConfig:
|
||||
"""
|
||||
Convert app model config to chat app config
|
||||
@@ -64,7 +66,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
|
||||
app_mode=app_mode,
|
||||
app_model_config_from=config_from,
|
||||
app_model_config_id=app_model_config.id,
|
||||
app_model_config_dict=config_dict,
|
||||
app_model_config_dict=cast(dict[str, Any], config_dict),
|
||||
model=ModelConfigManager.convert(config=config_dict),
|
||||
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
|
||||
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
|
||||
@@ -79,7 +81,7 @@ class ChatAppConfigManager(BaseAppConfigManager):
|
||||
return app_config
|
||||
|
||||
@classmethod
|
||||
def config_validate(cls, tenant_id: str, config: dict):
|
||||
def config_validate(cls, tenant_id: str, config: dict) -> AppModelConfigDict:
|
||||
"""
|
||||
Validate for chat app model config
|
||||
|
||||
@@ -145,4 +147,4 @@ class ChatAppConfigManager(BaseAppConfigManager):
|
||||
# Filter out extra parameters
|
||||
filtered_config = {key: config.get(key) for key in related_config_keys}
|
||||
|
||||
return filtered_config
|
||||
return cast(AppModelConfigDict, filtered_config)
|
||||
|
||||
@@ -173,8 +173,10 @@ class ChatAppRunner(AppRunner):
|
||||
memory=memory,
|
||||
message_id=message.id,
|
||||
inputs=inputs,
|
||||
vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
|
||||
"enabled", False
|
||||
vision_enabled=bool(
|
||||
application_generate_entity.app_config.app_model_config_dict.get("file_upload", {})
|
||||
.get("image", {})
|
||||
.get("enabled", False)
|
||||
),
|
||||
)
|
||||
context_files = retrieved_files or []
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
|
||||
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
|
||||
from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager
|
||||
@@ -8,7 +10,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager
|
||||
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
from models.model import App, AppMode, AppModelConfig, AppModelConfigDict
|
||||
|
||||
|
||||
class CompletionAppConfig(EasyUIBasedAppConfig):
|
||||
@@ -22,7 +24,7 @@ class CompletionAppConfig(EasyUIBasedAppConfig):
|
||||
class CompletionAppConfigManager(BaseAppConfigManager):
|
||||
@classmethod
|
||||
def get_app_config(
|
||||
cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: dict | None = None
|
||||
cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: AppModelConfigDict | None = None
|
||||
) -> CompletionAppConfig:
|
||||
"""
|
||||
Convert app model config to completion app config
|
||||
@@ -40,7 +42,9 @@ class CompletionAppConfigManager(BaseAppConfigManager):
|
||||
app_model_config_dict = app_model_config.to_dict()
|
||||
config_dict = app_model_config_dict.copy()
|
||||
else:
|
||||
config_dict = override_config_dict or {}
|
||||
if not override_config_dict:
|
||||
raise Exception("override_config_dict is required when config_from is ARGS")
|
||||
config_dict = override_config_dict
|
||||
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
app_config = CompletionAppConfig(
|
||||
@@ -49,7 +53,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
|
||||
app_mode=app_mode,
|
||||
app_model_config_from=config_from,
|
||||
app_model_config_id=app_model_config.id,
|
||||
app_model_config_dict=config_dict,
|
||||
app_model_config_dict=cast(dict[str, Any], config_dict),
|
||||
model=ModelConfigManager.convert(config=config_dict),
|
||||
prompt_template=PromptTemplateConfigManager.convert(config=config_dict),
|
||||
sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict),
|
||||
@@ -64,7 +68,7 @@ class CompletionAppConfigManager(BaseAppConfigManager):
|
||||
return app_config
|
||||
|
||||
@classmethod
|
||||
def config_validate(cls, tenant_id: str, config: dict):
|
||||
def config_validate(cls, tenant_id: str, config: dict) -> AppModelConfigDict:
|
||||
"""
|
||||
Validate for completion app model config
|
||||
|
||||
@@ -116,4 +120,4 @@ class CompletionAppConfigManager(BaseAppConfigManager):
|
||||
# Filter out extra parameters
|
||||
filtered_config = {key: config.get(key) for key in related_config_keys}
|
||||
|
||||
return filtered_config
|
||||
return cast(AppModelConfigDict, filtered_config)
|
||||
|
||||
@@ -275,7 +275,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
raise ValueError("Message app_model_config is None")
|
||||
override_model_config_dict = app_model_config.to_dict()
|
||||
model_dict = override_model_config_dict["model"]
|
||||
completion_params = model_dict.get("completion_params")
|
||||
completion_params = model_dict.get("completion_params", {})
|
||||
completion_params["temperature"] = 0.9
|
||||
model_dict["completion_params"] = completion_params
|
||||
override_model_config_dict["model"] = model_dict
|
||||
|
||||
@@ -132,8 +132,10 @@ class CompletionAppRunner(AppRunner):
|
||||
hit_callback=hit_callback,
|
||||
message_id=message.id,
|
||||
inputs=inputs,
|
||||
vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
|
||||
"enabled", False
|
||||
vision_enabled=bool(
|
||||
application_generate_entity.app_config.app_model_config_dict.get("file_upload", {})
|
||||
.get("image", {})
|
||||
.get("enabled", False)
|
||||
),
|
||||
)
|
||||
context_files = retrieved_files or []
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from threading import Thread
|
||||
from typing import Union, cast
|
||||
from typing import Any, Union, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -44,14 +44,13 @@ from core.app.entities.task_entities import (
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from core.app.task_pipeline.message_file_utils import prepare_file_dict
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_manager import ModelInstance
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.tools.signature import sign_tool_file
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from dify_graph.file.enums import FileTransferMethod
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
@@ -219,14 +218,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
tenant_id = self._application_generate_entity.app_config.tenant_id
|
||||
task_id = self._application_generate_entity.task_id
|
||||
publisher = None
|
||||
text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech")
|
||||
text_to_speech_dict = cast(dict[str, Any], self._app_config.app_model_config_dict.get("text_to_speech"))
|
||||
if (
|
||||
text_to_speech_dict
|
||||
and text_to_speech_dict.get("autoPlay") == "enabled"
|
||||
and text_to_speech_dict.get("enabled")
|
||||
):
|
||||
publisher = AppGeneratorTTSPublisher(
|
||||
tenant_id, text_to_speech_dict.get("voice", None), text_to_speech_dict.get("language", None)
|
||||
tenant_id, text_to_speech_dict.get("voice", ""), text_to_speech_dict.get("language", None)
|
||||
)
|
||||
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
@@ -460,91 +459,40 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
"""
|
||||
self._task_state.metadata.usage = self._task_state.llm_result.usage
|
||||
metadata_dict = self._task_state.metadata.model_dump()
|
||||
|
||||
# Fetch files associated with this message
|
||||
files = None
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
|
||||
|
||||
if message_files:
|
||||
# Fetch all required UploadFile objects in a single query to avoid N+1 problem
|
||||
upload_file_ids = list(
|
||||
dict.fromkeys(
|
||||
mf.upload_file_id
|
||||
for mf in message_files
|
||||
if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
|
||||
)
|
||||
)
|
||||
upload_files_map = {}
|
||||
if upload_file_ids:
|
||||
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
|
||||
upload_files_map = {uf.id: uf for uf in upload_files}
|
||||
|
||||
files_list = []
|
||||
for message_file in message_files:
|
||||
file_dict = prepare_file_dict(message_file, upload_files_map)
|
||||
files_list.append(file_dict)
|
||||
|
||||
files = files_list or None
|
||||
|
||||
return MessageEndStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=self._message_id,
|
||||
metadata=metadata_dict,
|
||||
files=files,
|
||||
)
|
||||
|
||||
def _record_files(self):
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
|
||||
if not message_files:
|
||||
return None
|
||||
|
||||
files_list = []
|
||||
upload_file_ids = [
|
||||
mf.upload_file_id
|
||||
for mf in message_files
|
||||
if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
|
||||
]
|
||||
upload_files_map = {}
|
||||
if upload_file_ids:
|
||||
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
|
||||
upload_files_map = {uf.id: uf for uf in upload_files}
|
||||
|
||||
for message_file in message_files:
|
||||
upload_file = None
|
||||
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
|
||||
upload_file = upload_files_map.get(message_file.upload_file_id)
|
||||
|
||||
url = None
|
||||
filename = "file"
|
||||
mime_type = "application/octet-stream"
|
||||
size = 0
|
||||
extension = ""
|
||||
|
||||
if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
url = message_file.url
|
||||
if message_file.url:
|
||||
filename = message_file.url.split("/")[-1].split("?")[0] # Remove query params
|
||||
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
if upload_file:
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
|
||||
filename = upload_file.name
|
||||
mime_type = upload_file.mime_type or "application/octet-stream"
|
||||
size = upload_file.size or 0
|
||||
extension = f".{upload_file.extension}" if upload_file.extension else ""
|
||||
elif message_file.upload_file_id:
|
||||
# Fallback: generate URL even if upload_file not found
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
|
||||
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
|
||||
# For tool files, use URL directly if it's HTTP, otherwise sign it
|
||||
if message_file.url.startswith("http"):
|
||||
url = message_file.url
|
||||
filename = message_file.url.split("/")[-1].split("?")[0]
|
||||
else:
|
||||
# Extract tool file id and extension from URL
|
||||
url_parts = message_file.url.split("/")
|
||||
if url_parts:
|
||||
file_part = url_parts[-1].split("?")[0] # Remove query params first
|
||||
# Use rsplit to correctly handle filenames with multiple dots
|
||||
if "." in file_part:
|
||||
tool_file_id, ext = file_part.rsplit(".", 1)
|
||||
extension = f".{ext}"
|
||||
else:
|
||||
tool_file_id = file_part
|
||||
extension = ".bin"
|
||||
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
|
||||
filename = file_part
|
||||
|
||||
transfer_method_value = message_file.transfer_method
|
||||
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
|
||||
file_dict = {
|
||||
"related_id": message_file.id,
|
||||
"extension": extension,
|
||||
"filename": filename,
|
||||
"size": size,
|
||||
"mime_type": mime_type,
|
||||
"transfer_method": transfer_method_value,
|
||||
"type": message_file.type,
|
||||
"url": url or "",
|
||||
"upload_file_id": message_file.upload_file_id or message_file.id,
|
||||
"remote_url": remote_url,
|
||||
}
|
||||
files_list.append(file_dict)
|
||||
return files_list or None
|
||||
|
||||
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
|
||||
"""
|
||||
Agent message to stream response.
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
from threading import Thread
|
||||
from threading import Thread, Timer
|
||||
from typing import Union
|
||||
|
||||
from flask import Flask, current_app
|
||||
@@ -96,9 +95,9 @@ class MessageCycleManager:
|
||||
if auto_generate_conversation_name and is_first_message:
|
||||
# start generate thread
|
||||
# time.sleep not block other logic
|
||||
time.sleep(1)
|
||||
thread = Thread(
|
||||
target=self._generate_conversation_name_worker,
|
||||
thread = Timer(
|
||||
1,
|
||||
self._generate_conversation_name_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"conversation_id": conversation_id,
|
||||
|
||||
76
api/core/app/task_pipeline/message_file_utils.py
Normal file
76
api/core/app/task_pipeline/message_file_utils.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from core.tools.signature import sign_tool_file
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from dify_graph.file.enums import FileTransferMethod
|
||||
from models.model import MessageFile, UploadFile
|
||||
|
||||
MAX_TOOL_FILE_EXTENSION_LENGTH = 10
|
||||
|
||||
|
||||
def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> dict:
|
||||
"""
|
||||
Prepare file dictionary for message end stream response.
|
||||
|
||||
:param message_file: MessageFile instance
|
||||
:param upload_files_map: Dictionary mapping upload_file_id to UploadFile
|
||||
:return: Dictionary containing file information
|
||||
"""
|
||||
upload_file = None
|
||||
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
|
||||
upload_file = upload_files_map.get(message_file.upload_file_id)
|
||||
|
||||
url = None
|
||||
filename = "file"
|
||||
mime_type = "application/octet-stream"
|
||||
size = 0
|
||||
extension = ""
|
||||
|
||||
if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
url = message_file.url
|
||||
if message_file.url:
|
||||
filename = message_file.url.split("/")[-1].split("?")[0]
|
||||
if "." in filename:
|
||||
extension = "." + filename.rsplit(".", 1)[1]
|
||||
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
if upload_file:
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
|
||||
filename = upload_file.name
|
||||
mime_type = upload_file.mime_type or "application/octet-stream"
|
||||
size = upload_file.size or 0
|
||||
extension = f".{upload_file.extension}" if upload_file.extension else ""
|
||||
elif message_file.upload_file_id:
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
|
||||
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
|
||||
if message_file.url.startswith(("http://", "https://")):
|
||||
url = message_file.url
|
||||
filename = message_file.url.split("/")[-1].split("?")[0]
|
||||
if "." in filename:
|
||||
extension = "." + filename.rsplit(".", 1)[1]
|
||||
else:
|
||||
url_parts = message_file.url.split("/")
|
||||
if url_parts:
|
||||
file_part = url_parts[-1].split("?")[0]
|
||||
if "." in file_part:
|
||||
tool_file_id, ext = file_part.rsplit(".", 1)
|
||||
extension = f".{ext}"
|
||||
if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH:
|
||||
extension = ".bin"
|
||||
else:
|
||||
tool_file_id = file_part
|
||||
extension = ".bin"
|
||||
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
|
||||
filename = file_part
|
||||
|
||||
transfer_method_value = message_file.transfer_method.value
|
||||
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
|
||||
return {
|
||||
"related_id": message_file.id,
|
||||
"extension": extension,
|
||||
"filename": filename,
|
||||
"size": size,
|
||||
"mime_type": mime_type,
|
||||
"transfer_method": transfer_method_value,
|
||||
"type": message_file.type,
|
||||
"url": url or "",
|
||||
"upload_file_id": message_file.upload_file_id or message_file.id,
|
||||
"remote_url": remote_url,
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Union
|
||||
from typing import Any, Union, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -34,14 +34,14 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
if workflow is None:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
features_dict = workflow.features_dict
|
||||
features_dict: dict[str, Any] = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app.app_model_config
|
||||
if app_model_config is None:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
features_dict = cast(dict[str, Any], app_model_config.to_dict())
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ class ChromaVector(BaseVector):
|
||||
self._client.get_or_create_collection(collection_name)
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
|
||||
uuids = self._get_uuids(documents)
|
||||
texts = [d.page_content for d in documents]
|
||||
metadatas = [d.metadata for d in documents]
|
||||
@@ -73,6 +73,7 @@ class ChromaVector(BaseVector):
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
# FIXME: chromadb using numpy array, fix the type error later
|
||||
collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) # type: ignore
|
||||
return uuids
|
||||
|
||||
def delete_by_metadata_field(self, key: str, value: str):
|
||||
collection = self._client.get_or_create_collection(self._collection_name)
|
||||
|
||||
@@ -605,25 +605,36 @@ class ClickzettaVector(BaseVector):
|
||||
logger.warning("Failed to create inverted index: %s", e)
|
||||
# Continue without inverted index - full-text search will fall back to LIKE
|
||||
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
|
||||
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]:
|
||||
"""Add documents with embeddings to the collection."""
|
||||
if not documents:
|
||||
return
|
||||
return []
|
||||
|
||||
batch_size = self._config.batch_size
|
||||
total_batches = (len(documents) + batch_size - 1) // batch_size
|
||||
added_ids = []
|
||||
|
||||
for i in range(0, len(documents), batch_size):
|
||||
batch_docs = documents[i : i + batch_size]
|
||||
batch_embeddings = embeddings[i : i + batch_size]
|
||||
batch_doc_ids = []
|
||||
for doc in batch_docs:
|
||||
metadata = doc.metadata if isinstance(doc.metadata, dict) else {}
|
||||
batch_doc_ids.append(self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4()))))
|
||||
added_ids.extend(batch_doc_ids)
|
||||
|
||||
# Execute batch insert through write queue
|
||||
self._execute_write(self._insert_batch, batch_docs, batch_embeddings, i, batch_size, total_batches)
|
||||
self._execute_write(
|
||||
self._insert_batch, batch_docs, batch_embeddings, batch_doc_ids, i, batch_size, total_batches
|
||||
)
|
||||
|
||||
return added_ids
|
||||
|
||||
def _insert_batch(
|
||||
self,
|
||||
batch_docs: list[Document],
|
||||
batch_embeddings: list[list[float]],
|
||||
batch_doc_ids: list[str],
|
||||
batch_index: int,
|
||||
batch_size: int,
|
||||
total_batches: int,
|
||||
@@ -641,14 +652,9 @@ class ClickzettaVector(BaseVector):
|
||||
data_rows = []
|
||||
vector_dimension = len(batch_embeddings[0]) if batch_embeddings and batch_embeddings[0] else 768
|
||||
|
||||
for doc, embedding in zip(batch_docs, batch_embeddings):
|
||||
for doc, embedding, doc_id in zip(batch_docs, batch_embeddings, batch_doc_ids):
|
||||
# Optimized: minimal checks for common case, fallback for edge cases
|
||||
metadata = doc.metadata or {}
|
||||
|
||||
if not isinstance(metadata, dict):
|
||||
metadata = {}
|
||||
|
||||
doc_id = self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4())))
|
||||
metadata = doc.metadata if isinstance(doc.metadata, dict) else {}
|
||||
|
||||
# Fast path for JSON serialization
|
||||
try:
|
||||
|
||||
@@ -194,6 +194,13 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
|
||||
|
||||
# Create a new database session
|
||||
with self._session_factory() as session:
|
||||
existing_model = session.get(WorkflowRun, db_model.id)
|
||||
if existing_model:
|
||||
if existing_model.tenant_id != self._tenant_id:
|
||||
raise ValueError("Unauthorized access to workflow run")
|
||||
# Preserve the original start time for pause/resume flows.
|
||||
db_model.created_at = existing_model.created_at
|
||||
|
||||
# SQLAlchemy merge intelligently handles both insert and update operations
|
||||
# based on the presence of the primary key
|
||||
session.merge(db_model)
|
||||
|
||||
@@ -37,6 +37,7 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
|
||||
VariableEntityType.CHECKBOX: ToolParameter.ToolParameterType.BOOLEAN,
|
||||
VariableEntityType.FILE: ToolParameter.ToolParameterType.FILE,
|
||||
VariableEntityType.FILE_LIST: ToolParameter.ToolParameterType.FILES,
|
||||
VariableEntityType.JSON_OBJECT: ToolParameter.ToolParameterType.OBJECT,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import zipfile
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -82,8 +83,18 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
|
||||
value = variable.value
|
||||
inputs = {"variable_selector": variable_selector}
|
||||
if isinstance(value, list):
|
||||
value = list(filter(lambda x: x, value))
|
||||
process_data = {"documents": value if isinstance(value, list) else [value]}
|
||||
|
||||
if not value:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs={"text": ArrayStringSegment(value=[])},
|
||||
)
|
||||
|
||||
try:
|
||||
if isinstance(value, list):
|
||||
extracted_text_list = [
|
||||
@@ -111,6 +122,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
else:
|
||||
raise DocumentExtractorError(f"Unsupported variable type: {type(value)}")
|
||||
except DocumentExtractorError as e:
|
||||
logger.warning(e, exc_info=True)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
@@ -385,6 +397,32 @@ def parser_docx_part(block, doc: Document, content_items, i):
|
||||
content_items.append((i, "table", Table(block, doc)))
|
||||
|
||||
|
||||
def _normalize_docx_zip(file_content: bytes) -> bytes:
|
||||
"""
|
||||
Some DOCX files (e.g. exported by Evernote on Windows) are malformed:
|
||||
ZIP entry names use backslash (\\) as path separator instead of the forward
|
||||
slash (/) required by both the ZIP spec and OOXML. On Linux/Mac the entry
|
||||
"word\\document.xml" is never found when python-docx looks for
|
||||
"word/document.xml", which triggers a KeyError about a missing relationship.
|
||||
|
||||
This function rewrites the ZIP in-memory, normalizing all entry names to
|
||||
use forward slashes without touching any actual document content.
|
||||
"""
|
||||
try:
|
||||
with zipfile.ZipFile(io.BytesIO(file_content), "r") as zin:
|
||||
out_buf = io.BytesIO()
|
||||
with zipfile.ZipFile(out_buf, "w", compression=zipfile.ZIP_DEFLATED) as zout:
|
||||
for item in zin.infolist():
|
||||
data = zin.read(item.filename)
|
||||
# Normalize backslash path separators to forward slash
|
||||
item.filename = item.filename.replace("\\", "/")
|
||||
zout.writestr(item, data)
|
||||
return out_buf.getvalue()
|
||||
except zipfile.BadZipFile:
|
||||
# Not a valid zip — return as-is and let python-docx report the real error
|
||||
return file_content
|
||||
|
||||
|
||||
def _extract_text_from_docx(file_content: bytes) -> str:
|
||||
"""
|
||||
Extract text from a DOCX file.
|
||||
@@ -392,7 +430,15 @@ def _extract_text_from_docx(file_content: bytes) -> str:
|
||||
"""
|
||||
try:
|
||||
doc_file = io.BytesIO(file_content)
|
||||
doc = docx.Document(doc_file)
|
||||
try:
|
||||
doc = docx.Document(doc_file)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to parse DOCX, attempting to normalize ZIP entry paths: %s", e)
|
||||
# Some DOCX files exported by tools like Evernote on Windows use
|
||||
# backslash path separators in ZIP entries and/or single-quoted XML
|
||||
# attributes, both of which break python-docx on Linux. Normalize and retry.
|
||||
file_content = _normalize_docx_zip(file_content)
|
||||
doc = docx.Document(io.BytesIO(file_content))
|
||||
text = []
|
||||
|
||||
# Keep track of paragraph and table positions
|
||||
|
||||
@@ -23,7 +23,11 @@ from dify_graph.variables import (
|
||||
)
|
||||
from dify_graph.variables.segments import ArrayObjectSegment
|
||||
|
||||
from .entities import KnowledgeRetrievalNodeData
|
||||
from .entities import (
|
||||
Condition,
|
||||
KnowledgeRetrievalNodeData,
|
||||
MetadataFilteringCondition,
|
||||
)
|
||||
from .exc import (
|
||||
KnowledgeRetrievalNodeError,
|
||||
RateLimitExceededError,
|
||||
@@ -116,7 +120,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
|
||||
try:
|
||||
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables)
|
||||
outputs = {"result": ArrayObjectSegment(value=[item.model_dump() for item in results])}
|
||||
outputs = {"result": ArrayObjectSegment(value=[item.model_dump(by_alias=True) for item in results])}
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
@@ -171,6 +175,12 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
if node_data.metadata_filtering_mode is not None:
|
||||
metadata_filtering_mode = node_data.metadata_filtering_mode
|
||||
|
||||
resolved_metadata_conditions = (
|
||||
self._resolve_metadata_filtering_conditions(node_data.metadata_filtering_conditions)
|
||||
if node_data.metadata_filtering_conditions
|
||||
else None
|
||||
)
|
||||
|
||||
if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query:
|
||||
# fetch model config
|
||||
if node_data.single_retrieval_config is None:
|
||||
@@ -189,7 +199,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
model_mode=model.mode,
|
||||
model_name=model.name,
|
||||
metadata_model_config=node_data.metadata_model_config,
|
||||
metadata_filtering_conditions=node_data.metadata_filtering_conditions,
|
||||
metadata_filtering_conditions=resolved_metadata_conditions,
|
||||
metadata_filtering_mode=metadata_filtering_mode,
|
||||
query=query,
|
||||
)
|
||||
@@ -247,7 +257,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
weights=weights,
|
||||
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
|
||||
metadata_model_config=node_data.metadata_model_config,
|
||||
metadata_filtering_conditions=node_data.metadata_filtering_conditions,
|
||||
metadata_filtering_conditions=resolved_metadata_conditions,
|
||||
metadata_filtering_mode=metadata_filtering_mode,
|
||||
attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
|
||||
)
|
||||
@@ -256,6 +266,48 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
usage = self._rag_retrieval.llm_usage
|
||||
return retrieval_resource_list, usage
|
||||
|
||||
def _resolve_metadata_filtering_conditions(
|
||||
self, conditions: MetadataFilteringCondition
|
||||
) -> MetadataFilteringCondition:
|
||||
if conditions.conditions is None:
|
||||
return MetadataFilteringCondition(
|
||||
logical_operator=conditions.logical_operator,
|
||||
conditions=None,
|
||||
)
|
||||
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
resolved_conditions: list[Condition] = []
|
||||
for cond in conditions.conditions or []:
|
||||
value = cond.value
|
||||
if isinstance(value, str):
|
||||
segment_group = variable_pool.convert_template(value)
|
||||
if len(segment_group.value) == 1:
|
||||
resolved_value = segment_group.value[0].to_object()
|
||||
else:
|
||||
resolved_value = segment_group.text
|
||||
elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
|
||||
resolved_values = []
|
||||
for v in value: # type: ignore
|
||||
segment_group = variable_pool.convert_template(v)
|
||||
if len(segment_group.value) == 1:
|
||||
resolved_values.append(segment_group.value[0].to_object())
|
||||
else:
|
||||
resolved_values.append(segment_group.text)
|
||||
resolved_value = resolved_values
|
||||
else:
|
||||
resolved_value = value
|
||||
resolved_conditions.append(
|
||||
Condition(
|
||||
name=cond.name,
|
||||
comparison_operator=cond.comparison_operator,
|
||||
value=resolved_value,
|
||||
)
|
||||
)
|
||||
return MetadataFilteringCondition(
|
||||
logical_operator=conditions.logical_operator or "and",
|
||||
conditions=resolved_conditions,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
|
||||
@@ -65,9 +65,15 @@ class VariablePool(BaseModel):
|
||||
# Add environment variables to the variable pool
|
||||
for var in self.environment_variables:
|
||||
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
|
||||
# Add conversation variables to the variable pool
|
||||
# Add conversation variables to the variable pool. When restoring from a serialized
|
||||
# snapshot, `variable_dictionary` already carries the latest runtime values.
|
||||
# In that case, keep existing entries instead of overwriting them with the
|
||||
# bootstrap list.
|
||||
for var in self.conversation_variables:
|
||||
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
|
||||
selector = (CONVERSATION_VARIABLE_NODE_ID, var.name)
|
||||
if self._has(selector):
|
||||
continue
|
||||
self.add(selector, var)
|
||||
# Add rag pipeline variables to the variable pool
|
||||
if self.rag_pipeline_variables:
|
||||
rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict)
|
||||
|
||||
@@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then
|
||||
if [[ -z "${CELERY_QUEUES}" ]]; then
|
||||
if [[ "${EDITION}" == "CLOUD" ]]; then
|
||||
# Cloud edition: separate queues for dataset and trigger tasks
|
||||
DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
|
||||
DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
|
||||
else
|
||||
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
|
||||
DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
|
||||
DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
|
||||
fi
|
||||
else
|
||||
DEFAULT_QUEUES="${CELERY_QUEUES}"
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from events.app_event import app_model_config_was_updated
|
||||
@@ -54,9 +56,11 @@ def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set[s
|
||||
continue
|
||||
|
||||
tool_type = list(tool.keys())[0]
|
||||
tool_config = list(tool.values())[0]
|
||||
tool_config = cast(dict[str, Any], list(tool.values())[0])
|
||||
if tool_type == "dataset":
|
||||
dataset_ids.add(tool_config.get("id"))
|
||||
dataset_id = tool_config.get("id")
|
||||
if isinstance(dataset_id, str):
|
||||
dataset_ids.add(dataset_id)
|
||||
|
||||
# get dataset from dataset_configs
|
||||
dataset_configs = app_model_config.dataset_configs_dict
|
||||
|
||||
@@ -13,6 +13,7 @@ def init_app(app: DifyApp):
|
||||
convert_to_agent_apps,
|
||||
create_tenant,
|
||||
delete_archived_workflow_runs,
|
||||
export_app_messages,
|
||||
extract_plugins,
|
||||
extract_unique_plugins,
|
||||
file_usage,
|
||||
@@ -66,6 +67,7 @@ def init_app(app: DifyApp):
|
||||
restore_workflow_runs,
|
||||
clean_workflow_runs,
|
||||
clean_expired_messages,
|
||||
export_app_messages,
|
||||
]
|
||||
for cmd in cmds_to_register:
|
||||
app.cli.add_command(cmd)
|
||||
|
||||
@@ -66,6 +66,7 @@ def run_migrations_offline():
|
||||
context.configure(
|
||||
url=url, target_metadata=get_metadata(), literal_binds=True
|
||||
)
|
||||
logger.info("Generating offline migration SQL with url: %s", url)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
@@ -7,7 +7,7 @@ from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from enum import StrEnum, auto
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
@@ -15,6 +15,7 @@ from flask import request
|
||||
from flask_login import UserMixin # type: ignore[import-untyped]
|
||||
from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
|
||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from configs import dify_config
|
||||
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||
@@ -36,6 +37,259 @@ if TYPE_CHECKING:
|
||||
from .workflow import Workflow
|
||||
|
||||
|
||||
# --- TypedDict definitions for structured dict return types ---
|
||||
|
||||
|
||||
class EnabledConfig(TypedDict):
|
||||
enabled: bool
|
||||
|
||||
|
||||
class EmbeddingModelInfo(TypedDict):
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class AnnotationReplyDisabledConfig(TypedDict):
|
||||
enabled: Literal[False]
|
||||
|
||||
|
||||
class AnnotationReplyEnabledConfig(TypedDict):
|
||||
id: str
|
||||
enabled: Literal[True]
|
||||
score_threshold: float
|
||||
embedding_model: EmbeddingModelInfo
|
||||
|
||||
|
||||
AnnotationReplyConfig = AnnotationReplyEnabledConfig | AnnotationReplyDisabledConfig
|
||||
|
||||
|
||||
class SensitiveWordAvoidanceConfig(TypedDict):
|
||||
enabled: bool
|
||||
type: str
|
||||
config: dict[str, Any]
|
||||
|
||||
|
||||
class AgentToolConfig(TypedDict):
|
||||
provider_type: str
|
||||
provider_id: str
|
||||
tool_name: str
|
||||
tool_parameters: dict[str, Any]
|
||||
plugin_unique_identifier: NotRequired[str | None]
|
||||
credential_id: NotRequired[str | None]
|
||||
|
||||
|
||||
class AgentModeConfig(TypedDict):
|
||||
enabled: bool
|
||||
strategy: str | None
|
||||
tools: list[AgentToolConfig | dict[str, Any]]
|
||||
prompt: str | None
|
||||
|
||||
|
||||
class ImageUploadConfig(TypedDict):
|
||||
enabled: bool
|
||||
number_limits: int
|
||||
detail: str
|
||||
transfer_methods: list[str]
|
||||
|
||||
|
||||
class FileUploadConfig(TypedDict):
|
||||
image: ImageUploadConfig
|
||||
|
||||
|
||||
class DeletedToolInfo(TypedDict):
|
||||
type: str
|
||||
tool_name: str
|
||||
provider_id: str
|
||||
|
||||
|
||||
class ExternalDataToolConfig(TypedDict):
|
||||
enabled: bool
|
||||
variable: str
|
||||
type: str
|
||||
config: dict[str, Any]
|
||||
|
||||
|
||||
class UserInputFormItemConfig(TypedDict):
|
||||
variable: str
|
||||
label: str
|
||||
description: NotRequired[str]
|
||||
required: NotRequired[bool]
|
||||
max_length: NotRequired[int]
|
||||
options: NotRequired[list[str]]
|
||||
default: NotRequired[str]
|
||||
type: NotRequired[str]
|
||||
config: NotRequired[dict[str, Any]]
|
||||
|
||||
|
||||
# Each item is a single-key dict, e.g. {"text-input": UserInputFormItemConfig}
|
||||
UserInputFormItem = dict[str, UserInputFormItemConfig]
|
||||
|
||||
|
||||
class DatasetConfigs(TypedDict):
|
||||
retrieval_model: str
|
||||
datasets: NotRequired[dict[str, Any]]
|
||||
top_k: NotRequired[int]
|
||||
score_threshold: NotRequired[float]
|
||||
score_threshold_enabled: NotRequired[bool]
|
||||
reranking_model: NotRequired[dict[str, Any] | None]
|
||||
weights: NotRequired[dict[str, Any] | None]
|
||||
reranking_enabled: NotRequired[bool]
|
||||
reranking_mode: NotRequired[str]
|
||||
metadata_filtering_mode: NotRequired[str]
|
||||
metadata_model_config: NotRequired[dict[str, Any] | None]
|
||||
metadata_filtering_conditions: NotRequired[dict[str, Any] | None]
|
||||
|
||||
|
||||
class ChatPromptMessage(TypedDict):
|
||||
text: str
|
||||
role: str
|
||||
|
||||
|
||||
class ChatPromptConfig(TypedDict, total=False):
|
||||
prompt: list[ChatPromptMessage]
|
||||
|
||||
|
||||
class CompletionPromptText(TypedDict):
|
||||
text: str
|
||||
|
||||
|
||||
class ConversationHistoriesRole(TypedDict):
|
||||
user_prefix: str
|
||||
assistant_prefix: str
|
||||
|
||||
|
||||
class CompletionPromptConfig(TypedDict):
|
||||
prompt: CompletionPromptText
|
||||
conversation_histories_role: NotRequired[ConversationHistoriesRole]
|
||||
|
||||
|
||||
class ModelConfig(TypedDict):
|
||||
provider: str
|
||||
name: str
|
||||
mode: str
|
||||
completion_params: NotRequired[dict[str, Any]]
|
||||
|
||||
|
||||
class AppModelConfigDict(TypedDict):
|
||||
opening_statement: str | None
|
||||
suggested_questions: list[str]
|
||||
suggested_questions_after_answer: EnabledConfig
|
||||
speech_to_text: EnabledConfig
|
||||
text_to_speech: EnabledConfig
|
||||
retriever_resource: EnabledConfig
|
||||
annotation_reply: AnnotationReplyConfig
|
||||
more_like_this: EnabledConfig
|
||||
sensitive_word_avoidance: SensitiveWordAvoidanceConfig
|
||||
external_data_tools: list[ExternalDataToolConfig]
|
||||
model: ModelConfig
|
||||
user_input_form: list[UserInputFormItem]
|
||||
dataset_query_variable: str | None
|
||||
pre_prompt: str | None
|
||||
agent_mode: AgentModeConfig
|
||||
prompt_type: str
|
||||
chat_prompt_config: ChatPromptConfig
|
||||
completion_prompt_config: CompletionPromptConfig
|
||||
dataset_configs: DatasetConfigs
|
||||
file_upload: FileUploadConfig
|
||||
# Added dynamically in Conversation.model_config
|
||||
model_id: NotRequired[str | None]
|
||||
provider: NotRequired[str | None]
|
||||
|
||||
|
||||
class ConversationDict(TypedDict):
|
||||
id: str
|
||||
app_id: str
|
||||
app_model_config_id: str | None
|
||||
model_provider: str | None
|
||||
override_model_configs: str | None
|
||||
model_id: str | None
|
||||
mode: str
|
||||
name: str
|
||||
summary: str | None
|
||||
inputs: dict[str, Any]
|
||||
introduction: str | None
|
||||
system_instruction: str | None
|
||||
system_instruction_tokens: int
|
||||
status: str
|
||||
invoke_from: str | None
|
||||
from_source: str
|
||||
from_end_user_id: str | None
|
||||
from_account_id: str | None
|
||||
read_at: datetime | None
|
||||
read_account_id: str | None
|
||||
dialogue_count: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class MessageDict(TypedDict):
|
||||
id: str
|
||||
app_id: str
|
||||
conversation_id: str
|
||||
model_id: str | None
|
||||
inputs: dict[str, Any]
|
||||
query: str
|
||||
total_price: Decimal | None
|
||||
message: dict[str, Any]
|
||||
answer: str
|
||||
status: str
|
||||
error: str | None
|
||||
message_metadata: dict[str, Any]
|
||||
from_source: str
|
||||
from_end_user_id: str | None
|
||||
from_account_id: str | None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
agent_based: bool
|
||||
workflow_run_id: str | None
|
||||
|
||||
|
||||
class MessageFeedbackDict(TypedDict):
|
||||
id: str
|
||||
app_id: str
|
||||
conversation_id: str
|
||||
message_id: str
|
||||
rating: str
|
||||
content: str | None
|
||||
from_source: str
|
||||
from_end_user_id: str | None
|
||||
from_account_id: str | None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class MessageFileInfo(TypedDict, total=False):
|
||||
belongs_to: str | None
|
||||
upload_file_id: str | None
|
||||
id: str
|
||||
tenant_id: str
|
||||
type: str
|
||||
transfer_method: str
|
||||
remote_url: str | None
|
||||
related_id: str | None
|
||||
filename: str | None
|
||||
extension: str | None
|
||||
mime_type: str | None
|
||||
size: int
|
||||
dify_model_identity: str
|
||||
url: str | None
|
||||
|
||||
|
||||
class ExtraContentDict(TypedDict, total=False):
|
||||
type: str
|
||||
workflow_run_id: str
|
||||
|
||||
|
||||
class TraceAppConfigDict(TypedDict):
|
||||
id: str
|
||||
app_id: str
|
||||
tracing_provider: str | None
|
||||
tracing_config: dict[str, Any]
|
||||
is_active: bool
|
||||
created_at: str | None
|
||||
updated_at: str | None
|
||||
|
||||
|
||||
class DifySetup(TypeBase):
|
||||
__tablename__ = "dify_setups"
|
||||
__table_args__ = (sa.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
|
||||
@@ -176,7 +430,7 @@ class App(Base):
|
||||
return str(self.mode)
|
||||
|
||||
@property
|
||||
def deleted_tools(self) -> list[dict[str, str]]:
|
||||
def deleted_tools(self) -> list[DeletedToolInfo]:
|
||||
from core.tools.tool_manager import ToolManager, ToolProviderType
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
@@ -257,7 +511,7 @@ class App(Base):
|
||||
provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids)
|
||||
}
|
||||
|
||||
deleted_tools: list[dict[str, str]] = []
|
||||
deleted_tools: list[DeletedToolInfo] = []
|
||||
|
||||
for tool in tools:
|
||||
keys = list(tool.keys())
|
||||
@@ -364,35 +618,38 @@ class AppModelConfig(TypeBase):
|
||||
return app
|
||||
|
||||
@property
|
||||
def model_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.model) if self.model else {}
|
||||
def model_dict(self) -> ModelConfig:
|
||||
return cast(ModelConfig, json.loads(self.model) if self.model else {})
|
||||
|
||||
@property
|
||||
def suggested_questions_list(self) -> list[str]:
|
||||
return json.loads(self.suggested_questions) if self.suggested_questions else []
|
||||
|
||||
@property
|
||||
def suggested_questions_after_answer_dict(self) -> dict[str, Any]:
|
||||
return (
|
||||
def suggested_questions_after_answer_dict(self) -> EnabledConfig:
|
||||
return cast(
|
||||
EnabledConfig,
|
||||
json.loads(self.suggested_questions_after_answer)
|
||||
if self.suggested_questions_after_answer
|
||||
else {"enabled": False}
|
||||
else {"enabled": False},
|
||||
)
|
||||
|
||||
@property
|
||||
def speech_to_text_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False}
|
||||
def speech_to_text_dict(self) -> EnabledConfig:
|
||||
return cast(EnabledConfig, json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False})
|
||||
|
||||
@property
|
||||
def text_to_speech_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False}
|
||||
def text_to_speech_dict(self) -> EnabledConfig:
|
||||
return cast(EnabledConfig, json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False})
|
||||
|
||||
@property
|
||||
def retriever_resource_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
|
||||
def retriever_resource_dict(self) -> EnabledConfig:
|
||||
return cast(
|
||||
EnabledConfig, json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
|
||||
)
|
||||
|
||||
@property
|
||||
def annotation_reply_dict(self) -> dict[str, Any]:
|
||||
def annotation_reply_dict(self) -> AnnotationReplyConfig:
|
||||
annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
|
||||
)
|
||||
@@ -415,56 +672,62 @@ class AppModelConfig(TypeBase):
|
||||
return {"enabled": False}
|
||||
|
||||
@property
|
||||
def more_like_this_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}
|
||||
def more_like_this_dict(self) -> EnabledConfig:
|
||||
return cast(EnabledConfig, json.loads(self.more_like_this) if self.more_like_this else {"enabled": False})
|
||||
|
||||
@property
|
||||
def sensitive_word_avoidance_dict(self) -> dict[str, Any]:
|
||||
return (
|
||||
def sensitive_word_avoidance_dict(self) -> SensitiveWordAvoidanceConfig:
|
||||
return cast(
|
||||
SensitiveWordAvoidanceConfig,
|
||||
json.loads(self.sensitive_word_avoidance)
|
||||
if self.sensitive_word_avoidance
|
||||
else {"enabled": False, "type": "", "configs": []}
|
||||
else {"enabled": False, "type": "", "config": {}},
|
||||
)
|
||||
|
||||
@property
|
||||
def external_data_tools_list(self) -> list[dict[str, Any]]:
|
||||
def external_data_tools_list(self) -> list[ExternalDataToolConfig]:
|
||||
return json.loads(self.external_data_tools) if self.external_data_tools else []
|
||||
|
||||
@property
|
||||
def user_input_form_list(self) -> list[dict[str, Any]]:
|
||||
def user_input_form_list(self) -> list[UserInputFormItem]:
|
||||
return json.loads(self.user_input_form) if self.user_input_form else []
|
||||
|
||||
@property
|
||||
def agent_mode_dict(self) -> dict[str, Any]:
|
||||
return (
|
||||
def agent_mode_dict(self) -> AgentModeConfig:
|
||||
return cast(
|
||||
AgentModeConfig,
|
||||
json.loads(self.agent_mode)
|
||||
if self.agent_mode
|
||||
else {"enabled": False, "strategy": None, "tools": [], "prompt": None}
|
||||
else {"enabled": False, "strategy": None, "tools": [], "prompt": None},
|
||||
)
|
||||
|
||||
@property
|
||||
def chat_prompt_config_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {}
|
||||
def chat_prompt_config_dict(self) -> ChatPromptConfig:
|
||||
return cast(ChatPromptConfig, json.loads(self.chat_prompt_config) if self.chat_prompt_config else {})
|
||||
|
||||
@property
|
||||
def completion_prompt_config_dict(self) -> dict[str, Any]:
|
||||
return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {}
|
||||
def completion_prompt_config_dict(self) -> CompletionPromptConfig:
|
||||
return cast(
|
||||
CompletionPromptConfig,
|
||||
json.loads(self.completion_prompt_config) if self.completion_prompt_config else {},
|
||||
)
|
||||
|
||||
@property
|
||||
def dataset_configs_dict(self) -> dict[str, Any]:
|
||||
def dataset_configs_dict(self) -> DatasetConfigs:
|
||||
if self.dataset_configs:
|
||||
dataset_configs: dict[str, Any] = json.loads(self.dataset_configs)
|
||||
dataset_configs = json.loads(self.dataset_configs)
|
||||
if "retrieval_model" not in dataset_configs:
|
||||
return {"retrieval_model": "single"}
|
||||
else:
|
||||
return dataset_configs
|
||||
return cast(DatasetConfigs, dataset_configs)
|
||||
return {
|
||||
"retrieval_model": "multiple",
|
||||
}
|
||||
|
||||
@property
|
||||
def file_upload_dict(self) -> dict[str, Any]:
|
||||
return (
|
||||
def file_upload_dict(self) -> FileUploadConfig:
|
||||
return cast(
|
||||
FileUploadConfig,
|
||||
json.loads(self.file_upload)
|
||||
if self.file_upload
|
||||
else {
|
||||
@@ -474,10 +737,10 @@ class AppModelConfig(TypeBase):
|
||||
"detail": "high",
|
||||
"transfer_methods": ["remote_url", "local_file"],
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> AppModelConfigDict:
|
||||
return {
|
||||
"opening_statement": self.opening_statement,
|
||||
"suggested_questions": self.suggested_questions_list,
|
||||
@@ -501,36 +764,42 @@ class AppModelConfig(TypeBase):
|
||||
"file_upload": self.file_upload_dict,
|
||||
}
|
||||
|
||||
def from_model_config_dict(self, model_config: Mapping[str, Any]):
|
||||
def from_model_config_dict(self, model_config: AppModelConfigDict):
|
||||
self.opening_statement = model_config.get("opening_statement")
|
||||
self.suggested_questions = (
|
||||
json.dumps(model_config["suggested_questions"]) if model_config.get("suggested_questions") else None
|
||||
json.dumps(model_config.get("suggested_questions")) if model_config.get("suggested_questions") else None
|
||||
)
|
||||
self.suggested_questions_after_answer = (
|
||||
json.dumps(model_config["suggested_questions_after_answer"])
|
||||
json.dumps(model_config.get("suggested_questions_after_answer"))
|
||||
if model_config.get("suggested_questions_after_answer")
|
||||
else None
|
||||
)
|
||||
self.speech_to_text = json.dumps(model_config["speech_to_text"]) if model_config.get("speech_to_text") else None
|
||||
self.text_to_speech = json.dumps(model_config["text_to_speech"]) if model_config.get("text_to_speech") else None
|
||||
self.more_like_this = json.dumps(model_config["more_like_this"]) if model_config.get("more_like_this") else None
|
||||
self.speech_to_text = (
|
||||
json.dumps(model_config.get("speech_to_text")) if model_config.get("speech_to_text") else None
|
||||
)
|
||||
self.text_to_speech = (
|
||||
json.dumps(model_config.get("text_to_speech")) if model_config.get("text_to_speech") else None
|
||||
)
|
||||
self.more_like_this = (
|
||||
json.dumps(model_config.get("more_like_this")) if model_config.get("more_like_this") else None
|
||||
)
|
||||
self.sensitive_word_avoidance = (
|
||||
json.dumps(model_config["sensitive_word_avoidance"])
|
||||
json.dumps(model_config.get("sensitive_word_avoidance"))
|
||||
if model_config.get("sensitive_word_avoidance")
|
||||
else None
|
||||
)
|
||||
self.external_data_tools = (
|
||||
json.dumps(model_config["external_data_tools"]) if model_config.get("external_data_tools") else None
|
||||
json.dumps(model_config.get("external_data_tools")) if model_config.get("external_data_tools") else None
|
||||
)
|
||||
self.model = json.dumps(model_config["model"]) if model_config.get("model") else None
|
||||
self.model = json.dumps(model_config.get("model")) if model_config.get("model") else None
|
||||
self.user_input_form = (
|
||||
json.dumps(model_config["user_input_form"]) if model_config.get("user_input_form") else None
|
||||
json.dumps(model_config.get("user_input_form")) if model_config.get("user_input_form") else None
|
||||
)
|
||||
self.dataset_query_variable = model_config.get("dataset_query_variable")
|
||||
self.pre_prompt = model_config["pre_prompt"]
|
||||
self.agent_mode = json.dumps(model_config["agent_mode"]) if model_config.get("agent_mode") else None
|
||||
self.pre_prompt = model_config.get("pre_prompt")
|
||||
self.agent_mode = json.dumps(model_config.get("agent_mode")) if model_config.get("agent_mode") else None
|
||||
self.retriever_resource = (
|
||||
json.dumps(model_config["retriever_resource"]) if model_config.get("retriever_resource") else None
|
||||
json.dumps(model_config.get("retriever_resource")) if model_config.get("retriever_resource") else None
|
||||
)
|
||||
self.prompt_type = model_config.get("prompt_type", "simple")
|
||||
self.chat_prompt_config = (
|
||||
@@ -823,24 +1092,26 @@ class Conversation(Base):
|
||||
self._inputs = inputs
|
||||
|
||||
@property
|
||||
def model_config(self):
|
||||
model_config = {}
|
||||
def model_config(self) -> AppModelConfigDict:
|
||||
model_config = cast(AppModelConfigDict, {})
|
||||
app_model_config: AppModelConfig | None = None
|
||||
|
||||
if self.mode == AppMode.ADVANCED_CHAT:
|
||||
if self.override_model_configs:
|
||||
override_model_configs = json.loads(self.override_model_configs)
|
||||
model_config = override_model_configs
|
||||
model_config = cast(AppModelConfigDict, override_model_configs)
|
||||
else:
|
||||
if self.override_model_configs:
|
||||
override_model_configs = json.loads(self.override_model_configs)
|
||||
|
||||
if "model" in override_model_configs:
|
||||
# where is app_id?
|
||||
app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict(override_model_configs)
|
||||
app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict(
|
||||
cast(AppModelConfigDict, override_model_configs)
|
||||
)
|
||||
model_config = app_model_config.to_dict()
|
||||
else:
|
||||
model_config["configs"] = override_model_configs
|
||||
model_config["configs"] = override_model_configs # type: ignore[typeddict-unknown-key]
|
||||
else:
|
||||
app_model_config = (
|
||||
db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
|
||||
@@ -1015,7 +1286,7 @@ class Conversation(Base):
|
||||
def in_debug_mode(self) -> bool:
|
||||
return self.override_model_configs is not None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> ConversationDict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"app_id": self.app_id,
|
||||
@@ -1295,7 +1566,7 @@ class Message(Base):
|
||||
return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else []
|
||||
|
||||
@property
|
||||
def message_files(self) -> list[dict[str, Any]]:
|
||||
def message_files(self) -> list[MessageFileInfo]:
|
||||
from factories import file_factory
|
||||
|
||||
message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all()
|
||||
@@ -1350,10 +1621,13 @@ class Message(Base):
|
||||
)
|
||||
files.append(file)
|
||||
|
||||
result: list[dict[str, Any]] = [
|
||||
{"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()}
|
||||
for (file, message_file) in zip(files, message_files)
|
||||
]
|
||||
result = cast(
|
||||
list[MessageFileInfo],
|
||||
[
|
||||
{"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()}
|
||||
for (file, message_file) in zip(files, message_files)
|
||||
],
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
return result
|
||||
@@ -1363,7 +1637,7 @@ class Message(Base):
|
||||
self._extra_contents = list(contents)
|
||||
|
||||
@property
|
||||
def extra_contents(self) -> list[dict[str, Any]]:
|
||||
def extra_contents(self) -> list[ExtraContentDict]:
|
||||
return getattr(self, "_extra_contents", [])
|
||||
|
||||
@property
|
||||
@@ -1379,7 +1653,7 @@ class Message(Base):
|
||||
|
||||
return None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> MessageDict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"app_id": self.app_id,
|
||||
@@ -1403,7 +1677,7 @@ class Message(Base):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> Message:
|
||||
def from_dict(cls, data: MessageDict) -> Message:
|
||||
return cls(
|
||||
id=data["id"],
|
||||
app_id=data["app_id"],
|
||||
@@ -1463,7 +1737,7 @@ class MessageFeedback(TypeBase):
|
||||
account = db.session.query(Account).where(Account.id == self.from_account_id).first()
|
||||
return account
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> MessageFeedbackDict:
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"app_id": str(self.app_id),
|
||||
@@ -1726,8 +2000,8 @@ class AppMCPServer(TypeBase):
|
||||
return result
|
||||
|
||||
@property
|
||||
def parameters_dict(self) -> dict[str, Any]:
|
||||
return cast(dict[str, Any], json.loads(self.parameters))
|
||||
def parameters_dict(self) -> dict[str, str]:
|
||||
return cast(dict[str, str], json.loads(self.parameters))
|
||||
|
||||
|
||||
class Site(Base):
|
||||
@@ -2167,7 +2441,7 @@ class TraceAppConfig(TypeBase):
|
||||
def tracing_config_str(self) -> str:
|
||||
return json.dumps(self.tracing_config_dict)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> TraceAppConfigDict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"app_id": self.app_id,
|
||||
|
||||
@@ -35,7 +35,7 @@ dependencies = [
|
||||
"jsonschema>=4.25.1",
|
||||
"langfuse~=2.51.3",
|
||||
"langsmith~=0.1.77",
|
||||
"markdown~=3.5.1",
|
||||
"markdown~=3.8.1",
|
||||
"mlflow-skinny>=3.0.0",
|
||||
"numpy~=1.26.4",
|
||||
"openpyxl~=3.1.5",
|
||||
@@ -247,3 +247,13 @@ module = [
|
||||
"extensions.logstore.repositories.logstore_api_workflow_run_repository",
|
||||
]
|
||||
ignore_errors = true
|
||||
|
||||
[tool.pyrefly]
|
||||
project-includes = ["."]
|
||||
project-excludes = [
|
||||
".venv",
|
||||
"migrations/",
|
||||
]
|
||||
python-platform = "linux"
|
||||
python-version = "3.11.0"
|
||||
infer-with-first-use = false
|
||||
|
||||
200
api/pyrefly-local-excludes.txt
Normal file
200
api/pyrefly-local-excludes.txt
Normal file
@@ -0,0 +1,200 @@
|
||||
configs/middleware/cache/redis_pubsub_config.py
|
||||
controllers/console/app/annotation.py
|
||||
controllers/console/app/app.py
|
||||
controllers/console/app/app_import.py
|
||||
controllers/console/app/mcp_server.py
|
||||
controllers/console/app/site.py
|
||||
controllers/console/auth/email_register.py
|
||||
controllers/console/human_input_form.py
|
||||
controllers/console/init_validate.py
|
||||
controllers/console/ping.py
|
||||
controllers/console/setup.py
|
||||
controllers/console/version.py
|
||||
controllers/console/workspace/trigger_providers.py
|
||||
controllers/service_api/app/annotation.py
|
||||
controllers/web/workflow_events.py
|
||||
core/agent/fc_agent_runner.py
|
||||
core/app/apps/advanced_chat/app_generator.py
|
||||
core/app/apps/advanced_chat/app_runner.py
|
||||
core/app/apps/advanced_chat/generate_task_pipeline.py
|
||||
core/app/apps/agent_chat/app_generator.py
|
||||
core/app/apps/base_app_generate_response_converter.py
|
||||
core/app/apps/base_app_generator.py
|
||||
core/app/apps/chat/app_generator.py
|
||||
core/app/apps/common/workflow_response_converter.py
|
||||
core/app/apps/completion/app_generator.py
|
||||
core/app/apps/pipeline/pipeline_generator.py
|
||||
core/app/apps/pipeline/pipeline_runner.py
|
||||
core/app/apps/workflow/app_generator.py
|
||||
core/app/apps/workflow/app_runner.py
|
||||
core/app/apps/workflow/generate_task_pipeline.py
|
||||
core/app/apps/workflow_app_runner.py
|
||||
core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
|
||||
core/datasource/datasource_manager.py
|
||||
core/external_data_tool/api/api.py
|
||||
core/llm_generator/llm_generator.py
|
||||
core/llm_generator/output_parser/structured_output.py
|
||||
core/mcp/mcp_client.py
|
||||
core/ops/aliyun_trace/data_exporter/traceclient.py
|
||||
core/ops/arize_phoenix_trace/arize_phoenix_trace.py
|
||||
core/ops/mlflow_trace/mlflow_trace.py
|
||||
core/ops/ops_trace_manager.py
|
||||
core/ops/tencent_trace/client.py
|
||||
core/ops/tencent_trace/utils.py
|
||||
core/plugin/backwards_invocation/base.py
|
||||
core/plugin/backwards_invocation/model.py
|
||||
core/prompt/utils/extract_thread_messages.py
|
||||
core/rag/datasource/keyword/jieba/jieba.py
|
||||
core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
|
||||
core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
|
||||
core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
|
||||
core/rag/datasource/vdb/baidu/baidu_vector.py
|
||||
core/rag/datasource/vdb/chroma/chroma_vector.py
|
||||
core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
|
||||
core/rag/datasource/vdb/couchbase/couchbase_vector.py
|
||||
core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
|
||||
core/rag/datasource/vdb/huawei/huawei_cloud_vector.py
|
||||
core/rag/datasource/vdb/lindorm/lindorm_vector.py
|
||||
core/rag/datasource/vdb/matrixone/matrixone_vector.py
|
||||
core/rag/datasource/vdb/milvus/milvus_vector.py
|
||||
core/rag/datasource/vdb/myscale/myscale_vector.py
|
||||
core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
|
||||
core/rag/datasource/vdb/opensearch/opensearch_vector.py
|
||||
core/rag/datasource/vdb/oracle/oraclevector.py
|
||||
core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
|
||||
core/rag/datasource/vdb/relyt/relyt_vector.py
|
||||
core/rag/datasource/vdb/tablestore/tablestore_vector.py
|
||||
core/rag/datasource/vdb/tencent/tencent_vector.py
|
||||
core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
|
||||
core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py
|
||||
core/rag/datasource/vdb/tidb_vector/tidb_vector.py
|
||||
core/rag/datasource/vdb/upstash/upstash_vector.py
|
||||
core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
|
||||
core/rag/datasource/vdb/weaviate/weaviate_vector.py
|
||||
core/rag/extractor/csv_extractor.py
|
||||
core/rag/extractor/excel_extractor.py
|
||||
core/rag/extractor/firecrawl/firecrawl_app.py
|
||||
core/rag/extractor/firecrawl/firecrawl_web_extractor.py
|
||||
core/rag/extractor/html_extractor.py
|
||||
core/rag/extractor/jina_reader_extractor.py
|
||||
core/rag/extractor/markdown_extractor.py
|
||||
core/rag/extractor/notion_extractor.py
|
||||
core/rag/extractor/pdf_extractor.py
|
||||
core/rag/extractor/text_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_doc_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_eml_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_epub_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_markdown_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_msg_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_ppt_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_pptx_extractor.py
|
||||
core/rag/extractor/unstructured/unstructured_xml_extractor.py
|
||||
core/rag/extractor/watercrawl/client.py
|
||||
core/rag/extractor/watercrawl/extractor.py
|
||||
core/rag/extractor/watercrawl/provider.py
|
||||
core/rag/extractor/word_extractor.py
|
||||
core/rag/index_processor/processor/paragraph_index_processor.py
|
||||
core/rag/index_processor/processor/parent_child_index_processor.py
|
||||
core/rag/index_processor/processor/qa_index_processor.py
|
||||
core/rag/retrieval/router/multi_dataset_function_call_router.py
|
||||
core/rag/summary_index/summary_index.py
|
||||
core/repositories/sqlalchemy_workflow_execution_repository.py
|
||||
core/repositories/sqlalchemy_workflow_node_execution_repository.py
|
||||
core/tools/__base/tool.py
|
||||
core/tools/mcp_tool/provider.py
|
||||
core/tools/plugin_tool/provider.py
|
||||
core/tools/utils/message_transformer.py
|
||||
core/tools/utils/web_reader_tool.py
|
||||
core/tools/workflow_as_tool/provider.py
|
||||
core/trigger/debug/event_selectors.py
|
||||
core/trigger/entities/entities.py
|
||||
core/trigger/provider.py
|
||||
core/workflow/workflow_entry.py
|
||||
dify_graph/entities/workflow_execution.py
|
||||
dify_graph/file/file_manager.py
|
||||
dify_graph/graph_engine/error_handler.py
|
||||
dify_graph/graph_engine/layers/execution_limits.py
|
||||
dify_graph/nodes/agent/agent_node.py
|
||||
dify_graph/nodes/base/node.py
|
||||
dify_graph/nodes/code/code_node.py
|
||||
dify_graph/nodes/datasource/datasource_node.py
|
||||
dify_graph/nodes/document_extractor/node.py
|
||||
dify_graph/nodes/human_input/human_input_node.py
|
||||
dify_graph/nodes/if_else/if_else_node.py
|
||||
dify_graph/nodes/iteration/iteration_node.py
|
||||
dify_graph/nodes/knowledge_index/knowledge_index_node.py
|
||||
dify_graph/nodes/knowledge_retrieval/knowledge_retrieval_node.py
|
||||
dify_graph/nodes/list_operator/node.py
|
||||
dify_graph/nodes/llm/node.py
|
||||
dify_graph/nodes/loop/loop_node.py
|
||||
dify_graph/nodes/parameter_extractor/parameter_extractor_node.py
|
||||
dify_graph/nodes/question_classifier/question_classifier_node.py
|
||||
dify_graph/nodes/start/start_node.py
|
||||
dify_graph/nodes/template_transform/template_transform_node.py
|
||||
dify_graph/nodes/tool/tool_node.py
|
||||
dify_graph/nodes/trigger_plugin/trigger_event_node.py
|
||||
dify_graph/nodes/trigger_schedule/trigger_schedule_node.py
|
||||
dify_graph/nodes/trigger_webhook/node.py
|
||||
dify_graph/nodes/variable_aggregator/variable_aggregator_node.py
|
||||
dify_graph/nodes/variable_assigner/v1/node.py
|
||||
dify_graph/nodes/variable_assigner/v2/node.py
|
||||
dify_graph/variables/types.py
|
||||
extensions/ext_fastopenapi.py
|
||||
extensions/logstore/repositories/logstore_api_workflow_run_repository.py
|
||||
extensions/otel/instrumentation.py
|
||||
extensions/otel/runtime.py
|
||||
extensions/storage/aliyun_oss_storage.py
|
||||
extensions/storage/aws_s3_storage.py
|
||||
extensions/storage/azure_blob_storage.py
|
||||
extensions/storage/baidu_obs_storage.py
|
||||
extensions/storage/clickzetta_volume/clickzetta_volume_storage.py
|
||||
extensions/storage/clickzetta_volume/file_lifecycle.py
|
||||
extensions/storage/google_cloud_storage.py
|
||||
extensions/storage/huawei_obs_storage.py
|
||||
extensions/storage/opendal_storage.py
|
||||
extensions/storage/oracle_oci_storage.py
|
||||
extensions/storage/supabase_storage.py
|
||||
extensions/storage/tencent_cos_storage.py
|
||||
extensions/storage/volcengine_tos_storage.py
|
||||
factories/variable_factory.py
|
||||
libs/external_api.py
|
||||
libs/gmpy2_pkcs10aep_cipher.py
|
||||
libs/helper.py
|
||||
libs/login.py
|
||||
libs/module_loading.py
|
||||
libs/oauth.py
|
||||
libs/oauth_data_source.py
|
||||
models/trigger.py
|
||||
models/workflow.py
|
||||
repositories/sqlalchemy_api_workflow_node_execution_repository.py
|
||||
repositories/sqlalchemy_api_workflow_run_repository.py
|
||||
repositories/sqlalchemy_execution_extra_content_repository.py
|
||||
schedule/queue_monitor_task.py
|
||||
services/account_service.py
|
||||
services/audio_service.py
|
||||
services/auth/firecrawl/firecrawl.py
|
||||
services/auth/jina.py
|
||||
services/auth/jina/jina.py
|
||||
services/auth/watercrawl/watercrawl.py
|
||||
services/conversation_service.py
|
||||
services/dataset_service.py
|
||||
services/document_indexing_proxy/document_indexing_task_proxy.py
|
||||
services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py
|
||||
services/external_knowledge_service.py
|
||||
services/plugin/plugin_migration.py
|
||||
services/recommend_app/buildin/buildin_retrieval.py
|
||||
services/recommend_app/database/database_retrieval.py
|
||||
services/recommend_app/remote/remote_retrieval.py
|
||||
services/summary_index_service.py
|
||||
services/tools/tools_transform_service.py
|
||||
services/trigger/trigger_provider_service.py
|
||||
services/trigger/trigger_subscription_builder_service.py
|
||||
services/trigger/webhook_service.py
|
||||
services/workflow_draft_variable_service.py
|
||||
services/workflow_event_snapshot_service.py
|
||||
services/workflow_service.py
|
||||
tasks/app_generate/workflow_execute_task.py
|
||||
tasks/regenerate_summary_index_task.py
|
||||
tasks/trigger_processing_tasks.py
|
||||
tasks/workflow_cfs_scheduler/cfs_scheduler.py
|
||||
tasks/workflow_execution_tasks.py
|
||||
@@ -1,8 +0,0 @@
|
||||
project-includes = ["."]
|
||||
project-excludes = [
|
||||
".venv",
|
||||
"migrations/",
|
||||
]
|
||||
python-platform = "linux"
|
||||
python-version = "3.11.0"
|
||||
infer-with-first-use = false
|
||||
@@ -1,5 +1,6 @@
|
||||
[pytest]
|
||||
addopts = --cov=./api --cov-report=json
|
||||
pythonpath = .
|
||||
addopts = --cov=./api --cov-report=json --import-mode=importlib
|
||||
env =
|
||||
ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz
|
||||
AZURE_OPENAI_API_BASE = https://difyai-openai.openai.azure.com
|
||||
@@ -19,7 +20,7 @@ env =
|
||||
GOOGLE_API_KEY = abcdefghijklmnopqrstuvwxyz
|
||||
HUGGINGFACE_API_KEY = hf-awuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwu
|
||||
HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL = c
|
||||
HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL = b
|
||||
HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL = b
|
||||
HUGGINGFACE_TEXT_GEN_ENDPOINT_URL = a
|
||||
MIXEDBREAD_API_KEY = mk-aaaaaaaaaaaaaaaaaaaa
|
||||
MOCK_SWITCH = true
|
||||
|
||||
@@ -21,6 +21,10 @@ celery_redis = Redis(
|
||||
ssl_cert_reqs=getattr(dify_config, "REDIS_SSL_CERT_REQS", None) if dify_config.BROKER_USE_SSL else None,
|
||||
ssl_certfile=getattr(dify_config, "REDIS_SSL_CERTFILE", None) if dify_config.BROKER_USE_SSL else None,
|
||||
ssl_keyfile=getattr(dify_config, "REDIS_SSL_KEYFILE", None) if dify_config.BROKER_USE_SSL else None,
|
||||
# Add conservative socket timeouts and health checks to avoid long-lived half-open sockets
|
||||
socket_timeout=5,
|
||||
socket_connect_timeout=5,
|
||||
health_check_interval=30,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -3,6 +3,7 @@ import math
|
||||
import time
|
||||
from collections.abc import Iterable, Sequence
|
||||
|
||||
from celery import group
|
||||
from sqlalchemy import ColumnElement, and_, func, or_, select
|
||||
from sqlalchemy.engine.row import Row
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -85,20 +86,25 @@ def trigger_provider_refresh() -> None:
|
||||
lock_keys: list[str] = build_trigger_refresh_lock_keys(subscriptions)
|
||||
acquired: list[bool] = _acquire_locks(keys=lock_keys, ttl_seconds=lock_ttl)
|
||||
|
||||
enqueued: int = 0
|
||||
for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired):
|
||||
if not is_locked:
|
||||
continue
|
||||
trigger_subscription_refresh.delay(tenant_id=tenant_id, subscription_id=subscription_id)
|
||||
enqueued += 1
|
||||
if not any(acquired):
|
||||
continue
|
||||
|
||||
jobs = [
|
||||
trigger_subscription_refresh.s(tenant_id=tenant_id, subscription_id=subscription_id)
|
||||
for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired)
|
||||
if is_locked
|
||||
]
|
||||
result = group(jobs).apply_async()
|
||||
enqueued = len(jobs)
|
||||
|
||||
logger.info(
|
||||
"Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d",
|
||||
"Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d result=%s",
|
||||
page + 1,
|
||||
pages,
|
||||
len(subscriptions),
|
||||
sum(1 for x in acquired if x),
|
||||
enqueued,
|
||||
result,
|
||||
)
|
||||
|
||||
logger.info("Trigger refresh scan done: due=%d", total_due)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
from celery import group, shared_task
|
||||
from celery import current_app, group, shared_task
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
@@ -29,31 +29,27 @@ def poll_workflow_schedules() -> None:
|
||||
with session_factory() as session:
|
||||
total_dispatched = 0
|
||||
|
||||
# Process in batches until we've handled all due schedules or hit the limit
|
||||
while True:
|
||||
due_schedules = _fetch_due_schedules(session)
|
||||
|
||||
if not due_schedules:
|
||||
break
|
||||
|
||||
dispatched_count = _process_schedules(session, due_schedules)
|
||||
total_dispatched += dispatched_count
|
||||
with current_app.producer_or_acquire() as producer: # type: ignore
|
||||
dispatched_count = _process_schedules(session, due_schedules, producer)
|
||||
total_dispatched += dispatched_count
|
||||
|
||||
logger.debug("Batch processed: %d dispatched", dispatched_count)
|
||||
|
||||
# Circuit breaker: check if we've hit the per-tick limit (if enabled)
|
||||
if (
|
||||
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK > 0
|
||||
and total_dispatched >= dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK
|
||||
):
|
||||
logger.warning(
|
||||
"Circuit breaker activated: reached dispatch limit (%d), will continue next tick",
|
||||
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK,
|
||||
)
|
||||
break
|
||||
logger.debug("Batch processed: %d dispatched", dispatched_count)
|
||||
|
||||
# Circuit breaker: check if we've hit the per-tick limit (if enabled)
|
||||
if 0 < dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK <= total_dispatched:
|
||||
logger.warning(
|
||||
"Circuit breaker activated: reached dispatch limit (%d), will continue next tick",
|
||||
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK,
|
||||
)
|
||||
break
|
||||
if total_dispatched > 0:
|
||||
logger.info("Total processed: %d dispatched", total_dispatched)
|
||||
logger.info("Total processed: %d workflow schedule(s) dispatched", total_dispatched)
|
||||
|
||||
|
||||
def _fetch_due_schedules(session: Session) -> list[WorkflowSchedulePlan]:
|
||||
@@ -90,7 +86,7 @@ def _fetch_due_schedules(session: Session) -> list[WorkflowSchedulePlan]:
|
||||
return list(due_schedules)
|
||||
|
||||
|
||||
def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan]) -> int:
|
||||
def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan], producer=None) -> int:
|
||||
"""Process schedules: check quota, update next run time and dispatch to Celery in parallel."""
|
||||
if not schedules:
|
||||
return 0
|
||||
@@ -107,7 +103,7 @@ def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan])
|
||||
|
||||
if tasks_to_dispatch:
|
||||
job = group(run_schedule_trigger.s(schedule_id) for schedule_id in tasks_to_dispatch)
|
||||
job.apply_async()
|
||||
job.apply_async(producer=producer)
|
||||
|
||||
logger.debug("Dispatched %d tasks in parallel", len(tasks_to_dispatch))
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import cast
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -32,7 +33,7 @@ from extensions.ext_redis import redis_client
|
||||
from factories import variable_factory
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, App, AppMode
|
||||
from models.model import AppModelConfig, IconType
|
||||
from models.model import AppModelConfig, AppModelConfigDict, IconType
|
||||
from models.workflow import Workflow
|
||||
from services.plugin.dependencies_analysis import DependenciesAnalysisService
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableService
|
||||
@@ -523,7 +524,7 @@ class AppDslService:
|
||||
if not app.app_model_config:
|
||||
app_model_config = AppModelConfig(
|
||||
app_id=app.id, created_by=account.id, updated_by=account.id
|
||||
).from_model_config_dict(model_config)
|
||||
).from_model_config_dict(cast(AppModelConfigDict, model_config))
|
||||
app_model_config.id = str(uuid4())
|
||||
app.app_model_config_id = app_model_config.id
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
||||
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
|
||||
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
|
||||
from models.model import AppMode
|
||||
from models.model import AppMode, AppModelConfigDict
|
||||
|
||||
|
||||
class AppModelConfigService:
|
||||
@classmethod
|
||||
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode):
|
||||
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> AppModelConfigDict:
|
||||
if app_mode == AppMode.CHAT:
|
||||
return ChatAppConfigManager.config_validate(tenant_id, config)
|
||||
elif app_mode == AppMode.AGENT_CHAT:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import TypedDict, cast
|
||||
from typing import Any, TypedDict, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask_sqlalchemy.pagination import Pagination
|
||||
@@ -187,7 +187,7 @@ class AppService:
|
||||
for tool in agent_mode.get("tools") or []:
|
||||
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
|
||||
continue
|
||||
agent_tool_entity = AgentToolEntity(**tool)
|
||||
agent_tool_entity = AgentToolEntity(**cast(dict[str, Any], tool))
|
||||
# get tool
|
||||
try:
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
@@ -388,7 +388,7 @@ class AppService:
|
||||
agent_config = app_model_config.agent_mode_dict
|
||||
|
||||
# get all tools
|
||||
tools = agent_config.get("tools", [])
|
||||
tools = cast(list[dict[str, Any]], agent_config.get("tools", []))
|
||||
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import io
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
from werkzeug.datastructures import FileStorage
|
||||
@@ -106,7 +107,7 @@ class AudioService:
|
||||
if not text_to_speech_dict.get("enabled"):
|
||||
raise ValueError("TTS is not enabled")
|
||||
|
||||
voice = text_to_speech_dict.get("voice")
|
||||
voice = cast(str | None, text_to_speech_dict.get("voice"))
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
|
||||
@@ -63,7 +63,12 @@ class RagPipelineTransformService:
|
||||
):
|
||||
node = self._deal_file_extensions(node)
|
||||
if node.get("data", {}).get("type") == "knowledge-index":
|
||||
node = self._deal_knowledge_index(dataset, doc_form, indexing_technique, retrieval_model, node)
|
||||
knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {}))
|
||||
if dataset.tenant_id != current_user.current_tenant_id:
|
||||
raise ValueError("Unauthorized")
|
||||
node = self._deal_knowledge_index(
|
||||
knowledge_configuration, dataset, indexing_technique, retrieval_model, node
|
||||
)
|
||||
new_nodes.append(node)
|
||||
if new_nodes:
|
||||
graph["nodes"] = new_nodes
|
||||
@@ -155,14 +160,13 @@ class RagPipelineTransformService:
|
||||
|
||||
def _deal_knowledge_index(
|
||||
self,
|
||||
knowledge_configuration: KnowledgeConfiguration,
|
||||
dataset: Dataset,
|
||||
doc_form: str,
|
||||
indexing_technique: str | None,
|
||||
retrieval_model: RetrievalSetting | None,
|
||||
node: dict,
|
||||
):
|
||||
knowledge_configuration_dict = node.get("data", {})
|
||||
knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration_dict)
|
||||
|
||||
if indexing_technique == "high_quality":
|
||||
knowledge_configuration.embedding_model = dataset.embedding_model
|
||||
|
||||
304
api/services/retention/conversation/message_export_service.py
Normal file
304
api/services/retention/conversation/message_export_service.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
Export app messages to JSONL.GZ format.
|
||||
|
||||
Outputs: conversation_id, message_id, query, answer, inputs (raw JSON),
|
||||
retriever_resources (from message_metadata), feedback (user feedbacks array).
|
||||
|
||||
Uses (created_at, id) cursor pagination and batch-loads feedbacks to avoid N+1.
|
||||
Does NOT touch Message.inputs / Message.user_feedback properties.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import gzip
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from collections.abc import Generator, Iterable
|
||||
from pathlib import Path, PurePosixPath
|
||||
from typing import Any, BinaryIO, cast
|
||||
|
||||
import orjson
|
||||
import sqlalchemy as sa
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import select, tuple_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import Message, MessageFeedback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_FILENAME_BASE_LENGTH = 1024
|
||||
FORBIDDEN_FILENAME_SUFFIXES = (".jsonl.gz", ".jsonl", ".gz")
|
||||
|
||||
|
||||
class AppMessageExportFeedback(BaseModel):
|
||||
id: str
|
||||
app_id: str
|
||||
conversation_id: str
|
||||
message_id: str
|
||||
rating: str
|
||||
content: str | None = None
|
||||
from_source: str
|
||||
from_end_user_id: str | None = None
|
||||
from_account_id: str | None = None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class AppMessageExportRecord(BaseModel):
|
||||
conversation_id: str
|
||||
message_id: str
|
||||
query: str
|
||||
answer: str
|
||||
inputs: dict[str, Any]
|
||||
retriever_resources: list[Any] = Field(default_factory=list)
|
||||
feedback: list[AppMessageExportFeedback] = Field(default_factory=list)
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class AppMessageExportStats(BaseModel):
|
||||
batches: int = 0
|
||||
total_messages: int = 0
|
||||
messages_with_feedback: int = 0
|
||||
total_feedbacks: int = 0
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class AppMessageExportService:
|
||||
@staticmethod
|
||||
def validate_export_filename(filename: str) -> str:
|
||||
normalized = filename.strip()
|
||||
if not normalized:
|
||||
raise ValueError("--filename must not be empty.")
|
||||
|
||||
normalized_lower = normalized.lower()
|
||||
if normalized_lower.endswith(FORBIDDEN_FILENAME_SUFFIXES):
|
||||
raise ValueError("--filename must not include .jsonl.gz/.jsonl/.gz suffix; pass base filename only.")
|
||||
|
||||
if normalized.startswith("/"):
|
||||
raise ValueError("--filename must be a relative path; absolute paths are not allowed.")
|
||||
|
||||
if "\\" in normalized:
|
||||
raise ValueError("--filename must use '/' as path separator; '\\' is not allowed.")
|
||||
|
||||
if "//" in normalized:
|
||||
raise ValueError("--filename must not contain empty path segments ('//').")
|
||||
|
||||
if len(normalized) > MAX_FILENAME_BASE_LENGTH:
|
||||
raise ValueError(f"--filename is too long; max length is {MAX_FILENAME_BASE_LENGTH}.")
|
||||
|
||||
for ch in normalized:
|
||||
if ch == "\x00" or ord(ch) < 32 or ord(ch) == 127:
|
||||
raise ValueError("--filename must not contain control characters or NUL.")
|
||||
|
||||
parts = PurePosixPath(normalized).parts
|
||||
if not parts:
|
||||
raise ValueError("--filename must include a file name.")
|
||||
|
||||
if any(part in (".", "..") for part in parts):
|
||||
raise ValueError("--filename must not contain '.' or '..' path segments.")
|
||||
|
||||
return normalized
|
||||
|
||||
@property
|
||||
def output_gz_name(self) -> str:
|
||||
return f"{self._filename_base}.jsonl.gz"
|
||||
|
||||
@property
|
||||
def output_jsonl_name(self) -> str:
|
||||
return f"{self._filename_base}.jsonl"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_id: str,
|
||||
end_before: datetime.datetime,
|
||||
filename: str,
|
||||
*,
|
||||
start_from: datetime.datetime | None = None,
|
||||
batch_size: int = 1000,
|
||||
use_cloud_storage: bool = False,
|
||||
dry_run: bool = False,
|
||||
) -> None:
|
||||
if start_from and start_from >= end_before:
|
||||
raise ValueError(f"start_from ({start_from}) must be before end_before ({end_before})")
|
||||
|
||||
self._app_id = app_id
|
||||
self._end_before = end_before
|
||||
self._start_from = start_from
|
||||
self._filename_base = self.validate_export_filename(filename)
|
||||
self._batch_size = batch_size
|
||||
self._use_cloud_storage = use_cloud_storage
|
||||
self._dry_run = dry_run
|
||||
|
||||
def run(self) -> AppMessageExportStats:
|
||||
stats = AppMessageExportStats()
|
||||
|
||||
logger.info(
|
||||
"export_app_messages: app_id=%s, start_from=%s, end_before=%s, dry_run=%s, cloud=%s, output_gz=%s",
|
||||
self._app_id,
|
||||
self._start_from,
|
||||
self._end_before,
|
||||
self._dry_run,
|
||||
self._use_cloud_storage,
|
||||
self.output_gz_name,
|
||||
)
|
||||
|
||||
if self._dry_run:
|
||||
for _ in self._iter_records_with_stats(stats):
|
||||
pass
|
||||
self._finalize_stats(stats)
|
||||
return stats
|
||||
|
||||
if self._use_cloud_storage:
|
||||
self._export_to_cloud(stats)
|
||||
else:
|
||||
self._export_to_local(stats)
|
||||
|
||||
self._finalize_stats(stats)
|
||||
return stats
|
||||
|
||||
def iter_records(self) -> Generator[AppMessageExportRecord, None, None]:
|
||||
for batch in self._iter_record_batches():
|
||||
yield from batch
|
||||
|
||||
@staticmethod
|
||||
def write_jsonl_gz(records: Iterable[AppMessageExportRecord], fileobj: BinaryIO) -> None:
|
||||
with gzip.GzipFile(fileobj=fileobj, mode="wb") as gz:
|
||||
for record in records:
|
||||
gz.write(orjson.dumps(record.model_dump(mode="json")) + b"\n")
|
||||
|
||||
def _export_to_local(self, stats: AppMessageExportStats) -> None:
|
||||
output_path = Path.cwd() / self.output_gz_name
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with output_path.open("wb") as output_file:
|
||||
self.write_jsonl_gz(self._iter_records_with_stats(stats), output_file)
|
||||
|
||||
def _export_to_cloud(self, stats: AppMessageExportStats) -> None:
|
||||
with tempfile.SpooledTemporaryFile(max_size=64 * 1024 * 1024) as tmp:
|
||||
self.write_jsonl_gz(self._iter_records_with_stats(stats), cast(BinaryIO, tmp))
|
||||
tmp.seek(0)
|
||||
data = tmp.read()
|
||||
|
||||
storage.save(self.output_gz_name, data)
|
||||
logger.info("export_app_messages: uploaded %d bytes to cloud key=%s", len(data), self.output_gz_name)
|
||||
|
||||
def _iter_records_with_stats(self, stats: AppMessageExportStats) -> Generator[AppMessageExportRecord, None, None]:
|
||||
for record in self.iter_records():
|
||||
self._update_stats(stats, record)
|
||||
yield record
|
||||
|
||||
@staticmethod
|
||||
def _update_stats(stats: AppMessageExportStats, record: AppMessageExportRecord) -> None:
|
||||
stats.total_messages += 1
|
||||
if record.feedback:
|
||||
stats.messages_with_feedback += 1
|
||||
stats.total_feedbacks += len(record.feedback)
|
||||
|
||||
def _finalize_stats(self, stats: AppMessageExportStats) -> None:
|
||||
if stats.total_messages == 0:
|
||||
stats.batches = 0
|
||||
return
|
||||
stats.batches = (stats.total_messages + self._batch_size - 1) // self._batch_size
|
||||
|
||||
def _iter_record_batches(self) -> Generator[list[AppMessageExportRecord], None, None]:
|
||||
cursor: tuple[datetime.datetime, str] | None = None
|
||||
while True:
|
||||
rows, cursor = self._fetch_batch(cursor)
|
||||
if not rows:
|
||||
break
|
||||
|
||||
message_ids = [str(row.id) for row in rows]
|
||||
feedbacks_map = self._fetch_feedbacks(message_ids)
|
||||
yield [self._build_record(row, feedbacks_map) for row in rows]
|
||||
|
||||
def _fetch_batch(
|
||||
self, cursor: tuple[datetime.datetime, str] | None
|
||||
) -> tuple[list[Any], tuple[datetime.datetime, str] | None]:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = (
|
||||
select(
|
||||
Message.id,
|
||||
Message.conversation_id,
|
||||
Message.query,
|
||||
Message.answer,
|
||||
Message._inputs, # pyright: ignore[reportPrivateUsage]
|
||||
Message.message_metadata,
|
||||
Message.created_at,
|
||||
)
|
||||
.where(
|
||||
Message.app_id == self._app_id,
|
||||
Message.created_at < self._end_before,
|
||||
)
|
||||
.order_by(Message.created_at, Message.id)
|
||||
.limit(self._batch_size)
|
||||
)
|
||||
|
||||
if self._start_from:
|
||||
stmt = stmt.where(Message.created_at >= self._start_from)
|
||||
|
||||
if cursor:
|
||||
stmt = stmt.where(
|
||||
tuple_(Message.created_at, Message.id)
|
||||
> tuple_(
|
||||
sa.literal(cursor[0], type_=sa.DateTime()),
|
||||
sa.literal(cursor[1], type_=Message.id.type),
|
||||
)
|
||||
)
|
||||
|
||||
rows = list(session.execute(stmt).all())
|
||||
|
||||
if not rows:
|
||||
return [], cursor
|
||||
|
||||
last = rows[-1]
|
||||
return rows, (last.created_at, last.id)
|
||||
|
||||
def _fetch_feedbacks(self, message_ids: list[str]) -> dict[str, list[AppMessageExportFeedback]]:
|
||||
if not message_ids:
|
||||
return {}
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = (
|
||||
select(MessageFeedback)
|
||||
.where(
|
||||
MessageFeedback.message_id.in_(message_ids),
|
||||
MessageFeedback.from_source == "user",
|
||||
)
|
||||
.order_by(MessageFeedback.message_id, MessageFeedback.created_at)
|
||||
)
|
||||
feedbacks = list(session.scalars(stmt).all())
|
||||
|
||||
result: dict[str, list[AppMessageExportFeedback]] = defaultdict(list)
|
||||
for feedback in feedbacks:
|
||||
result[str(feedback.message_id)].append(AppMessageExportFeedback.model_validate(feedback.to_dict()))
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _build_record(row: Any, feedbacks_map: dict[str, list[AppMessageExportFeedback]]) -> AppMessageExportRecord:
|
||||
retriever_resources: list[Any] = []
|
||||
if row.message_metadata:
|
||||
try:
|
||||
metadata = json.loads(row.message_metadata)
|
||||
value = metadata.get("retriever_resources", [])
|
||||
if isinstance(value, list):
|
||||
retriever_resources = value
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
message_id = str(row.id)
|
||||
return AppMessageExportRecord(
|
||||
conversation_id=str(row.conversation_id),
|
||||
message_id=message_id,
|
||||
query=row.query,
|
||||
answer=row.answer,
|
||||
inputs=row._inputs if isinstance(row._inputs, dict) else {},
|
||||
retriever_resources=retriever_resources,
|
||||
feedback=feedbacks_map.get(message_id, []),
|
||||
)
|
||||
@@ -12,6 +12,7 @@ from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import (
|
||||
App,
|
||||
AppAnnotationHitHistory,
|
||||
@@ -142,7 +143,7 @@ class MessagesCleanService:
|
||||
if batch_size <= 0:
|
||||
raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
|
||||
|
||||
end_before = datetime.datetime.now() - datetime.timedelta(days=days)
|
||||
end_before = naive_utc_now() - datetime.timedelta(days=days)
|
||||
|
||||
logger.info(
|
||||
"clean_messages: days=%s, end_before=%s, batch_size=%s, policy=%s",
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable, Sequence
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Protocol
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from celery import current_app, shared_task
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
@@ -19,6 +20,12 @@ from tasks.generate_summary_index_task import generate_summary_index_task
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CeleryTaskLike(Protocol):
|
||||
def delay(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
def apply_async(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
def document_indexing_task(dataset_id: str, document_ids: list):
|
||||
"""
|
||||
@@ -179,8 +186,8 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
||||
|
||||
|
||||
def _document_indexing_with_tenant_queue(
|
||||
tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: Callable[[str, str, Sequence[str]], None]
|
||||
):
|
||||
tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: CeleryTaskLike
|
||||
) -> None:
|
||||
try:
|
||||
_document_indexing(dataset_id, document_ids)
|
||||
except Exception:
|
||||
@@ -201,16 +208,20 @@ def _document_indexing_with_tenant_queue(
|
||||
logger.info("document indexing tenant isolation queue %s next tasks: %s", tenant_id, next_tasks)
|
||||
|
||||
if next_tasks:
|
||||
for next_task in next_tasks:
|
||||
document_task = DocumentTask(**next_task)
|
||||
# Process the next waiting task
|
||||
# Keep the flag set to indicate a task is running
|
||||
tenant_isolated_task_queue.set_task_waiting_time()
|
||||
task_func.delay( # type: ignore
|
||||
tenant_id=document_task.tenant_id,
|
||||
dataset_id=document_task.dataset_id,
|
||||
document_ids=document_task.document_ids,
|
||||
)
|
||||
with current_app.producer_or_acquire() as producer: # type: ignore
|
||||
for next_task in next_tasks:
|
||||
document_task = DocumentTask(**next_task)
|
||||
# Keep the flag set to indicate a task is running
|
||||
tenant_isolated_task_queue.set_task_waiting_time()
|
||||
task_func.apply_async(
|
||||
kwargs={
|
||||
"tenant_id": document_task.tenant_id,
|
||||
"dataset_id": document_task.dataset_id,
|
||||
"document_ids": document_task.document_ids,
|
||||
},
|
||||
producer=producer,
|
||||
)
|
||||
|
||||
else:
|
||||
# No more waiting tasks, clear the flag
|
||||
tenant_isolated_task_queue.delete_task_key()
|
||||
|
||||
@@ -14,7 +14,7 @@ from services.summary_index_service import SummaryIndexService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
@shared_task(queue="dataset_summary")
|
||||
def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids: list[str] | None = None):
|
||||
"""
|
||||
Async generate summary index for document segments.
|
||||
|
||||
@@ -6,7 +6,6 @@ import typing
|
||||
import click
|
||||
from celery import shared_task
|
||||
|
||||
from core.helper.marketplace import record_install_plugin_event
|
||||
from core.plugin.entities.marketplace import MarketplacePluginSnapshot
|
||||
from core.plugin.entities.plugin import PluginInstallationSource
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
@@ -166,7 +165,6 @@ def process_tenant_plugin_autoupgrade_check_task(
|
||||
# execute upgrade
|
||||
new_unique_identifier = manifest.latest_package_identifier
|
||||
|
||||
record_install_plugin_event(new_unique_identifier)
|
||||
click.echo(
|
||||
click.style(
|
||||
f"Upgrade plugin: {original_unique_identifier} -> {new_unique_identifier}",
|
||||
|
||||
@@ -3,12 +3,13 @@ import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from itertools import islice
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
from celery import shared_task # type: ignore
|
||||
from celery import group, shared_task
|
||||
from flask import current_app, g
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
@@ -27,6 +28,11 @@ from services.file_service import FileService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def chunked(iterable: Sequence, size: int):
|
||||
it = iter(iterable)
|
||||
return iter(lambda: list(islice(it, size)), [])
|
||||
|
||||
|
||||
@shared_task(queue="pipeline")
|
||||
def rag_pipeline_run_task(
|
||||
rag_pipeline_invoke_entities_file_id: str,
|
||||
@@ -83,16 +89,24 @@ def rag_pipeline_run_task(
|
||||
logger.info("rag pipeline tenant isolation queue %s next files: %s", tenant_id, next_file_ids)
|
||||
|
||||
if next_file_ids:
|
||||
for next_file_id in next_file_ids:
|
||||
# Process the next waiting task
|
||||
# Keep the flag set to indicate a task is running
|
||||
tenant_isolated_task_queue.set_task_waiting_time()
|
||||
rag_pipeline_run_task.delay( # type: ignore
|
||||
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
|
||||
if isinstance(next_file_id, bytes)
|
||||
else next_file_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
for batch in chunked(next_file_ids, 100):
|
||||
jobs = []
|
||||
for next_file_id in batch:
|
||||
tenant_isolated_task_queue.set_task_waiting_time()
|
||||
|
||||
file_id = (
|
||||
next_file_id.decode("utf-8") if isinstance(next_file_id, (bytes, bytearray)) else next_file_id
|
||||
)
|
||||
|
||||
jobs.append(
|
||||
rag_pipeline_run_task.s(
|
||||
rag_pipeline_invoke_entities_file_id=file_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
if jobs:
|
||||
group(jobs).apply_async()
|
||||
else:
|
||||
# No more waiting tasks, clear the flag
|
||||
tenant_isolated_task_queue.delete_task_key()
|
||||
|
||||
@@ -16,7 +16,7 @@ from services.summary_index_service import SummaryIndexService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
@shared_task(queue="dataset_summary")
|
||||
def regenerate_summary_index_task(
|
||||
dataset_id: str,
|
||||
regenerate_reason: str = "summary_model_changed",
|
||||
|
||||
@@ -5,14 +5,10 @@ This test module validates the 400-character limit enforcement
|
||||
for App descriptions across all creation and editing endpoints.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
# Add the API root to Python path for imports
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
|
||||
|
||||
|
||||
class TestAppDescriptionValidationUnit:
|
||||
"""Unit tests for description validation function"""
|
||||
|
||||
@@ -10,8 +10,11 @@ more reliable and realistic test scenarios.
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Protocol, TypeVar
|
||||
|
||||
import psycopg2
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.testing import FlaskClient
|
||||
@@ -31,6 +34,25 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(level
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _CloserProtocol(Protocol):
|
||||
"""_Closer is any type which implement the close() method."""
|
||||
|
||||
def close(self):
|
||||
"""close the current object, release any external resouece (file, transaction, connection etc.)
|
||||
associated with it.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
_Closer = TypeVar("_Closer", bound=_CloserProtocol)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _auto_close(closer: _Closer) -> Generator[_Closer, None, None]:
|
||||
yield closer
|
||||
closer.close()
|
||||
|
||||
|
||||
class DifyTestContainers:
|
||||
"""
|
||||
Manages all test containers required for Dify integration tests.
|
||||
@@ -97,45 +119,28 @@ class DifyTestContainers:
|
||||
wait_for_logs(self.postgres, "is ready to accept connections", timeout=30)
|
||||
logger.info("PostgreSQL container is ready and accepting connections")
|
||||
|
||||
# Install uuid-ossp extension for UUID generation
|
||||
logger.info("Installing uuid-ossp extension...")
|
||||
try:
|
||||
import psycopg2
|
||||
|
||||
conn = psycopg2.connect(
|
||||
host=db_host,
|
||||
port=db_port,
|
||||
user=self.postgres.username,
|
||||
password=self.postgres.password,
|
||||
database=self.postgres.dbname,
|
||||
)
|
||||
conn.autocommit = True
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
|
||||
cursor.close()
|
||||
conn.close()
|
||||
conn = psycopg2.connect(
|
||||
host=db_host,
|
||||
port=db_port,
|
||||
user=self.postgres.username,
|
||||
password=self.postgres.password,
|
||||
database=self.postgres.dbname,
|
||||
)
|
||||
conn.autocommit = True
|
||||
with _auto_close(conn):
|
||||
with conn.cursor() as cursor:
|
||||
# Install uuid-ossp extension for UUID generation
|
||||
logger.info("Installing uuid-ossp extension...")
|
||||
cursor.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
|
||||
logger.info("uuid-ossp extension installed successfully")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to install uuid-ossp extension: %s", e)
|
||||
|
||||
# Create plugin database for dify-plugin-daemon
|
||||
logger.info("Creating plugin database...")
|
||||
try:
|
||||
conn = psycopg2.connect(
|
||||
host=db_host,
|
||||
port=db_port,
|
||||
user=self.postgres.username,
|
||||
password=self.postgres.password,
|
||||
database=self.postgres.dbname,
|
||||
)
|
||||
conn.autocommit = True
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("CREATE DATABASE dify_plugin;")
|
||||
cursor.close()
|
||||
conn.close()
|
||||
# NOTE: We cannot use `with conn.cursor() as cursor:` as it will wrap the statement
|
||||
# inside a transaction. However, the `CREATE DATABASE` statement cannot run inside a transaction block.
|
||||
with _auto_close(conn.cursor()) as cursor:
|
||||
# Create plugin database for dify-plugin-daemon
|
||||
logger.info("Creating plugin database...")
|
||||
cursor.execute("CREATE DATABASE dify_plugin;")
|
||||
logger.info("Plugin database created successfully")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to create plugin database: %s", e)
|
||||
|
||||
# Set up storage environment variables
|
||||
os.environ.setdefault("STORAGE_TYPE", "opendal")
|
||||
@@ -258,23 +263,16 @@ class DifyTestContainers:
|
||||
containers = [self.redis, self.postgres, self.dify_sandbox, self.dify_plugin_daemon]
|
||||
for container in containers:
|
||||
if container:
|
||||
try:
|
||||
container_name = container.image
|
||||
logger.info("Stopping container: %s", container_name)
|
||||
container.stop()
|
||||
logger.info("Successfully stopped container: %s", container_name)
|
||||
except Exception as e:
|
||||
# Log error but don't fail the test cleanup
|
||||
logger.warning("Failed to stop container %s: %s", container, e)
|
||||
container_name = container.image
|
||||
logger.info("Stopping container: %s", container_name)
|
||||
container.stop()
|
||||
logger.info("Successfully stopped container: %s", container_name)
|
||||
|
||||
# Stop and remove the network
|
||||
if self.network:
|
||||
try:
|
||||
logger.info("Removing Docker network...")
|
||||
self.network.remove()
|
||||
logger.info("Successfully removed Docker network")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to remove Docker network: %s", e)
|
||||
logger.info("Removing Docker network...")
|
||||
self.network.remove()
|
||||
logger.info("Successfully removed Docker network")
|
||||
|
||||
self._containers_started = False
|
||||
logger.info("All test containers stopped and cleaned up successfully")
|
||||
|
||||
@@ -0,0 +1,497 @@
|
||||
"""
|
||||
Container-backed integration tests for dataset permission services on the real SQL path.
|
||||
|
||||
This module exercises persisted DatasetPermission rows and dataset permission
|
||||
checks with testcontainers-backed infrastructure instead of database-chain mocks.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import (
|
||||
Dataset,
|
||||
DatasetPermission,
|
||||
DatasetPermissionEnum,
|
||||
)
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||
from services.errors.account import NoPermissionError
|
||||
|
||||
|
||||
class DatasetPermissionTestDataFactory:
|
||||
"""Create persisted entities and request payloads for dataset permission integration tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_account_with_tenant(
|
||||
role: TenantAccountRole = TenantAccountRole.NORMAL,
|
||||
tenant: Tenant | None = None,
|
||||
) -> tuple[Account, Tenant]:
|
||||
"""Create a real account and tenant with specified role."""
|
||||
account = Account(
|
||||
email=f"{uuid4()}@example.com",
|
||||
name=f"user-{uuid4()}",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
if tenant is None:
|
||||
tenant = Tenant(name=f"tenant-{uuid4()}", status="normal")
|
||||
db.session.add_all([account, tenant])
|
||||
else:
|
||||
db.session.add(account)
|
||||
|
||||
db.session.flush()
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=role,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
|
||||
account.current_tenant = tenant
|
||||
return account, tenant
|
||||
|
||||
@staticmethod
|
||||
def create_dataset(
|
||||
tenant_id: str,
|
||||
created_by: str,
|
||||
permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME,
|
||||
name: str = "Test Dataset",
|
||||
) -> Dataset:
|
||||
"""Create a real dataset with specified attributes."""
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description="desc",
|
||||
data_source_type="upload_file",
|
||||
indexing_technique="high_quality",
|
||||
created_by=created_by,
|
||||
permission=permission,
|
||||
provider="vendor",
|
||||
retrieval_model={"top_k": 2},
|
||||
)
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_permission(
|
||||
dataset_id: str,
|
||||
account_id: str,
|
||||
tenant_id: str,
|
||||
has_permission: bool = True,
|
||||
) -> DatasetPermission:
|
||||
"""Create a real DatasetPermission instance."""
|
||||
permission = DatasetPermission(
|
||||
dataset_id=dataset_id,
|
||||
account_id=account_id,
|
||||
tenant_id=tenant_id,
|
||||
has_permission=has_permission,
|
||||
)
|
||||
db.session.add(permission)
|
||||
db.session.commit()
|
||||
return permission
|
||||
|
||||
@staticmethod
|
||||
def build_user_list_payload(user_ids: list[str]) -> list[dict[str, str]]:
|
||||
"""Build the request payload shape used by partial-member list updates."""
|
||||
return [{"user_id": user_id} for user_id in user_ids]
|
||||
|
||||
|
||||
class TestDatasetPermissionServiceGetPartialMemberList:
|
||||
"""Verify partial-member list reads against persisted DatasetPermission rows."""
|
||||
|
||||
def test_get_dataset_partial_member_list_with_members(self, db_session_with_containers):
|
||||
"""
|
||||
Test retrieving partial member list with multiple members.
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
user_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.NORMAL,
|
||||
tenant=tenant,
|
||||
)
|
||||
user_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.NORMAL,
|
||||
tenant=tenant,
|
||||
)
|
||||
user_3, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.NORMAL,
|
||||
tenant=tenant,
|
||||
)
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)
|
||||
|
||||
expected_account_ids = [user_1.id, user_2.id, user_3.id]
|
||||
for account_id in expected_account_ids:
|
||||
DatasetPermissionTestDataFactory.create_dataset_permission(dataset.id, account_id, tenant.id)
|
||||
|
||||
# Act
|
||||
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
|
||||
|
||||
# Assert
|
||||
assert set(result) == set(expected_account_ids)
|
||||
assert len(result) == 3
|
||||
|
||||
def test_get_dataset_partial_member_list_with_single_member(self, db_session_with_containers):
|
||||
"""
|
||||
Test retrieving partial member list with single member.
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.NORMAL,
|
||||
tenant=tenant,
|
||||
)
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)
|
||||
|
||||
expected_account_ids = [user.id]
|
||||
DatasetPermissionTestDataFactory.create_dataset_permission(dataset.id, user.id, tenant.id)
|
||||
|
||||
# Act
|
||||
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
|
||||
|
||||
# Assert
|
||||
assert set(result) == set(expected_account_ids)
|
||||
assert len(result) == 1
|
||||
|
||||
def test_get_dataset_partial_member_list_empty(self, db_session_with_containers):
|
||||
"""
|
||||
Test retrieving partial member list when no members exist.
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)
|
||||
|
||||
# Act
|
||||
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
class TestDatasetPermissionServiceUpdatePartialMemberList:
|
||||
"""Verify partial-member list updates against persisted DatasetPermission rows."""
|
||||
|
||||
def test_update_partial_member_list_add_new_members(self, db_session_with_containers):
|
||||
"""
|
||||
Test adding new partial members to a dataset.
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.NORMAL,
|
||||
tenant=tenant,
|
||||
)
|
||||
member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.NORMAL,
|
||||
tenant=tenant,
|
||||
)
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)
|
||||
user_list = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id])
|
||||
|
||||
# Act
|
||||
DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, user_list)
|
||||
|
||||
# Assert
|
||||
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
|
||||
assert set(result) == {member_1.id, member_2.id}
|
||||
|
||||
def test_update_partial_member_list_replace_existing(self, db_session_with_containers):
|
||||
"""
|
||||
Test replacing existing partial members with new ones.
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
old_member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.NORMAL,
|
||||
tenant=tenant,
|
||||
)
|
||||
old_member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.NORMAL,
|
||||
tenant=tenant,
|
||||
)
|
||||
new_member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.NORMAL,
|
||||
tenant=tenant,
|
||||
)
|
||||
new_member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.NORMAL,
|
||||
tenant=tenant,
|
||||
)
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)
|
||||
|
||||
old_users = DatasetPermissionTestDataFactory.build_user_list_payload([old_member_1.id, old_member_2.id])
|
||||
DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, old_users)
|
||||
|
||||
new_users = DatasetPermissionTestDataFactory.build_user_list_payload([new_member_1.id, new_member_2.id])
|
||||
|
||||
# Act
|
||||
DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, new_users)
|
||||
|
||||
# Assert
|
||||
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
|
||||
assert set(result) == {new_member_1.id, new_member_2.id}
|
||||
|
||||
def test_update_partial_member_list_empty_list(self, db_session_with_containers):
|
||||
"""
|
||||
Test updating with empty member list (clearing all members).
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.NORMAL,
|
||||
tenant=tenant,
|
||||
)
|
||||
member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.NORMAL,
|
||||
tenant=tenant,
|
||||
)
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)
|
||||
users = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id])
|
||||
DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, users)
|
||||
|
||||
# Act
|
||||
DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, [])
|
||||
|
||||
# Assert
|
||||
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
|
||||
assert result == []
|
||||
|
||||
def test_update_partial_member_list_database_error_rollback(self, db_session_with_containers):
|
||||
"""
|
||||
Test error handling and rollback on database error.
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
existing_member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.NORMAL,
|
||||
tenant=tenant,
|
||||
)
|
||||
replacement_member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.NORMAL,
|
||||
tenant=tenant,
|
||||
)
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)
|
||||
DatasetPermissionService.update_partial_member_list(
|
||||
tenant.id,
|
||||
dataset.id,
|
||||
DatasetPermissionTestDataFactory.build_user_list_payload([existing_member.id]),
|
||||
)
|
||||
user_list = DatasetPermissionTestDataFactory.build_user_list_payload([replacement_member.id])
|
||||
rollback_called = {"count": 0}
|
||||
original_rollback = db.session.rollback
|
||||
|
||||
# Act / Assert
|
||||
with pytest.MonkeyPatch.context() as mp:
|
||||
|
||||
def _raise_commit():
|
||||
raise Exception("Database connection error")
|
||||
|
||||
def _rollback_and_mark():
|
||||
rollback_called["count"] += 1
|
||||
original_rollback()
|
||||
|
||||
mp.setattr("services.dataset_service.db.session.commit", _raise_commit)
|
||||
mp.setattr("services.dataset_service.db.session.rollback", _rollback_and_mark)
|
||||
with pytest.raises(Exception, match="Database connection error"):
|
||||
DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, user_list)
|
||||
|
||||
# Assert
|
||||
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
|
||||
assert rollback_called["count"] == 1
|
||||
assert result == [existing_member.id]
|
||||
assert db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).count() == 1
|
||||
|
||||
|
||||
class TestDatasetPermissionServiceClearPartialMemberList:
|
||||
"""Verify partial-member clearing against persisted DatasetPermission rows."""
|
||||
|
||||
def test_clear_partial_member_list_success(self, db_session_with_containers):
|
||||
"""
|
||||
Test successful clearing of partial member list.
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.NORMAL,
|
||||
tenant=tenant,
|
||||
)
|
||||
member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.NORMAL,
|
||||
tenant=tenant,
|
||||
)
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)
|
||||
users = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id])
|
||||
DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, users)
|
||||
|
||||
# Act
|
||||
DatasetPermissionService.clear_partial_member_list(dataset.id)
|
||||
|
||||
# Assert
|
||||
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
|
||||
assert result == []
|
||||
|
||||
def test_clear_partial_member_list_empty_list(self, db_session_with_containers):
|
||||
"""
|
||||
Test clearing partial member list when no members exist.
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)
|
||||
|
||||
# Act
|
||||
DatasetPermissionService.clear_partial_member_list(dataset.id)
|
||||
|
||||
# Assert
|
||||
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
|
||||
assert result == []
|
||||
|
||||
def test_clear_partial_member_list_database_error_rollback(self, db_session_with_containers):
|
||||
"""
|
||||
Test error handling and rollback on database error.
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.NORMAL,
|
||||
tenant=tenant,
|
||||
)
|
||||
member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.NORMAL,
|
||||
tenant=tenant,
|
||||
)
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id)
|
||||
users = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id])
|
||||
DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, users)
|
||||
rollback_called = {"count": 0}
|
||||
original_rollback = db.session.rollback
|
||||
|
||||
# Act / Assert
|
||||
with pytest.MonkeyPatch.context() as mp:
|
||||
|
||||
def _raise_commit():
|
||||
raise Exception("Database connection error")
|
||||
|
||||
def _rollback_and_mark():
|
||||
rollback_called["count"] += 1
|
||||
original_rollback()
|
||||
|
||||
mp.setattr("services.dataset_service.db.session.commit", _raise_commit)
|
||||
mp.setattr("services.dataset_service.db.session.rollback", _rollback_and_mark)
|
||||
with pytest.raises(Exception, match="Database connection error"):
|
||||
DatasetPermissionService.clear_partial_member_list(dataset.id)
|
||||
|
||||
# Assert
|
||||
result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
|
||||
assert rollback_called["count"] == 1
|
||||
assert set(result) == {member_1.id, member_2.id}
|
||||
assert db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).count() == 2
|
||||
|
||||
|
||||
class TestDatasetServiceCheckDatasetPermission:
|
||||
"""Verify dataset access checks against persisted partial-member permissions."""
|
||||
|
||||
def test_check_dataset_permission_partial_members_with_permission_success(self, db_session_with_containers):
|
||||
"""
|
||||
Test that user with explicit permission can access partial_members dataset.
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.NORMAL,
|
||||
tenant=tenant,
|
||||
)
|
||||
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset(
|
||||
tenant.id,
|
||||
owner.id,
|
||||
permission=DatasetPermissionEnum.PARTIAL_TEAM,
|
||||
)
|
||||
DatasetPermissionTestDataFactory.create_dataset_permission(dataset.id, user.id, tenant.id)
|
||||
|
||||
# Act (should not raise)
|
||||
DatasetService.check_dataset_permission(dataset, user)
|
||||
|
||||
# Assert
|
||||
permissions = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
|
||||
assert user.id in permissions
|
||||
|
||||
def test_check_dataset_permission_partial_members_without_permission_error(self, db_session_with_containers):
|
||||
"""
|
||||
Test error when user without permission tries to access partial_members dataset.
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.NORMAL,
|
||||
tenant=tenant,
|
||||
)
|
||||
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset(
|
||||
tenant.id,
|
||||
owner.id,
|
||||
permission=DatasetPermissionEnum.PARTIAL_TEAM,
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"):
|
||||
DatasetService.check_dataset_permission(dataset, user)
|
||||
|
||||
|
||||
class TestDatasetServiceCheckDatasetOperatorPermission:
|
||||
"""Verify operator permission checks against persisted partial-member permissions."""
|
||||
|
||||
def test_check_dataset_operator_permission_partial_members_with_permission_success(
|
||||
self, db_session_with_containers
|
||||
):
|
||||
"""
|
||||
Test that user with explicit permission can access partial_members dataset.
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.NORMAL,
|
||||
tenant=tenant,
|
||||
)
|
||||
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset(
|
||||
tenant.id,
|
||||
owner.id,
|
||||
permission=DatasetPermissionEnum.PARTIAL_TEAM,
|
||||
)
|
||||
DatasetPermissionTestDataFactory.create_dataset_permission(dataset.id, user.id, tenant.id)
|
||||
|
||||
# Act (should not raise)
|
||||
DatasetService.check_dataset_operator_permission(user=user, dataset=dataset)
|
||||
|
||||
# Assert
|
||||
permissions = DatasetPermissionService.get_dataset_partial_member_list(dataset.id)
|
||||
assert user.id in permissions
|
||||
|
||||
def test_check_dataset_operator_permission_partial_members_without_permission_error(
|
||||
self, db_session_with_containers
|
||||
):
|
||||
"""
|
||||
Test error when user without permission tries to access partial_members dataset.
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.NORMAL,
|
||||
tenant=tenant,
|
||||
)
|
||||
|
||||
dataset = DatasetPermissionTestDataFactory.create_dataset(
|
||||
tenant.id,
|
||||
owner.id,
|
||||
permission=DatasetPermissionEnum.PARTIAL_TEAM,
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"):
|
||||
DatasetService.check_dataset_operator_permission(user=user, dataset=dataset)
|
||||
@@ -0,0 +1,233 @@
|
||||
import datetime
|
||||
import json
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.model import (
|
||||
App,
|
||||
AppAnnotationHitHistory,
|
||||
Conversation,
|
||||
DatasetRetrieverResource,
|
||||
Message,
|
||||
MessageAgentThought,
|
||||
MessageAnnotation,
|
||||
MessageChain,
|
||||
MessageFeedback,
|
||||
MessageFile,
|
||||
)
|
||||
from models.web import SavedMessage
|
||||
from services.retention.conversation.message_export_service import AppMessageExportService, AppMessageExportStats
|
||||
|
||||
|
||||
class TestAppMessageExportServiceIntegration:
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_database(self, db_session_with_containers: Session):
|
||||
yield
|
||||
db_session_with_containers.query(DatasetRetrieverResource).delete()
|
||||
db_session_with_containers.query(AppAnnotationHitHistory).delete()
|
||||
db_session_with_containers.query(SavedMessage).delete()
|
||||
db_session_with_containers.query(MessageFile).delete()
|
||||
db_session_with_containers.query(MessageAgentThought).delete()
|
||||
db_session_with_containers.query(MessageChain).delete()
|
||||
db_session_with_containers.query(MessageAnnotation).delete()
|
||||
db_session_with_containers.query(MessageFeedback).delete()
|
||||
db_session_with_containers.query(Message).delete()
|
||||
db_session_with_containers.query(Conversation).delete()
|
||||
db_session_with_containers.query(App).delete()
|
||||
db_session_with_containers.query(TenantAccountJoin).delete()
|
||||
db_session_with_containers.query(Tenant).delete()
|
||||
db_session_with_containers.query(Account).delete()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
@staticmethod
|
||||
def _create_app_context(session: Session) -> tuple[App, Conversation]:
|
||||
account = Account(
|
||||
email=f"test-{uuid.uuid4()}@example.com",
|
||||
name="tester",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
session.add(account)
|
||||
session.flush()
|
||||
|
||||
tenant = Tenant(name=f"tenant-{uuid.uuid4()}", status="normal")
|
||||
session.add(tenant)
|
||||
session.flush()
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
session.add(join)
|
||||
session.flush()
|
||||
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name="export-app",
|
||||
description="integration test app",
|
||||
mode="chat",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
api_rpm=60,
|
||||
api_rph=3600,
|
||||
is_demo=False,
|
||||
is_public=False,
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
session.add(app)
|
||||
session.flush()
|
||||
|
||||
conversation = Conversation(
|
||||
app_id=app.id,
|
||||
app_model_config_id=str(uuid.uuid4()),
|
||||
model_provider="openai",
|
||||
model_id="gpt-4o-mini",
|
||||
mode="chat",
|
||||
name="conv",
|
||||
inputs={"seed": 1},
|
||||
status="normal",
|
||||
from_source="api",
|
||||
from_end_user_id=str(uuid.uuid4()),
|
||||
)
|
||||
session.add(conversation)
|
||||
session.commit()
|
||||
return app, conversation
|
||||
|
||||
@staticmethod
|
||||
def _create_message(
|
||||
session: Session,
|
||||
app: App,
|
||||
conversation: Conversation,
|
||||
created_at: datetime.datetime,
|
||||
*,
|
||||
query: str,
|
||||
answer: str,
|
||||
inputs: dict,
|
||||
message_metadata: str | None,
|
||||
) -> Message:
|
||||
message = Message(
|
||||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
model_provider="openai",
|
||||
model_id="gpt-4o-mini",
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
answer=answer,
|
||||
message=[{"role": "assistant", "content": answer}],
|
||||
message_tokens=10,
|
||||
message_unit_price=Decimal("0.001"),
|
||||
answer_tokens=20,
|
||||
answer_unit_price=Decimal("0.002"),
|
||||
total_price=Decimal("0.003"),
|
||||
currency="USD",
|
||||
message_metadata=message_metadata,
|
||||
from_source="api",
|
||||
from_end_user_id=conversation.from_end_user_id,
|
||||
created_at=created_at,
|
||||
)
|
||||
session.add(message)
|
||||
session.flush()
|
||||
return message
|
||||
|
||||
def test_iter_records_with_stats(self, db_session_with_containers: Session):
|
||||
app, conversation = self._create_app_context(db_session_with_containers)
|
||||
|
||||
first_inputs = {
|
||||
"plain": "v1",
|
||||
"nested": {"a": 1, "b": [1, {"x": True}]},
|
||||
"list": ["x", 2, {"y": "z"}],
|
||||
}
|
||||
second_inputs = {"other": "value", "items": [1, 2, 3]}
|
||||
|
||||
base_time = datetime.datetime(2026, 2, 25, 10, 0, 0)
|
||||
first_message = self._create_message(
|
||||
db_session_with_containers,
|
||||
app,
|
||||
conversation,
|
||||
created_at=base_time,
|
||||
query="q1",
|
||||
answer="a1",
|
||||
inputs=first_inputs,
|
||||
message_metadata=json.dumps({"retriever_resources": [{"dataset_id": "ds-1"}]}),
|
||||
)
|
||||
second_message = self._create_message(
|
||||
db_session_with_containers,
|
||||
app,
|
||||
conversation,
|
||||
created_at=base_time + datetime.timedelta(minutes=1),
|
||||
query="q2",
|
||||
answer="a2",
|
||||
inputs=second_inputs,
|
||||
message_metadata=None,
|
||||
)
|
||||
|
||||
user_feedback_1 = MessageFeedback(
|
||||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
message_id=first_message.id,
|
||||
rating="like",
|
||||
from_source="user",
|
||||
content="first",
|
||||
from_end_user_id=conversation.from_end_user_id,
|
||||
)
|
||||
user_feedback_2 = MessageFeedback(
|
||||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
message_id=first_message.id,
|
||||
rating="dislike",
|
||||
from_source="user",
|
||||
content="second",
|
||||
from_end_user_id=conversation.from_end_user_id,
|
||||
)
|
||||
admin_feedback = MessageFeedback(
|
||||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
message_id=first_message.id,
|
||||
rating="like",
|
||||
from_source="admin",
|
||||
content="should-be-filtered",
|
||||
from_account_id=str(uuid.uuid4()),
|
||||
)
|
||||
db_session_with_containers.add_all([user_feedback_1, user_feedback_2, admin_feedback])
|
||||
user_feedback_1.created_at = base_time + datetime.timedelta(minutes=2)
|
||||
user_feedback_2.created_at = base_time + datetime.timedelta(minutes=3)
|
||||
admin_feedback.created_at = base_time + datetime.timedelta(minutes=4)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
service = AppMessageExportService(
|
||||
app_id=app.id,
|
||||
start_from=base_time - datetime.timedelta(minutes=1),
|
||||
end_before=base_time + datetime.timedelta(minutes=10),
|
||||
filename="unused",
|
||||
batch_size=1,
|
||||
dry_run=True,
|
||||
)
|
||||
stats = AppMessageExportStats()
|
||||
records = list(service._iter_records_with_stats(stats))
|
||||
service._finalize_stats(stats)
|
||||
|
||||
assert len(records) == 2
|
||||
assert records[0].message_id == first_message.id
|
||||
assert records[1].message_id == second_message.id
|
||||
|
||||
assert records[0].inputs == first_inputs
|
||||
assert records[1].inputs == second_inputs
|
||||
|
||||
assert records[0].retriever_resources == [{"dataset_id": "ds-1"}]
|
||||
assert records[1].retriever_resources == []
|
||||
|
||||
assert [feedback.rating for feedback in records[0].feedback] == ["like", "dislike"]
|
||||
assert [feedback.content for feedback in records[0].feedback] == ["first", "second"]
|
||||
assert records[1].feedback == []
|
||||
|
||||
assert stats.batches == 2
|
||||
assert stats.total_messages == 2
|
||||
assert stats.messages_with_feedback == 1
|
||||
assert stats.total_feedbacks == 2
|
||||
@@ -322,11 +322,14 @@ class TestDatasetIndexingTaskIntegration:
|
||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||
|
||||
# Assert
|
||||
task_dispatch_spy.delay.assert_called_once_with(
|
||||
tenant_id=next_task["tenant_id"],
|
||||
dataset_id=next_task["dataset_id"],
|
||||
document_ids=next_task["document_ids"],
|
||||
)
|
||||
# apply_async is used by implementation; assert it was called once with expected kwargs
|
||||
assert task_dispatch_spy.apply_async.call_count == 1
|
||||
call_kwargs = task_dispatch_spy.apply_async.call_args.kwargs.get("kwargs", {})
|
||||
assert call_kwargs == {
|
||||
"tenant_id": next_task["tenant_id"],
|
||||
"dataset_id": next_task["dataset_id"],
|
||||
"document_ids": next_task["document_ids"],
|
||||
}
|
||||
set_waiting_spy.assert_called_once()
|
||||
delete_key_spy.assert_not_called()
|
||||
|
||||
@@ -352,7 +355,7 @@ class TestDatasetIndexingTaskIntegration:
|
||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||
|
||||
# Assert
|
||||
task_dispatch_spy.delay.assert_not_called()
|
||||
task_dispatch_spy.apply_async.assert_not_called()
|
||||
delete_key_spy.assert_called_once()
|
||||
|
||||
def test_validation_failure_sets_error_status_when_vector_space_at_limit(
|
||||
@@ -447,7 +450,7 @@ class TestDatasetIndexingTaskIntegration:
|
||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||
|
||||
# Assert
|
||||
task_dispatch_spy.delay.assert_called_once()
|
||||
task_dispatch_spy.apply_async.assert_called_once()
|
||||
|
||||
def test_sessions_close_on_successful_indexing(
|
||||
self,
|
||||
@@ -534,7 +537,7 @@ class TestDatasetIndexingTaskIntegration:
|
||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||
|
||||
# Assert
|
||||
assert task_dispatch_spy.delay.call_count == concurrency_limit
|
||||
assert task_dispatch_spy.apply_async.call_count == concurrency_limit
|
||||
assert set_waiting_spy.call_count == concurrency_limit
|
||||
|
||||
def test_task_queue_fifo_ordering(self, db_session_with_containers, patched_external_dependencies):
|
||||
@@ -565,9 +568,10 @@ class TestDatasetIndexingTaskIntegration:
|
||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||
|
||||
# Assert
|
||||
assert task_dispatch_spy.delay.call_count == 3
|
||||
assert task_dispatch_spy.apply_async.call_count == 3
|
||||
for index, expected_task in enumerate(ordered_tasks):
|
||||
assert task_dispatch_spy.delay.call_args_list[index].kwargs["document_ids"] == expected_task["document_ids"]
|
||||
call_kwargs = task_dispatch_spy.apply_async.call_args_list[index].kwargs.get("kwargs", {})
|
||||
assert call_kwargs.get("document_ids") == expected_task["document_ids"]
|
||||
|
||||
def test_billing_disabled_skips_limit_checks(self, db_session_with_containers, patched_external_dependencies):
|
||||
"""Skip limit checks when billing feature is disabled."""
|
||||
|
||||
@@ -762,11 +762,12 @@ class TestDocumentIndexingTasks:
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify task function was called for each waiting task
|
||||
assert mock_task_func.delay.call_count == 1
|
||||
assert mock_task_func.apply_async.call_count == 1
|
||||
|
||||
# Verify correct parameters for each call
|
||||
calls = mock_task_func.delay.call_args_list
|
||||
assert calls[0][1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
|
||||
calls = mock_task_func.apply_async.call_args_list
|
||||
sent_kwargs = calls[0][1]["kwargs"]
|
||||
assert sent_kwargs == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
|
||||
|
||||
# Verify queue is empty after processing (tasks were pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10) # Pull more than we added
|
||||
@@ -830,11 +831,15 @@ class TestDocumentIndexingTasks:
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
# Verify waiting task was still processed despite core processing error
|
||||
mock_task_func.delay.assert_called_once()
|
||||
mock_task_func.apply_async.assert_called_once()
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call = mock_task_func.delay.call_args
|
||||
assert call[1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
|
||||
call = mock_task_func.apply_async.call_args
|
||||
assert call[1]["kwargs"] == {
|
||||
"tenant_id": tenant_id,
|
||||
"dataset_id": dataset_id,
|
||||
"document_ids": ["waiting-doc-1"],
|
||||
}
|
||||
|
||||
# Verify queue is empty after processing (task was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
@@ -896,9 +901,13 @@ class TestDocumentIndexingTasks:
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify only tenant1's waiting task was processed
|
||||
mock_task_func.delay.assert_called_once()
|
||||
call = mock_task_func.delay.call_args
|
||||
assert call[1] == {"tenant_id": tenant1_id, "dataset_id": dataset1_id, "document_ids": ["tenant1-doc-1"]}
|
||||
mock_task_func.apply_async.assert_called_once()
|
||||
call = mock_task_func.apply_async.call_args
|
||||
assert call[1]["kwargs"] == {
|
||||
"tenant_id": tenant1_id,
|
||||
"dataset_id": dataset1_id,
|
||||
"document_ids": ["tenant1-doc-1"],
|
||||
}
|
||||
|
||||
# Verify tenant1's queue is empty
|
||||
remaining_tasks1 = queue1.pull_tasks(count=10)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
@@ -388,8 +388,10 @@ class TestRagPipelineRunTasks:
|
||||
# Set the task key to indicate there are waiting tasks (legacy behavior)
|
||||
redis_client.set(legacy_task_key, 1, ex=60 * 60)
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Mock the Celery group scheduling used by the implementation
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
|
||||
mock_group.return_value.apply_async = MagicMock()
|
||||
|
||||
# Act: Execute the priority task with new code but legacy queue data
|
||||
rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
@@ -398,13 +400,14 @@ class TestRagPipelineRunTasks:
|
||||
mock_file_service["delete_file"].assert_called_once_with(file_id)
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify waiting tasks were processed, pull 1 task a time by default
|
||||
assert mock_delay.call_count == 1
|
||||
# Verify waiting tasks were processed via group, pull 1 task a time by default
|
||||
assert mock_group.return_value.apply_async.called
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0]
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
# Verify correct parameters for the first scheduled job signature
|
||||
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
|
||||
first_kwargs = jobs[0].kwargs if jobs else {}
|
||||
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0]
|
||||
assert first_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify that new code can process legacy queue entries
|
||||
# The new TenantIsolatedTaskQueue should be able to read from the legacy format
|
||||
@@ -446,8 +449,10 @@ class TestRagPipelineRunTasks:
|
||||
waiting_file_ids = [str(uuid.uuid4()) for _ in range(3)]
|
||||
queue.push_tasks(waiting_file_ids)
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Mock the Celery group scheduling used by the implementation
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
|
||||
mock_group.return_value.apply_async = MagicMock()
|
||||
|
||||
# Act: Execute the regular task
|
||||
rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
@@ -456,13 +461,14 @@ class TestRagPipelineRunTasks:
|
||||
mock_file_service["delete_file"].assert_called_once_with(file_id)
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify waiting tasks were processed, pull 1 task a time by default
|
||||
assert mock_delay.call_count == 1
|
||||
# Verify waiting tasks were processed via group.apply_async
|
||||
assert mock_group.return_value.apply_async.called
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
# Verify correct parameters for the first scheduled job signature
|
||||
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
|
||||
first_kwargs = jobs[0].kwargs if jobs else {}
|
||||
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
|
||||
assert first_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify queue still has remaining tasks (only 1 was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
@@ -557,8 +563,10 @@ class TestRagPipelineRunTasks:
|
||||
waiting_file_id = str(uuid.uuid4())
|
||||
queue.push_tasks([waiting_file_id])
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Mock the Celery group scheduling used by the implementation
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
|
||||
mock_group.return_value.apply_async = MagicMock()
|
||||
|
||||
# Act: Execute the regular task (should not raise exception)
|
||||
rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
@@ -569,12 +577,13 @@ class TestRagPipelineRunTasks:
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify waiting task was still processed despite core processing error
|
||||
mock_delay.assert_called_once()
|
||||
assert mock_group.return_value.apply_async.called
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
# Verify correct parameters for the first scheduled job signature
|
||||
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
|
||||
first_kwargs = jobs[0].kwargs if jobs else {}
|
||||
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
||||
assert first_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify queue is empty after processing (task was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
@@ -684,8 +693,10 @@ class TestRagPipelineRunTasks:
|
||||
queue1.push_tasks([waiting_file_id1])
|
||||
queue2.push_tasks([waiting_file_id2])
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Mock the Celery group scheduling used by the implementation
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
|
||||
mock_group.return_value.apply_async = MagicMock()
|
||||
|
||||
# Act: Execute the regular task for tenant1 only
|
||||
rag_pipeline_run_task(file_id1, tenant1.id)
|
||||
|
||||
@@ -694,11 +705,12 @@ class TestRagPipelineRunTasks:
|
||||
assert mock_file_service["delete_file"].call_count == 1
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify only tenant1's waiting task was processed
|
||||
mock_delay.assert_called_once()
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
|
||||
assert call_kwargs.get("tenant_id") == tenant1.id
|
||||
# Verify only tenant1's waiting task was processed (via group)
|
||||
assert mock_group.return_value.apply_async.called
|
||||
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
|
||||
first_kwargs = jobs[0].kwargs if jobs else {}
|
||||
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
|
||||
assert first_kwargs.get("tenant_id") == tenant1.id
|
||||
|
||||
# Verify tenant1's queue is empty
|
||||
remaining_tasks1 = queue1.pull_tasks(count=10)
|
||||
@@ -913,8 +925,10 @@ class TestRagPipelineRunTasks:
|
||||
waiting_file_id = str(uuid.uuid4())
|
||||
queue.push_tasks([waiting_file_id])
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Mock the Celery group scheduling used by the implementation
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
|
||||
mock_group.return_value.apply_async = MagicMock()
|
||||
|
||||
# Act & Assert: Execute the regular task (should raise Exception)
|
||||
with pytest.raises(Exception, match="File not found"):
|
||||
rag_pipeline_run_task(file_id, tenant.id)
|
||||
@@ -924,12 +938,13 @@ class TestRagPipelineRunTasks:
|
||||
mock_pipeline_generator.assert_not_called()
|
||||
|
||||
# Verify waiting task was still processed despite file error
|
||||
mock_delay.assert_called_once()
|
||||
assert mock_group.return_value.apply_async.called
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
# Verify correct parameters for the first scheduled job signature
|
||||
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
|
||||
first_kwargs = jobs[0].kwargs if jobs else {}
|
||||
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
||||
assert first_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify queue is empty after processing (task was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
|
||||
@@ -105,18 +105,26 @@ def app_model(
|
||||
|
||||
|
||||
class MockCeleryGroup:
|
||||
"""Mock for celery group() function that collects dispatched tasks."""
|
||||
"""Mock for celery group() function that collects dispatched tasks.
|
||||
|
||||
Matches the Celery group API loosely, accepting arbitrary kwargs on apply_async
|
||||
(e.g. producer) so production code can pass broker-related options without
|
||||
breaking tests.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.collected: list[dict[str, Any]] = []
|
||||
self._applied = False
|
||||
self.last_apply_async_kwargs: dict[str, Any] | None = None
|
||||
|
||||
def __call__(self, items: Any) -> MockCeleryGroup:
|
||||
self.collected = list(items)
|
||||
return self
|
||||
|
||||
def apply_async(self) -> None:
|
||||
def apply_async(self, **kwargs: Any) -> None:
|
||||
# Accept arbitrary kwargs like producer to be compatible with Celery
|
||||
self._applied = True
|
||||
self.last_apply_async_kwargs = kwargs
|
||||
|
||||
@property
|
||||
def applied(self) -> bool:
|
||||
|
||||
181
api/tests/unit_tests/commands/test_clean_expired_messages.py
Normal file
181
api/tests/unit_tests/commands/test_clean_expired_messages.py
Normal file
@@ -0,0 +1,181 @@
|
||||
import datetime
|
||||
import re
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import click
|
||||
import pytest
|
||||
|
||||
from commands import clean_expired_messages
|
||||
|
||||
|
||||
def _mock_service() -> MagicMock:
|
||||
service = MagicMock()
|
||||
service.run.return_value = {
|
||||
"batches": 1,
|
||||
"total_messages": 10,
|
||||
"filtered_messages": 5,
|
||||
"total_deleted": 5,
|
||||
}
|
||||
return service
|
||||
|
||||
|
||||
def test_absolute_mode_calls_from_time_range():
|
||||
policy = object()
|
||||
service = _mock_service()
|
||||
start_from = datetime.datetime(2024, 1, 1, 0, 0, 0)
|
||||
end_before = datetime.datetime(2024, 2, 1, 0, 0, 0)
|
||||
|
||||
with (
|
||||
patch("commands.create_message_clean_policy", return_value=policy),
|
||||
patch("commands.MessagesCleanService.from_time_range", return_value=service) as mock_from_time_range,
|
||||
patch("commands.MessagesCleanService.from_days") as mock_from_days,
|
||||
):
|
||||
clean_expired_messages.callback(
|
||||
batch_size=200,
|
||||
graceful_period=21,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
from_days_ago=None,
|
||||
before_days=None,
|
||||
dry_run=True,
|
||||
)
|
||||
|
||||
mock_from_time_range.assert_called_once_with(
|
||||
policy=policy,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
batch_size=200,
|
||||
dry_run=True,
|
||||
)
|
||||
mock_from_days.assert_not_called()
|
||||
|
||||
|
||||
def test_relative_mode_before_days_only_calls_from_days():
|
||||
policy = object()
|
||||
service = _mock_service()
|
||||
|
||||
with (
|
||||
patch("commands.create_message_clean_policy", return_value=policy),
|
||||
patch("commands.MessagesCleanService.from_days", return_value=service) as mock_from_days,
|
||||
patch("commands.MessagesCleanService.from_time_range") as mock_from_time_range,
|
||||
):
|
||||
clean_expired_messages.callback(
|
||||
batch_size=500,
|
||||
graceful_period=14,
|
||||
start_from=None,
|
||||
end_before=None,
|
||||
from_days_ago=None,
|
||||
before_days=30,
|
||||
dry_run=False,
|
||||
)
|
||||
|
||||
mock_from_days.assert_called_once_with(
|
||||
policy=policy,
|
||||
days=30,
|
||||
batch_size=500,
|
||||
dry_run=False,
|
||||
)
|
||||
mock_from_time_range.assert_not_called()
|
||||
|
||||
|
||||
def test_relative_mode_with_from_days_ago_calls_from_time_range():
|
||||
policy = object()
|
||||
service = _mock_service()
|
||||
fixed_now = datetime.datetime(2024, 8, 20, 12, 0, 0)
|
||||
|
||||
with (
|
||||
patch("commands.create_message_clean_policy", return_value=policy),
|
||||
patch("commands.MessagesCleanService.from_time_range", return_value=service) as mock_from_time_range,
|
||||
patch("commands.MessagesCleanService.from_days") as mock_from_days,
|
||||
patch("commands.naive_utc_now", return_value=fixed_now),
|
||||
):
|
||||
clean_expired_messages.callback(
|
||||
batch_size=1000,
|
||||
graceful_period=21,
|
||||
start_from=None,
|
||||
end_before=None,
|
||||
from_days_ago=60,
|
||||
before_days=30,
|
||||
dry_run=False,
|
||||
)
|
||||
|
||||
mock_from_time_range.assert_called_once_with(
|
||||
policy=policy,
|
||||
start_from=fixed_now - datetime.timedelta(days=60),
|
||||
end_before=fixed_now - datetime.timedelta(days=30),
|
||||
batch_size=1000,
|
||||
dry_run=False,
|
||||
)
|
||||
mock_from_days.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("kwargs", "message"),
|
||||
[
|
||||
(
|
||||
{
|
||||
"start_from": datetime.datetime(2024, 1, 1),
|
||||
"end_before": datetime.datetime(2024, 2, 1),
|
||||
"from_days_ago": None,
|
||||
"before_days": 30,
|
||||
},
|
||||
"mutually exclusive",
|
||||
),
|
||||
(
|
||||
{
|
||||
"start_from": datetime.datetime(2024, 1, 1),
|
||||
"end_before": None,
|
||||
"from_days_ago": None,
|
||||
"before_days": None,
|
||||
},
|
||||
"Both --start-from and --end-before are required",
|
||||
),
|
||||
(
|
||||
{
|
||||
"start_from": None,
|
||||
"end_before": None,
|
||||
"from_days_ago": 10,
|
||||
"before_days": None,
|
||||
},
|
||||
"--from-days-ago must be used together with --before-days",
|
||||
),
|
||||
(
|
||||
{
|
||||
"start_from": None,
|
||||
"end_before": None,
|
||||
"from_days_ago": None,
|
||||
"before_days": -1,
|
||||
},
|
||||
"--before-days must be >= 0",
|
||||
),
|
||||
(
|
||||
{
|
||||
"start_from": None,
|
||||
"end_before": None,
|
||||
"from_days_ago": 30,
|
||||
"before_days": 30,
|
||||
},
|
||||
"--from-days-ago must be greater than --before-days",
|
||||
),
|
||||
(
|
||||
{
|
||||
"start_from": None,
|
||||
"end_before": None,
|
||||
"from_days_ago": None,
|
||||
"before_days": None,
|
||||
},
|
||||
"You must provide either (--start-from,--end-before) or (--before-days [--from-days-ago])",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_invalid_inputs_raise_usage_error(kwargs: dict, message: str):
|
||||
with pytest.raises(click.UsageError, match=re.escape(message)):
|
||||
clean_expired_messages.callback(
|
||||
batch_size=1000,
|
||||
graceful_period=21,
|
||||
start_from=kwargs["start_from"],
|
||||
end_before=kwargs["end_before"],
|
||||
from_days_ago=kwargs["from_days_ago"],
|
||||
before_days=kwargs["before_days"],
|
||||
dry_run=False,
|
||||
)
|
||||
@@ -32,11 +32,6 @@ os.environ.setdefault("OPENDAL_SCHEME", "fs")
|
||||
os.environ.setdefault("OPENDAL_FS_ROOT", "/tmp/dify-storage")
|
||||
os.environ.setdefault("STORAGE_TYPE", "opendal")
|
||||
|
||||
# Add the API directory to Python path to ensure proper imports
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, PROJECT_DIR)
|
||||
|
||||
from core.db.session_factory import configure_session_factory, session_factory
|
||||
from extensions import ext_redis
|
||||
|
||||
|
||||
70
api/tests/unit_tests/controllers/common/test_errors.py
Normal file
70
api/tests/unit_tests/controllers/common/test_errors.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from controllers.common.errors import (
|
||||
BlockedFileExtensionError,
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
RemoteFileUploadError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
|
||||
|
||||
class TestFilenameNotExistsError:
|
||||
def test_defaults(self):
|
||||
error = FilenameNotExistsError()
|
||||
|
||||
assert error.code == 400
|
||||
assert error.description == "The specified filename does not exist."
|
||||
|
||||
|
||||
class TestRemoteFileUploadError:
|
||||
def test_defaults(self):
|
||||
error = RemoteFileUploadError()
|
||||
|
||||
assert error.code == 400
|
||||
assert error.description == "Error uploading remote file."
|
||||
|
||||
|
||||
class TestFileTooLargeError:
|
||||
def test_defaults(self):
|
||||
error = FileTooLargeError()
|
||||
|
||||
assert error.code == 413
|
||||
assert error.error_code == "file_too_large"
|
||||
assert error.description == "File size exceeded. {message}"
|
||||
|
||||
|
||||
class TestUnsupportedFileTypeError:
|
||||
def test_defaults(self):
|
||||
error = UnsupportedFileTypeError()
|
||||
|
||||
assert error.code == 415
|
||||
assert error.error_code == "unsupported_file_type"
|
||||
assert error.description == "File type not allowed."
|
||||
|
||||
|
||||
class TestBlockedFileExtensionError:
|
||||
def test_defaults(self):
|
||||
error = BlockedFileExtensionError()
|
||||
|
||||
assert error.code == 400
|
||||
assert error.error_code == "file_extension_blocked"
|
||||
assert error.description == "The file extension is blocked for security reasons."
|
||||
|
||||
|
||||
class TestTooManyFilesError:
|
||||
def test_defaults(self):
|
||||
error = TooManyFilesError()
|
||||
|
||||
assert error.code == 400
|
||||
assert error.error_code == "too_many_files"
|
||||
assert error.description == "Only one file is allowed."
|
||||
|
||||
|
||||
class TestNoFileUploadedError:
|
||||
def test_defaults(self):
|
||||
error = NoFileUploadedError()
|
||||
|
||||
assert error.code == 400
|
||||
assert error.error_code == "no_file_uploaded"
|
||||
assert error.description == "Please upload your file."
|
||||
@@ -1,22 +1,95 @@
|
||||
from flask import Response
|
||||
|
||||
from controllers.common.file_response import enforce_download_for_html, is_html_content
|
||||
from controllers.common.file_response import (
|
||||
_normalize_mime_type,
|
||||
enforce_download_for_html,
|
||||
is_html_content,
|
||||
)
|
||||
|
||||
|
||||
class TestFileResponseHelpers:
|
||||
def test_is_html_content_detects_mime_type(self):
|
||||
class TestNormalizeMimeType:
|
||||
def test_returns_empty_string_for_none(self):
|
||||
assert _normalize_mime_type(None) == ""
|
||||
|
||||
def test_returns_empty_string_for_empty_string(self):
|
||||
assert _normalize_mime_type("") == ""
|
||||
|
||||
def test_normalizes_mime_type(self):
|
||||
assert _normalize_mime_type("Text/HTML; Charset=UTF-8") == "text/html"
|
||||
|
||||
|
||||
class TestIsHtmlContent:
|
||||
def test_detects_html_via_mime_type(self):
|
||||
mime_type = "text/html; charset=UTF-8"
|
||||
|
||||
result = is_html_content(mime_type, filename="file.txt", extension="txt")
|
||||
result = is_html_content(
|
||||
mime_type=mime_type,
|
||||
filename="file.txt",
|
||||
extension="txt",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_is_html_content_detects_extension(self):
|
||||
result = is_html_content("text/plain", filename="report.html", extension=None)
|
||||
def test_detects_html_via_extension_argument(self):
|
||||
result = is_html_content(
|
||||
mime_type="text/plain",
|
||||
filename=None,
|
||||
extension="html",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_enforce_download_for_html_sets_headers(self):
|
||||
def test_detects_html_via_filename_extension(self):
|
||||
result = is_html_content(
|
||||
mime_type="text/plain",
|
||||
filename="report.html",
|
||||
extension=None,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_returns_false_when_no_html_detected_anywhere(self):
|
||||
"""
|
||||
Missing negative test:
|
||||
- MIME type is not HTML
|
||||
- filename has no HTML extension
|
||||
- extension argument is not HTML
|
||||
"""
|
||||
result = is_html_content(
|
||||
mime_type="application/json",
|
||||
filename="data.json",
|
||||
extension="json",
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_returns_false_when_all_inputs_are_none(self):
|
||||
result = is_html_content(
|
||||
mime_type=None,
|
||||
filename=None,
|
||||
extension=None,
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestEnforceDownloadForHtml:
|
||||
def test_sets_attachment_when_filename_missing(self):
|
||||
response = Response("payload", mimetype="text/html")
|
||||
|
||||
updated = enforce_download_for_html(
|
||||
response,
|
||||
mime_type="text/html",
|
||||
filename=None,
|
||||
extension="html",
|
||||
)
|
||||
|
||||
assert updated is True
|
||||
assert response.headers["Content-Disposition"] == "attachment"
|
||||
assert response.headers["Content-Type"] == "application/octet-stream"
|
||||
assert response.headers["X-Content-Type-Options"] == "nosniff"
|
||||
|
||||
def test_sets_headers_when_filename_present(self):
|
||||
response = Response("payload", mimetype="text/html")
|
||||
|
||||
updated = enforce_download_for_html(
|
||||
@@ -27,11 +100,12 @@ class TestFileResponseHelpers:
|
||||
)
|
||||
|
||||
assert updated is True
|
||||
assert "attachment" in response.headers["Content-Disposition"]
|
||||
assert response.headers["Content-Disposition"].startswith("attachment")
|
||||
assert "unsafe.html" in response.headers["Content-Disposition"]
|
||||
assert response.headers["Content-Type"] == "application/octet-stream"
|
||||
assert response.headers["X-Content-Type-Options"] == "nosniff"
|
||||
|
||||
def test_enforce_download_for_html_no_change_for_non_html(self):
|
||||
def test_does_not_modify_response_for_non_html_content(self):
|
||||
response = Response("payload", mimetype="text/plain")
|
||||
|
||||
updated = enforce_download_for_html(
|
||||
|
||||
188
api/tests/unit_tests/controllers/common/test_helpers.py
Normal file
188
api/tests/unit_tests/controllers/common/test_helpers.py
Normal file
@@ -0,0 +1,188 @@
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from controllers.common import helpers
|
||||
from controllers.common.helpers import FileInfo, guess_file_info_from_response
|
||||
|
||||
|
||||
def make_response(
|
||||
url="https://example.com/file.txt",
|
||||
headers=None,
|
||||
content=None,
|
||||
):
|
||||
return httpx.Response(
|
||||
200,
|
||||
request=httpx.Request("GET", url),
|
||||
headers=headers or {},
|
||||
content=content or b"",
|
||||
)
|
||||
|
||||
|
||||
class TestGuessFileInfoFromResponse:
|
||||
def test_filename_from_url(self):
|
||||
response = make_response(
|
||||
url="https://example.com/test.pdf",
|
||||
content=b"Hello World",
|
||||
)
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
assert info.filename == "test.pdf"
|
||||
assert info.extension == ".pdf"
|
||||
assert info.mimetype == "application/pdf"
|
||||
|
||||
def test_filename_from_content_disposition(self):
|
||||
headers = {
|
||||
"Content-Disposition": "attachment; filename=myfile.csv",
|
||||
"Content-Type": "text/csv",
|
||||
}
|
||||
response = make_response(
|
||||
url="https://example.com/",
|
||||
headers=headers,
|
||||
content=b"Hello World",
|
||||
)
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
assert info.filename == "myfile.csv"
|
||||
assert info.extension == ".csv"
|
||||
assert info.mimetype == "text/csv"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("magic_available", "expected_ext"),
|
||||
[
|
||||
(True, "txt"),
|
||||
(False, "bin"),
|
||||
],
|
||||
)
|
||||
def test_generated_filename_when_missing(self, monkeypatch, magic_available, expected_ext):
|
||||
if magic_available:
|
||||
if helpers.magic is None:
|
||||
pytest.skip("python-magic is not installed, cannot run 'magic_available=True' test variant")
|
||||
else:
|
||||
monkeypatch.setattr(helpers, "magic", None)
|
||||
|
||||
response = make_response(
|
||||
url="https://example.com/",
|
||||
content=b"Hello World",
|
||||
)
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
name, ext = info.filename.split(".")
|
||||
UUID(name)
|
||||
assert ext == expected_ext
|
||||
|
||||
def test_mimetype_from_header_when_unknown(self):
|
||||
headers = {"Content-Type": "application/json"}
|
||||
response = make_response(
|
||||
url="https://example.com/file.unknown",
|
||||
headers=headers,
|
||||
content=b'{"a": 1}',
|
||||
)
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
assert info.mimetype == "application/json"
|
||||
|
||||
def test_extension_added_when_missing(self):
|
||||
headers = {"Content-Type": "image/png"}
|
||||
response = make_response(
|
||||
url="https://example.com/image",
|
||||
headers=headers,
|
||||
content=b"fakepngdata",
|
||||
)
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
assert info.extension == ".png"
|
||||
assert info.filename.endswith(".png")
|
||||
|
||||
def test_content_length_used_as_size(self):
|
||||
headers = {
|
||||
"Content-Length": "1234",
|
||||
"Content-Type": "text/plain",
|
||||
}
|
||||
response = make_response(
|
||||
url="https://example.com/a.txt",
|
||||
headers=headers,
|
||||
content=b"a" * 1234,
|
||||
)
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
assert info.size == 1234
|
||||
|
||||
def test_size_minus_one_when_header_missing(self):
|
||||
response = make_response(url="https://example.com/a.txt")
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
assert info.size == -1
|
||||
|
||||
def test_fallback_to_bin_extension(self):
|
||||
headers = {"Content-Type": "application/octet-stream"}
|
||||
response = make_response(
|
||||
url="https://example.com/download",
|
||||
headers=headers,
|
||||
content=b"\x00\x01\x02\x03",
|
||||
)
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
assert info.extension == ".bin"
|
||||
assert info.filename.endswith(".bin")
|
||||
|
||||
def test_return_type(self):
|
||||
response = make_response()
|
||||
|
||||
info = guess_file_info_from_response(response)
|
||||
|
||||
assert isinstance(info, FileInfo)
|
||||
|
||||
|
||||
class TestMagicImportWarnings:
|
||||
@pytest.mark.parametrize(
|
||||
("platform_name", "expected_message"),
|
||||
[
|
||||
("Windows", "pip install python-magic-bin"),
|
||||
("Darwin", "brew install libmagic"),
|
||||
("Linux", "sudo apt-get install libmagic1"),
|
||||
("Other", "install `libmagic`"),
|
||||
],
|
||||
)
|
||||
def test_magic_import_warning_per_platform(
|
||||
self,
|
||||
monkeypatch,
|
||||
platform_name,
|
||||
expected_message,
|
||||
):
|
||||
import builtins
|
||||
import importlib
|
||||
|
||||
# Force ImportError when "magic" is imported
|
||||
real_import = builtins.__import__
|
||||
|
||||
def fake_import(name, *args, **kwargs):
|
||||
if name == "magic":
|
||||
raise ImportError("No module named magic")
|
||||
return real_import(name, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(builtins, "__import__", fake_import)
|
||||
monkeypatch.setattr("platform.system", lambda: platform_name)
|
||||
|
||||
# Remove helpers so it imports fresh
|
||||
import sys
|
||||
|
||||
original_helpers = sys.modules.get(helpers.__name__)
|
||||
sys.modules.pop(helpers.__name__, None)
|
||||
|
||||
try:
|
||||
with pytest.warns(UserWarning, match="To use python-magic") as warning:
|
||||
imported_helpers = importlib.import_module(helpers.__name__)
|
||||
assert expected_message in str(warning[0].message)
|
||||
finally:
|
||||
if original_helpers is not None:
|
||||
sys.modules[helpers.__name__] = original_helpers
|
||||
189
api/tests/unit_tests/controllers/common/test_schema.py
Normal file
189
api/tests/unit_tests/controllers/common/test_schema.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import sys
|
||||
from enum import StrEnum
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask_restx import Namespace
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class UserModel(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
|
||||
class ProductModel(BaseModel):
|
||||
id: int
|
||||
price: float
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_console_ns():
|
||||
"""Mock the console_ns to avoid circular imports during test collection."""
|
||||
mock_ns = MagicMock(spec=Namespace)
|
||||
mock_ns.models = {}
|
||||
|
||||
# Inject mock before importing schema module
|
||||
with patch.dict(sys.modules, {"controllers.console": MagicMock(console_ns=mock_ns)}):
|
||||
yield mock_ns
|
||||
|
||||
|
||||
def test_default_ref_template_value():
|
||||
from controllers.common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0
|
||||
|
||||
assert DEFAULT_REF_TEMPLATE_SWAGGER_2_0 == "#/definitions/{model}"
|
||||
|
||||
|
||||
def test_register_schema_model_calls_namespace_schema_model():
|
||||
from controllers.common.schema import register_schema_model
|
||||
|
||||
namespace = MagicMock(spec=Namespace)
|
||||
|
||||
register_schema_model(namespace, UserModel)
|
||||
|
||||
namespace.schema_model.assert_called_once()
|
||||
|
||||
model_name, schema = namespace.schema_model.call_args.args
|
||||
|
||||
assert model_name == "UserModel"
|
||||
assert isinstance(schema, dict)
|
||||
assert "properties" in schema
|
||||
|
||||
|
||||
def test_register_schema_model_passes_schema_from_pydantic():
|
||||
from controllers.common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_model
|
||||
|
||||
namespace = MagicMock(spec=Namespace)
|
||||
|
||||
register_schema_model(namespace, UserModel)
|
||||
|
||||
schema = namespace.schema_model.call_args.args[1]
|
||||
|
||||
expected_schema = UserModel.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
|
||||
assert schema == expected_schema
|
||||
|
||||
|
||||
def test_register_schema_models_registers_multiple_models():
|
||||
from controllers.common.schema import register_schema_models
|
||||
|
||||
namespace = MagicMock(spec=Namespace)
|
||||
|
||||
register_schema_models(namespace, UserModel, ProductModel)
|
||||
|
||||
assert namespace.schema_model.call_count == 2
|
||||
|
||||
called_names = [call.args[0] for call in namespace.schema_model.call_args_list]
|
||||
assert called_names == ["UserModel", "ProductModel"]
|
||||
|
||||
|
||||
def test_register_schema_models_calls_register_schema_model(monkeypatch):
|
||||
from controllers.common.schema import register_schema_models
|
||||
|
||||
namespace = MagicMock(spec=Namespace)
|
||||
|
||||
calls = []
|
||||
|
||||
def fake_register(ns, model):
|
||||
calls.append((ns, model))
|
||||
|
||||
monkeypatch.setattr(
|
||||
"controllers.common.schema.register_schema_model",
|
||||
fake_register,
|
||||
)
|
||||
|
||||
register_schema_models(namespace, UserModel, ProductModel)
|
||||
|
||||
assert calls == [
|
||||
(namespace, UserModel),
|
||||
(namespace, ProductModel),
|
||||
]
|
||||
|
||||
|
||||
class StatusEnum(StrEnum):
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
|
||||
|
||||
class PriorityEnum(StrEnum):
|
||||
HIGH = "high"
|
||||
LOW = "low"
|
||||
|
||||
|
||||
def test_get_or_create_model_returns_existing_model(mock_console_ns):
|
||||
from controllers.common.schema import get_or_create_model
|
||||
|
||||
existing_model = MagicMock()
|
||||
mock_console_ns.models = {"TestModel": existing_model}
|
||||
|
||||
result = get_or_create_model("TestModel", {"key": "value"})
|
||||
|
||||
assert result == existing_model
|
||||
mock_console_ns.model.assert_not_called()
|
||||
|
||||
|
||||
def test_get_or_create_model_creates_new_model_when_not_exists(mock_console_ns):
|
||||
from controllers.common.schema import get_or_create_model
|
||||
|
||||
mock_console_ns.models = {}
|
||||
new_model = MagicMock()
|
||||
mock_console_ns.model.return_value = new_model
|
||||
field_def = {"name": {"type": "string"}}
|
||||
|
||||
result = get_or_create_model("NewModel", field_def)
|
||||
|
||||
assert result == new_model
|
||||
mock_console_ns.model.assert_called_once_with("NewModel", field_def)
|
||||
|
||||
|
||||
def test_get_or_create_model_does_not_call_model_if_exists(mock_console_ns):
|
||||
from controllers.common.schema import get_or_create_model
|
||||
|
||||
existing_model = MagicMock()
|
||||
mock_console_ns.models = {"ExistingModel": existing_model}
|
||||
|
||||
result = get_or_create_model("ExistingModel", {"key": "value"})
|
||||
|
||||
assert result == existing_model
|
||||
mock_console_ns.model.assert_not_called()
|
||||
|
||||
|
||||
def test_register_enum_models_registers_single_enum():
|
||||
from controllers.common.schema import register_enum_models
|
||||
|
||||
namespace = MagicMock(spec=Namespace)
|
||||
|
||||
register_enum_models(namespace, StatusEnum)
|
||||
|
||||
namespace.schema_model.assert_called_once()
|
||||
|
||||
model_name, schema = namespace.schema_model.call_args.args
|
||||
|
||||
assert model_name == "StatusEnum"
|
||||
assert isinstance(schema, dict)
|
||||
|
||||
|
||||
def test_register_enum_models_registers_multiple_enums():
|
||||
from controllers.common.schema import register_enum_models
|
||||
|
||||
namespace = MagicMock(spec=Namespace)
|
||||
|
||||
register_enum_models(namespace, StatusEnum, PriorityEnum)
|
||||
|
||||
assert namespace.schema_model.call_count == 2
|
||||
|
||||
called_names = [call.args[0] for call in namespace.schema_model.call_args_list]
|
||||
assert called_names == ["StatusEnum", "PriorityEnum"]
|
||||
|
||||
|
||||
def test_register_enum_models_uses_correct_ref_template():
|
||||
from controllers.common.schema import register_enum_models
|
||||
|
||||
namespace = MagicMock(spec=Namespace)
|
||||
|
||||
register_enum_models(namespace, StatusEnum)
|
||||
|
||||
schema = namespace.schema_model.call_args.args[1]
|
||||
|
||||
# Verify the schema contains enum values
|
||||
assert "enum" in schema or "anyOf" in schema
|
||||
0
api/tests/unit_tests/controllers/web/__init__.py
Normal file
0
api/tests/unit_tests/controllers/web/__init__.py
Normal file
85
api/tests/unit_tests/controllers/web/conftest.py
Normal file
85
api/tests/unit_tests/controllers/web/conftest.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Shared fixtures for controllers.web unit tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
"""Minimal Flask app for request contexts."""
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
class FakeSession:
|
||||
"""Stand-in for db.session that returns pre-seeded objects by model class name."""
|
||||
|
||||
def __init__(self, mapping: dict[str, Any] | None = None):
|
||||
self._mapping: dict[str, Any] = mapping or {}
|
||||
self._model_name: str | None = None
|
||||
|
||||
def query(self, model: type) -> FakeSession:
|
||||
self._model_name = model.__name__
|
||||
return self
|
||||
|
||||
def where(self, *_args: object, **_kwargs: object) -> FakeSession:
|
||||
return self
|
||||
|
||||
def first(self) -> Any:
|
||||
assert self._model_name is not None
|
||||
return self._mapping.get(self._model_name)
|
||||
|
||||
|
||||
class FakeDB:
|
||||
"""Minimal db stub exposing engine and session."""
|
||||
|
||||
def __init__(self, session: FakeSession | None = None):
|
||||
self.session = session or FakeSession()
|
||||
self.engine = object()
|
||||
|
||||
|
||||
def make_app_model(
|
||||
*,
|
||||
app_id: str = "app-1",
|
||||
tenant_id: str = "tenant-1",
|
||||
mode: str = "chat",
|
||||
enable_site: bool = True,
|
||||
status: str = "normal",
|
||||
) -> SimpleNamespace:
|
||||
"""Build a fake App model with common defaults."""
|
||||
tenant = SimpleNamespace(
|
||||
id=tenant_id,
|
||||
status="normal",
|
||||
plan="basic",
|
||||
custom_config_dict={},
|
||||
)
|
||||
return SimpleNamespace(
|
||||
id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
tenant=tenant,
|
||||
mode=mode,
|
||||
enable_site=enable_site,
|
||||
status=status,
|
||||
workflow=None,
|
||||
app_model_config=None,
|
||||
)
|
||||
|
||||
|
||||
def make_end_user(
|
||||
*,
|
||||
user_id: str = "end-user-1",
|
||||
session_id: str = "session-1",
|
||||
external_user_id: str = "ext-user-1",
|
||||
) -> SimpleNamespace:
|
||||
"""Build a fake EndUser model with common defaults."""
|
||||
return SimpleNamespace(
|
||||
id=user_id,
|
||||
session_id=session_id,
|
||||
external_user_id=external_user_id,
|
||||
)
|
||||
165
api/tests/unit_tests/controllers/web/test_app.py
Normal file
165
api/tests/unit_tests/controllers/web/test_app.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""Unit tests for controllers.web.app endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.app import AppAccessMode, AppMeta, AppParameterApi, AppWebAuthPermission
|
||||
from controllers.web.error import AppUnavailableError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AppParameterApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAppParameterApi:
|
||||
def test_advanced_chat_mode_uses_workflow(self, app: Flask) -> None:
|
||||
features_dict = {"opening_statement": "Hello"}
|
||||
workflow = SimpleNamespace(
|
||||
features_dict=features_dict,
|
||||
user_input_form=lambda to_old_structure=False: [],
|
||||
)
|
||||
app_model = SimpleNamespace(mode="advanced-chat", workflow=workflow)
|
||||
|
||||
with (
|
||||
app.test_request_context("/parameters"),
|
||||
patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as mock_params,
|
||||
patch("controllers.web.app.fields.Parameters") as mock_fields,
|
||||
):
|
||||
mock_fields.model_validate.return_value.model_dump.return_value = {"result": "ok"}
|
||||
result = AppParameterApi().get(app_model, SimpleNamespace())
|
||||
|
||||
mock_params.assert_called_once_with(features_dict=features_dict, user_input_form=[])
|
||||
assert result == {"result": "ok"}
|
||||
|
||||
def test_workflow_mode_uses_workflow(self, app: Flask) -> None:
|
||||
features_dict = {}
|
||||
workflow = SimpleNamespace(
|
||||
features_dict=features_dict,
|
||||
user_input_form=lambda to_old_structure=False: [{"var": "x"}],
|
||||
)
|
||||
app_model = SimpleNamespace(mode="workflow", workflow=workflow)
|
||||
|
||||
with (
|
||||
app.test_request_context("/parameters"),
|
||||
patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as mock_params,
|
||||
patch("controllers.web.app.fields.Parameters") as mock_fields,
|
||||
):
|
||||
mock_fields.model_validate.return_value.model_dump.return_value = {}
|
||||
AppParameterApi().get(app_model, SimpleNamespace())
|
||||
|
||||
mock_params.assert_called_once_with(features_dict=features_dict, user_input_form=[{"var": "x"}])
|
||||
|
||||
def test_advanced_chat_mode_no_workflow_raises(self, app: Flask) -> None:
|
||||
app_model = SimpleNamespace(mode="advanced-chat", workflow=None)
|
||||
with app.test_request_context("/parameters"):
|
||||
with pytest.raises(AppUnavailableError):
|
||||
AppParameterApi().get(app_model, SimpleNamespace())
|
||||
|
||||
def test_standard_mode_uses_app_model_config(self, app: Flask) -> None:
|
||||
config = SimpleNamespace(to_dict=lambda: {"user_input_form": [{"var": "y"}], "key": "val"})
|
||||
app_model = SimpleNamespace(mode="chat", app_model_config=config)
|
||||
|
||||
with (
|
||||
app.test_request_context("/parameters"),
|
||||
patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as mock_params,
|
||||
patch("controllers.web.app.fields.Parameters") as mock_fields,
|
||||
):
|
||||
mock_fields.model_validate.return_value.model_dump.return_value = {}
|
||||
AppParameterApi().get(app_model, SimpleNamespace())
|
||||
|
||||
call_kwargs = mock_params.call_args
|
||||
assert call_kwargs.kwargs["user_input_form"] == [{"var": "y"}]
|
||||
|
||||
def test_standard_mode_no_config_raises(self, app: Flask) -> None:
|
||||
app_model = SimpleNamespace(mode="chat", app_model_config=None)
|
||||
with app.test_request_context("/parameters"):
|
||||
with pytest.raises(AppUnavailableError):
|
||||
AppParameterApi().get(app_model, SimpleNamespace())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AppMeta
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAppMeta:
|
||||
@patch("controllers.web.app.AppService")
|
||||
def test_get_returns_meta(self, mock_service_cls: MagicMock, app: Flask) -> None:
|
||||
mock_service_cls.return_value.get_app_meta.return_value = {"tool_icons": {}}
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
with app.test_request_context("/meta"):
|
||||
result = AppMeta().get(app_model, SimpleNamespace())
|
||||
|
||||
assert result == {"tool_icons": {}}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AppAccessMode
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAppAccessMode:
|
||||
@patch("controllers.web.app.FeatureService.get_system_features")
|
||||
def test_returns_public_when_webapp_auth_disabled(self, mock_features: MagicMock, app: Flask) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
with app.test_request_context("/webapp/access-mode?appId=app-1"):
|
||||
result = AppAccessMode().get()
|
||||
|
||||
assert result == {"accessMode": "public"}
|
||||
|
||||
@patch("controllers.web.app.EnterpriseService.WebAppAuth.get_app_access_mode_by_id")
|
||||
@patch("controllers.web.app.FeatureService.get_system_features")
|
||||
def test_returns_access_mode_with_app_id(
|
||||
self, mock_features: MagicMock, mock_access: MagicMock, app: Flask
|
||||
) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True))
|
||||
mock_access.return_value = SimpleNamespace(access_mode="internal")
|
||||
|
||||
with app.test_request_context("/webapp/access-mode?appId=app-1"):
|
||||
result = AppAccessMode().get()
|
||||
|
||||
assert result == {"accessMode": "internal"}
|
||||
mock_access.assert_called_once_with("app-1")
|
||||
|
||||
@patch("controllers.web.app.AppService.get_app_id_by_code", return_value="resolved-id")
|
||||
@patch("controllers.web.app.EnterpriseService.WebAppAuth.get_app_access_mode_by_id")
|
||||
@patch("controllers.web.app.FeatureService.get_system_features")
|
||||
def test_resolves_app_code_to_id(
|
||||
self, mock_features: MagicMock, mock_access: MagicMock, mock_resolve: MagicMock, app: Flask
|
||||
) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True))
|
||||
mock_access.return_value = SimpleNamespace(access_mode="external")
|
||||
|
||||
with app.test_request_context("/webapp/access-mode?appCode=code1"):
|
||||
result = AppAccessMode().get()
|
||||
|
||||
mock_resolve.assert_called_once_with("code1")
|
||||
mock_access.assert_called_once_with("resolved-id")
|
||||
assert result == {"accessMode": "external"}
|
||||
|
||||
@patch("controllers.web.app.FeatureService.get_system_features")
|
||||
def test_raises_when_no_app_id_or_code(self, mock_features: MagicMock, app: Flask) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True))
|
||||
|
||||
with app.test_request_context("/webapp/access-mode"):
|
||||
with pytest.raises(ValueError, match="appId or appCode"):
|
||||
AppAccessMode().get()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AppWebAuthPermission
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAppWebAuthPermission:
|
||||
@patch("controllers.web.app.WebAppAuthService.is_app_require_permission_check", return_value=False)
|
||||
def test_returns_true_when_no_permission_check_required(self, mock_check: MagicMock, app: Flask) -> None:
|
||||
with app.test_request_context("/webapp/permission?appId=app-1", headers={"X-App-Code": "code1"}):
|
||||
result = AppWebAuthPermission().get()
|
||||
|
||||
assert result == {"result": True}
|
||||
|
||||
def test_raises_when_missing_app_id(self, app: Flask) -> None:
|
||||
with app.test_request_context("/webapp/permission", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(ValueError, match="appId"):
|
||||
AppWebAuthPermission().get()
|
||||
135
api/tests/unit_tests/controllers/web/test_audio.py
Normal file
135
api/tests/unit_tests/controllers/web/test_audio.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Unit tests for controllers.web.audio endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from io import BytesIO
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.audio import AudioApi, TextApi
|
||||
from controllers.web.error import (
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
NoAudioUploadedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
|
||||
def _app_model() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="chat")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1", external_user_id="ext-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AudioApi (audio-to-text)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAudioApi:
|
||||
@patch("controllers.web.audio.AudioService.transcript_asr", return_value={"text": "hello"})
|
||||
def test_happy_path(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
data = {"file": (BytesIO(b"fake-audio"), "test.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
result = AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
assert result == {"text": "hello"}
|
||||
|
||||
@patch("controllers.web.audio.AudioService.transcript_asr", side_effect=NoAudioUploadedServiceError())
|
||||
def test_no_audio_uploaded(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b""), "empty.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(NoAudioUploadedError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch("controllers.web.audio.AudioService.transcript_asr", side_effect=AudioTooLargeServiceError("too big"))
|
||||
def test_audio_too_large(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"big"), "big.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(AudioTooLargeError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch("controllers.web.audio.AudioService.transcript_asr", side_effect=UnsupportedAudioTypeServiceError())
|
||||
def test_unsupported_type(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"bad"), "bad.xyz")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(UnsupportedAudioTypeError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch(
|
||||
"controllers.web.audio.AudioService.transcript_asr",
|
||||
side_effect=ProviderNotSupportSpeechToTextServiceError(),
|
||||
)
|
||||
def test_provider_not_support(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"x"), "x.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(ProviderNotSupportSpeechToTextError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch(
|
||||
"controllers.web.audio.AudioService.transcript_asr",
|
||||
side_effect=ProviderTokenNotInitError(description="no token"),
|
||||
)
|
||||
def test_provider_not_init(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"x"), "x.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch("controllers.web.audio.AudioService.transcript_asr", side_effect=QuotaExceededError())
|
||||
def test_quota_exceeded(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"x"), "x.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch("controllers.web.audio.AudioService.transcript_asr", side_effect=ModelCurrentlyNotSupportError())
|
||||
def test_model_not_support(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"x"), "x.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(ProviderModelCurrentlyNotSupportError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TextApi (text-to-audio)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestTextApi:
|
||||
@patch("controllers.web.audio.AudioService.transcript_tts", return_value="audio-bytes")
|
||||
@patch("controllers.web.audio.web_ns")
|
||||
def test_happy_path(self, mock_ns: MagicMock, mock_tts: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"text": "hello", "voice": "alloy"}
|
||||
|
||||
with app.test_request_context("/text-to-audio", method="POST"):
|
||||
result = TextApi().post(_app_model(), _end_user())
|
||||
|
||||
assert result == "audio-bytes"
|
||||
mock_tts.assert_called_once()
|
||||
|
||||
@patch(
|
||||
"controllers.web.audio.AudioService.transcript_tts",
|
||||
side_effect=InvokeError(description="invoke failed"),
|
||||
)
|
||||
@patch("controllers.web.audio.web_ns")
|
||||
def test_invoke_error_mapped(self, mock_ns: MagicMock, mock_tts: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"text": "hello"}
|
||||
|
||||
with app.test_request_context("/text-to-audio", method="POST"):
|
||||
with pytest.raises(CompletionRequestError):
|
||||
TextApi().post(_app_model(), _end_user())
|
||||
161
api/tests/unit_tests/controllers/web/test_completion.py
Normal file
161
api/tests/unit_tests/controllers/web/test_completion.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Unit tests for controllers.web.completion endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
|
||||
from controllers.web.error import (
|
||||
CompletionRequestError,
|
||||
NotChatAppError,
|
||||
NotCompletionAppError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
|
||||
|
||||
def _completion_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="completion")
|
||||
|
||||
|
||||
def _chat_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="chat")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CompletionApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestCompletionApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/completion-messages", method="POST"):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
CompletionApi().post(_chat_app(), _end_user())
|
||||
|
||||
@patch("controllers.web.completion.helper.compact_generate_response", return_value={"answer": "hi"})
|
||||
@patch("controllers.web.completion.AppGenerateService.generate")
|
||||
@patch("controllers.web.completion.web_ns")
|
||||
def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}, "query": "test"}
|
||||
mock_gen.return_value = "response-obj"
|
||||
|
||||
with app.test_request_context("/completion-messages", method="POST"):
|
||||
result = CompletionApi().post(_completion_app(), _end_user())
|
||||
|
||||
assert result == {"answer": "hi"}
|
||||
|
||||
@patch(
|
||||
"controllers.web.completion.AppGenerateService.generate",
|
||||
side_effect=ProviderTokenNotInitError(description="not init"),
|
||||
)
|
||||
@patch("controllers.web.completion.web_ns")
|
||||
def test_provider_not_init_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}}
|
||||
|
||||
with app.test_request_context("/completion-messages", method="POST"):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
CompletionApi().post(_completion_app(), _end_user())
|
||||
|
||||
@patch(
|
||||
"controllers.web.completion.AppGenerateService.generate",
|
||||
side_effect=QuotaExceededError(),
|
||||
)
|
||||
@patch("controllers.web.completion.web_ns")
|
||||
def test_quota_exceeded_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}}
|
||||
|
||||
with app.test_request_context("/completion-messages", method="POST"):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
CompletionApi().post(_completion_app(), _end_user())
|
||||
|
||||
@patch(
|
||||
"controllers.web.completion.AppGenerateService.generate",
|
||||
side_effect=ModelCurrentlyNotSupportError(),
|
||||
)
|
||||
@patch("controllers.web.completion.web_ns")
|
||||
def test_model_not_support_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}}
|
||||
|
||||
with app.test_request_context("/completion-messages", method="POST"):
|
||||
with pytest.raises(ProviderModelCurrentlyNotSupportError):
|
||||
CompletionApi().post(_completion_app(), _end_user())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CompletionStopApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestCompletionStopApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/completion-messages/task-1/stop", method="POST"):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
CompletionStopApi().post(_chat_app(), _end_user(), "task-1")
|
||||
|
||||
@patch("controllers.web.completion.AppTaskService.stop_task")
|
||||
def test_stop_success(self, mock_stop: MagicMock, app: Flask) -> None:
|
||||
with app.test_request_context("/completion-messages/task-1/stop", method="POST"):
|
||||
result, status = CompletionStopApi().post(_completion_app(), _end_user(), "task-1")
|
||||
|
||||
assert status == 200
|
||||
assert result == {"result": "success"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ChatApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestChatApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/chat-messages", method="POST"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ChatApi().post(_completion_app(), _end_user())
|
||||
|
||||
@patch("controllers.web.completion.helper.compact_generate_response", return_value={"answer": "reply"})
|
||||
@patch("controllers.web.completion.AppGenerateService.generate")
|
||||
@patch("controllers.web.completion.web_ns")
|
||||
def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}, "query": "hi"}
|
||||
mock_gen.return_value = "response"
|
||||
|
||||
with app.test_request_context("/chat-messages", method="POST"):
|
||||
result = ChatApi().post(_chat_app(), _end_user())
|
||||
|
||||
assert result == {"answer": "reply"}
|
||||
|
||||
@patch(
|
||||
"controllers.web.completion.AppGenerateService.generate",
|
||||
side_effect=InvokeError(description="rate limit"),
|
||||
)
|
||||
@patch("controllers.web.completion.web_ns")
|
||||
def test_invoke_error_mapped(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}, "query": "x"}
|
||||
|
||||
with app.test_request_context("/chat-messages", method="POST"):
|
||||
with pytest.raises(CompletionRequestError):
|
||||
ChatApi().post(_chat_app(), _end_user())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ChatStopApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestChatStopApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/chat-messages/task-1/stop", method="POST"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ChatStopApi().post(_completion_app(), _end_user(), "task-1")
|
||||
|
||||
@patch("controllers.web.completion.AppTaskService.stop_task")
|
||||
def test_stop_success(self, mock_stop: MagicMock, app: Flask) -> None:
|
||||
with app.test_request_context("/chat-messages/task-1/stop", method="POST"):
|
||||
result, status = ChatStopApi().post(_chat_app(), _end_user(), "task-1")
|
||||
|
||||
assert status == 200
|
||||
assert result == {"result": "success"}
|
||||
183
api/tests/unit_tests/controllers/web/test_conversation.py
Normal file
183
api/tests/unit_tests/controllers/web/test_conversation.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""Unit tests for controllers.web.conversation endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.web.conversation import (
|
||||
ConversationApi,
|
||||
ConversationListApi,
|
||||
ConversationPinApi,
|
||||
ConversationRenameApi,
|
||||
ConversationUnPinApi,
|
||||
)
|
||||
from controllers.web.error import NotChatAppError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
|
||||
|
||||
def _chat_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="chat")
|
||||
|
||||
|
||||
def _completion_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="completion")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationListApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestConversationListApi:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/conversations"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationListApi().get(_completion_app(), _end_user())
|
||||
|
||||
@patch("controllers.web.conversation.WebConversationService.pagination_by_last_id")
|
||||
@patch("controllers.web.conversation.db")
|
||||
def test_happy_path(self, mock_db: MagicMock, mock_paginate: MagicMock, app: Flask) -> None:
|
||||
conv_id = str(uuid4())
|
||||
conv = SimpleNamespace(
|
||||
id=conv_id,
|
||||
name="Test",
|
||||
inputs={},
|
||||
status="normal",
|
||||
introduction="",
|
||||
created_at=1700000000,
|
||||
updated_at=1700000000,
|
||||
)
|
||||
mock_paginate.return_value = SimpleNamespace(limit=20, has_more=False, data=[conv])
|
||||
mock_db.engine = "engine"
|
||||
|
||||
session_mock = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/conversations?limit=20"),
|
||||
patch("controllers.web.conversation.Session", return_value=session_ctx),
|
||||
):
|
||||
result = ConversationListApi().get(_chat_app(), _end_user())
|
||||
|
||||
assert result["limit"] == 20
|
||||
assert result["has_more"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationApi (delete)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestConversationApi:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context(f"/conversations/{uuid4()}"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationApi().delete(_completion_app(), _end_user(), uuid4())
|
||||
|
||||
@patch("controllers.web.conversation.ConversationService.delete")
|
||||
def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}"):
|
||||
result, status = ConversationApi().delete(_chat_app(), _end_user(), c_id)
|
||||
|
||||
assert status == 204
|
||||
assert result["result"] == "success"
|
||||
|
||||
@patch("controllers.web.conversation.ConversationService.delete", side_effect=ConversationNotExistsError())
|
||||
def test_delete_not_found(self, mock_delete: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}"):
|
||||
with pytest.raises(NotFound, match="Conversation Not Exists"):
|
||||
ConversationApi().delete(_chat_app(), _end_user(), c_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationRenameApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestConversationRenameApi:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context(f"/conversations/{uuid4()}/name", method="POST", json={"name": "x"}):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationRenameApi().post(_completion_app(), _end_user(), uuid4())
|
||||
|
||||
@patch("controllers.web.conversation.ConversationService.rename")
|
||||
@patch("controllers.web.conversation.web_ns")
|
||||
def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
mock_ns.payload = {"name": "New Name", "auto_generate": False}
|
||||
conv = SimpleNamespace(
|
||||
id=str(c_id),
|
||||
name="New Name",
|
||||
inputs={},
|
||||
status="normal",
|
||||
introduction="",
|
||||
created_at=1700000000,
|
||||
updated_at=1700000000,
|
||||
)
|
||||
mock_rename.return_value = conv
|
||||
|
||||
with app.test_request_context(f"/conversations/{c_id}/name", method="POST", json={"name": "New Name"}):
|
||||
result = ConversationRenameApi().post(_chat_app(), _end_user(), c_id)
|
||||
|
||||
assert result["name"] == "New Name"
|
||||
|
||||
@patch(
|
||||
"controllers.web.conversation.ConversationService.rename",
|
||||
side_effect=ConversationNotExistsError(),
|
||||
)
|
||||
@patch("controllers.web.conversation.web_ns")
|
||||
def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
mock_ns.payload = {"name": "X", "auto_generate": False}
|
||||
|
||||
with app.test_request_context(f"/conversations/{c_id}/name", method="POST", json={"name": "X"}):
|
||||
with pytest.raises(NotFound, match="Conversation Not Exists"):
|
||||
ConversationRenameApi().post(_chat_app(), _end_user(), c_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationPinApi / ConversationUnPinApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestConversationPinApi:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context(f"/conversations/{uuid4()}/pin", method="PATCH"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationPinApi().patch(_completion_app(), _end_user(), uuid4())
|
||||
|
||||
@patch("controllers.web.conversation.WebConversationService.pin")
|
||||
def test_pin_success(self, mock_pin: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"):
|
||||
result = ConversationPinApi().patch(_chat_app(), _end_user(), c_id)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
@patch("controllers.web.conversation.WebConversationService.pin", side_effect=ConversationNotExistsError())
|
||||
def test_pin_not_found(self, mock_pin: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"):
|
||||
with pytest.raises(NotFound):
|
||||
ConversationPinApi().patch(_chat_app(), _end_user(), c_id)
|
||||
|
||||
|
||||
class TestConversationUnPinApi:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context(f"/conversations/{uuid4()}/unpin", method="PATCH"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationUnPinApi().patch(_completion_app(), _end_user(), uuid4())
|
||||
|
||||
@patch("controllers.web.conversation.WebConversationService.unpin")
|
||||
def test_unpin_success(self, mock_unpin: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}/unpin", method="PATCH"):
|
||||
result = ConversationUnPinApi().patch(_chat_app(), _end_user(), c_id)
|
||||
|
||||
assert result["result"] == "success"
|
||||
75
api/tests/unit_tests/controllers/web/test_error.py
Normal file
75
api/tests/unit_tests/controllers/web/test_error.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Unit tests for controllers.web.error HTTP exception classes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.web.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
AppSuggestedQuestionsAfterAnswerDisabledError,
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
InvalidArgumentError,
|
||||
InvokeRateLimitError,
|
||||
NoAudioUploadedError,
|
||||
NotChatAppError,
|
||||
NotCompletionAppError,
|
||||
NotFoundError,
|
||||
NotWorkflowAppError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
WebAppAuthAccessDeniedError,
|
||||
WebAppAuthRequiredError,
|
||||
WebFormRateLimitExceededError,
|
||||
)
|
||||
|
||||
_ERROR_SPECS: list[tuple[type, str, int]] = [
|
||||
(AppUnavailableError, "app_unavailable", 400),
|
||||
(NotCompletionAppError, "not_completion_app", 400),
|
||||
(NotChatAppError, "not_chat_app", 400),
|
||||
(NotWorkflowAppError, "not_workflow_app", 400),
|
||||
(ConversationCompletedError, "conversation_completed", 400),
|
||||
(ProviderNotInitializeError, "provider_not_initialize", 400),
|
||||
(ProviderQuotaExceededError, "provider_quota_exceeded", 400),
|
||||
(ProviderModelCurrentlyNotSupportError, "model_currently_not_support", 400),
|
||||
(CompletionRequestError, "completion_request_error", 400),
|
||||
(AppMoreLikeThisDisabledError, "app_more_like_this_disabled", 403),
|
||||
(AppSuggestedQuestionsAfterAnswerDisabledError, "app_suggested_questions_after_answer_disabled", 403),
|
||||
(NoAudioUploadedError, "no_audio_uploaded", 400),
|
||||
(AudioTooLargeError, "audio_too_large", 413),
|
||||
(UnsupportedAudioTypeError, "unsupported_audio_type", 415),
|
||||
(ProviderNotSupportSpeechToTextError, "provider_not_support_speech_to_text", 400),
|
||||
(WebAppAuthRequiredError, "web_sso_auth_required", 401),
|
||||
(WebAppAuthAccessDeniedError, "web_app_access_denied", 401),
|
||||
(InvokeRateLimitError, "rate_limit_error", 429),
|
||||
(WebFormRateLimitExceededError, "web_form_rate_limit_exceeded", 429),
|
||||
(NotFoundError, "not_found", 404),
|
||||
(InvalidArgumentError, "invalid_param", 400),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("cls", "expected_code", "expected_status"),
|
||||
_ERROR_SPECS,
|
||||
ids=[cls.__name__ for cls, _, _ in _ERROR_SPECS],
|
||||
)
|
||||
def test_error_class_attributes(cls: type, expected_code: str, expected_status: int) -> None:
|
||||
"""Each error class exposes the correct error_code and HTTP status code."""
|
||||
assert cls.error_code == expected_code
|
||||
assert cls.code == expected_status
|
||||
|
||||
|
||||
def test_error_classes_have_description() -> None:
|
||||
"""Every error class has a description (string or None for generic errors)."""
|
||||
# NotFoundError and InvalidArgumentError use None description by design
|
||||
_NO_DESCRIPTION = {NotFoundError, InvalidArgumentError}
|
||||
for cls, _, _ in _ERROR_SPECS:
|
||||
if cls in _NO_DESCRIPTION:
|
||||
continue
|
||||
assert isinstance(cls.description, str), f"{cls.__name__} missing description"
|
||||
assert len(cls.description) > 0, f"{cls.__name__} has empty description"
|
||||
38
api/tests/unit_tests/controllers/web/test_feature.py
Normal file
38
api/tests/unit_tests/controllers/web/test_feature.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Unit tests for controllers.web.feature endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.feature import SystemFeatureApi
|
||||
|
||||
|
||||
class TestSystemFeatureApi:
|
||||
@patch("controllers.web.feature.FeatureService.get_system_features")
|
||||
def test_returns_system_features(self, mock_features: MagicMock, app: Flask) -> None:
|
||||
mock_model = MagicMock()
|
||||
mock_model.model_dump.return_value = {"sso_enforced_for_signin": False, "webapp_auth": {"enabled": False}}
|
||||
mock_features.return_value = mock_model
|
||||
|
||||
with app.test_request_context("/system-features"):
|
||||
result = SystemFeatureApi().get()
|
||||
|
||||
assert result == {"sso_enforced_for_signin": False, "webapp_auth": {"enabled": False}}
|
||||
mock_features.assert_called_once()
|
||||
|
||||
@patch("controllers.web.feature.FeatureService.get_system_features")
|
||||
def test_unauthenticated_access(self, mock_features: MagicMock, app: Flask) -> None:
|
||||
"""SystemFeatureApi is unauthenticated by design — no WebApiResource decorator."""
|
||||
mock_model = MagicMock()
|
||||
mock_model.model_dump.return_value = {}
|
||||
mock_features.return_value = mock_model
|
||||
|
||||
# Verify it's a bare Resource, not WebApiResource
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.web.wraps import WebApiResource
|
||||
|
||||
assert issubclass(SystemFeatureApi, Resource)
|
||||
assert not issubclass(SystemFeatureApi, WebApiResource)
|
||||
89
api/tests/unit_tests/controllers/web/test_files.py
Normal file
89
api/tests/unit_tests/controllers/web/test_files.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Unit tests for controllers.web.files endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from io import BytesIO
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.common.errors import (
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
)
|
||||
from controllers.web.files import FileApi
|
||||
|
||||
|
||||
def _app_model() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
class TestFileApi:
|
||||
def test_no_file_uploaded(self, app: Flask) -> None:
|
||||
with app.test_request_context("/files/upload", method="POST", content_type="multipart/form-data"):
|
||||
with pytest.raises(NoFileUploadedError):
|
||||
FileApi().post(_app_model(), _end_user())
|
||||
|
||||
def test_too_many_files(self, app: Flask) -> None:
|
||||
data = {
|
||||
"file": (BytesIO(b"a"), "a.txt"),
|
||||
"file2": (BytesIO(b"b"), "b.txt"),
|
||||
}
|
||||
with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"):
|
||||
# Now has "file" key but len(request.files) > 1
|
||||
with pytest.raises(TooManyFilesError):
|
||||
FileApi().post(_app_model(), _end_user())
|
||||
|
||||
def test_filename_missing(self, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"content"), "")}
|
||||
with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(FilenameNotExistsError):
|
||||
FileApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch("controllers.web.files.FileService")
|
||||
@patch("controllers.web.files.db")
|
||||
def test_upload_success(self, mock_db: MagicMock, mock_file_svc_cls: MagicMock, app: Flask) -> None:
|
||||
mock_db.engine = "engine"
|
||||
from datetime import datetime
|
||||
|
||||
upload_file = SimpleNamespace(
|
||||
id="file-1",
|
||||
name="test.txt",
|
||||
size=100,
|
||||
extension="txt",
|
||||
mime_type="text/plain",
|
||||
created_by="eu-1",
|
||||
created_at=datetime(2024, 1, 1),
|
||||
)
|
||||
mock_file_svc_cls.return_value.upload_file.return_value = upload_file
|
||||
|
||||
data = {"file": (BytesIO(b"content"), "test.txt")}
|
||||
with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"):
|
||||
result, status = FileApi().post(_app_model(), _end_user())
|
||||
|
||||
assert status == 201
|
||||
assert result["id"] == "file-1"
|
||||
assert result["name"] == "test.txt"
|
||||
|
||||
@patch("controllers.web.files.FileService")
|
||||
@patch("controllers.web.files.db")
|
||||
def test_file_too_large_from_service(self, mock_db: MagicMock, mock_file_svc_cls: MagicMock, app: Flask) -> None:
|
||||
import services.errors.file
|
||||
|
||||
mock_db.engine = "engine"
|
||||
mock_file_svc_cls.return_value.upload_file.side_effect = services.errors.file.FileTooLargeError(
|
||||
description="max 10MB"
|
||||
)
|
||||
|
||||
data = {"file": (BytesIO(b"big"), "big.txt")}
|
||||
with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(FileTooLargeError):
|
||||
FileApi().post(_app_model(), _end_user())
|
||||
156
api/tests/unit_tests/controllers/web/test_message_endpoints.py
Normal file
156
api/tests/unit_tests/controllers/web/test_message_endpoints.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Unit tests for controllers.web.message — feedback, more-like-this, suggested questions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.web.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
NotChatAppError,
|
||||
NotCompletionAppError,
|
||||
)
|
||||
from controllers.web.message import (
|
||||
MessageFeedbackApi,
|
||||
MessageMoreLikeThisApi,
|
||||
MessageSuggestedQuestionApi,
|
||||
)
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
|
||||
def _chat_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="chat")
|
||||
|
||||
|
||||
def _completion_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="completion")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MessageFeedbackApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestMessageFeedbackApi:
|
||||
@patch("controllers.web.message.MessageService.create_feedback")
|
||||
@patch("controllers.web.message.web_ns")
|
||||
def test_feedback_success(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"rating": "like", "content": "great"}
|
||||
msg_id = uuid4()
|
||||
|
||||
with app.test_request_context(f"/messages/{msg_id}/feedbacks", method="POST"):
|
||||
result = MessageFeedbackApi().post(_chat_app(), _end_user(), msg_id)
|
||||
|
||||
assert result == {"result": "success"}
|
||||
mock_create.assert_called_once()
|
||||
|
||||
@patch("controllers.web.message.MessageService.create_feedback")
|
||||
@patch("controllers.web.message.web_ns")
|
||||
def test_feedback_null_rating(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"rating": None}
|
||||
msg_id = uuid4()
|
||||
|
||||
with app.test_request_context(f"/messages/{msg_id}/feedbacks", method="POST"):
|
||||
result = MessageFeedbackApi().post(_chat_app(), _end_user(), msg_id)
|
||||
|
||||
assert result == {"result": "success"}
|
||||
|
||||
@patch(
|
||||
"controllers.web.message.MessageService.create_feedback",
|
||||
side_effect=MessageNotExistsError(),
|
||||
)
|
||||
@patch("controllers.web.message.web_ns")
|
||||
def test_feedback_message_not_found(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"rating": "dislike"}
|
||||
msg_id = uuid4()
|
||||
|
||||
with app.test_request_context(f"/messages/{msg_id}/feedbacks", method="POST"):
|
||||
with pytest.raises(NotFound, match="Message Not Exists"):
|
||||
MessageFeedbackApi().post(_chat_app(), _end_user(), msg_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MessageMoreLikeThisApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestMessageMoreLikeThisApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
MessageMoreLikeThisApi().get(_chat_app(), _end_user(), msg_id)
|
||||
|
||||
@patch("controllers.web.message.helper.compact_generate_response", return_value={"answer": "similar"})
|
||||
@patch("controllers.web.message.AppGenerateService.generate_more_like_this")
|
||||
def test_happy_path(self, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
mock_gen.return_value = "response"
|
||||
|
||||
with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"):
|
||||
result = MessageMoreLikeThisApi().get(_completion_app(), _end_user(), msg_id)
|
||||
|
||||
assert result == {"answer": "similar"}
|
||||
|
||||
@patch(
|
||||
"controllers.web.message.AppGenerateService.generate_more_like_this",
|
||||
side_effect=MessageNotExistsError(),
|
||||
)
|
||||
def test_message_not_found(self, mock_gen: MagicMock, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"):
|
||||
with pytest.raises(NotFound, match="Message Not Exists"):
|
||||
MessageMoreLikeThisApi().get(_completion_app(), _end_user(), msg_id)
|
||||
|
||||
@patch(
|
||||
"controllers.web.message.AppGenerateService.generate_more_like_this",
|
||||
side_effect=MoreLikeThisDisabledError(),
|
||||
)
|
||||
def test_feature_disabled(self, mock_gen: MagicMock, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"):
|
||||
with pytest.raises(AppMoreLikeThisDisabledError):
|
||||
MessageMoreLikeThisApi().get(_completion_app(), _end_user(), msg_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MessageSuggestedQuestionApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestMessageSuggestedQuestionApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/messages/{msg_id}/suggested-questions"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id)
|
||||
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/messages/{msg_id}/suggested-questions"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id)
|
||||
|
||||
@patch("controllers.web.message.MessageService.get_suggested_questions_after_answer")
|
||||
def test_happy_path(self, mock_suggest: MagicMock, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
mock_suggest.return_value = ["What about X?", "Tell me more about Y."]
|
||||
|
||||
with app.test_request_context(f"/messages/{msg_id}/suggested-questions"):
|
||||
result = MessageSuggestedQuestionApi().get(_chat_app(), _end_user(), msg_id)
|
||||
|
||||
assert result["data"] == ["What about X?", "Tell me more about Y."]
|
||||
|
||||
@patch(
|
||||
"controllers.web.message.MessageService.get_suggested_questions_after_answer",
|
||||
side_effect=MessageNotExistsError(),
|
||||
)
|
||||
def test_message_not_found(self, mock_suggest: MagicMock, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/messages/{msg_id}/suggested-questions"):
|
||||
with pytest.raises(NotFound, match="Message not found"):
|
||||
MessageSuggestedQuestionApi().get(_chat_app(), _end_user(), msg_id)
|
||||
103
api/tests/unit_tests/controllers/web/test_passport.py
Normal file
103
api/tests/unit_tests/controllers/web/test_passport.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from controllers.web.error import WebAppAuthRequiredError
|
||||
from controllers.web.passport import (
|
||||
PassportService,
|
||||
decode_enterprise_webapp_user_id,
|
||||
exchange_token_for_existing_web_user,
|
||||
generate_session_id,
|
||||
)
|
||||
from services.webapp_auth_service import WebAppAuthType
|
||||
|
||||
|
||||
def test_decode_enterprise_webapp_user_id_none() -> None:
|
||||
assert decode_enterprise_webapp_user_id(None) is None
|
||||
|
||||
|
||||
def test_decode_enterprise_webapp_user_id_invalid_source(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(PassportService, "verify", lambda *_args, **_kwargs: {"token_source": "bad"})
|
||||
with pytest.raises(Unauthorized):
|
||||
decode_enterprise_webapp_user_id("token")
|
||||
|
||||
|
||||
def test_decode_enterprise_webapp_user_id_valid(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
decoded = {"token_source": "webapp_login_token", "user_id": "u1"}
|
||||
monkeypatch.setattr(PassportService, "verify", lambda *_args, **_kwargs: decoded)
|
||||
assert decode_enterprise_webapp_user_id("token") == decoded
|
||||
|
||||
|
||||
def test_exchange_token_public_flow(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal")
|
||||
app_model = SimpleNamespace(id="a1", status="normal", enable_site=True)
|
||||
|
||||
def _scalar_side_effect(*_args, **_kwargs):
|
||||
if not hasattr(_scalar_side_effect, "calls"):
|
||||
_scalar_side_effect.calls = 0
|
||||
_scalar_side_effect.calls += 1
|
||||
return site if _scalar_side_effect.calls == 1 else app_model
|
||||
|
||||
db_session = SimpleNamespace(scalar=_scalar_side_effect)
|
||||
monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session))
|
||||
monkeypatch.setattr("controllers.web.passport._exchange_for_public_app_token", lambda *_args, **_kwargs: "resp")
|
||||
|
||||
decoded = {"auth_type": "public"}
|
||||
result = exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.PUBLIC)
|
||||
assert result == "resp"
|
||||
|
||||
|
||||
def test_exchange_token_requires_external(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal")
|
||||
app_model = SimpleNamespace(id="a1", status="normal", enable_site=True)
|
||||
|
||||
def _scalar_side_effect(*_args, **_kwargs):
|
||||
if not hasattr(_scalar_side_effect, "calls"):
|
||||
_scalar_side_effect.calls = 0
|
||||
_scalar_side_effect.calls += 1
|
||||
return site if _scalar_side_effect.calls == 1 else app_model
|
||||
|
||||
db_session = SimpleNamespace(scalar=_scalar_side_effect)
|
||||
monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session))
|
||||
|
||||
decoded = {"auth_type": "internal"}
|
||||
with pytest.raises(WebAppAuthRequiredError):
|
||||
exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.EXTERNAL)
|
||||
|
||||
|
||||
def test_exchange_token_missing_session_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal")
|
||||
app_model = SimpleNamespace(id="a1", status="normal", enable_site=True, tenant_id="t1")
|
||||
|
||||
def _scalar_side_effect(*_args, **_kwargs):
|
||||
if not hasattr(_scalar_side_effect, "calls"):
|
||||
_scalar_side_effect.calls = 0
|
||||
_scalar_side_effect.calls += 1
|
||||
if _scalar_side_effect.calls == 1:
|
||||
return site
|
||||
if _scalar_side_effect.calls == 2:
|
||||
return app_model
|
||||
return None
|
||||
|
||||
db_session = SimpleNamespace(scalar=_scalar_side_effect, add=lambda *_a, **_k: None, commit=lambda: None)
|
||||
monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session))
|
||||
|
||||
decoded = {"auth_type": "internal"}
|
||||
with pytest.raises(NotFound):
|
||||
exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.INTERNAL)
|
||||
|
||||
|
||||
def test_generate_session_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
counts = [1, 0]
|
||||
|
||||
def _scalar(*_args, **_kwargs):
|
||||
return counts.pop(0)
|
||||
|
||||
db_session = SimpleNamespace(scalar=_scalar)
|
||||
monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session))
|
||||
|
||||
session_id = generate_session_id()
|
||||
assert session_id
|
||||
423
api/tests/unit_tests/controllers/web/test_pydantic_models.py
Normal file
423
api/tests/unit_tests/controllers/web/test_pydantic_models.py
Normal file
@@ -0,0 +1,423 @@
|
||||
"""Unit tests for Pydantic models defined in controllers.web modules.
|
||||
|
||||
Covers validation logic, field defaults, constraints, and custom validators
|
||||
for all ~15 Pydantic models across the web controller layer.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# app.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.app import AppAccessModeQuery
|
||||
|
||||
|
||||
class TestAppAccessModeQuery:
|
||||
def test_alias_resolution(self) -> None:
|
||||
q = AppAccessModeQuery.model_validate({"appId": "abc", "appCode": "xyz"})
|
||||
assert q.app_id == "abc"
|
||||
assert q.app_code == "xyz"
|
||||
|
||||
def test_defaults_to_none(self) -> None:
|
||||
q = AppAccessModeQuery.model_validate({})
|
||||
assert q.app_id is None
|
||||
assert q.app_code is None
|
||||
|
||||
def test_accepts_snake_case(self) -> None:
|
||||
q = AppAccessModeQuery(app_id="id1", app_code="code1")
|
||||
assert q.app_id == "id1"
|
||||
assert q.app_code == "code1"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# audio.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.audio import TextToAudioPayload
|
||||
|
||||
|
||||
class TestTextToAudioPayload:
|
||||
def test_defaults(self) -> None:
|
||||
p = TextToAudioPayload.model_validate({})
|
||||
assert p.message_id is None
|
||||
assert p.voice is None
|
||||
assert p.text is None
|
||||
assert p.streaming is None
|
||||
|
||||
def test_valid_uuid_message_id(self) -> None:
|
||||
uid = str(uuid4())
|
||||
p = TextToAudioPayload(message_id=uid)
|
||||
assert p.message_id == uid
|
||||
|
||||
def test_none_message_id_passthrough(self) -> None:
|
||||
p = TextToAudioPayload(message_id=None)
|
||||
assert p.message_id is None
|
||||
|
||||
def test_invalid_uuid_message_id(self) -> None:
|
||||
with pytest.raises(ValidationError, match="not a valid uuid"):
|
||||
TextToAudioPayload(message_id="not-a-uuid")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# completion.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.completion import ChatMessagePayload, CompletionMessagePayload
|
||||
|
||||
|
||||
class TestCompletionMessagePayload:
|
||||
def test_defaults(self) -> None:
|
||||
p = CompletionMessagePayload(inputs={})
|
||||
assert p.query == ""
|
||||
assert p.files is None
|
||||
assert p.response_mode is None
|
||||
assert p.retriever_from == "web_app"
|
||||
|
||||
def test_accepts_full_payload(self) -> None:
|
||||
p = CompletionMessagePayload(
|
||||
inputs={"key": "val"},
|
||||
query="test",
|
||||
files=[{"id": "f1"}],
|
||||
response_mode="streaming",
|
||||
)
|
||||
assert p.response_mode == "streaming"
|
||||
assert p.files == [{"id": "f1"}]
|
||||
|
||||
def test_invalid_response_mode(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
CompletionMessagePayload(inputs={}, response_mode="invalid")
|
||||
|
||||
|
||||
class TestChatMessagePayload:
|
||||
def test_valid_uuid_fields(self) -> None:
|
||||
cid = str(uuid4())
|
||||
pid = str(uuid4())
|
||||
p = ChatMessagePayload(inputs={}, query="hi", conversation_id=cid, parent_message_id=pid)
|
||||
assert p.conversation_id == cid
|
||||
assert p.parent_message_id == pid
|
||||
|
||||
def test_none_uuid_fields(self) -> None:
|
||||
p = ChatMessagePayload(inputs={}, query="hi")
|
||||
assert p.conversation_id is None
|
||||
assert p.parent_message_id is None
|
||||
|
||||
def test_invalid_conversation_id(self) -> None:
|
||||
with pytest.raises(ValidationError, match="not a valid uuid"):
|
||||
ChatMessagePayload(inputs={}, query="hi", conversation_id="bad")
|
||||
|
||||
def test_invalid_parent_message_id(self) -> None:
|
||||
with pytest.raises(ValidationError, match="not a valid uuid"):
|
||||
ChatMessagePayload(inputs={}, query="hi", parent_message_id="bad")
|
||||
|
||||
def test_query_required(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ChatMessagePayload(inputs={})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# conversation.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.conversation import ConversationListQuery, ConversationRenamePayload
|
||||
|
||||
|
||||
class TestConversationListQuery:
|
||||
def test_defaults(self) -> None:
|
||||
q = ConversationListQuery()
|
||||
assert q.last_id is None
|
||||
assert q.limit == 20
|
||||
assert q.pinned is None
|
||||
assert q.sort_by == "-updated_at"
|
||||
|
||||
def test_limit_lower_bound(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ConversationListQuery(limit=0)
|
||||
|
||||
def test_limit_upper_bound(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ConversationListQuery(limit=101)
|
||||
|
||||
def test_limit_boundaries_valid(self) -> None:
|
||||
assert ConversationListQuery(limit=1).limit == 1
|
||||
assert ConversationListQuery(limit=100).limit == 100
|
||||
|
||||
def test_valid_sort_by_options(self) -> None:
|
||||
for opt in ("created_at", "-created_at", "updated_at", "-updated_at"):
|
||||
assert ConversationListQuery(sort_by=opt).sort_by == opt
|
||||
|
||||
def test_invalid_sort_by(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ConversationListQuery(sort_by="invalid")
|
||||
|
||||
def test_valid_last_id(self) -> None:
|
||||
uid = str(uuid4())
|
||||
assert ConversationListQuery(last_id=uid).last_id == uid
|
||||
|
||||
def test_invalid_last_id(self) -> None:
|
||||
with pytest.raises(ValidationError, match="not a valid uuid"):
|
||||
ConversationListQuery(last_id="not-uuid")
|
||||
|
||||
|
||||
class TestConversationRenamePayload:
|
||||
def test_auto_generate_true_no_name_required(self) -> None:
|
||||
p = ConversationRenamePayload(auto_generate=True)
|
||||
assert p.name is None
|
||||
|
||||
def test_auto_generate_false_requires_name(self) -> None:
|
||||
with pytest.raises(ValidationError, match="name is required"):
|
||||
ConversationRenamePayload(auto_generate=False)
|
||||
|
||||
def test_auto_generate_false_blank_name_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError, match="name is required"):
|
||||
ConversationRenamePayload(auto_generate=False, name=" ")
|
||||
|
||||
def test_auto_generate_false_with_valid_name(self) -> None:
|
||||
p = ConversationRenamePayload(auto_generate=False, name="My Chat")
|
||||
assert p.name == "My Chat"
|
||||
|
||||
def test_defaults(self) -> None:
|
||||
p = ConversationRenamePayload(name="test")
|
||||
assert p.auto_generate is False
|
||||
assert p.name == "test"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# message.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.message import MessageFeedbackPayload, MessageListQuery, MessageMoreLikeThisQuery
|
||||
|
||||
|
||||
class TestMessageListQuery:
|
||||
def test_valid_query(self) -> None:
|
||||
cid = str(uuid4())
|
||||
q = MessageListQuery(conversation_id=cid)
|
||||
assert q.conversation_id == cid
|
||||
assert q.first_id is None
|
||||
assert q.limit == 20
|
||||
|
||||
def test_invalid_conversation_id(self) -> None:
|
||||
with pytest.raises(ValidationError, match="not a valid uuid"):
|
||||
MessageListQuery(conversation_id="bad")
|
||||
|
||||
def test_limit_bounds(self) -> None:
|
||||
cid = str(uuid4())
|
||||
with pytest.raises(ValidationError):
|
||||
MessageListQuery(conversation_id=cid, limit=0)
|
||||
with pytest.raises(ValidationError):
|
||||
MessageListQuery(conversation_id=cid, limit=101)
|
||||
|
||||
def test_valid_first_id(self) -> None:
|
||||
cid = str(uuid4())
|
||||
fid = str(uuid4())
|
||||
q = MessageListQuery(conversation_id=cid, first_id=fid)
|
||||
assert q.first_id == fid
|
||||
|
||||
def test_invalid_first_id(self) -> None:
|
||||
cid = str(uuid4())
|
||||
with pytest.raises(ValidationError, match="not a valid uuid"):
|
||||
MessageListQuery(conversation_id=cid, first_id="invalid")
|
||||
|
||||
|
||||
class TestMessageFeedbackPayload:
|
||||
def test_defaults(self) -> None:
|
||||
p = MessageFeedbackPayload()
|
||||
assert p.rating is None
|
||||
assert p.content is None
|
||||
|
||||
def test_valid_ratings(self) -> None:
|
||||
assert MessageFeedbackPayload(rating="like").rating == "like"
|
||||
assert MessageFeedbackPayload(rating="dislike").rating == "dislike"
|
||||
|
||||
def test_invalid_rating(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
MessageFeedbackPayload(rating="neutral")
|
||||
|
||||
|
||||
class TestMessageMoreLikeThisQuery:
|
||||
def test_valid_modes(self) -> None:
|
||||
assert MessageMoreLikeThisQuery(response_mode="blocking").response_mode == "blocking"
|
||||
assert MessageMoreLikeThisQuery(response_mode="streaming").response_mode == "streaming"
|
||||
|
||||
def test_invalid_mode(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
MessageMoreLikeThisQuery(response_mode="invalid")
|
||||
|
||||
def test_required(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
MessageMoreLikeThisQuery()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# remote_files.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.remote_files import RemoteFileUploadPayload
|
||||
|
||||
|
||||
class TestRemoteFileUploadPayload:
|
||||
def test_valid_url(self) -> None:
|
||||
p = RemoteFileUploadPayload(url="https://example.com/file.pdf")
|
||||
assert str(p.url) == "https://example.com/file.pdf"
|
||||
|
||||
def test_invalid_url(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
RemoteFileUploadPayload(url="not-a-url")
|
||||
|
||||
def test_url_required(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
RemoteFileUploadPayload()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# saved_message.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.saved_message import SavedMessageCreatePayload, SavedMessageListQuery
|
||||
|
||||
|
||||
class TestSavedMessageListQuery:
|
||||
def test_defaults(self) -> None:
|
||||
q = SavedMessageListQuery()
|
||||
assert q.last_id is None
|
||||
assert q.limit == 20
|
||||
|
||||
def test_limit_bounds(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
SavedMessageListQuery(limit=0)
|
||||
with pytest.raises(ValidationError):
|
||||
SavedMessageListQuery(limit=101)
|
||||
|
||||
def test_valid_last_id(self) -> None:
|
||||
uid = str(uuid4())
|
||||
q = SavedMessageListQuery(last_id=uid)
|
||||
assert q.last_id == uid
|
||||
|
||||
def test_empty_last_id(self) -> None:
|
||||
q = SavedMessageListQuery(last_id="")
|
||||
assert q.last_id == ""
|
||||
|
||||
|
||||
class TestSavedMessageCreatePayload:
|
||||
def test_valid_message_id(self) -> None:
|
||||
uid = str(uuid4())
|
||||
p = SavedMessageCreatePayload(message_id=uid)
|
||||
assert p.message_id == uid
|
||||
|
||||
def test_required(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
SavedMessageCreatePayload()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# workflow.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.workflow import WorkflowRunPayload
|
||||
|
||||
|
||||
class TestWorkflowRunPayload:
|
||||
def test_defaults(self) -> None:
|
||||
p = WorkflowRunPayload(inputs={})
|
||||
assert p.inputs == {}
|
||||
assert p.files is None
|
||||
|
||||
def test_with_files(self) -> None:
|
||||
p = WorkflowRunPayload(inputs={"k": "v"}, files=[{"id": "f1"}])
|
||||
assert p.files == [{"id": "f1"}]
|
||||
|
||||
def test_inputs_required(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
WorkflowRunPayload()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# forgot_password.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.forgot_password import (
|
||||
ForgotPasswordCheckPayload,
|
||||
ForgotPasswordResetPayload,
|
||||
ForgotPasswordSendPayload,
|
||||
)
|
||||
|
||||
|
||||
class TestForgotPasswordSendPayload:
|
||||
def test_valid_email(self) -> None:
|
||||
p = ForgotPasswordSendPayload(email="user@example.com")
|
||||
assert p.email == "user@example.com"
|
||||
|
||||
def test_invalid_email(self) -> None:
|
||||
with pytest.raises(ValidationError, match="not a valid email"):
|
||||
ForgotPasswordSendPayload(email="not-an-email")
|
||||
|
||||
def test_language_optional(self) -> None:
|
||||
p = ForgotPasswordSendPayload(email="a@b.com")
|
||||
assert p.language is None
|
||||
|
||||
|
||||
class TestForgotPasswordCheckPayload:
|
||||
def test_valid(self) -> None:
|
||||
p = ForgotPasswordCheckPayload(email="a@b.com", code="1234", token="tok")
|
||||
assert p.email == "a@b.com"
|
||||
assert p.code == "1234"
|
||||
assert p.token == "tok"
|
||||
|
||||
def test_empty_token_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ForgotPasswordCheckPayload(email="a@b.com", code="1234", token="")
|
||||
|
||||
|
||||
class TestForgotPasswordResetPayload:
|
||||
def test_valid_passwords(self) -> None:
|
||||
p = ForgotPasswordResetPayload(token="tok", new_password="Valid1234", password_confirm="Valid1234")
|
||||
assert p.new_password == "Valid1234"
|
||||
|
||||
def test_weak_password_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError, match="Password must contain"):
|
||||
ForgotPasswordResetPayload(token="tok", new_password="short", password_confirm="short")
|
||||
|
||||
def test_letters_only_password_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError, match="Password must contain"):
|
||||
ForgotPasswordResetPayload(token="tok", new_password="abcdefghi", password_confirm="abcdefghi")
|
||||
|
||||
def test_digits_only_password_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError, match="Password must contain"):
|
||||
ForgotPasswordResetPayload(token="tok", new_password="123456789", password_confirm="123456789")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# login.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.login import EmailCodeLoginSendPayload, EmailCodeLoginVerifyPayload, LoginPayload
|
||||
|
||||
|
||||
class TestLoginPayload:
|
||||
def test_valid(self) -> None:
|
||||
p = LoginPayload(email="a@b.com", password="Valid1234")
|
||||
assert p.email == "a@b.com"
|
||||
|
||||
def test_invalid_email(self) -> None:
|
||||
with pytest.raises(ValidationError, match="not a valid email"):
|
||||
LoginPayload(email="bad", password="Valid1234")
|
||||
|
||||
def test_weak_password(self) -> None:
|
||||
with pytest.raises(ValidationError, match="Password must contain"):
|
||||
LoginPayload(email="a@b.com", password="weak")
|
||||
|
||||
|
||||
class TestEmailCodeLoginSendPayload:
|
||||
def test_valid(self) -> None:
|
||||
p = EmailCodeLoginSendPayload(email="a@b.com")
|
||||
assert p.language is None
|
||||
|
||||
def test_with_language(self) -> None:
|
||||
p = EmailCodeLoginSendPayload(email="a@b.com", language="zh-Hans")
|
||||
assert p.language == "zh-Hans"
|
||||
|
||||
|
||||
class TestEmailCodeLoginVerifyPayload:
|
||||
def test_valid(self) -> None:
|
||||
p = EmailCodeLoginVerifyPayload(email="a@b.com", code="1234", token="tok")
|
||||
assert p.code == "1234"
|
||||
|
||||
def test_empty_token_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
EmailCodeLoginVerifyPayload(email="a@b.com", code="1234", token="")
|
||||
147
api/tests/unit_tests/controllers/web/test_remote_files.py
Normal file
147
api/tests/unit_tests/controllers/web/test_remote_files.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""Unit tests for controllers.web.remote_files endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.common.errors import FileTooLargeError, RemoteFileUploadError
|
||||
from controllers.web.remote_files import RemoteFileInfoApi, RemoteFileUploadApi
|
||||
|
||||
|
||||
def _app_model() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RemoteFileInfoApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestRemoteFileInfoApi:
|
||||
@patch("controllers.web.remote_files.ssrf_proxy")
|
||||
def test_head_success(self, mock_proxy: MagicMock, app: Flask) -> None:
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/pdf", "Content-Length": "1024"}
|
||||
mock_proxy.head.return_value = mock_resp
|
||||
|
||||
with app.test_request_context("/remote-files/https%3A%2F%2Fexample.com%2Ffile.pdf"):
|
||||
result = RemoteFileInfoApi().get(_app_model(), _end_user(), "https%3A%2F%2Fexample.com%2Ffile.pdf")
|
||||
|
||||
assert result["file_type"] == "application/pdf"
|
||||
assert result["file_length"] == 1024
|
||||
|
||||
@patch("controllers.web.remote_files.ssrf_proxy")
|
||||
def test_fallback_to_get(self, mock_proxy: MagicMock, app: Flask) -> None:
|
||||
head_resp = MagicMock()
|
||||
head_resp.status_code = 405 # Method not allowed
|
||||
get_resp = MagicMock()
|
||||
get_resp.status_code = 200
|
||||
get_resp.headers = {"Content-Type": "text/plain", "Content-Length": "42"}
|
||||
get_resp.raise_for_status = MagicMock()
|
||||
mock_proxy.head.return_value = head_resp
|
||||
mock_proxy.get.return_value = get_resp
|
||||
|
||||
with app.test_request_context("/remote-files/https%3A%2F%2Fexample.com%2Ffile.txt"):
|
||||
result = RemoteFileInfoApi().get(_app_model(), _end_user(), "https%3A%2F%2Fexample.com%2Ffile.txt")
|
||||
|
||||
assert result["file_type"] == "text/plain"
|
||||
mock_proxy.get.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RemoteFileUploadApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestRemoteFileUploadApi:
|
||||
@patch("controllers.web.remote_files.file_helpers.get_signed_file_url", return_value="https://signed-url")
|
||||
@patch("controllers.web.remote_files.FileService")
|
||||
@patch("controllers.web.remote_files.helpers.guess_file_info_from_response")
|
||||
@patch("controllers.web.remote_files.ssrf_proxy")
|
||||
@patch("controllers.web.remote_files.web_ns")
|
||||
@patch("controllers.web.remote_files.db")
|
||||
def test_upload_success(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_ns: MagicMock,
|
||||
mock_proxy: MagicMock,
|
||||
mock_guess: MagicMock,
|
||||
mock_file_svc_cls: MagicMock,
|
||||
mock_signed: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_db.engine = "engine"
|
||||
mock_ns.payload = {"url": "https://example.com/file.pdf"}
|
||||
head_resp = MagicMock()
|
||||
head_resp.status_code = 200
|
||||
head_resp.content = b"pdf-content"
|
||||
head_resp.request.method = "HEAD"
|
||||
mock_proxy.head.return_value = head_resp
|
||||
get_resp = MagicMock()
|
||||
get_resp.content = b"pdf-content"
|
||||
mock_proxy.get.return_value = get_resp
|
||||
|
||||
mock_guess.return_value = SimpleNamespace(
|
||||
filename="file.pdf", extension="pdf", mimetype="application/pdf", size=100
|
||||
)
|
||||
mock_file_svc_cls.is_file_size_within_limit.return_value = True
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
upload_file = SimpleNamespace(
|
||||
id="f-1",
|
||||
name="file.pdf",
|
||||
size=100,
|
||||
extension="pdf",
|
||||
mime_type="application/pdf",
|
||||
created_by="eu-1",
|
||||
created_at=datetime(2024, 1, 1),
|
||||
)
|
||||
mock_file_svc_cls.return_value.upload_file.return_value = upload_file
|
||||
|
||||
with app.test_request_context("/remote-files/upload", method="POST"):
|
||||
result, status = RemoteFileUploadApi().post(_app_model(), _end_user())
|
||||
|
||||
assert status == 201
|
||||
assert result["id"] == "f-1"
|
||||
|
||||
@patch("controllers.web.remote_files.FileService.is_file_size_within_limit", return_value=False)
|
||||
@patch("controllers.web.remote_files.helpers.guess_file_info_from_response")
|
||||
@patch("controllers.web.remote_files.ssrf_proxy")
|
||||
@patch("controllers.web.remote_files.web_ns")
|
||||
def test_file_too_large(
|
||||
self,
|
||||
mock_ns: MagicMock,
|
||||
mock_proxy: MagicMock,
|
||||
mock_guess: MagicMock,
|
||||
mock_size_check: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_ns.payload = {"url": "https://example.com/big.zip"}
|
||||
head_resp = MagicMock()
|
||||
head_resp.status_code = 200
|
||||
mock_proxy.head.return_value = head_resp
|
||||
mock_guess.return_value = SimpleNamespace(
|
||||
filename="big.zip", extension="zip", mimetype="application/zip", size=999999999
|
||||
)
|
||||
|
||||
with app.test_request_context("/remote-files/upload", method="POST"):
|
||||
with pytest.raises(FileTooLargeError):
|
||||
RemoteFileUploadApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch("controllers.web.remote_files.ssrf_proxy")
|
||||
@patch("controllers.web.remote_files.web_ns")
|
||||
def test_fetch_failure_raises(self, mock_ns: MagicMock, mock_proxy: MagicMock, app: Flask) -> None:
|
||||
import httpx
|
||||
|
||||
mock_ns.payload = {"url": "https://example.com/bad"}
|
||||
mock_proxy.head.side_effect = httpx.RequestError("connection failed")
|
||||
|
||||
with app.test_request_context("/remote-files/upload", method="POST"):
|
||||
with pytest.raises(RemoteFileUploadError):
|
||||
RemoteFileUploadApi().post(_app_model(), _end_user())
|
||||
97
api/tests/unit_tests/controllers/web/test_saved_message.py
Normal file
97
api/tests/unit_tests/controllers/web/test_saved_message.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Unit tests for controllers.web.saved_message endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.web.error import NotCompletionAppError
|
||||
from controllers.web.saved_message import SavedMessageApi, SavedMessageListApi
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
|
||||
def _completion_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="completion")
|
||||
|
||||
|
||||
def _chat_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="chat")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SavedMessageListApi (GET)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestSavedMessageListApiGet:
|
||||
def test_non_completion_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/saved-messages"):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
SavedMessageListApi().get(_chat_app(), _end_user())
|
||||
|
||||
@patch("controllers.web.saved_message.SavedMessageService.pagination_by_last_id")
|
||||
def test_happy_path(self, mock_paginate: MagicMock, app: Flask) -> None:
|
||||
mock_paginate.return_value = SimpleNamespace(limit=20, has_more=False, data=[])
|
||||
|
||||
with app.test_request_context("/saved-messages?limit=20"):
|
||||
result = SavedMessageListApi().get(_completion_app(), _end_user())
|
||||
|
||||
assert result["limit"] == 20
|
||||
assert result["has_more"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SavedMessageListApi (POST)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestSavedMessageListApiPost:
|
||||
def test_non_completion_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/saved-messages", method="POST"):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
SavedMessageListApi().post(_chat_app(), _end_user())
|
||||
|
||||
@patch("controllers.web.saved_message.SavedMessageService.save")
|
||||
@patch("controllers.web.saved_message.web_ns")
|
||||
def test_save_success(self, mock_ns: MagicMock, mock_save: MagicMock, app: Flask) -> None:
|
||||
msg_id = str(uuid4())
|
||||
mock_ns.payload = {"message_id": msg_id}
|
||||
|
||||
with app.test_request_context("/saved-messages", method="POST"):
|
||||
result = SavedMessageListApi().post(_completion_app(), _end_user())
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
@patch("controllers.web.saved_message.SavedMessageService.save", side_effect=MessageNotExistsError())
|
||||
@patch("controllers.web.saved_message.web_ns")
|
||||
def test_save_not_found(self, mock_ns: MagicMock, mock_save: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"message_id": str(uuid4())}
|
||||
|
||||
with app.test_request_context("/saved-messages", method="POST"):
|
||||
with pytest.raises(NotFound, match="Message Not Exists"):
|
||||
SavedMessageListApi().post(_completion_app(), _end_user())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SavedMessageApi (DELETE)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestSavedMessageApi:
|
||||
def test_non_completion_mode_raises(self, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/saved-messages/{msg_id}", method="DELETE"):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
SavedMessageApi().delete(_chat_app(), _end_user(), msg_id)
|
||||
|
||||
@patch("controllers.web.saved_message.SavedMessageService.delete")
|
||||
def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/saved-messages/{msg_id}", method="DELETE"):
|
||||
result, status = SavedMessageApi().delete(_completion_app(), _end_user(), msg_id)
|
||||
|
||||
assert status == 204
|
||||
assert result["result"] == "success"
|
||||
126
api/tests/unit_tests/controllers/web/test_site.py
Normal file
126
api/tests/unit_tests/controllers/web/test_site.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""Unit tests for controllers.web.site endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.web.site import AppSiteApi, AppSiteInfo
|
||||
|
||||
|
||||
def _tenant(*, status: str = "normal") -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id="tenant-1",
|
||||
status=status,
|
||||
plan="basic",
|
||||
custom_config_dict={"remove_webapp_brand": False, "replace_webapp_logo": False},
|
||||
)
|
||||
|
||||
|
||||
def _site() -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
title="Site",
|
||||
icon_type="emoji",
|
||||
icon="robot",
|
||||
icon_background="#fff",
|
||||
description="desc",
|
||||
default_language="en",
|
||||
chat_color_theme="light",
|
||||
chat_color_theme_inverted=False,
|
||||
copyright=None,
|
||||
privacy_policy=None,
|
||||
custom_disclaimer=None,
|
||||
prompt_public=False,
|
||||
show_workflow_steps=True,
|
||||
use_icon_as_answer_icon=False,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AppSiteApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAppSiteApi:
|
||||
@patch("controllers.web.site.FeatureService.get_features")
|
||||
@patch("controllers.web.site.db")
|
||||
def test_happy_path(self, mock_db: MagicMock, mock_features: MagicMock, app: Flask) -> None:
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
mock_features.return_value = SimpleNamespace(can_replace_logo=False)
|
||||
site_obj = _site()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = site_obj
|
||||
tenant = _tenant()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
|
||||
end_user = SimpleNamespace(id="eu-1")
|
||||
|
||||
with app.test_request_context("/site"):
|
||||
result = AppSiteApi().get(app_model, end_user)
|
||||
|
||||
# marshal_with serializes AppSiteInfo to a dict
|
||||
assert result["app_id"] == "app-1"
|
||||
assert result["plan"] == "basic"
|
||||
assert result["enable_site"] is True
|
||||
|
||||
@patch("controllers.web.site.db")
|
||||
def test_missing_site_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None:
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
tenant = _tenant()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant)
|
||||
end_user = SimpleNamespace(id="eu-1")
|
||||
|
||||
with app.test_request_context("/site"):
|
||||
with pytest.raises(Forbidden):
|
||||
AppSiteApi().get(app_model, end_user)
|
||||
|
||||
@patch("controllers.web.site.db")
|
||||
def test_archived_tenant_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None:
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
from models.account import TenantStatus
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = _site()
|
||||
tenant = SimpleNamespace(
|
||||
id="tenant-1",
|
||||
status=TenantStatus.ARCHIVE,
|
||||
plan="basic",
|
||||
custom_config_dict={},
|
||||
)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant)
|
||||
end_user = SimpleNamespace(id="eu-1")
|
||||
|
||||
with app.test_request_context("/site"):
|
||||
with pytest.raises(Forbidden):
|
||||
AppSiteApi().get(app_model, end_user)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AppSiteInfo
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAppSiteInfo:
|
||||
def test_basic_fields(self) -> None:
|
||||
tenant = _tenant()
|
||||
site_obj = _site()
|
||||
info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", False)
|
||||
|
||||
assert info.app_id == "app-1"
|
||||
assert info.end_user_id == "eu-1"
|
||||
assert info.enable_site is True
|
||||
assert info.plan == "basic"
|
||||
assert info.can_replace_logo is False
|
||||
assert info.model_config is None
|
||||
|
||||
@patch("controllers.web.site.dify_config", SimpleNamespace(FILES_URL="https://files.example.com"))
|
||||
def test_can_replace_logo_sets_custom_config(self) -> None:
|
||||
tenant = SimpleNamespace(
|
||||
id="tenant-1",
|
||||
plan="pro",
|
||||
custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": True},
|
||||
)
|
||||
site_obj = _site()
|
||||
info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", True)
|
||||
|
||||
assert info.can_replace_logo is True
|
||||
assert info.custom_config["remove_webapp_brand"] is True
|
||||
assert "webapp-logo" in info.custom_config["replace_webapp_logo"]
|
||||
@@ -5,7 +5,8 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi
|
||||
import services.errors.account
|
||||
from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi, LoginApi, LoginStatusApi, LogoutApi
|
||||
|
||||
|
||||
def encode_code(code: str) -> str:
|
||||
@@ -89,3 +90,114 @@ class TestEmailCodeLoginApi:
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_login.assert_called_once()
|
||||
mock_reset_login_rate.assert_called_once_with("user@example.com")
|
||||
|
||||
|
||||
class TestLoginApi:
|
||||
@patch("controllers.web.login.WebAppAuthService.login", return_value="access-tok")
|
||||
@patch("controllers.web.login.WebAppAuthService.authenticate")
|
||||
def test_login_success(self, mock_auth: MagicMock, mock_login: MagicMock, app: Flask) -> None:
|
||||
mock_auth.return_value = MagicMock()
|
||||
|
||||
with app.test_request_context(
|
||||
"/web/login",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()},
|
||||
):
|
||||
response = LoginApi().post()
|
||||
|
||||
assert response.get_json()["data"]["access_token"] == "access-tok"
|
||||
mock_auth.assert_called_once()
|
||||
|
||||
@patch(
|
||||
"controllers.web.login.WebAppAuthService.authenticate",
|
||||
side_effect=services.errors.account.AccountLoginError(),
|
||||
)
|
||||
def test_login_banned_account(self, mock_auth: MagicMock, app: Flask) -> None:
|
||||
from controllers.console.error import AccountBannedError
|
||||
|
||||
with app.test_request_context(
|
||||
"/web/login",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()},
|
||||
):
|
||||
with pytest.raises(AccountBannedError):
|
||||
LoginApi().post()
|
||||
|
||||
@patch(
|
||||
"controllers.web.login.WebAppAuthService.authenticate",
|
||||
side_effect=services.errors.account.AccountPasswordError(),
|
||||
)
|
||||
def test_login_wrong_password(self, mock_auth: MagicMock, app: Flask) -> None:
|
||||
from controllers.console.auth.error import AuthenticationFailedError
|
||||
|
||||
with app.test_request_context(
|
||||
"/web/login",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()},
|
||||
):
|
||||
with pytest.raises(AuthenticationFailedError):
|
||||
LoginApi().post()
|
||||
|
||||
|
||||
class TestLoginStatusApi:
|
||||
@patch("controllers.web.login.extract_webapp_access_token", return_value=None)
|
||||
def test_no_app_code_returns_logged_in_false(self, mock_extract: MagicMock, app: Flask) -> None:
|
||||
with app.test_request_context("/web/login/status"):
|
||||
result = LoginStatusApi().get()
|
||||
|
||||
assert result["logged_in"] is False
|
||||
assert result["app_logged_in"] is False
|
||||
|
||||
@patch("controllers.web.login.decode_jwt_token")
|
||||
@patch("controllers.web.login.PassportService")
|
||||
@patch("controllers.web.login.WebAppAuthService.is_app_require_permission_check", return_value=False)
|
||||
@patch("controllers.web.login.AppService.get_app_id_by_code", return_value="app-1")
|
||||
@patch("controllers.web.login.extract_webapp_access_token", return_value="tok")
|
||||
def test_public_app_user_logged_in(
|
||||
self,
|
||||
mock_extract: MagicMock,
|
||||
mock_app_id: MagicMock,
|
||||
mock_perm: MagicMock,
|
||||
mock_passport: MagicMock,
|
||||
mock_decode: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_decode.return_value = (MagicMock(), MagicMock())
|
||||
|
||||
with app.test_request_context("/web/login/status?app_code=code1"):
|
||||
result = LoginStatusApi().get()
|
||||
|
||||
assert result["logged_in"] is True
|
||||
assert result["app_logged_in"] is True
|
||||
|
||||
@patch("controllers.web.login.decode_jwt_token", side_effect=Exception("bad"))
|
||||
@patch("controllers.web.login.PassportService")
|
||||
@patch("controllers.web.login.WebAppAuthService.is_app_require_permission_check", return_value=True)
|
||||
@patch("controllers.web.login.AppService.get_app_id_by_code", return_value="app-1")
|
||||
@patch("controllers.web.login.extract_webapp_access_token", return_value="tok")
|
||||
def test_private_app_passport_fails(
|
||||
self,
|
||||
mock_extract: MagicMock,
|
||||
mock_app_id: MagicMock,
|
||||
mock_perm: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_decode: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_passport_cls.return_value.verify.side_effect = Exception("bad")
|
||||
|
||||
with app.test_request_context("/web/login/status?app_code=code1"):
|
||||
result = LoginStatusApi().get()
|
||||
|
||||
assert result["logged_in"] is False
|
||||
assert result["app_logged_in"] is False
|
||||
|
||||
|
||||
class TestLogoutApi:
|
||||
@patch("controllers.web.login.clear_webapp_access_token_from_cookie")
|
||||
def test_logout_success(self, mock_clear: MagicMock, app: Flask) -> None:
|
||||
with app.test_request_context("/web/logout", method="POST"):
|
||||
response = LogoutApi().post()
|
||||
|
||||
assert response.get_json() == {"result": "success"}
|
||||
mock_clear.assert_called_once()
|
||||
|
||||
192
api/tests/unit_tests/controllers/web/test_web_passport.py
Normal file
192
api/tests/unit_tests/controllers/web/test_web_passport.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""Unit tests for controllers.web.passport — token issuance and enterprise auth exchange."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from controllers.web.error import WebAppAuthRequiredError
|
||||
from controllers.web.passport import (
|
||||
PassportResource,
|
||||
decode_enterprise_webapp_user_id,
|
||||
exchange_token_for_existing_web_user,
|
||||
generate_session_id,
|
||||
)
|
||||
from services.webapp_auth_service import WebAppAuthType
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# decode_enterprise_webapp_user_id
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestDecodeEnterpriseWebappUserId:
|
||||
def test_none_token_returns_none(self) -> None:
|
||||
assert decode_enterprise_webapp_user_id(None) is None
|
||||
|
||||
@patch("controllers.web.passport.PassportService")
|
||||
def test_valid_token_returns_decoded(self, mock_passport_cls: MagicMock) -> None:
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"token_source": "webapp_login_token",
|
||||
"user_id": "u1",
|
||||
}
|
||||
result = decode_enterprise_webapp_user_id("valid-jwt")
|
||||
assert result["user_id"] == "u1"
|
||||
|
||||
@patch("controllers.web.passport.PassportService")
|
||||
def test_wrong_source_raises_unauthorized(self, mock_passport_cls: MagicMock) -> None:
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"token_source": "other_source",
|
||||
}
|
||||
with pytest.raises(Unauthorized, match="Expected 'webapp_login_token'"):
|
||||
decode_enterprise_webapp_user_id("bad-jwt")
|
||||
|
||||
@patch("controllers.web.passport.PassportService")
|
||||
def test_missing_source_raises_unauthorized(self, mock_passport_cls: MagicMock) -> None:
|
||||
mock_passport_cls.return_value.verify.return_value = {}
|
||||
with pytest.raises(Unauthorized, match="Expected 'webapp_login_token'"):
|
||||
decode_enterprise_webapp_user_id("no-source-jwt")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# generate_session_id
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestGenerateSessionId:
|
||||
@patch("controllers.web.passport.db")
|
||||
def test_returns_unique_session_id(self, mock_db: MagicMock) -> None:
|
||||
mock_db.session.scalar.return_value = 0
|
||||
sid = generate_session_id()
|
||||
assert isinstance(sid, str)
|
||||
assert len(sid) == 36 # UUID format
|
||||
|
||||
@patch("controllers.web.passport.db")
|
||||
def test_retries_on_collision(self, mock_db: MagicMock) -> None:
|
||||
# First call returns count=1 (collision), second returns 0
|
||||
mock_db.session.scalar.side_effect = [1, 0]
|
||||
sid = generate_session_id()
|
||||
assert isinstance(sid, str)
|
||||
assert mock_db.session.scalar.call_count == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# exchange_token_for_existing_web_user
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestExchangeTokenForExistingWebUser:
|
||||
@patch("controllers.web.passport.PassportService")
|
||||
@patch("controllers.web.passport.db")
|
||||
def test_external_auth_type_mismatch_raises(self, mock_db: MagicMock, mock_passport_cls: MagicMock) -> None:
|
||||
site = SimpleNamespace(code="code1", app_id="app-1")
|
||||
app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1")
|
||||
mock_db.session.scalar.side_effect = [site, app_model]
|
||||
|
||||
decoded = {"user_id": "u1", "auth_type": "internal"} # mismatch: expected "external"
|
||||
with pytest.raises(WebAppAuthRequiredError, match="external"):
|
||||
exchange_token_for_existing_web_user(
|
||||
app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.EXTERNAL
|
||||
)
|
||||
|
||||
@patch("controllers.web.passport.PassportService")
|
||||
@patch("controllers.web.passport.db")
|
||||
def test_internal_auth_type_mismatch_raises(self, mock_db: MagicMock, mock_passport_cls: MagicMock) -> None:
|
||||
site = SimpleNamespace(code="code1", app_id="app-1")
|
||||
app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1")
|
||||
mock_db.session.scalar.side_effect = [site, app_model]
|
||||
|
||||
decoded = {"user_id": "u1", "auth_type": "external"} # mismatch: expected "internal"
|
||||
with pytest.raises(WebAppAuthRequiredError, match="internal"):
|
||||
exchange_token_for_existing_web_user(
|
||||
app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.INTERNAL
|
||||
)
|
||||
|
||||
@patch("controllers.web.passport.PassportService")
|
||||
@patch("controllers.web.passport.db")
|
||||
def test_site_not_found_raises(self, mock_db: MagicMock, mock_passport_cls: MagicMock) -> None:
|
||||
mock_db.session.scalar.return_value = None
|
||||
decoded = {"user_id": "u1", "auth_type": "external"}
|
||||
with pytest.raises(NotFound):
|
||||
exchange_token_for_existing_web_user(
|
||||
app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.EXTERNAL
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PassportResource.get
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestPassportResource:
|
||||
@patch("controllers.web.passport.FeatureService.get_system_features")
|
||||
def test_missing_app_code_raises_unauthorized(self, mock_features: MagicMock, app: Flask) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
with app.test_request_context("/passport"):
|
||||
with pytest.raises(Unauthorized, match="X-App-Code"):
|
||||
PassportResource().get()
|
||||
|
||||
@patch("controllers.web.passport.PassportService")
|
||||
@patch("controllers.web.passport.generate_session_id", return_value="new-sess-id")
|
||||
@patch("controllers.web.passport.db")
|
||||
@patch("controllers.web.passport.FeatureService.get_system_features")
|
||||
def test_creates_new_end_user_when_no_user_id(
|
||||
self,
|
||||
mock_features: MagicMock,
|
||||
mock_db: MagicMock,
|
||||
mock_gen_session: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
site = SimpleNamespace(app_id="app-1", code="code1")
|
||||
app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1")
|
||||
mock_db.session.scalar.side_effect = [site, app_model]
|
||||
mock_passport_cls.return_value.issue.return_value = "issued-token"
|
||||
|
||||
with app.test_request_context("/passport", headers={"X-App-Code": "code1"}):
|
||||
response = PassportResource().get()
|
||||
|
||||
assert response.get_json()["access_token"] == "issued-token"
|
||||
mock_db.session.add.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@patch("controllers.web.passport.PassportService")
|
||||
@patch("controllers.web.passport.db")
|
||||
@patch("controllers.web.passport.FeatureService.get_system_features")
|
||||
def test_reuses_existing_end_user_when_user_id_provided(
|
||||
self,
|
||||
mock_features: MagicMock,
|
||||
mock_db: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
site = SimpleNamespace(app_id="app-1", code="code1")
|
||||
app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1")
|
||||
existing_user = SimpleNamespace(id="eu-1", session_id="sess-existing")
|
||||
mock_db.session.scalar.side_effect = [site, app_model, existing_user]
|
||||
mock_passport_cls.return_value.issue.return_value = "reused-token"
|
||||
|
||||
with app.test_request_context("/passport?user_id=sess-existing", headers={"X-App-Code": "code1"}):
|
||||
response = PassportResource().get()
|
||||
|
||||
assert response.get_json()["access_token"] == "reused-token"
|
||||
# Should not create a new end user
|
||||
mock_db.session.add.assert_not_called()
|
||||
|
||||
@patch("controllers.web.passport.db")
|
||||
@patch("controllers.web.passport.FeatureService.get_system_features")
|
||||
def test_site_not_found_raises(self, mock_features: MagicMock, mock_db: MagicMock, app: Flask) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
mock_db.session.scalar.return_value = None
|
||||
with app.test_request_context("/passport", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(NotFound):
|
||||
PassportResource().get()
|
||||
|
||||
@patch("controllers.web.passport.db")
|
||||
@patch("controllers.web.passport.FeatureService.get_system_features")
|
||||
def test_disabled_app_raises_not_found(self, mock_features: MagicMock, mock_db: MagicMock, app: Flask) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
site = SimpleNamespace(app_id="app-1", code="code1")
|
||||
disabled_app = SimpleNamespace(id="app-1", status="normal", enable_site=False)
|
||||
mock_db.session.scalar.side_effect = [site, disabled_app]
|
||||
with app.test_request_context("/passport", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(NotFound):
|
||||
PassportResource().get()
|
||||
95
api/tests/unit_tests/controllers/web/test_workflow.py
Normal file
95
api/tests/unit_tests/controllers/web/test_workflow.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""Unit tests for controllers.web.workflow endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.error import (
|
||||
NotWorkflowAppError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.web.workflow import WorkflowRunApi, WorkflowTaskStopApi
|
||||
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError
|
||||
|
||||
|
||||
def _workflow_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="workflow")
|
||||
|
||||
|
||||
def _chat_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="chat")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WorkflowRunApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestWorkflowRunApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/workflows/run", method="POST"):
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
WorkflowRunApi().post(_chat_app(), _end_user())
|
||||
|
||||
@patch("controllers.web.workflow.helper.compact_generate_response", return_value={"result": "ok"})
|
||||
@patch("controllers.web.workflow.AppGenerateService.generate")
|
||||
@patch("controllers.web.workflow.web_ns")
|
||||
def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {"key": "val"}}
|
||||
mock_gen.return_value = "response"
|
||||
|
||||
with app.test_request_context("/workflows/run", method="POST"):
|
||||
result = WorkflowRunApi().post(_workflow_app(), _end_user())
|
||||
|
||||
assert result == {"result": "ok"}
|
||||
|
||||
@patch(
|
||||
"controllers.web.workflow.AppGenerateService.generate",
|
||||
side_effect=ProviderTokenNotInitError(description="not init"),
|
||||
)
|
||||
@patch("controllers.web.workflow.web_ns")
|
||||
def test_provider_not_init(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}}
|
||||
|
||||
with app.test_request_context("/workflows/run", method="POST"):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
WorkflowRunApi().post(_workflow_app(), _end_user())
|
||||
|
||||
@patch(
|
||||
"controllers.web.workflow.AppGenerateService.generate",
|
||||
side_effect=QuotaExceededError(),
|
||||
)
|
||||
@patch("controllers.web.workflow.web_ns")
|
||||
def test_quota_exceeded(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}}
|
||||
|
||||
with app.test_request_context("/workflows/run", method="POST"):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
WorkflowRunApi().post(_workflow_app(), _end_user())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WorkflowTaskStopApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestWorkflowTaskStopApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"):
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
WorkflowTaskStopApi().post(_chat_app(), _end_user(), "task-1")
|
||||
|
||||
@patch("controllers.web.workflow.GraphEngineManager.send_stop_command")
|
||||
@patch("controllers.web.workflow.AppQueueManager.set_stop_flag_no_user_check")
|
||||
def test_stop_calls_both_mechanisms(self, mock_legacy: MagicMock, mock_graph: MagicMock, app: Flask) -> None:
|
||||
with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"):
|
||||
result = WorkflowTaskStopApi().post(_workflow_app(), _end_user(), "task-1")
|
||||
|
||||
assert result == {"result": "success"}
|
||||
mock_legacy.assert_called_once_with("task-1")
|
||||
mock_graph.assert_called_once_with("task-1")
|
||||
127
api/tests/unit_tests/controllers/web/test_workflow_events.py
Normal file
127
api/tests/unit_tests/controllers/web/test_workflow_events.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""Unit tests for controllers.web.workflow_events endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.error import NotFoundError
|
||||
from controllers.web.workflow_events import WorkflowEventsApi
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
|
||||
def _workflow_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", tenant_id="tenant-1", mode="workflow")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WorkflowEventsApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestWorkflowEventsApi:
|
||||
@patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
|
||||
@patch("controllers.web.workflow_events.db")
|
||||
def test_workflow_run_not_found(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None:
|
||||
mock_db.engine = "engine"
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = None
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_repo
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events"):
|
||||
with pytest.raises(NotFoundError):
|
||||
WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")
|
||||
|
||||
@patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
|
||||
@patch("controllers.web.workflow_events.db")
|
||||
def test_workflow_run_wrong_app(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None:
|
||||
mock_db.engine = "engine"
|
||||
run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="other-app",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="eu-1",
|
||||
finished_at=None,
|
||||
)
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_repo
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events"):
|
||||
with pytest.raises(NotFoundError):
|
||||
WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")
|
||||
|
||||
@patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
|
||||
@patch("controllers.web.workflow_events.db")
|
||||
def test_workflow_run_not_created_by_end_user(
|
||||
self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask
|
||||
) -> None:
|
||||
mock_db.engine = "engine"
|
||||
run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by="eu-1",
|
||||
finished_at=None,
|
||||
)
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_repo
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events"):
|
||||
with pytest.raises(NotFoundError):
|
||||
WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")
|
||||
|
||||
@patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
|
||||
@patch("controllers.web.workflow_events.db")
|
||||
def test_workflow_run_wrong_end_user(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None:
|
||||
mock_db.engine = "engine"
|
||||
run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="other-user",
|
||||
finished_at=None,
|
||||
)
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_repo
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events"):
|
||||
with pytest.raises(NotFoundError):
|
||||
WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")
|
||||
|
||||
@patch("controllers.web.workflow_events.WorkflowResponseConverter")
|
||||
@patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
|
||||
@patch("controllers.web.workflow_events.db")
|
||||
def test_finished_run_returns_sse_response(
|
||||
self, mock_db: MagicMock, mock_factory: MagicMock, mock_converter: MagicMock, app: Flask
|
||||
) -> None:
|
||||
from datetime import datetime
|
||||
|
||||
mock_db.engine = "engine"
|
||||
run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="eu-1",
|
||||
finished_at=datetime(2024, 1, 1),
|
||||
)
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_repo
|
||||
|
||||
finish_response = MagicMock()
|
||||
finish_response.model_dump.return_value = {"task_id": "run-1"}
|
||||
finish_response.event.value = "workflow_finished"
|
||||
mock_converter.workflow_run_result_to_finish_response.return_value = finish_response
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events"):
|
||||
response = WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")
|
||||
|
||||
assert response.mimetype == "text/event-stream"
|
||||
393
api/tests/unit_tests/controllers/web/test_wraps.py
Normal file
393
api/tests/unit_tests/controllers/web/test_wraps.py
Normal file
@@ -0,0 +1,393 @@
|
||||
"""Unit tests for controllers.web.wraps — JWT auth decorator and validation helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
|
||||
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
|
||||
from controllers.web.wraps import (
|
||||
_validate_user_accessibility,
|
||||
_validate_webapp_token,
|
||||
decode_jwt_token,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_webapp_token
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestValidateWebappToken:
|
||||
def test_enterprise_enabled_and_app_auth_requires_webapp_source(self) -> None:
|
||||
"""When both flags are true, a non-webapp source must raise."""
|
||||
decoded = {"token_source": "other"}
|
||||
with pytest.raises(WebAppAuthRequiredError):
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True)
|
||||
|
||||
def test_enterprise_enabled_and_app_auth_accepts_webapp_source(self) -> None:
|
||||
decoded = {"token_source": "webapp"}
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True)
|
||||
|
||||
def test_enterprise_enabled_and_app_auth_missing_source_raises(self) -> None:
|
||||
decoded = {}
|
||||
with pytest.raises(WebAppAuthRequiredError):
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True)
|
||||
|
||||
def test_public_app_rejects_webapp_source(self) -> None:
|
||||
"""When auth is not required, a webapp-sourced token must be rejected."""
|
||||
decoded = {"token_source": "webapp"}
|
||||
with pytest.raises(Unauthorized):
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False)
|
||||
|
||||
def test_public_app_accepts_non_webapp_source(self) -> None:
|
||||
decoded = {"token_source": "other"}
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False)
|
||||
|
||||
def test_public_app_accepts_no_source(self) -> None:
|
||||
decoded = {}
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False)
|
||||
|
||||
def test_system_enabled_but_app_public(self) -> None:
|
||||
"""system_webapp_auth_enabled=True but app is public — webapp source rejected."""
|
||||
decoded = {"token_source": "webapp"}
|
||||
with pytest.raises(Unauthorized):
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=True)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_user_accessibility
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestValidateUserAccessibility:
|
||||
def test_skips_when_auth_disabled(self) -> None:
|
||||
"""No checks when system or app auth is disabled."""
|
||||
_validate_user_accessibility(
|
||||
decoded={},
|
||||
app_code="code",
|
||||
app_web_auth_enabled=False,
|
||||
system_webapp_auth_enabled=False,
|
||||
webapp_settings=None,
|
||||
)
|
||||
|
||||
def test_missing_user_id_raises(self) -> None:
|
||||
decoded = {}
|
||||
with pytest.raises(WebAppAuthRequiredError):
|
||||
_validate_user_accessibility(
|
||||
decoded=decoded,
|
||||
app_code="code",
|
||||
app_web_auth_enabled=True,
|
||||
system_webapp_auth_enabled=True,
|
||||
webapp_settings=SimpleNamespace(access_mode="internal"),
|
||||
)
|
||||
|
||||
def test_missing_webapp_settings_raises(self) -> None:
|
||||
decoded = {"user_id": "u1"}
|
||||
with pytest.raises(WebAppAuthRequiredError, match="settings not found"):
|
||||
_validate_user_accessibility(
|
||||
decoded=decoded,
|
||||
app_code="code",
|
||||
app_web_auth_enabled=True,
|
||||
system_webapp_auth_enabled=True,
|
||||
webapp_settings=None,
|
||||
)
|
||||
|
||||
def test_missing_auth_type_raises(self) -> None:
|
||||
decoded = {"user_id": "u1", "granted_at": 1}
|
||||
settings = SimpleNamespace(access_mode="public")
|
||||
with pytest.raises(WebAppAuthAccessDeniedError, match="auth_type"):
|
||||
_validate_user_accessibility(
|
||||
decoded=decoded,
|
||||
app_code="code",
|
||||
app_web_auth_enabled=True,
|
||||
system_webapp_auth_enabled=True,
|
||||
webapp_settings=settings,
|
||||
)
|
||||
|
||||
def test_missing_granted_at_raises(self) -> None:
|
||||
decoded = {"user_id": "u1", "auth_type": "external"}
|
||||
settings = SimpleNamespace(access_mode="public")
|
||||
with pytest.raises(WebAppAuthAccessDeniedError, match="granted_at"):
|
||||
_validate_user_accessibility(
|
||||
decoded=decoded,
|
||||
app_code="code",
|
||||
app_web_auth_enabled=True,
|
||||
system_webapp_auth_enabled=True,
|
||||
webapp_settings=settings,
|
||||
)
|
||||
|
||||
@patch("controllers.web.wraps.EnterpriseService.get_app_sso_settings_last_update_time")
|
||||
@patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False)
|
||||
def test_external_auth_type_checks_sso_update_time(
|
||||
self, mock_perm_check: MagicMock, mock_sso_time: MagicMock
|
||||
) -> None:
|
||||
# granted_at is before SSO update time → denied
|
||||
mock_sso_time.return_value = datetime.now(UTC)
|
||||
old_granted = int((datetime.now(UTC) - timedelta(hours=1)).timestamp())
|
||||
decoded = {"user_id": "u1", "auth_type": "external", "granted_at": old_granted}
|
||||
settings = SimpleNamespace(access_mode="public")
|
||||
with pytest.raises(WebAppAuthAccessDeniedError, match="SSO settings"):
|
||||
_validate_user_accessibility(
|
||||
decoded=decoded,
|
||||
app_code="code",
|
||||
app_web_auth_enabled=True,
|
||||
system_webapp_auth_enabled=True,
|
||||
webapp_settings=settings,
|
||||
)
|
||||
|
||||
@patch("controllers.web.wraps.EnterpriseService.get_workspace_sso_settings_last_update_time")
|
||||
@patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False)
|
||||
def test_internal_auth_type_checks_workspace_sso_update_time(
|
||||
self, mock_perm_check: MagicMock, mock_workspace_sso: MagicMock
|
||||
) -> None:
|
||||
mock_workspace_sso.return_value = datetime.now(UTC)
|
||||
old_granted = int((datetime.now(UTC) - timedelta(hours=1)).timestamp())
|
||||
decoded = {"user_id": "u1", "auth_type": "internal", "granted_at": old_granted}
|
||||
settings = SimpleNamespace(access_mode="public")
|
||||
with pytest.raises(WebAppAuthAccessDeniedError, match="SSO settings"):
|
||||
_validate_user_accessibility(
|
||||
decoded=decoded,
|
||||
app_code="code",
|
||||
app_web_auth_enabled=True,
|
||||
system_webapp_auth_enabled=True,
|
||||
webapp_settings=settings,
|
||||
)
|
||||
|
||||
@patch("controllers.web.wraps.EnterpriseService.get_app_sso_settings_last_update_time")
|
||||
@patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False)
|
||||
def test_external_auth_passes_when_granted_after_sso_update(
|
||||
self, mock_perm_check: MagicMock, mock_sso_time: MagicMock
|
||||
) -> None:
|
||||
mock_sso_time.return_value = datetime.now(UTC) - timedelta(hours=2)
|
||||
recent_granted = int(datetime.now(UTC).timestamp())
|
||||
decoded = {"user_id": "u1", "auth_type": "external", "granted_at": recent_granted}
|
||||
settings = SimpleNamespace(access_mode="public")
|
||||
# Should not raise
|
||||
_validate_user_accessibility(
|
||||
decoded=decoded,
|
||||
app_code="code",
|
||||
app_web_auth_enabled=True,
|
||||
system_webapp_auth_enabled=True,
|
||||
webapp_settings=settings,
|
||||
)
|
||||
|
||||
@patch("controllers.web.wraps.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp", return_value=False)
|
||||
@patch("controllers.web.wraps.AppService.get_app_id_by_code", return_value="app-id-1")
|
||||
@patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=True)
|
||||
def test_permission_check_denies_unauthorized_user(
|
||||
self, mock_perm: MagicMock, mock_app_id: MagicMock, mock_allowed: MagicMock
|
||||
) -> None:
|
||||
decoded = {"user_id": "u1", "auth_type": "external", "granted_at": int(datetime.now(UTC).timestamp())}
|
||||
settings = SimpleNamespace(access_mode="internal")
|
||||
with pytest.raises(WebAppAuthAccessDeniedError):
|
||||
_validate_user_accessibility(
|
||||
decoded=decoded,
|
||||
app_code="code",
|
||||
app_web_auth_enabled=True,
|
||||
system_webapp_auth_enabled=True,
|
||||
webapp_settings=settings,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# decode_jwt_token
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestDecodeJwtToken:
|
||||
@patch("controllers.web.wraps._validate_user_accessibility")
|
||||
@patch("controllers.web.wraps._validate_webapp_token")
|
||||
@patch("controllers.web.wraps.EnterpriseService.WebAppAuth.get_app_access_mode_by_id")
|
||||
@patch("controllers.web.wraps.AppService.get_app_id_by_code")
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_happy_path(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
mock_app_id: MagicMock,
|
||||
mock_access_mode: MagicMock,
|
||||
mock_validate_token: MagicMock,
|
||||
mock_validate_user: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", enable_site=True)
|
||||
site = SimpleNamespace(code="code1")
|
||||
end_user = SimpleNamespace(id="eu-1", session_id="sess-1")
|
||||
|
||||
# Configure session mock to return correct objects via scalar()
|
||||
session_mock = MagicMock()
|
||||
session_mock.scalar.side_effect = [app_model, site, end_user]
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
result_app, result_user = decode_jwt_token()
|
||||
|
||||
assert result_app.id == "app-1"
|
||||
assert result_user.id == "eu-1"
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
def test_missing_token_raises_unauthorized(
|
||||
self, mock_extract: MagicMock, mock_features: MagicMock, app: Flask
|
||||
) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
mock_extract.return_value = None
|
||||
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(Unauthorized):
|
||||
decode_jwt_token()
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_missing_app_raises_not_found(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
session_mock = MagicMock()
|
||||
session_mock.scalar.return_value = None # No app found
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(NotFound):
|
||||
decode_jwt_token()
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_disabled_site_raises_bad_request(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", enable_site=False)
|
||||
|
||||
session_mock = MagicMock()
|
||||
# scalar calls: app_model, site (code found), then end_user
|
||||
session_mock.scalar.side_effect = [app_model, SimpleNamespace(code="code1"), None]
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(BadRequest, match="Site is disabled"):
|
||||
decode_jwt_token()
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_missing_end_user_raises_not_found(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", enable_site=True)
|
||||
site = SimpleNamespace(code="code1")
|
||||
|
||||
session_mock = MagicMock()
|
||||
session_mock.scalar.side_effect = [app_model, site, None] # end_user is None
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(NotFound):
|
||||
decode_jwt_token()
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_user_id_mismatch_raises_unauthorized(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", enable_site=True)
|
||||
site = SimpleNamespace(code="code1")
|
||||
end_user = SimpleNamespace(id="eu-1", session_id="sess-1")
|
||||
|
||||
session_mock = MagicMock()
|
||||
session_mock.scalar.side_effect = [app_model, site, end_user]
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(Unauthorized, match="expired"):
|
||||
decode_jwt_token(user_id="different-user")
|
||||
@@ -9,8 +9,16 @@ import pytest
|
||||
|
||||
from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueTextChunkEvent, QueueWorkflowPausedEvent
|
||||
from core.app.entities.queue_entities import (
|
||||
QueuePingEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.enums import WorkflowExecutionStatus
|
||||
from models.enums import MessageStatus
|
||||
from models.execution_extra_content import HumanInputContent
|
||||
from models.model import EndUser
|
||||
@@ -185,3 +193,97 @@ def test_resume_appends_chunks_to_paused_answer() -> None:
|
||||
|
||||
assert message.answer == "beforeafter"
|
||||
assert message.status == MessageStatus.NORMAL
|
||||
|
||||
|
||||
def test_workflow_succeeded_emits_message_end_before_workflow_finished() -> None:
|
||||
pipeline = _build_pipeline()
|
||||
pipeline._application_generate_entity = SimpleNamespace(task_id="task-1")
|
||||
pipeline._workflow_id = "workflow-1"
|
||||
pipeline._ensure_workflow_initialized = mock.Mock()
|
||||
runtime_state = SimpleNamespace()
|
||||
pipeline._ensure_graph_runtime_initialized = mock.Mock(return_value=runtime_state)
|
||||
pipeline._handle_advanced_chat_message_end_event = mock.Mock(
|
||||
return_value=iter([SimpleNamespace(event=StreamEvent.MESSAGE_END)])
|
||||
)
|
||||
pipeline._workflow_response_converter = mock.Mock()
|
||||
pipeline._workflow_response_converter.workflow_finish_to_stream_response.return_value = SimpleNamespace(
|
||||
event=StreamEvent.WORKFLOW_FINISHED,
|
||||
data=SimpleNamespace(status=WorkflowExecutionStatus.SUCCEEDED),
|
||||
)
|
||||
|
||||
event = QueueWorkflowSucceededEvent(outputs={})
|
||||
responses = list(pipeline._handle_workflow_succeeded_event(event))
|
||||
|
||||
assert [resp.event for resp in responses] == [StreamEvent.MESSAGE_END, StreamEvent.WORKFLOW_FINISHED]
|
||||
|
||||
|
||||
def test_workflow_partial_success_emits_message_end_before_workflow_finished() -> None:
|
||||
pipeline = _build_pipeline()
|
||||
pipeline._application_generate_entity = SimpleNamespace(task_id="task-1")
|
||||
pipeline._workflow_id = "workflow-1"
|
||||
pipeline._ensure_workflow_initialized = mock.Mock()
|
||||
runtime_state = SimpleNamespace()
|
||||
pipeline._ensure_graph_runtime_initialized = mock.Mock(return_value=runtime_state)
|
||||
pipeline._handle_advanced_chat_message_end_event = mock.Mock(
|
||||
return_value=iter([SimpleNamespace(event=StreamEvent.MESSAGE_END)])
|
||||
)
|
||||
pipeline._workflow_response_converter = mock.Mock()
|
||||
pipeline._workflow_response_converter.workflow_finish_to_stream_response.return_value = SimpleNamespace(
|
||||
event=StreamEvent.WORKFLOW_FINISHED,
|
||||
data=SimpleNamespace(status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED),
|
||||
)
|
||||
|
||||
event = QueueWorkflowPartialSuccessEvent(exceptions_count=1, outputs={})
|
||||
responses = list(pipeline._handle_workflow_partial_success_event(event))
|
||||
|
||||
assert [resp.event for resp in responses] == [StreamEvent.MESSAGE_END, StreamEvent.WORKFLOW_FINISHED]
|
||||
|
||||
|
||||
def test_process_stream_response_breaks_after_workflow_succeeded() -> None:
|
||||
pipeline = _build_pipeline()
|
||||
succeeded_event = QueueWorkflowSucceededEvent(outputs={})
|
||||
ping_event = QueuePingEvent()
|
||||
queue_messages = [
|
||||
SimpleNamespace(event=succeeded_event),
|
||||
SimpleNamespace(event=ping_event),
|
||||
]
|
||||
|
||||
pipeline._conversation_name_generate_thread = None
|
||||
pipeline._base_task_pipeline = mock.Mock()
|
||||
pipeline._base_task_pipeline.queue_manager = mock.Mock()
|
||||
pipeline._base_task_pipeline.queue_manager.listen.return_value = iter(queue_messages)
|
||||
pipeline._base_task_pipeline.ping_stream_response = mock.Mock(return_value=SimpleNamespace(event=StreamEvent.PING))
|
||||
pipeline._handle_workflow_succeeded_event = mock.Mock(
|
||||
return_value=iter([SimpleNamespace(event=StreamEvent.WORKFLOW_FINISHED)])
|
||||
)
|
||||
|
||||
responses = list(pipeline._process_stream_response())
|
||||
|
||||
assert [resp.event for resp in responses] == [StreamEvent.WORKFLOW_FINISHED]
|
||||
pipeline._handle_workflow_succeeded_event.assert_called_once_with(succeeded_event, trace_manager=None)
|
||||
pipeline._base_task_pipeline.ping_stream_response.assert_not_called()
|
||||
|
||||
|
||||
def test_process_stream_response_breaks_after_workflow_partial_success() -> None:
|
||||
pipeline = _build_pipeline()
|
||||
partial_event = QueueWorkflowPartialSuccessEvent(exceptions_count=1, outputs={})
|
||||
ping_event = QueuePingEvent()
|
||||
queue_messages = [
|
||||
SimpleNamespace(event=partial_event),
|
||||
SimpleNamespace(event=ping_event),
|
||||
]
|
||||
|
||||
pipeline._conversation_name_generate_thread = None
|
||||
pipeline._base_task_pipeline = mock.Mock()
|
||||
pipeline._base_task_pipeline.queue_manager = mock.Mock()
|
||||
pipeline._base_task_pipeline.queue_manager.listen.return_value = iter(queue_messages)
|
||||
pipeline._base_task_pipeline.ping_stream_response = mock.Mock(return_value=SimpleNamespace(event=StreamEvent.PING))
|
||||
pipeline._handle_workflow_partial_success_event = mock.Mock(
|
||||
return_value=iter([SimpleNamespace(event=StreamEvent.WORKFLOW_FINISHED)])
|
||||
)
|
||||
|
||||
responses = list(pipeline._process_stream_response())
|
||||
|
||||
assert [resp.event for resp in responses] == [StreamEvent.WORKFLOW_FINISHED]
|
||||
pipeline._handle_workflow_partial_success_event.assert_called_once_with(partial_event, trace_manager=None)
|
||||
pipeline._base_task_pipeline.ping_stream_response.assert_not_called()
|
||||
|
||||
@@ -124,12 +124,12 @@ def test_message_cycle_manager_uses_new_conversation_flag(monkeypatch):
|
||||
def start(self):
|
||||
self.started = True
|
||||
|
||||
def fake_thread(**kwargs):
|
||||
def fake_thread(*args, **kwargs):
|
||||
thread = DummyThread(**kwargs)
|
||||
captured["thread"] = thread
|
||||
return thread
|
||||
|
||||
monkeypatch.setattr(message_cycle_manager, "Thread", fake_thread)
|
||||
monkeypatch.setattr(message_cycle_manager, "Timer", fake_thread)
|
||||
|
||||
manager = MessageCycleManager(application_generate_entity=entity, task_state=MagicMock())
|
||||
thread = manager.generate_conversation_name(conversation_id="existing-conversation-id", query="hello")
|
||||
|
||||
@@ -1,13 +1,8 @@
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
API_DIR = str(Path(__file__).resolve().parents[5])
|
||||
if API_DIR not in sys.path:
|
||||
sys.path.insert(0, API_DIR)
|
||||
|
||||
import dify_graph.nodes.human_input.entities # noqa: F401
|
||||
from core.app.apps.advanced_chat import app_generator as adv_app_gen_module
|
||||
from core.app.apps.workflow import app_generator as wf_app_gen_module
|
||||
|
||||
@@ -0,0 +1,425 @@
|
||||
"""
|
||||
Unit tests for EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response method.
|
||||
|
||||
This test suite ensures that the files array is correctly populated in the message_end
|
||||
SSE event, which is critical for vision/image chat responses to render correctly.
|
||||
|
||||
Test Coverage:
|
||||
- Files array populated when MessageFile records exist
|
||||
- Files array is None when no MessageFile records exist
|
||||
- Correct signed URL generation for LOCAL_FILE transfer method
|
||||
- Correct URL handling for REMOTE_URL transfer method
|
||||
- Correct URL handling for TOOL_FILE transfer method
|
||||
- Proper file metadata formatting (filename, mime_type, size, extension)
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.task_entities import MessageEndStreamResponse
|
||||
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
|
||||
from dify_graph.file.enums import FileTransferMethod
|
||||
from models.model import MessageFile, UploadFile
|
||||
|
||||
|
||||
class TestMessageEndStreamResponseFiles:
|
||||
"""Test suite for files array population in message_end SSE event."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pipeline(self):
|
||||
"""Create a mock EasyUIBasedGenerateTaskPipeline instance."""
|
||||
pipeline = Mock(spec=EasyUIBasedGenerateTaskPipeline)
|
||||
pipeline._message_id = str(uuid.uuid4())
|
||||
pipeline._task_state = Mock()
|
||||
pipeline._task_state.metadata = Mock()
|
||||
pipeline._task_state.metadata.model_dump = Mock(return_value={"test": "metadata"})
|
||||
pipeline._task_state.llm_result = Mock()
|
||||
pipeline._task_state.llm_result.usage = Mock()
|
||||
pipeline._application_generate_entity = Mock()
|
||||
pipeline._application_generate_entity.task_id = str(uuid.uuid4())
|
||||
return pipeline
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message_file_local(self):
|
||||
"""Create a mock MessageFile with LOCAL_FILE transfer method."""
|
||||
message_file = Mock(spec=MessageFile)
|
||||
message_file.id = str(uuid.uuid4())
|
||||
message_file.message_id = str(uuid.uuid4())
|
||||
message_file.transfer_method = FileTransferMethod.LOCAL_FILE
|
||||
message_file.upload_file_id = str(uuid.uuid4())
|
||||
message_file.url = None
|
||||
message_file.type = "image"
|
||||
return message_file
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message_file_remote(self):
|
||||
"""Create a mock MessageFile with REMOTE_URL transfer method."""
|
||||
message_file = Mock(spec=MessageFile)
|
||||
message_file.id = str(uuid.uuid4())
|
||||
message_file.message_id = str(uuid.uuid4())
|
||||
message_file.transfer_method = FileTransferMethod.REMOTE_URL
|
||||
message_file.upload_file_id = None
|
||||
message_file.url = "https://example.com/image.jpg"
|
||||
message_file.type = "image"
|
||||
return message_file
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message_file_tool(self):
|
||||
"""Create a mock MessageFile with TOOL_FILE transfer method."""
|
||||
message_file = Mock(spec=MessageFile)
|
||||
message_file.id = str(uuid.uuid4())
|
||||
message_file.message_id = str(uuid.uuid4())
|
||||
message_file.transfer_method = FileTransferMethod.TOOL_FILE
|
||||
message_file.upload_file_id = None
|
||||
message_file.url = "tool_file_123.png"
|
||||
message_file.type = "image"
|
||||
return message_file
|
||||
|
||||
@pytest.fixture
|
||||
def mock_upload_file(self, mock_message_file_local):
|
||||
"""Create a mock UploadFile."""
|
||||
upload_file = Mock(spec=UploadFile)
|
||||
upload_file.id = mock_message_file_local.upload_file_id
|
||||
upload_file.name = "test_image.png"
|
||||
upload_file.mime_type = "image/png"
|
||||
upload_file.size = 1024
|
||||
upload_file.extension = "png"
|
||||
return upload_file
|
||||
|
||||
def test_message_end_with_no_files(self, mock_pipeline):
|
||||
"""Test that files array is None when no MessageFile records exist."""
|
||||
# Arrange
|
||||
with (
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
|
||||
):
|
||||
mock_engine = MagicMock()
|
||||
mock_db.engine = mock_engine
|
||||
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalars.return_value.all.return_value = []
|
||||
|
||||
# Act
|
||||
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, MessageEndStreamResponse)
|
||||
assert result.files is None
|
||||
assert result.id == mock_pipeline._message_id
|
||||
assert result.metadata == {"test": "metadata"}
|
||||
|
||||
def test_message_end_with_local_file(self, mock_pipeline, mock_message_file_local, mock_upload_file):
|
||||
"""Test that files array is populated correctly for LOCAL_FILE transfer method."""
|
||||
# Arrange
|
||||
mock_message_file_local.message_id = mock_pipeline._message_id
|
||||
|
||||
with (
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url,
|
||||
):
|
||||
mock_engine = MagicMock()
|
||||
mock_db.engine = mock_engine
|
||||
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock database queries
|
||||
# First query: MessageFile
|
||||
mock_message_files_result = Mock()
|
||||
mock_message_files_result.all.return_value = [mock_message_file_local]
|
||||
|
||||
# Second query: UploadFile (batch query to avoid N+1)
|
||||
mock_upload_files_result = Mock()
|
||||
mock_upload_files_result.all.return_value = [mock_upload_file]
|
||||
|
||||
# Setup scalars to return different results for different queries
|
||||
call_count = [0] # Use list to allow modification in nested function
|
||||
|
||||
def scalars_side_effect(query):
|
||||
call_count[0] += 1
|
||||
# First call is for MessageFile, second call is for UploadFile
|
||||
if call_count[0] == 1:
|
||||
return mock_message_files_result
|
||||
else:
|
||||
return mock_upload_files_result
|
||||
|
||||
mock_session.scalars.side_effect = scalars_side_effect
|
||||
mock_get_url.return_value = "https://example.com/signed-url?signature=abc123"
|
||||
|
||||
# Act
|
||||
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, MessageEndStreamResponse)
|
||||
assert result.files is not None
|
||||
assert len(result.files) == 1
|
||||
|
||||
file_dict = result.files[0]
|
||||
assert file_dict["related_id"] == mock_message_file_local.id
|
||||
assert file_dict["filename"] == "test_image.png"
|
||||
assert file_dict["mime_type"] == "image/png"
|
||||
assert file_dict["size"] == 1024
|
||||
assert file_dict["extension"] == ".png"
|
||||
assert file_dict["type"] == "image"
|
||||
assert file_dict["transfer_method"] == FileTransferMethod.LOCAL_FILE.value
|
||||
assert "https://example.com/signed-url" in file_dict["url"]
|
||||
assert file_dict["upload_file_id"] == mock_message_file_local.upload_file_id
|
||||
assert file_dict["remote_url"] == ""
|
||||
|
||||
# Verify database queries
|
||||
# Should be called twice: once for MessageFile, once for UploadFile
|
||||
assert mock_session.scalars.call_count == 2
|
||||
mock_get_url.assert_called_once_with(upload_file_id=str(mock_upload_file.id))
|
||||
|
||||
def test_message_end_with_remote_url(self, mock_pipeline, mock_message_file_remote):
|
||||
"""Test that files array is populated correctly for REMOTE_URL transfer method."""
|
||||
# Arrange
|
||||
mock_message_file_remote.message_id = mock_pipeline._message_id
|
||||
|
||||
with (
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
|
||||
):
|
||||
mock_engine = MagicMock()
|
||||
mock_db.engine = mock_engine
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock database queries
|
||||
mock_scalars_result = Mock()
|
||||
mock_scalars_result.all.return_value = [mock_message_file_remote]
|
||||
mock_session.scalars.return_value = mock_scalars_result
|
||||
|
||||
# Act
|
||||
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, MessageEndStreamResponse)
|
||||
assert result.files is not None
|
||||
assert len(result.files) == 1
|
||||
|
||||
file_dict = result.files[0]
|
||||
assert file_dict["related_id"] == mock_message_file_remote.id
|
||||
assert file_dict["filename"] == "image.jpg"
|
||||
assert file_dict["url"] == "https://example.com/image.jpg"
|
||||
assert file_dict["extension"] == ".jpg"
|
||||
assert file_dict["type"] == "image"
|
||||
assert file_dict["transfer_method"] == FileTransferMethod.REMOTE_URL.value
|
||||
assert file_dict["remote_url"] == "https://example.com/image.jpg"
|
||||
assert file_dict["upload_file_id"] == mock_message_file_remote.id
|
||||
|
||||
# Verify only one query for message_files is made
|
||||
mock_session.scalars.assert_called_once()
|
||||
|
||||
def test_message_end_with_tool_file_http(self, mock_pipeline, mock_message_file_tool):
|
||||
"""Test that files array is populated correctly for TOOL_FILE with HTTP URL."""
|
||||
# Arrange
|
||||
mock_message_file_tool.message_id = mock_pipeline._message_id
|
||||
mock_message_file_tool.url = "https://example.com/tool_file.png"
|
||||
|
||||
with (
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
|
||||
):
|
||||
mock_engine = MagicMock()
|
||||
mock_db.engine = mock_engine
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock database queries
|
||||
mock_scalars_result = Mock()
|
||||
mock_scalars_result.all.return_value = [mock_message_file_tool]
|
||||
mock_session.scalars.return_value = mock_scalars_result
|
||||
|
||||
# Act
|
||||
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, MessageEndStreamResponse)
|
||||
assert result.files is not None
|
||||
assert len(result.files) == 1
|
||||
|
||||
file_dict = result.files[0]
|
||||
assert file_dict["url"] == "https://example.com/tool_file.png"
|
||||
assert file_dict["filename"] == "tool_file.png"
|
||||
assert file_dict["extension"] == ".png"
|
||||
assert file_dict["transfer_method"] == FileTransferMethod.TOOL_FILE.value
|
||||
|
||||
def test_message_end_with_tool_file_local(self, mock_pipeline, mock_message_file_tool):
|
||||
"""Test that files array is populated correctly for TOOL_FILE with local path."""
|
||||
# Arrange
|
||||
mock_message_file_tool.message_id = mock_pipeline._message_id
|
||||
mock_message_file_tool.url = "tool_file_123.png"
|
||||
|
||||
with (
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_file_utils.sign_tool_file") as mock_sign_tool,
|
||||
):
|
||||
mock_engine = MagicMock()
|
||||
mock_db.engine = mock_engine
|
||||
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock database queries
|
||||
mock_scalars_result = Mock()
|
||||
mock_scalars_result.all.return_value = [mock_message_file_tool]
|
||||
mock_session.scalars.return_value = mock_scalars_result
|
||||
|
||||
mock_sign_tool.return_value = "https://example.com/signed-tool-file.png?signature=xyz"
|
||||
|
||||
# Act
|
||||
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, MessageEndStreamResponse)
|
||||
assert result.files is not None
|
||||
assert len(result.files) == 1
|
||||
|
||||
file_dict = result.files[0]
|
||||
assert "https://example.com/signed-tool-file.png" in file_dict["url"]
|
||||
assert file_dict["filename"] == "tool_file_123.png"
|
||||
assert file_dict["extension"] == ".png"
|
||||
assert file_dict["transfer_method"] == FileTransferMethod.TOOL_FILE.value
|
||||
|
||||
# Verify tool file signing was called
|
||||
mock_sign_tool.assert_called_once_with(tool_file_id="tool_file_123", extension=".png")
|
||||
|
||||
def test_message_end_with_tool_file_long_extension(self, mock_pipeline, mock_message_file_tool):
|
||||
"""Test that TOOL_FILE extensions longer than MAX_TOOL_FILE_EXTENSION_LENGTH fall back to .bin."""
|
||||
mock_message_file_tool.message_id = mock_pipeline._message_id
|
||||
mock_message_file_tool.url = "tool_file_abc.verylongextension"
|
||||
|
||||
with (
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_file_utils.sign_tool_file") as mock_sign_tool,
|
||||
):
|
||||
mock_engine = MagicMock()
|
||||
mock_db.engine = mock_engine
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_scalars_result = Mock()
|
||||
mock_scalars_result.all.return_value = [mock_message_file_tool]
|
||||
mock_session.scalars.return_value = mock_scalars_result
|
||||
mock_sign_tool.return_value = "https://example.com/signed.bin"
|
||||
|
||||
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
|
||||
|
||||
assert result.files is not None
|
||||
file_dict = result.files[0]
|
||||
assert file_dict["extension"] == ".bin"
|
||||
mock_sign_tool.assert_called_once_with(tool_file_id="tool_file_abc", extension=".bin")
|
||||
|
||||
def test_message_end_with_multiple_files(
|
||||
self, mock_pipeline, mock_message_file_local, mock_message_file_remote, mock_upload_file
|
||||
):
|
||||
"""Test that files array contains all MessageFile records when multiple exist."""
|
||||
# Arrange
|
||||
mock_message_file_local.message_id = mock_pipeline._message_id
|
||||
mock_message_file_remote.message_id = mock_pipeline._message_id
|
||||
|
||||
with (
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url,
|
||||
):
|
||||
mock_engine = MagicMock()
|
||||
mock_db.engine = mock_engine
|
||||
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock database queries
|
||||
# First query: MessageFile
|
||||
mock_message_files_result = Mock()
|
||||
mock_message_files_result.all.return_value = [mock_message_file_local, mock_message_file_remote]
|
||||
|
||||
# Second query: UploadFile (batch query to avoid N+1)
|
||||
mock_upload_files_result = Mock()
|
||||
mock_upload_files_result.all.return_value = [mock_upload_file]
|
||||
|
||||
# Setup scalars to return different results for different queries
|
||||
call_count = [0] # Use list to allow modification in nested function
|
||||
|
||||
def scalars_side_effect(query):
|
||||
call_count[0] += 1
|
||||
# First call is for MessageFile, second call is for UploadFile
|
||||
if call_count[0] == 1:
|
||||
return mock_message_files_result
|
||||
else:
|
||||
return mock_upload_files_result
|
||||
|
||||
mock_session.scalars.side_effect = scalars_side_effect
|
||||
mock_get_url.return_value = "https://example.com/signed-url?signature=abc123"
|
||||
|
||||
# Act
|
||||
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, MessageEndStreamResponse)
|
||||
assert result.files is not None
|
||||
assert len(result.files) == 2
|
||||
|
||||
# Verify both files are present
|
||||
file_ids = [f["related_id"] for f in result.files]
|
||||
assert mock_message_file_local.id in file_ids
|
||||
assert mock_message_file_remote.id in file_ids
|
||||
|
||||
def test_message_end_with_local_file_no_upload_file(self, mock_pipeline, mock_message_file_local):
|
||||
"""Test fallback when UploadFile is not found for LOCAL_FILE."""
|
||||
# Arrange
|
||||
mock_message_file_local.message_id = mock_pipeline._message_id
|
||||
|
||||
with (
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
|
||||
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
|
||||
patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url,
|
||||
):
|
||||
mock_engine = MagicMock()
|
||||
mock_db.engine = mock_engine
|
||||
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock database queries
|
||||
# First query: MessageFile
|
||||
mock_message_files_result = Mock()
|
||||
mock_message_files_result.all.return_value = [mock_message_file_local]
|
||||
|
||||
# Second query: UploadFile (batch query) - returns empty list (not found)
|
||||
mock_upload_files_result = Mock()
|
||||
mock_upload_files_result.all.return_value = [] # UploadFile not found
|
||||
|
||||
# Setup scalars to return different results for different queries
|
||||
call_count = [0] # Use list to allow modification in nested function
|
||||
|
||||
def scalars_side_effect(query):
|
||||
call_count[0] += 1
|
||||
# First call is for MessageFile, second call is for UploadFile
|
||||
if call_count[0] == 1:
|
||||
return mock_message_files_result
|
||||
else:
|
||||
return mock_upload_files_result
|
||||
|
||||
mock_session.scalars.side_effect = scalars_side_effect
|
||||
mock_get_url.return_value = "https://example.com/fallback-url?signature=def456"
|
||||
|
||||
# Act
|
||||
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, MessageEndStreamResponse)
|
||||
assert result.files is not None
|
||||
assert len(result.files) == 1
|
||||
|
||||
file_dict = result.files[0]
|
||||
assert "https://example.com/fallback-url" in file_dict["url"]
|
||||
# Verify fallback URL was generated using upload_file_id from message_file
|
||||
mock_get_url.assert_called_with(upload_file_id=str(mock_message_file_local.upload_file_id))
|
||||
@@ -0,0 +1,84 @@
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowType
|
||||
from models import Account, WorkflowRun
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
|
||||
|
||||
def _build_repository_with_mocked_session(session: MagicMock) -> SQLAlchemyWorkflowExecutionRepository:
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
real_session_factory = sessionmaker(bind=engine, expire_on_commit=False)
|
||||
|
||||
user = MagicMock(spec=Account)
|
||||
user.id = str(uuid4())
|
||||
user.current_tenant_id = str(uuid4())
|
||||
|
||||
repository = SQLAlchemyWorkflowExecutionRepository(
|
||||
session_factory=real_session_factory,
|
||||
user=user,
|
||||
app_id="app-id",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = session
|
||||
session_context.__exit__.return_value = False
|
||||
repository._session_factory = MagicMock(return_value=session_context)
|
||||
return repository
|
||||
|
||||
|
||||
def _build_execution(*, execution_id: str, started_at: datetime) -> WorkflowExecution:
|
||||
return WorkflowExecution.new(
|
||||
id_=execution_id,
|
||||
workflow_id="workflow-id",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_version="1.0.0",
|
||||
graph={"nodes": [], "edges": []},
|
||||
inputs={"query": "hello"},
|
||||
started_at=started_at,
|
||||
)
|
||||
|
||||
|
||||
def test_save_uses_execution_started_at_when_record_does_not_exist():
|
||||
session = MagicMock()
|
||||
session.get.return_value = None
|
||||
repository = _build_repository_with_mocked_session(session)
|
||||
|
||||
started_at = datetime(2026, 1, 1, 12, 0, 0)
|
||||
execution = _build_execution(execution_id=str(uuid4()), started_at=started_at)
|
||||
|
||||
repository.save(execution)
|
||||
|
||||
saved_model = session.merge.call_args.args[0]
|
||||
assert saved_model.created_at == started_at
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_save_preserves_existing_created_at_when_record_already_exists():
|
||||
session = MagicMock()
|
||||
repository = _build_repository_with_mocked_session(session)
|
||||
|
||||
execution_id = str(uuid4())
|
||||
existing_created_at = datetime(2026, 1, 1, 12, 0, 0)
|
||||
existing_run = WorkflowRun()
|
||||
existing_run.id = execution_id
|
||||
existing_run.tenant_id = repository._tenant_id
|
||||
existing_run.created_at = existing_created_at
|
||||
session.get.return_value = existing_run
|
||||
|
||||
execution = _build_execution(
|
||||
execution_id=execution_id,
|
||||
started_at=datetime(2026, 1, 1, 12, 30, 0),
|
||||
)
|
||||
|
||||
repository.save(execution)
|
||||
|
||||
saved_model = session.merge.call_args.args[0]
|
||||
assert saved_model.created_at == existing_created_at
|
||||
session.commit.assert_called_once()
|
||||
@@ -4,8 +4,10 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
|
||||
from dify_graph.variables.variables import StringVariable
|
||||
|
||||
|
||||
class StubCoordinator:
|
||||
@@ -278,3 +280,17 @@ class TestGraphRuntimeState:
|
||||
assert restored_execution.started is True
|
||||
|
||||
assert new_stub.state == "configured"
|
||||
|
||||
def test_snapshot_restore_preserves_updated_conversation_variable(self):
|
||||
variable_pool = VariablePool(
|
||||
conversation_variables=[StringVariable(name="session_name", value="before")],
|
||||
)
|
||||
variable_pool.add((CONVERSATION_VARIABLE_NODE_ID, "session_name"), "after")
|
||||
|
||||
state = GraphRuntimeState(variable_pool=variable_pool, start_at=time())
|
||||
snapshot = state.dumps()
|
||||
restored = GraphRuntimeState.from_snapshot(snapshot)
|
||||
|
||||
restored_value = restored.variable_pool.get((CONVERSATION_VARIABLE_NODE_ID, "session_name"))
|
||||
assert restored_value is not None
|
||||
assert restored_value.value == "after"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user