Compare commits

..

1 Commit

Author  SHA1        Message      Date
jyong   bea428e308  evaluations  2026-01-30 17:35:36 +08:00
240 changed files with 4790 additions and 8019 deletions

View File

@@ -480,4 +480,4 @@ const useButtonState = () => {
### Related Skills
- `frontend-testing` - For testing refactored components
- `web/docs/test.md` - Testing specification
- `web/testing/testing.md` - Testing specification

View File

@@ -7,7 +7,7 @@ description: Generate Vitest + React Testing Library tests for Dify frontend com
This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.
> **⚠️ Authoritative Source**: This skill is derived from `web/docs/test.md`. Use Vitest mock/timer APIs (`vi.*`).
> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. Use Vitest mock/timer APIs (`vi.*`).
## When to Apply This Skill
@@ -309,7 +309,7 @@ For more detailed information, refer to:
### Primary Specification (MUST follow)
- **`web/docs/test.md`** - The canonical testing specification. This skill is derived from this document.
- **`web/testing/testing.md`** - The canonical testing specification. This skill is derived from this document.
### Reference Examples in Codebase

View File

@@ -4,7 +4,7 @@ This guide defines the workflow for generating tests, especially for complex com
## Scope Clarification
This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/docs/test.md` § Coverage Goals.
This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/testing/testing.md` § Coverage Goals.
| Scope | Rule |
|-------|------|

3
.github/CODEOWNERS vendored
View File

@@ -9,9 +9,6 @@
# CODEOWNERS file
/.github/CODEOWNERS @laipz8200 @crazywoola
# Agents
/.agents/skills/ @hyoban
# Docs
/docs/ @crazywoola

View File

@@ -72,7 +72,6 @@ jobs:
OPENDAL_FS_ROOT: /tmp/dify-storage
run: |
uv run --project api pytest \
-n auto \
--timeout "${PYTEST_TIMEOUT:-180}" \
api/tests/integration_tests/workflow \
api/tests/integration_tests/tools \

View File

@@ -47,9 +47,13 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: uv run --directory api --dev lint-imports
- name: Run Type Checks
- name: Run Basedpyright Checks
if: steps.changed-files.outputs.any_changed == 'true'
run: make type-check
run: dev/basedpyright-check
- name: Run Mypy Type Checks
if: steps.changed-files.outputs.any_changed == 'true'
run: uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
- name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true'

View File

@@ -7,7 +7,7 @@ Dify is an open-source platform for developing LLM applications with an intuitiv
The codebase is split into:
- **Backend API** (`/api`): Python Flask application organized with Domain-Driven Design
- **Frontend Web** (`/web`): Next.js application using TypeScript and React
- **Frontend Web** (`/web`): Next.js 15 application using TypeScript and React 19
- **Docker deployment** (`/docker`): Containerized deployment configurations
## Backend Workflow
@@ -18,7 +18,36 @@ The codebase is split into:
## Frontend Workflow
- Read `web/AGENTS.md` for details
```bash
cd web
pnpm lint:fix
pnpm type-check:tsgo
pnpm test
```
### Frontend Linting
ESLint is used for frontend code quality. Available commands:
```bash
# Lint all files (report only)
pnpm lint
# Lint and auto-fix issues
pnpm lint:fix
# Lint specific files or directories
pnpm lint:fix app/components/base/button/
pnpm lint:fix app/components/base/button/index.tsx
# Lint quietly (errors only, no warnings)
pnpm lint:quiet
# Check code complexity
pnpm lint:complexity
```
**Important**: Always run `pnpm lint:fix` before committing. The pre-commit hook runs `lint-staged` which only lints staged files.
## Testing & Quality Practices

View File

@@ -77,7 +77,7 @@ How we prioritize:
For setting up the frontend service, please refer to our comprehensive [guide](https://github.com/langgenius/dify/blob/main/web/README.md) in the `web/README.md` file. This document provides detailed instructions to help you set up the frontend environment properly.
**Testing**: All React components must have comprehensive test coverage. See [web/docs/test.md](https://github.com/langgenius/dify/blob/main/web/docs/test.md) for the canonical frontend testing guidelines and follow every requirement described there.
**Testing**: All React components must have comprehensive test coverage. See [web/testing/testing.md](https://github.com/langgenius/dify/blob/main/web/testing/testing.md) for the canonical frontend testing guidelines and follow every requirement described there.
#### Backend

View File

@@ -68,11 +68,9 @@ lint:
@echo "✅ Linting complete"
type-check:
@echo "📝 Running type checks (basedpyright + mypy + ty)..."
@./dev/basedpyright-check $(PATH_TO_CHECK)
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
@cd api && uv run ty check
@echo "✅ Type checks complete"
@echo "📝 Running type check with basedpyright..."
@uv run --directory api --dev basedpyright
@echo "✅ Type check complete"
test:
@echo "🧪 Running backend unit tests..."
@@ -80,7 +78,7 @@ test:
echo "Target: $(TARGET_TESTS)"; \
uv run --project api --dev pytest $(TARGET_TESTS); \
else \
PYTEST_XDIST_ARGS="-n auto" uv run --project api --dev dev/pytest/pytest_unit_tests.sh; \
uv run --project api --dev dev/pytest/pytest_unit_tests.sh; \
fi
@echo "✅ Tests complete"
@@ -132,7 +130,7 @@ help:
@echo " make format - Format code with ruff"
@echo " make check - Check code with ruff"
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
@echo " make type-check - Run type checks (basedpyright, mypy, ty)"
@echo " make type-check - Run type checking with basedpyright"
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
@echo ""
@echo "Docker Build Targets:"

View File

@@ -617,7 +617,6 @@ PLUGIN_DAEMON_URL=http://127.0.0.1:5002
PLUGIN_REMOTE_INSTALL_PORT=5003
PLUGIN_REMOTE_INSTALL_HOST=localhost
PLUGIN_MAX_PACKAGE_SIZE=15728640
PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600
INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
# Marketplace configuration
@@ -717,3 +716,4 @@ SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000

View File

@@ -53,7 +53,6 @@ select = [
"S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
"S302", # suspicious-marshal-usage, disallow use of `marshal` module
"S311", # suspicious-non-cryptographic-random-usage,
"TID", # flake8-tidy-imports
]
@@ -89,7 +88,6 @@ ignore = [
"SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false
"TID252", # allow relative imports from parent modules
]
[lint.per-file-ignores]
@@ -111,20 +109,10 @@ ignore = [
"S110", # allow ignoring exceptions in tests code (currently)
]
"controllers/console/explore/trial.py" = ["TID251"]
"controllers/console/human_input_form.py" = ["TID251"]
"controllers/web/human_input_form.py" = ["TID251"]
[lint.pyflakes]
allowed-unused-imports = [
"_pytest.monkeypatch",
"tests.integration_tests",
"tests.unit_tests",
]
[lint.flake8-tidy-imports]
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]
msg = "Use Pydantic payload/query models instead of reqparse."
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse.RequestParser"]
msg = "Use Pydantic payload/query models instead of reqparse."

View File

@@ -1,12 +1,4 @@
from __future__ import annotations
import sys
from typing import TYPE_CHECKING, cast
if TYPE_CHECKING:
from celery import Celery
celery: Celery
def is_db_command() -> bool:
@@ -31,7 +23,7 @@ else:
from app_factory import create_app
app = create_app()
celery = cast("Celery", app.extensions["celery"])
celery = app.extensions["celery"]
if __name__ == "__main__":
app.run(host="0.0.0.0", port=5001)
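
The lines dropped by this hunk used Python's `TYPE_CHECKING` + `cast` idiom so static type checkers would treat `app.extensions["celery"]` as a `Celery` instance without importing Celery at runtime. A minimal standalone sketch of that idiom (the registry below is a placeholder, not Dify's app factory):

```python
# Sketch of the TYPE_CHECKING + cast idiom removed above.
# cast() only informs the type checker; it performs no runtime conversion.
from typing import TYPE_CHECKING, cast

if TYPE_CHECKING:
    from celery import Celery  # imported for type checking only, never at runtime

extensions: dict[str, object] = {"celery": object()}  # placeholder registry

def get_celery() -> "Celery":
    return cast("Celery", extensions["celery"])
```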

View File

@@ -149,7 +149,7 @@ def initialize_extensions(app: DifyApp):
logger.info("Loaded %s (%s ms)", short_name, round((end_time - start_time) * 1000, 2))
def create_migrations_app() -> DifyApp:
def create_migrations_app():
app = create_flask_app_with_configs()
from extensions import ext_database, ext_migrate

View File

@@ -1450,58 +1450,54 @@ def clear_orphaned_file_records(force: bool):
all_ids_in_tables = []
for ids_table in ids_tables:
query = ""
match ids_table["type"]:
case "uuid":
click.echo(
click.style(
f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}",
fg="white",
)
if ids_table["type"] == "uuid":
click.echo(
click.style(
f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}", fg="white"
)
c = ids_table["column"]
query = f"SELECT {c} FROM {ids_table['table']} WHERE {c} IS NOT NULL"
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for i in rs:
all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])})
case "text":
t = ids_table["table"]
click.echo(
click.style(
f"- Listing file-id-like strings in column {ids_table['column']} in table {t}",
fg="white",
)
)
query = (
f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for i in rs:
all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])})
elif ids_table["type"] == "text":
click.echo(
click.style(
f"- Listing file-id-like strings in column {ids_table['column']} in table {ids_table['table']}",
fg="white",
)
query = (
f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id "
f"FROM {ids_table['table']}"
)
query = (
f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id "
f"FROM {ids_table['table']}"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for i in rs:
for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
elif ids_table["type"] == "json":
click.echo(
click.style(
(
f"- Listing file-id-like JSON string in column {ids_table['column']} "
f"in table {ids_table['table']}"
),
fg="white",
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for i in rs:
for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
case "json":
click.echo(
click.style(
(
f"- Listing file-id-like JSON string in column {ids_table['column']} "
f"in table {ids_table['table']}"
),
fg="white",
)
)
query = (
f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id "
f"FROM {ids_table['table']}"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for i in rs:
for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
case _:
pass
)
query = (
f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id "
f"FROM {ids_table['table']}"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for i in rs:
for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
click.echo(click.style(f"Found {len(all_ids_in_tables)} file ids in tables.", fg="white"))
except Exception as e:
@@ -1741,18 +1737,59 @@ def file_usage(
if src_filter != src:
continue
match ids_table["type"]:
case "uuid":
# Direct UUID match
query = (
f"SELECT {ids_table['pk_column']}, {ids_table['column']} "
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
record_id = str(row[0])
ref_file_id = str(row[1])
if ids_table["type"] == "uuid":
# Direct UUID match
query = (
f"SELECT {ids_table['pk_column']}, {ids_table['column']} "
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
record_id = str(row[0])
ref_file_id = str(row[1])
if ref_file_id not in file_key_map:
continue
storage_key = file_key_map[ref_file_id]
# Apply filters
if file_id and ref_file_id != file_id:
continue
if key and not storage_key.endswith(key):
continue
# Only collect items within the requested page range
if offset <= total_count < offset + limit:
paginated_usages.append(
{
"src": f"{ids_table['table']}.{ids_table['column']}",
"record_id": record_id,
"file_id": ref_file_id,
"key": storage_key,
}
)
total_count += 1
elif ids_table["type"] in ("text", "json"):
# Extract UUIDs from text/json content
column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"]
query = (
f"SELECT {ids_table['pk_column']}, {column_cast} "
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
record_id = str(row[0])
content = str(row[1])
# Find all UUIDs in the content
import re
uuid_pattern = re.compile(guid_regexp, re.IGNORECASE)
matches = uuid_pattern.findall(content)
for ref_file_id in matches:
if ref_file_id not in file_key_map:
continue
storage_key = file_key_map[ref_file_id]
@@ -1775,50 +1812,6 @@ def file_usage(
)
total_count += 1
case "text" | "json":
# Extract UUIDs from text/json content
column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"]
query = (
f"SELECT {ids_table['pk_column']}, {column_cast} "
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
record_id = str(row[0])
content = str(row[1])
# Find all UUIDs in the content
import re
uuid_pattern = re.compile(guid_regexp, re.IGNORECASE)
matches = uuid_pattern.findall(content)
for ref_file_id in matches:
if ref_file_id not in file_key_map:
continue
storage_key = file_key_map[ref_file_id]
# Apply filters
if file_id and ref_file_id != file_id:
continue
if key and not storage_key.endswith(key):
continue
# Only collect items within the requested page range
if offset <= total_count < offset + limit:
paginated_usages.append(
{
"src": f"{ids_table['table']}.{ids_table['column']}",
"record_id": record_id,
"file_id": ref_file_id,
"key": storage_key,
}
)
total_count += 1
case _:
pass
# Output results
if output_json:
result = {
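
The two hunks above rewrite `match`/`case` dispatch on `ids_table["type"]` as plain `if`/`elif` chains. For readers unfamiliar with structural pattern matching, a minimal standalone sketch of the two equivalent shapes (table and column names are placeholders, not Dify's schema):

```python
# The same branching expressed with match/case (the removed form)
# and if/elif (the added form); both build an identical query string.

def build_query_match(ids_table: dict[str, str]) -> str:
    match ids_table["type"]:
        case "uuid":
            return f"SELECT {ids_table['column']} FROM {ids_table['table']}"
        case "text" | "json":
            return f"SELECT regexp_matches({ids_table['column']}, '<uuid-regexp>', 'g') FROM {ids_table['table']}"
        case _:
            raise ValueError("unsupported id column type")

def build_query_if_elif(ids_table: dict[str, str]) -> str:
    if ids_table["type"] == "uuid":
        return f"SELECT {ids_table['column']} FROM {ids_table['table']}"
    elif ids_table["type"] in ("text", "json"):
        return f"SELECT regexp_matches({ids_table['column']}, '<uuid-regexp>', 'g') FROM {ids_table['table']}"
    raise ValueError("unsupported id column type")

assert build_query_match({"type": "uuid", "table": "t", "column": "c"}) == \
    build_query_if_elif({"type": "uuid", "table": "t", "column": "c"})
```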

View File

@@ -245,7 +245,7 @@ class PluginConfig(BaseSettings):
PLUGIN_MODEL_SCHEMA_CACHE_TTL: PositiveInt = Field(
description="TTL in seconds for caching plugin model schemas in Redis",
default=60 * 60,
default=24 * 60 * 60,
)

View File

@@ -115,6 +115,12 @@ from .explore import (
trial,
)
# Import evaluation controllers
from .evaluation import evaluation
# Import snippet controllers
from .snippets import snippet_workflow
# Import tag controllers
from .tag import tags
@@ -128,6 +134,7 @@ from .workspace import (
model_providers,
models,
plugin,
snippets,
tool_providers,
trigger_providers,
workspace,
@@ -165,6 +172,7 @@ __all__ = [
"datasource_content_preview",
"email_register",
"endpoint",
"evaluation",
"extension",
"external",
"feature",
@@ -197,6 +205,8 @@ __all__ = [
"saved_message",
"setup",
"site",
"snippet_workflow",
"snippets",
"spec",
"statistic",
"tags",

View File

@@ -1,11 +1,10 @@
from typing import Any, Literal
from flask import abort, make_response, request
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
@@ -17,11 +16,9 @@ from controllers.console.wraps import (
)
from extensions.ext_redis import redis_client
from fields.annotation_fields import (
Annotation,
AnnotationExportList,
AnnotationHitHistory,
AnnotationHitHistoryList,
AnnotationList,
annotation_fields,
annotation_hit_history_fields,
build_annotation_model,
)
from libs.helper import uuid_value
from libs.login import login_required
@@ -92,14 +89,6 @@ reg(CreateAnnotationPayload)
reg(UpdateAnnotationPayload)
reg(AnnotationReplyStatusQuery)
reg(AnnotationFilePayload)
register_schema_models(
console_ns,
Annotation,
AnnotationList,
AnnotationExportList,
AnnotationHitHistory,
AnnotationHitHistoryList,
)
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
@@ -118,11 +107,10 @@ class AnnotationReplyActionApi(Resource):
def post(self, app_id, action: Literal["enable", "disable"]):
app_id = str(app_id)
args = AnnotationReplyPayload.model_validate(console_ns.payload)
match action:
case "enable":
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
case "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
if action == "enable":
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
return result, 200
@@ -213,33 +201,33 @@ class AnnotationApi(Resource):
app_id = str(app_id)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
response = AnnotationList(
data=annotation_models,
has_more=len(annotation_list) == limit,
limit=limit,
total=total,
page=page,
)
return response.model_dump(mode="json"), 200
response = {
"data": marshal(annotation_list, annotation_fields),
"has_more": len(annotation_list) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response, 200
@console_ns.doc("create_annotation")
@console_ns.doc(description="Create a new annotation for an app")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__])
@console_ns.response(201, "Annotation created successfully", console_ns.models[Annotation.__name__])
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@marshal_with(annotation_fields)
@edit_permission_required
def post(self, app_id):
app_id = str(app_id)
args = CreateAnnotationPayload.model_validate(console_ns.payload)
data = args.model_dump(exclude_none=True)
annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
return annotation
@setup_required
@login_required
@@ -276,7 +264,7 @@ class AnnotationExportApi(Resource):
@console_ns.response(
200,
"Annotations exported successfully",
console_ns.models[AnnotationExportList.__name__],
console_ns.model("AnnotationList", {"data": fields.List(fields.Nested(build_annotation_model(console_ns)))}),
)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@@ -286,8 +274,7 @@ class AnnotationExportApi(Resource):
def get(self, app_id):
app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
response_data = AnnotationExportList(data=annotation_models).model_dump(mode="json")
response_data = {"data": marshal(annotation_list, annotation_fields)}
# Create response with secure headers for CSV export
response = make_response(response_data, 200)
@@ -302,7 +289,7 @@ class AnnotationUpdateDeleteApi(Resource):
@console_ns.doc("update_delete_annotation")
@console_ns.doc(description="Update or delete an annotation")
@console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
@console_ns.response(200, "Annotation updated successfully", console_ns.models[Annotation.__name__])
@console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
@console_ns.response(204, "Annotation deleted successfully")
@console_ns.response(403, "Insufficient permissions")
@console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__])
@@ -311,6 +298,7 @@ class AnnotationUpdateDeleteApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
@marshal_with(annotation_fields)
def post(self, app_id, annotation_id):
app_id = str(app_id)
annotation_id = str(annotation_id)
@@ -318,7 +306,7 @@ class AnnotationUpdateDeleteApi(Resource):
annotation = AppAnnotationService.update_app_annotation_directly(
args.model_dump(exclude_none=True), app_id, annotation_id
)
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
return annotation
@setup_required
@login_required
@@ -426,7 +414,14 @@ class AnnotationHitHistoryListApi(Resource):
@console_ns.response(
200,
"Hit histories retrieved successfully",
console_ns.models[AnnotationHitHistoryList.__name__],
console_ns.model(
"AnnotationHitHistoryList",
{
"data": fields.List(
fields.Nested(console_ns.model("AnnotationHitHistoryItem", annotation_hit_history_fields))
)
},
),
)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@@ -441,14 +436,11 @@ class AnnotationHitHistoryListApi(Resource):
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(
app_id, annotation_id, page, limit
)
history_models = TypeAdapter(list[AnnotationHitHistory]).validate_python(
annotation_hit_history_list, from_attributes=True
)
response = AnnotationHitHistoryList(
data=history_models,
has_more=len(annotation_hit_history_list) == limit,
limit=limit,
total=total,
page=page,
)
return response.model_dump(mode="json")
response = {
"data": marshal(annotation_hit_history_list, annotation_hit_history_fields),
"has_more": len(annotation_hit_history_list) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response
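
This file replaces the Pydantic response models (`TypeAdapter` + `model_dump`) with flask_restx `marshal`/`marshal_with` field dictionaries. A minimal sketch of that marshalling pattern, using placeholder field names rather than Dify's actual annotation schema:

```python
# marshal() serializes each item according to the field dict, mirroring what
# the Pydantic model_dump(mode="json") call did before. Fields are placeholders.
from flask_restx import fields, marshal

annotation_fields_example = {
    "id": fields.String,
    "question": fields.String,
    "answer": fields.String,
    "hit_count": fields.Integer,
}

def build_list_response(items: list, page: int, limit: int, total: int) -> dict:
    return {
        "data": marshal(items, annotation_fields_example),
        "has_more": len(items) == limit,
        "limit": limit,
        "total": total,
        "page": page,
    }
```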

View File

@@ -6,7 +6,6 @@ from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
import services
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
AppUnavailableError,
@@ -34,6 +33,7 @@ from services.errors.audio import (
)
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class TextToSpeechPayload(BaseModel):
@@ -47,11 +47,13 @@ class TextToSpeechVoiceQuery(BaseModel):
language: str = Field(..., description="Language code")
class AudioTranscriptResponse(BaseModel):
text: str = Field(description="Transcribed text from audio")
register_schema_models(console_ns, AudioTranscriptResponse, TextToSpeechPayload, TextToSpeechVoiceQuery)
console_ns.schema_model(
TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
TextToSpeechVoiceQuery.__name__,
TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/audio-to-text")
@@ -62,7 +64,7 @@ class ChatMessageAudioApi(Resource):
@console_ns.response(
200,
"Audio transcription successful",
console_ns.models[AudioTranscriptResponse.__name__],
console_ns.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}),
)
@console_ns.response(400, "Bad request - No audio uploaded or unsupported type")
@console_ns.response(413, "Audio file too large")

View File

@@ -508,19 +508,16 @@ class ChatConversationApi(Resource):
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at <= end_datetime_utc)
match args.annotation_status:
case "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
case "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(MessageAnnotation.id) == 0)
)
case "all":
pass
if args.annotation_status == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
elif args.annotation_status == "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(MessageAnnotation.id) == 0)
)
if app_model.mode == AppMode.ADVANCED_CHAT:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)

View File

@@ -1,4 +1,5 @@
from collections.abc import Sequence
from typing import Any
from flask_restx import Resource
from pydantic import BaseModel, Field
@@ -11,12 +12,10 @@ from controllers.console.app.error import (
ProviderQuotaExceededError,
)
from controllers.console.wraps import account_initialization_required, setup_required
from core.app.app_config.entities import ModelConfig
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
@@ -27,13 +26,28 @@ from services.workflow_service import WorkflowService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class RuleGeneratePayload(BaseModel):
instruction: str = Field(..., description="Rule generation instruction")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
no_variable: bool = Field(default=False, description="Whether to exclude variables")
class RuleCodeGeneratePayload(RuleGeneratePayload):
code_language: str = Field(default="javascript", description="Programming language for code generation")
class RuleStructuredOutputPayload(BaseModel):
instruction: str = Field(..., description="Structured output generation instruction")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
class InstructionGeneratePayload(BaseModel):
flow_id: str = Field(..., description="Workflow/Flow ID")
node_id: str = Field(default="", description="Node ID for workflow context")
current: str = Field(default="", description="Current instruction text")
language: str = Field(default="javascript", description="Programming language (javascript/python)")
instruction: str = Field(..., description="Instruction for generation")
model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
ideal_output: str = Field(default="", description="Expected ideal output")
@@ -50,7 +64,6 @@ reg(RuleCodeGeneratePayload)
reg(RuleStructuredOutputPayload)
reg(InstructionGeneratePayload)
reg(InstructionTemplatePayload)
reg(ModelConfig)
@console_ns.route("/rule-generate")
@@ -69,7 +82,12 @@ class RuleGenerateApi(Resource):
_, current_tenant_id = current_account_with_tenant()
try:
rules = LLMGenerator.generate_rule_config(tenant_id=current_tenant_id, args=args)
rules = LLMGenerator.generate_rule_config(
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=args.no_variable,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
@@ -100,7 +118,9 @@ class RuleCodeGenerateApi(Resource):
try:
code_result = LLMGenerator.generate_code(
tenant_id=current_tenant_id,
args=args,
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.code_language,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -132,7 +152,8 @@ class RuleStructuredOutputGenerateApi(Resource):
try:
structured_output = LLMGenerator.generate_structured_output(
tenant_id=current_tenant_id,
args=args,
instruction=args.instruction,
model_config=args.model_config_data,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -183,29 +204,23 @@ class InstructionGenerateApi(Resource):
case "llm":
return LLMGenerator.generate_rule_config(
current_tenant_id,
args=RuleGeneratePayload(
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
),
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
)
case "agent":
return LLMGenerator.generate_rule_config(
current_tenant_id,
args=RuleGeneratePayload(
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
),
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
)
case "code":
return LLMGenerator.generate_code(
tenant_id=current_tenant_id,
args=RuleCodeGeneratePayload(
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.language,
),
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.language,
)
case _:
return {"error": f"invalid node type: {node_type}"}

View File

@@ -7,7 +7,6 @@ from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, select
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
CompletionRequestError,
@@ -36,6 +35,7 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft
from services.message_service import MessageService
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ChatMessagesQuery(BaseModel):
@@ -90,22 +90,13 @@ class FeedbackExportQuery(BaseModel):
raise ValueError("has_comment must be a boolean value")
class AnnotationCountResponse(BaseModel):
count: int = Field(description="Number of annotations")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
class SuggestedQuestionsResponse(BaseModel):
data: list[str] = Field(description="Suggested question")
register_schema_models(
console_ns,
ChatMessagesQuery,
MessageFeedbackPayload,
FeedbackExportQuery,
AnnotationCountResponse,
SuggestedQuestionsResponse,
)
reg(ChatMessagesQuery)
reg(MessageFeedbackPayload)
reg(FeedbackExportQuery)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@@ -240,7 +231,7 @@ class ChatMessageListApi(Resource):
@marshal_with(message_infinite_scroll_pagination_model)
@edit_permission_required
def get(self, app_model):
args = ChatMessagesQuery.model_validate(request.args.to_dict())
args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
conversation = (
db.session.query(Conversation)
@@ -365,7 +356,7 @@ class MessageAnnotationCountApi(Resource):
@console_ns.response(
200,
"Annotation count retrieved successfully",
console_ns.models[AnnotationCountResponse.__name__],
console_ns.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}),
)
@get_app_model
@setup_required
@@ -385,7 +376,9 @@ class MessageSuggestedQuestionApi(Resource):
@console_ns.response(
200,
"Suggested questions retrieved successfully",
console_ns.models[SuggestedQuestionsResponse.__name__],
console_ns.model(
"SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}
),
)
@console_ns.response(404, "Message or conversation not found")
@setup_required
@@ -435,7 +428,7 @@ class MessageFeedbackExportApi(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
args = FeedbackExportQuery.model_validate(request.args.to_dict())
args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
# Import the service function
from services.feedback_service import FeedbackService
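
The hunks above validate Flask query strings with Pydantic models and switch to `request.args.to_dict(flat=True)`. A minimal, self-contained sketch of that pattern (the route and field names are placeholders, not Dify's API):

```python
# to_dict(flat=True) keeps only the first value per key, so repeated query
# params collapse to a single string before Pydantic validation runs.
from flask import Flask, request
from pydantic import BaseModel, Field

class ChatMessagesQueryExample(BaseModel):
    conversation_id: str
    limit: int = Field(default=20, ge=1, le=100)

app = Flask(__name__)

@app.get("/messages")
def list_messages():
    args = ChatMessagesQueryExample.model_validate(request.args.to_dict(flat=True))
    return {"conversation_id": args.conversation_id, "limit": args.limit}
```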

View File

@@ -2,11 +2,9 @@ import logging
import httpx
from flask import current_app, redirect, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from flask_restx import Resource, fields
from configs import dify_config
from controllers.common.schema import register_schema_models
from libs.login import login_required
from libs.oauth_data_source import NotionOAuth
@@ -16,26 +14,6 @@ from ..wraps import account_initialization_required, is_admin_or_owner_required,
logger = logging.getLogger(__name__)
class OAuthDataSourceResponse(BaseModel):
data: str = Field(description="Authorization URL or 'internal' for internal setup")
class OAuthDataSourceBindingResponse(BaseModel):
result: str = Field(description="Operation result")
class OAuthDataSourceSyncResponse(BaseModel):
result: str = Field(description="Operation result")
register_schema_models(
console_ns,
OAuthDataSourceResponse,
OAuthDataSourceBindingResponse,
OAuthDataSourceSyncResponse,
)
def get_oauth_providers():
with current_app.app_context():
notion_oauth = NotionOAuth(
@@ -56,7 +34,10 @@ class OAuthDataSource(Resource):
@console_ns.response(
200,
"Authorization URL or internal setup success",
console_ns.models[OAuthDataSourceResponse.__name__],
console_ns.model(
"OAuthDataSourceResponse",
{"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")},
),
)
@console_ns.response(400, "Invalid provider")
@console_ns.response(403, "Admin privileges required")
@@ -120,7 +101,7 @@ class OAuthDataSourceBinding(Resource):
@console_ns.response(
200,
"Data source binding success",
console_ns.models[OAuthDataSourceBindingResponse.__name__],
console_ns.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}),
)
@console_ns.response(400, "Invalid provider or code")
def get(self, provider: str):
@@ -152,7 +133,7 @@ class OAuthDataSourceSync(Resource):
@console_ns.response(
200,
"Data source sync success",
console_ns.models[OAuthDataSourceSyncResponse.__name__],
console_ns.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}),
)
@console_ns.response(400, "Invalid provider or sync failed")
@setup_required

View File

@@ -2,11 +2,10 @@ import base64
import secrets
from flask import request
from flask_restx import Resource
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.auth.error import (
EmailCodeError,
@@ -49,31 +48,8 @@ class ForgotPasswordResetPayload(BaseModel):
return valid_password(value)
class ForgotPasswordEmailResponse(BaseModel):
result: str = Field(description="Operation result")
data: str | None = Field(default=None, description="Reset token")
code: str | None = Field(default=None, description="Error code if account not found")
class ForgotPasswordCheckResponse(BaseModel):
is_valid: bool = Field(description="Whether code is valid")
email: EmailStr = Field(description="Email address")
token: str = Field(description="New reset token")
class ForgotPasswordResetResponse(BaseModel):
result: str = Field(description="Operation result")
register_schema_models(
console_ns,
ForgotPasswordSendPayload,
ForgotPasswordCheckPayload,
ForgotPasswordResetPayload,
ForgotPasswordEmailResponse,
ForgotPasswordCheckResponse,
ForgotPasswordResetResponse,
)
for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/forgot-password")
@@ -84,7 +60,14 @@ class ForgotPasswordSendEmailApi(Resource):
@console_ns.response(
200,
"Email sent successfully",
console_ns.models[ForgotPasswordEmailResponse.__name__],
console_ns.model(
"ForgotPasswordEmailResponse",
{
"result": fields.String(description="Operation result"),
"data": fields.String(description="Reset token"),
"code": fields.String(description="Error code if account not found"),
},
),
)
@console_ns.response(400, "Invalid email or rate limit exceeded")
@setup_required
@@ -123,7 +106,14 @@ class ForgotPasswordCheckApi(Resource):
@console_ns.response(
200,
"Code verified successfully",
console_ns.models[ForgotPasswordCheckResponse.__name__],
console_ns.model(
"ForgotPasswordCheckResponse",
{
"is_valid": fields.Boolean(description="Whether code is valid"),
"email": fields.String(description="Email address"),
"token": fields.String(description="New reset token"),
},
),
)
@console_ns.response(400, "Invalid code or token")
@setup_required
@@ -173,7 +163,7 @@ class ForgotPasswordResetApi(Resource):
@console_ns.response(
200,
"Password reset successfully",
console_ns.models[ForgotPasswordResetResponse.__name__],
console_ns.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}),
)
@console_ns.response(400, "Invalid token or password mismatch")
@setup_required

View File

@@ -155,43 +155,43 @@ class OAuthServerUserTokenApi(Resource):
grant_type = OAuthGrantType(payload.grant_type)
except ValueError:
raise BadRequest("invalid grant_type")
match grant_type:
case OAuthGrantType.AUTHORIZATION_CODE:
if not payload.code:
raise BadRequest("code is required")
if payload.client_secret != oauth_provider_app.client_secret:
raise BadRequest("client_secret is invalid")
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
if not payload.code:
raise BadRequest("code is required")
if payload.redirect_uri not in oauth_provider_app.redirect_uris:
raise BadRequest("redirect_uri is invalid")
if payload.client_secret != oauth_provider_app.client_secret:
raise BadRequest("client_secret is invalid")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, code=payload.code, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
case OAuthGrantType.REFRESH_TOKEN:
if not payload.refresh_token:
raise BadRequest("refresh_token is required")
if payload.redirect_uri not in oauth_provider_app.redirect_uris:
raise BadRequest("redirect_uri is invalid")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, code=payload.code, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
if not payload.refresh_token:
raise BadRequest("refresh_token is required")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
@console_ns.route("/oauth/provider/account")

View File

@@ -1,6 +1,6 @@
import json
from collections.abc import Generator
from typing import Any, Literal, cast
from typing import Any, cast
from flask import request
from flask_restx import Resource, fields, marshal_with
@@ -157,8 +157,9 @@ class DataSourceApi(Resource):
@setup_required
@login_required
@account_initialization_required
def patch(self, binding_id, action: Literal["enable", "disable"]):
def patch(self, binding_id, action):
binding_id = str(binding_id)
action = str(action)
with Session(db.engine) as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).filter_by(id=binding_id)
@@ -166,24 +167,23 @@ class DataSourceApi(Resource):
if data_source_binding is None:
raise NotFound("Data source binding not found.")
# enable binding
match action:
case "enable":
if data_source_binding.disabled:
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
raise ValueError("Data source is not disabled.")
# disable binding
case "disable":
if not data_source_binding.disabled:
data_source_binding.disabled = True
data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
raise ValueError("Data source is disabled.")
if action == "enable":
if data_source_binding.disabled:
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
raise ValueError("Data source is not disabled.")
# disable binding
if action == "disable":
if not data_source_binding.disabled:
data_source_binding.disabled = True
data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
raise ValueError("Data source is disabled.")
return {"result": "success"}, 200

View File

@@ -576,62 +576,63 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if document.indexing_status in {"completed", "error"}:
raise DocumentAlreadyFinishedError()
data_source_info = document.data_source_info_dict
match document.data_source_type:
case "upload_file":
if not data_source_info:
continue
file_id = data_source_info["upload_file_id"]
file_detail = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
.first()
)
if file_detail is None:
raise NotFound("File not found.")
if document.data_source_type == "upload_file":
if not data_source_info:
continue
file_id = data_source_info["upload_file_id"]
file_detail = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
.first()
)
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
)
extract_settings.append(extract_setting)
case "notion_import":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"tenant_id": current_tenant_id,
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
case "website_crawl":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"url": data_source_info["url"],
"tenant_id": current_tenant_id,
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
if file_detail is None:
raise NotFound("File not found.")
case _:
raise ValueError("Data source type not support")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
)
extract_settings.append(extract_setting)
elif document.data_source_type == "notion_import":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"tenant_id": current_tenant_id,
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
elif document.data_source_type == "website_crawl":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"url": data_source_info["url"],
"tenant_id": current_tenant_id,
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
else:
raise ValueError("Data source type not support")
indexing_runner = IndexingRunner()
try:
response = indexing_runner.indexing_estimate(
@@ -953,24 +954,23 @@ class DocumentProcessingApi(DocumentResource):
if not current_user.is_dataset_editor:
raise Forbidden()
match action:
case "pause":
if document.indexing_status != "indexing":
raise InvalidActionError("Document not in indexing state.")
if action == "pause":
if document.indexing_status != "indexing":
raise InvalidActionError("Document not in indexing state.")
document.paused_by = current_user.id
document.paused_at = naive_utc_now()
document.is_paused = True
db.session.commit()
document.paused_by = current_user.id
document.paused_at = naive_utc_now()
document.is_paused = True
db.session.commit()
case "resume":
if document.indexing_status not in {"paused", "error"}:
raise InvalidActionError("Document not in paused or error state.")
elif action == "resume":
if document.indexing_status not in {"paused", "error"}:
raise InvalidActionError("Document not in paused or error state.")
document.paused_by = None
document.paused_at = None
document.is_paused = False
db.session.commit()
document.paused_by = None
document.paused_at = None
document.is_paused = False
db.session.commit()
return {"result": "success"}, 200
@@ -1339,18 +1339,6 @@ class DocumentGenerateSummaryApi(Resource):
missing_ids = set(document_list) - found_ids
raise NotFound(f"Some documents not found: {list(missing_ids)}")
# Update need_summary to True for documents that don't have it set
# This handles the case where documents were created when summary_index_setting was disabled
documents_to_update = [doc for doc in documents if not doc.need_summary and doc.doc_form != "qa_model"]
if documents_to_update:
document_ids_to_update = [str(doc.id) for doc in documents_to_update]
DocumentService.update_documents_need_summary(
dataset_id=dataset_id,
document_ids=document_ids_to_update,
need_summary=True,
)
# Dispatch async tasks for each document
for document in documents:
# Skip qa_model documents as they don't generate summaries

View File

@@ -126,11 +126,10 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
match action:
case "enable":
MetadataService.enable_built_in_field(dataset)
case "disable":
MetadataService.disable_built_in_field(dataset)
if action == "enable":
MetadataService.enable_built_in_field(dataset)
elif action == "disable":
MetadataService.disable_built_in_field(dataset)
return {"result": "success"}, 200

View File

@@ -1,9 +1,10 @@
import json
import logging
from typing import Any, Literal, cast
from uuid import UUID
from flask import abort, request
from flask_restx import Resource, marshal_with # type: ignore
from flask_restx import Resource, marshal_with, reqparse # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@@ -37,7 +38,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from factories import variable_factory
from libs import helper
from libs.helper import TimestampField, UUIDStrOrEmpty
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, current_user, login_required
from models import Account
from models.dataset import Pipeline
@@ -109,7 +110,7 @@ class NodeIdQuery(BaseModel):
class WorkflowRunQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = None
last_id: UUID | None = None
limit: int = Field(default=20, ge=1, le=100)
@@ -120,10 +121,6 @@ class DatasourceVariablesPayload(BaseModel):
start_node_title: str
class RagPipelineRecommendedPluginQuery(BaseModel):
type: str = "all"
register_schema_models(
console_ns,
DraftWorkflowSyncPayload,
@@ -138,7 +135,6 @@ register_schema_models(
NodeIdQuery,
WorkflowRunQuery,
DatasourceVariablesPayload,
RagPipelineRecommendedPluginQuery,
)
@@ -979,8 +975,11 @@ class RagPipelineRecommendedPluginApi(Resource):
@login_required
@account_initialization_required
def get(self):
query = RagPipelineRecommendedPluginQuery.model_validate(request.args.to_dict())
parser = reqparse.RequestParser()
parser.add_argument("type", type=str, location="args", required=False, default="all")
args = parser.parse_args()
type = args["type"]
rag_pipeline_service = RagPipelineService()
recommended_plugins = rag_pipeline_service.get_recommended_plugins(query.type)
recommended_plugins = rag_pipeline_service.get_recommended_plugins(type)
return recommended_plugins

View File

@@ -0,0 +1 @@
# Evaluation controller module

View File

@@ -0,0 +1,288 @@
import logging
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar, Union
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
)
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models import App
from models.snippet import CustomizedSnippet
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
# Valid evaluation target types
EVALUATE_TARGET_TYPES = {"app", "snippets"}
class VersionQuery(BaseModel):
"""Query parameters for version endpoint."""
version: str
register_schema_models(
console_ns,
VersionQuery,
)
# Response field definitions
file_info_fields = {
"id": fields.String,
"name": fields.String,
}
evaluation_log_fields = {
"created_at": TimestampField,
"created_by": fields.String,
"test_file": fields.Nested(
console_ns.model(
"EvaluationTestFile",
file_info_fields,
)
),
"result_file": fields.Nested(
console_ns.model(
"EvaluationResultFile",
file_info_fields,
),
allow_null=True,
),
"version": fields.String,
}
evaluation_log_list_model = console_ns.model(
"EvaluationLogList",
{
"data": fields.List(fields.Nested(console_ns.model("EvaluationLog", evaluation_log_fields))),
},
)
customized_matrix_fields = {
"evaluation_workflow_id": fields.String,
"input_fields": fields.Raw,
"output_fields": fields.Raw,
}
condition_fields = {
"name": fields.List(fields.String),
"comparison_operator": fields.String,
"value": fields.String,
}
judgement_conditions_fields = {
"logical_operator": fields.String,
"conditions": fields.List(fields.Nested(console_ns.model("EvaluationCondition", condition_fields))),
}
evaluation_detail_fields = {
"evaluation_model": fields.String,
"evaluation_model_provider": fields.String,
"customized_matrix": fields.Nested(
console_ns.model("EvaluationCustomizedMatrix", customized_matrix_fields),
allow_null=True,
),
"judgement_conditions": fields.Nested(
console_ns.model("EvaluationJudgementConditions", judgement_conditions_fields),
allow_null=True,
),
}
evaluation_detail_model = console_ns.model("EvaluationDetail", evaluation_detail_fields)
def get_evaluation_target(view_func: Callable[P, R]):
"""
Decorator to resolve polymorphic evaluation target (app or snippet).
Validates the target_type parameter and fetches the corresponding
model (App or CustomizedSnippet) with tenant isolation.
"""
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
target_type = kwargs.get("evaluate_target_type")
target_id = kwargs.get("evaluate_target_id")
if target_type not in EVALUATE_TARGET_TYPES:
raise NotFound(f"Invalid evaluation target type: {target_type}")
_, current_tenant_id = current_account_with_tenant()
target_id = str(target_id)
# Remove path parameters
del kwargs["evaluate_target_type"]
del kwargs["evaluate_target_id"]
target: Union[App, CustomizedSnippet] | None = None
if target_type == "app":
target = (
db.session.query(App).where(App.id == target_id, App.tenant_id == current_tenant_id).first()
)
elif target_type == "snippets":
target = (
db.session.query(CustomizedSnippet)
.where(CustomizedSnippet.id == target_id, CustomizedSnippet.tenant_id == current_tenant_id)
.first()
)
if not target:
raise NotFound(f"{str(target_type)} not found")
kwargs["target"] = target
kwargs["target_type"] = target_type
return view_func(*args, **kwargs)
return decorated_view
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/dataset-template/download")
class EvaluationDatasetTemplateDownloadApi(Resource):
@console_ns.doc("download_evaluation_dataset_template")
@console_ns.response(200, "Template download URL generated successfully")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
@edit_permission_required
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Download evaluation dataset template.
Generates a download URL for the evaluation dataset template
based on the target type (app or snippets).
"""
# TODO: Implement actual template generation logic
# This is a placeholder implementation
return {
"download_url": f"/api/evaluation/{target_type}/{target.id}/template.csv",
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation")
class EvaluationDetailApi(Resource):
@console_ns.doc("get_evaluation_detail")
@console_ns.response(200, "Evaluation details retrieved successfully", evaluation_detail_model)
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Get evaluation details for the target.
Returns evaluation configuration including model settings,
customized matrix, and judgement conditions.
"""
# TODO: Implement actual evaluation detail retrieval
# This is a placeholder implementation
return {
"evaluation_model": None,
"evaluation_model_provider": None,
"customized_matrix": None,
"judgement_conditions": None,
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/logs")
class EvaluationLogsApi(Resource):
@console_ns.doc("get_evaluation_logs")
@console_ns.response(200, "Evaluation logs retrieved successfully", evaluation_log_list_model)
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Get offline evaluation logs for the target.
Returns a list of evaluation runs with test files,
result files, and version information.
"""
# TODO: Implement actual evaluation logs retrieval
# This is a placeholder implementation
return {
"data": [],
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/files/<uuid:file_id>")
class EvaluationFileDownloadApi(Resource):
@console_ns.doc("download_evaluation_file")
@console_ns.response(200, "File download URL generated successfully")
@console_ns.response(404, "Target or file not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str, file_id: str):
"""
Download evaluation test file or result file.
Returns file information and download URL for the specified file.
"""
file_id = str(file_id)
# TODO: Implement actual file download logic
# This is a placeholder implementation
return {
"created_at": None,
"created_by": None,
"test_file": None,
"result_file": None,
"version": None,
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/version")
class EvaluationVersionApi(Resource):
@console_ns.doc("get_evaluation_version_detail")
@console_ns.expect(console_ns.models.get(VersionQuery.__name__))
@console_ns.response(200, "Version details retrieved successfully")
@console_ns.response(404, "Target or version not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Get evaluation target version details.
Returns the workflow graph for the specified version.
"""
version = request.args.get("version")
if not version:
return {"message": "version parameter is required"}, 400
# TODO: Implement actual version detail retrieval
# For now, return the current graph if available
graph = {}
if target_type == "snippets" and isinstance(target, CustomizedSnippet):
graph = target.graph_dict
return {
"graph": graph,
}
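
For context, a hedged sketch of how the `get_evaluation_target` decorator above is meant to be reused: any resource routed under `/<string:evaluate_target_type>/<uuid:evaluate_target_id>/...` receives the resolved `target` and `target_type` kwargs instead of the raw path parameters. The route and class name below are illustrative and not part of this commit.

```python
# Hypothetical resource reusing get_evaluation_target; only the decorator,
# console_ns, and the auth wraps are taken from the code above.
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/summary")
class EvaluationSummaryApi(Resource):
    @setup_required
    @login_required
    @account_initialization_required
    @get_evaluation_target
    def get(self, target: Union[App, CustomizedSnippet], target_type: str):
        # By the time this body runs, the decorator has already resolved the
        # App or CustomizedSnippet for the current tenant and raised 404 otherwise.
        return {"id": target.id, "target_type": target_type}
```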

View File

@@ -9,7 +9,7 @@ import services
from controllers.common.fields import Parameters as ParametersResponse
from controllers.common.fields import Site as SiteResponse
from controllers.common.schema import get_or_create_model
from controllers.console import api
from controllers.console import api, console_ns
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
@@ -51,7 +51,7 @@ from fields.app_fields import (
tag_fields,
)
from fields.dataset_fields import dataset_fields
from fields.member_fields import simple_account_fields
from fields.member_fields import build_simple_account_model
from fields.workflow_fields import (
conversation_variable_fields,
pipeline_variable_fields,
@@ -103,7 +103,7 @@ app_detail_fields_with_site_copy["tags"] = fields.List(fields.Nested(tag_model))
app_detail_fields_with_site_copy["site"] = fields.Nested(site_model)
app_detail_with_site_model = get_or_create_model("TrialAppDetailWithSite", app_detail_fields_with_site_copy)
simple_account_model = get_or_create_model("SimpleAccount", simple_account_fields)
simple_account_model = build_simple_account_model(console_ns)
conversation_variable_model = get_or_create_model("TrialConversationVariable", conversation_variable_fields)
pipeline_variable_model = get_or_create_model("TrialPipelineVariable", pipeline_variable_fields)

View File

@@ -1,74 +1,87 @@
import os
from typing import Literal
from flask import session
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from controllers.fastopenapi import console_router
from extensions.ext_database import db
from models.model import DifySetup
from services.account_service import TenantService
from . import console_ns
from .error import AlreadySetupError, InitValidateFailedError
from .wraps import only_edition_self_hosted
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class InitValidatePayload(BaseModel):
password: str = Field(..., max_length=30, description="Initialization password")
password: str = Field(..., max_length=30)
class InitStatusResponse(BaseModel):
status: Literal["finished", "not_started"] = Field(..., description="Initialization status")
class InitValidateResponse(BaseModel):
result: str = Field(description="Operation result", examples=["success"])
@console_router.get(
"/init",
response_model=InitStatusResponse,
tags=["console"],
console_ns.schema_model(
InitValidatePayload.__name__,
InitValidatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def get_init_status() -> InitStatusResponse:
"""Get initialization validation status."""
init_status = get_init_validate_status()
if init_status:
return InitStatusResponse(status="finished")
return InitStatusResponse(status="not_started")
@console_router.post(
"/init",
response_model=InitValidateResponse,
tags=["console"],
status_code=201,
)
@only_edition_self_hosted
def validate_init_password(payload: InitValidatePayload) -> InitValidateResponse:
"""Validate initialization password."""
tenant_count = TenantService.get_tenant_count()
if tenant_count > 0:
raise AlreadySetupError()
@console_ns.route("/init")
class InitValidateAPI(Resource):
@console_ns.doc("get_init_status")
@console_ns.doc(description="Get initialization validation status")
@console_ns.response(
200,
"Success",
model=console_ns.model(
"InitStatusResponse",
{"status": fields.String(description="Initialization status", enum=["finished", "not_started"])},
),
)
def get(self):
"""Get initialization validation status"""
init_status = get_init_validate_status()
if init_status:
return {"status": "finished"}
return {"status": "not_started"}
if payload.password != os.environ.get("INIT_PASSWORD"):
session["is_init_validated"] = False
raise InitValidateFailedError()
@console_ns.doc("validate_init_password")
@console_ns.doc(description="Validate initialization password for self-hosted edition")
@console_ns.expect(console_ns.models[InitValidatePayload.__name__])
@console_ns.response(
201,
"Success",
model=console_ns.model("InitValidateResponse", {"result": fields.String(description="Operation result")}),
)
@console_ns.response(400, "Already setup or validation failed")
@only_edition_self_hosted
def post(self):
"""Validate initialization password"""
# is tenant created
tenant_count = TenantService.get_tenant_count()
if tenant_count > 0:
raise AlreadySetupError()
session["is_init_validated"] = True
return InitValidateResponse(result="success")
payload = InitValidatePayload.model_validate(console_ns.payload)
input_password = payload.password
if input_password != os.environ.get("INIT_PASSWORD"):
session["is_init_validated"] = False
raise InitValidateFailedError()
session["is_init_validated"] = True
return {"result": "success"}, 201
def get_init_validate_status() -> bool:
def get_init_validate_status():
if dify_config.EDITION == "SELF_HOSTED":
if os.environ.get("INIT_PASSWORD"):
if session.get("is_init_validated"):
return True
with Session(db.engine) as db_session:
return db_session.execute(select(DifySetup)).scalar_one_or_none() is not None
return db_session.execute(select(DifySetup)).scalar_one_or_none()
return True
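
A minimal client-side sketch of the init flow defined above, assuming the console API is served under `/console/api` on the default local port (neither assumption is shown in this diff):

```python
# Hedged sketch: GET /init reports the status, POST /init validates INIT_PASSWORD.
import requests

BASE = "http://localhost:5001/console/api"  # assumed mount point, not part of this diff

status = requests.get(f"{BASE}/init").json()
if status["status"] == "not_started":
    resp = requests.post(f"{BASE}/init", json={"password": "<INIT_PASSWORD value>"})
    resp.raise_for_status()  # 201 with {"result": "success"} when the password matches
```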

View File

@@ -1,6 +1,7 @@
import urllib.parse
import httpx
from flask_restx import Resource
from pydantic import BaseModel, Field
import services
@@ -10,7 +11,7 @@ from controllers.common.errors import (
RemoteFileUploadError,
UnsupportedFileTypeError,
)
from controllers.fastopenapi import console_router
from controllers.common.schema import register_schema_models
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
@@ -18,74 +19,84 @@ from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from libs.login import current_account_with_tenant
from services.file_service import FileService
from . import console_ns
register_schema_models(console_ns, RemoteFileInfo, FileWithSignedUrl)
@console_ns.route("/remote-files/<path:url>")
class RemoteFileInfoApi(Resource):
@console_ns.response(200, "Remote file info", console_ns.models[RemoteFileInfo.__name__])
def get(self, url):
decoded_url = urllib.parse.unquote(url)
resp = ssrf_proxy.head(decoded_url)
if resp.status_code != httpx.codes.OK:
# failed back to get method
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp.raise_for_status()
info = RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
file_length=int(resp.headers.get("Content-Length", 0)),
)
return info.model_dump(mode="json")
class RemoteFileUploadPayload(BaseModel):
url: str = Field(..., description="URL to fetch")
@console_router.get(
"/remote-files/<path:url>",
response_model=RemoteFileInfo,
tags=["console"],
console_ns.schema_model(
RemoteFileUploadPayload.__name__,
RemoteFileUploadPayload.model_json_schema(ref_template="#/definitions/{model}"),
)
def get_remote_file_info(url: str) -> RemoteFileInfo:
decoded_url = urllib.parse.unquote(url)
resp = ssrf_proxy.head(decoded_url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp.raise_for_status()
return RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
file_length=int(resp.headers.get("Content-Length", 0)),
)
@console_router.post(
"/remote-files/upload",
response_model=FileWithSignedUrl,
tags=["console"],
status_code=201,
)
def upload_remote_file(payload: RemoteFileUploadPayload) -> FileWithSignedUrl:
url = payload.url
@console_ns.route("/remote-files/upload")
class RemoteFileUploadApi(Resource):
@console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
@console_ns.response(201, "Remote file uploaded", console_ns.models[FileWithSignedUrl.__name__])
def post(self):
args = RemoteFileUploadPayload.model_validate(console_ns.payload)
url = args.url
try:
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
except httpx.RequestError as e:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")
try:
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
except httpx.RequestError as e:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")
file_info = helpers.guess_file_info_from_response(resp)
file_info = helpers.guess_file_info_from_response(resp)
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
raise FileTooLargeError
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
raise FileTooLargeError
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try:
user, _ = current_account_with_tenant()
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,
mimetype=file_info.mimetype,
user=user,
source_url=url,
try:
user, _ = current_account_with_tenant()
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,
mimetype=file_info.mimetype,
user=user,
source_url=url,
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
payload = FileWithSignedUrl(
id=upload_file.id,
name=upload_file.name,
size=upload_file.size,
extension=upload_file.extension,
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
mime_type=upload_file.mime_type,
created_by=upload_file.created_by,
created_at=int(upload_file.created_at.timestamp()),
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return FileWithSignedUrl(
id=upload_file.id,
name=upload_file.name,
size=upload_file.size,
extension=upload_file.extension,
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
mime_type=upload_file.mime_type,
created_by=upload_file.created_by,
created_at=int(upload_file.created_at.timestamp()),
)
return payload.model_dump(mode="json"), 201
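
A hedged example of calling the upload endpoint above; the host, prefix, and auth header are assumptions, while the request body and response fields follow `RemoteFileUploadPayload` and `FileWithSignedUrl`:

```python
import requests

resp = requests.post(
    "http://localhost:5001/console/api/remote-files/upload",  # assumed base URL
    json={"url": "https://example.com/report.pdf"},
    headers={"Authorization": "Bearer <console-token>"},       # auth mechanism assumed
)
resp.raise_for_status()
file_info = resp.json()  # id, name, size, extension, signed url, mime_type, created_by, created_at
```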

View File

@@ -0,0 +1,75 @@
from typing import Any, Literal
from pydantic import BaseModel, Field
class SnippetListQuery(BaseModel):
"""Query parameters for listing snippets."""
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=20, ge=1, le=100)
keyword: str | None = None
class IconInfo(BaseModel):
"""Icon information model."""
icon: str | None = None
icon_type: Literal["emoji", "image"] | None = None
icon_background: str | None = None
icon_url: str | None = None
class InputFieldDefinition(BaseModel):
"""Input field definition for snippet parameters."""
default: str | None = None
hint: bool | None = None
label: str | None = None
max_length: int | None = None
options: list[str] | None = None
placeholder: str | None = None
required: bool | None = None
type: str | None = None # e.g., "text-input"
class CreateSnippetPayload(BaseModel):
"""Payload for creating a new snippet."""
name: str = Field(..., min_length=1, max_length=255)
description: str | None = Field(default=None, max_length=2000)
type: Literal["node", "group"] = "node"
icon_info: IconInfo | None = None
graph: dict[str, Any] | None = None
input_fields: list[InputFieldDefinition] | None = Field(default_factory=list)
class UpdateSnippetPayload(BaseModel):
"""Payload for updating a snippet."""
name: str | None = Field(default=None, min_length=1, max_length=255)
description: str | None = Field(default=None, max_length=2000)
icon_info: IconInfo | None = None
class SnippetDraftSyncPayload(BaseModel):
"""Payload for syncing snippet draft workflow."""
graph: dict[str, Any]
hash: str | None = None
environment_variables: list[dict[str, Any]] | None = None
conversation_variables: list[dict[str, Any]] | None = None
input_variables: list[dict[str, Any]] | None = None
class WorkflowRunQuery(BaseModel):
"""Query parameters for workflow runs."""
last_id: str | None = None
limit: int = Field(default=20, ge=1, le=100)
class PublishWorkflowPayload(BaseModel):
"""Payload for publishing snippet workflow."""
knowledge_base_setting: dict[str, Any] | None = None

View File

@@ -0,0 +1,306 @@
import logging
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request
from flask_restx import Resource, marshal_with
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.workflow import workflow_model
from controllers.console.app.workflow_run import (
workflow_run_detail_model,
workflow_run_node_execution_list_model,
workflow_run_pagination_model,
)
from controllers.console.snippets.payloads import (
PublishWorkflowPayload,
SnippetDraftSyncPayload,
WorkflowRunQuery,
)
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
)
from extensions.ext_database import db
from factories import variable_factory
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.snippet import CustomizedSnippet
from services.errors.app import WorkflowHashNotEqualError
from services.snippet_service import SnippetService
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
# Register Pydantic models with Swagger
register_schema_models(
console_ns,
SnippetDraftSyncPayload,
WorkflowRunQuery,
PublishWorkflowPayload,
)
class SnippetNotFoundError(Exception):
"""Snippet not found error."""
pass
def get_snippet(view_func: Callable[P, R]):
"""Decorator to fetch and validate snippet access."""
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
if not kwargs.get("snippet_id"):
raise ValueError("missing snippet_id in path parameters")
_, current_tenant_id = current_account_with_tenant()
snippet_id = str(kwargs.get("snippet_id"))
del kwargs["snippet_id"]
snippet = SnippetService.get_snippet_by_id(
snippet_id=snippet_id,
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
kwargs["snippet"] = snippet
return view_func(*args, **kwargs)
return decorated_view
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft")
class SnippetDraftWorkflowApi(Resource):
@console_ns.doc("get_snippet_draft_workflow")
@console_ns.response(200, "Draft workflow retrieved successfully", workflow_model)
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
@marshal_with(workflow_model)
def get(self, snippet: CustomizedSnippet):
"""Get draft workflow for snippet."""
snippet_service = SnippetService()
workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not workflow:
raise DraftWorkflowNotExist()
return workflow
@console_ns.doc("sync_snippet_draft_workflow")
@console_ns.expect(console_ns.models.get(SnippetDraftSyncPayload.__name__))
@console_ns.response(200, "Draft workflow synced successfully")
@console_ns.response(400, "Hash mismatch")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet):
"""Sync draft workflow for snippet."""
current_user, _ = current_account_with_tenant()
payload = SnippetDraftSyncPayload.model_validate(console_ns.payload or {})
try:
environment_variables_list = payload.environment_variables or []
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
conversation_variables_list = payload.conversation_variables or []
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
snippet_service = SnippetService()
workflow = snippet_service.sync_draft_workflow(
snippet=snippet,
graph=payload.graph,
unique_hash=payload.hash,
account=current_user,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
input_variables=payload.input_variables,
)
except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync()
return {
"result": "success",
"hash": workflow.unique_hash,
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/config")
class SnippetDraftConfigApi(Resource):
@console_ns.doc("get_snippet_draft_config")
@console_ns.response(200, "Draft config retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def get(self, snippet: CustomizedSnippet):
"""Get snippet draft workflow configuration limits."""
return {
"parallel_depth_limit": 3,
}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/publish")
class SnippetPublishedWorkflowApi(Resource):
@console_ns.doc("get_snippet_published_workflow")
@console_ns.response(200, "Published workflow retrieved successfully", workflow_model)
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
@marshal_with(workflow_model)
def get(self, snippet: CustomizedSnippet):
"""Get published workflow for snippet."""
if not snippet.is_published:
return None
snippet_service = SnippetService()
workflow = snippet_service.get_published_workflow(snippet=snippet)
return workflow
@console_ns.doc("publish_snippet_workflow")
@console_ns.expect(console_ns.models.get(PublishWorkflowPayload.__name__))
@console_ns.response(200, "Workflow published successfully")
@console_ns.response(400, "No draft workflow found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet):
"""Publish snippet workflow."""
current_user, _ = current_account_with_tenant()
snippet_service = SnippetService()
with Session(db.engine) as session:
snippet = session.merge(snippet)
try:
workflow = snippet_service.publish_workflow(
session=session,
snippet=snippet,
account=current_user,
)
workflow_created_at = TimestampField().format(workflow.created_at)
session.commit()
except ValueError as e:
return {"message": str(e)}, 400
return {
"result": "success",
"created_at": workflow_created_at,
}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/default-workflow-block-configs")
class SnippetDefaultBlockConfigsApi(Resource):
@console_ns.doc("get_snippet_default_block_configs")
@console_ns.response(200, "Default block configs retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def get(self, snippet: CustomizedSnippet):
"""Get default block configurations for snippet workflow."""
snippet_service = SnippetService()
return snippet_service.get_default_block_configs()
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs")
class SnippetWorkflowRunsApi(Resource):
@console_ns.doc("list_snippet_workflow_runs")
@console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_pagination_model)
def get(self, snippet: CustomizedSnippet):
"""List workflow runs for snippet."""
query = WorkflowRunQuery.model_validate(
{
"last_id": request.args.get("last_id"),
"limit": request.args.get("limit", type=int, default=20),
}
)
args = {
"last_id": query.last_id,
"limit": query.limit,
}
snippet_service = SnippetService()
result = snippet_service.get_snippet_workflow_runs(snippet=snippet, args=args)
return result
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>")
class SnippetWorkflowRunDetailApi(Resource):
@console_ns.doc("get_snippet_workflow_run_detail")
@console_ns.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_model)
@console_ns.response(404, "Workflow run not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_detail_model)
def get(self, snippet: CustomizedSnippet, run_id):
"""Get workflow run detail for snippet."""
run_id = str(run_id)
snippet_service = SnippetService()
workflow_run = snippet_service.get_snippet_workflow_run(snippet=snippet, run_id=run_id)
if not workflow_run:
raise NotFound("Workflow run not found")
return workflow_run
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>/node-executions")
class SnippetWorkflowRunNodeExecutionsApi(Resource):
@console_ns.doc("list_snippet_workflow_run_node_executions")
@console_ns.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_model)
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_node_execution_list_model)
def get(self, snippet: CustomizedSnippet, run_id):
"""List node executions for a workflow run."""
run_id = str(run_id)
snippet_service = SnippetService()
node_executions = snippet_service.get_snippet_workflow_run_node_executions(
snippet=snippet,
run_id=run_id,
)
return {"data": node_executions}

View File

@@ -1,27 +1,17 @@
from typing import Literal
from flask import request
from flask_restx import Namespace, Resource, fields, marshal_with
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from fields.tag_fields import dataset_tag_fields
from libs.login import current_account_with_tenant, login_required
from services.tag_service import TagService
dataset_tag_fields = {
"id": fields.String,
"name": fields.String,
"type": fields.String,
"binding_count": fields.String,
}
def build_dataset_tag_fields(api_or_ns: Namespace):
return api_or_ns.model("DataSetTag", dataset_tag_fields)
class TagBasePayload(BaseModel):
name: str = Field(description="Tag name", min_length=1, max_length=50)

View File

@@ -12,7 +12,6 @@ from sqlalchemy.orm import Session
from configs import dify_config
from constants.languages import supported_language
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.auth.error import (
EmailAlreadyInUseError,
@@ -38,7 +37,7 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.member_fields import Account as AccountResponse
from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
@@ -171,12 +170,6 @@ reg(ChangeEmailSendPayload)
reg(ChangeEmailValidityPayload)
reg(ChangeEmailResetPayload)
reg(CheckEmailUniquePayload)
register_schema_models(console_ns, AccountResponse)
def _serialize_account(account) -> dict:
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")
integrate_fields = {
"provider": fields.String,
@@ -243,11 +236,11 @@ class AccountProfileApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
@marshal_with(account_fields)
@enterprise_license_required
def get(self):
current_user, _ = current_account_with_tenant()
return _serialize_account(current_user)
return current_user
@console_ns.route("/account/name")
@@ -256,14 +249,14 @@ class AccountNameApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
args = AccountNamePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, name=args.name)
return _serialize_account(updated_account)
return updated_account
@console_ns.route("/account/avatar")
@@ -272,7 +265,7 @@ class AccountAvatarApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
@@ -280,7 +273,7 @@ class AccountAvatarApi(Resource):
updated_account = AccountService.update_account(current_user, avatar=args.avatar)
return _serialize_account(updated_account)
return updated_account
@console_ns.route("/account/interface-language")
@@ -289,7 +282,7 @@ class AccountInterfaceLanguageApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
@@ -297,7 +290,7 @@ class AccountInterfaceLanguageApi(Resource):
updated_account = AccountService.update_account(current_user, interface_language=args.interface_language)
return _serialize_account(updated_account)
return updated_account
@console_ns.route("/account/interface-theme")
@@ -306,7 +299,7 @@ class AccountInterfaceThemeApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
@@ -314,7 +307,7 @@ class AccountInterfaceThemeApi(Resource):
updated_account = AccountService.update_account(current_user, interface_theme=args.interface_theme)
return _serialize_account(updated_account)
return updated_account
@console_ns.route("/account/timezone")
@@ -323,7 +316,7 @@ class AccountTimezoneApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
@@ -331,7 +324,7 @@ class AccountTimezoneApi(Resource):
updated_account = AccountService.update_account(current_user, timezone=args.timezone)
return _serialize_account(updated_account)
return updated_account
@console_ns.route("/account/password")
@@ -340,7 +333,7 @@ class AccountPasswordApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
@marshal_with(account_fields)
def post(self):
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
@@ -351,7 +344,7 @@ class AccountPasswordApi(Resource):
except ServiceCurrentPasswordIncorrectError:
raise CurrentPasswordIncorrectError()
return _serialize_account(current_user)
return {"result": "success"}
@console_ns.route("/account/integrates")
@@ -627,7 +620,7 @@ class ChangeEmailResetApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
@marshal_with(account_fields)
def post(self):
payload = console_ns.payload or {}
args = ChangeEmailResetPayload.model_validate(payload)
@@ -656,7 +649,7 @@ class ChangeEmailResetApi(Resource):
email=normalized_new_email,
)
return _serialize_account(updated_account)
return updated_account
@console_ns.route("/account/change-email/check-email-unique")

View File

@@ -1,10 +1,9 @@
from typing import Any
from flask import request
from flask_restx import Resource
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -39,53 +38,15 @@ class EndpointListForPluginQuery(EndpointListQuery):
plugin_id: str
class EndpointCreateResponse(BaseModel):
success: bool = Field(description="Operation success")
class EndpointListResponse(BaseModel):
endpoints: list[dict[str, Any]] = Field(description="Endpoint information")
class PluginEndpointListResponse(BaseModel):
endpoints: list[dict[str, Any]] = Field(description="Endpoint information")
class EndpointDeleteResponse(BaseModel):
success: bool = Field(description="Operation success")
class EndpointUpdateResponse(BaseModel):
success: bool = Field(description="Operation success")
class EndpointEnableResponse(BaseModel):
success: bool = Field(description="Operation success")
class EndpointDisableResponse(BaseModel):
success: bool = Field(description="Operation success")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
register_schema_models(
console_ns,
EndpointCreatePayload,
EndpointIdPayload,
EndpointUpdatePayload,
EndpointListQuery,
EndpointListForPluginQuery,
EndpointCreateResponse,
EndpointListResponse,
PluginEndpointListResponse,
EndpointDeleteResponse,
EndpointUpdateResponse,
EndpointEnableResponse,
EndpointDisableResponse,
)
reg(EndpointCreatePayload)
reg(EndpointIdPayload)
reg(EndpointUpdatePayload)
reg(EndpointListQuery)
reg(EndpointListForPluginQuery)
@console_ns.route("/workspaces/current/endpoints/create")
@@ -96,7 +57,7 @@ class EndpointCreateApi(Resource):
@console_ns.response(
200,
"Endpoint created successfully",
console_ns.models[EndpointCreateResponse.__name__],
console_ns.model("EndpointCreateResponse", {"success": fields.Boolean(description="Operation success")}),
)
@console_ns.response(403, "Admin privileges required")
@setup_required
@@ -130,7 +91,9 @@ class EndpointListApi(Resource):
@console_ns.response(
200,
"Success",
console_ns.models[EndpointListResponse.__name__],
console_ns.model(
"EndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
),
)
@setup_required
@login_required
@@ -163,7 +126,9 @@ class EndpointListForSinglePluginApi(Resource):
@console_ns.response(
200,
"Success",
console_ns.models[PluginEndpointListResponse.__name__],
console_ns.model(
"PluginEndpointListResponse", {"endpoints": fields.List(fields.Raw(description="Endpoint information"))}
),
)
@setup_required
@login_required
@@ -198,7 +163,7 @@ class EndpointDeleteApi(Resource):
@console_ns.response(
200,
"Endpoint deleted successfully",
console_ns.models[EndpointDeleteResponse.__name__],
console_ns.model("EndpointDeleteResponse", {"success": fields.Boolean(description="Operation success")}),
)
@console_ns.response(403, "Admin privileges required")
@setup_required
@@ -225,7 +190,7 @@ class EndpointUpdateApi(Resource):
@console_ns.response(
200,
"Endpoint updated successfully",
console_ns.models[EndpointUpdateResponse.__name__],
console_ns.model("EndpointUpdateResponse", {"success": fields.Boolean(description="Operation success")}),
)
@console_ns.response(403, "Admin privileges required")
@setup_required
@@ -256,7 +221,7 @@ class EndpointEnableApi(Resource):
@console_ns.response(
200,
"Endpoint enabled successfully",
console_ns.models[EndpointEnableResponse.__name__],
console_ns.model("EndpointEnableResponse", {"success": fields.Boolean(description="Operation success")}),
)
@console_ns.response(403, "Admin privileges required")
@setup_required
@@ -283,7 +248,7 @@ class EndpointDisableApi(Resource):
@console_ns.response(
200,
"Endpoint disabled successfully",
console_ns.models[EndpointDisableResponse.__name__],
console_ns.model("EndpointDisableResponse", {"success": fields.Boolean(description="Operation success")}),
)
@console_ns.response(403, "Admin privileges required")
@setup_required

View File

@@ -1,12 +1,12 @@
from urllib import parse
from flask import abort, request
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
import services
from configs import dify_config
from controllers.common.schema import register_enum_models, register_schema_models
from controllers.common.schema import get_or_create_model, register_enum_models
from controllers.console import console_ns
from controllers.console.auth.error import (
CannotTransferOwnerToSelfError,
@@ -25,7 +25,7 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.member_fields import AccountWithRole, AccountWithRoleList
from fields.member_fields import account_with_role_fields, account_with_role_list_fields
from libs.helper import extract_remote_ip
from libs.login import current_account_with_tenant, login_required
from models.account import Account, TenantAccountRole
@@ -69,7 +69,12 @@ reg(OwnerTransferEmailPayload)
reg(OwnerTransferCheckPayload)
reg(OwnerTransferPayload)
register_enum_models(console_ns, TenantAccountRole)
register_schema_models(console_ns, AccountWithRole, AccountWithRoleList)
account_with_role_model = get_or_create_model("AccountWithRole", account_with_role_fields)
account_with_role_list_fields_copy = account_with_role_list_fields.copy()
account_with_role_list_fields_copy["accounts"] = fields.List(fields.Nested(account_with_role_model))
account_with_role_list_model = get_or_create_model("AccountWithRoleList", account_with_role_list_fields_copy)
@console_ns.route("/workspaces/current/members")
@@ -79,15 +84,13 @@ class MemberListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
@marshal_with(account_with_role_list_model)
def get(self):
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_tenant_members(current_user.current_tenant)
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
response = AccountWithRoleList(accounts=member_models)
return response.model_dump(mode="json"), 200
return {"result": "success", "accounts": members}, 200
@console_ns.route("/workspaces/current/members/invite-email")
@@ -232,15 +235,13 @@ class DatasetOperatorMemberListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountWithRoleList.__name__])
@marshal_with(account_with_role_list_model)
def get(self):
current_user, _ = current_account_with_tenant()
if not current_user.current_tenant:
raise ValueError("No current tenant")
members = TenantService.get_dataset_operator_members(current_user.current_tenant)
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
response = AccountWithRoleList(accounts=member_models)
return response.model_dump(mode="json"), 200
return {"result": "success", "accounts": members}, 200
@console_ns.route("/workspaces/current/members/send-owner-transfer-confirm-email")

View File

@@ -0,0 +1,202 @@
import logging
from flask import request
from flask_restx import Resource, marshal, marshal_with
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.snippets.payloads import (
CreateSnippetPayload,
SnippetListQuery,
UpdateSnippetPayload,
)
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
)
from extensions.ext_database import db
from fields.snippet_fields import snippet_fields, snippet_list_fields, snippet_pagination_fields
from libs.login import current_account_with_tenant, login_required
from models.snippet import SnippetType
from services.snippet_service import SnippetService
logger = logging.getLogger(__name__)
# Register Pydantic models with Swagger
register_schema_models(
console_ns,
SnippetListQuery,
CreateSnippetPayload,
UpdateSnippetPayload,
)
# Create namespace models for marshaling
snippet_model = console_ns.model("Snippet", snippet_fields)
snippet_list_model = console_ns.model("SnippetList", snippet_list_fields)
snippet_pagination_model = console_ns.model("SnippetPagination", snippet_pagination_fields)
@console_ns.route("/workspaces/current/customized-snippets")
class CustomizedSnippetsApi(Resource):
@console_ns.doc("list_customized_snippets")
@console_ns.expect(console_ns.models.get(SnippetListQuery.__name__))
@console_ns.response(200, "Snippets retrieved successfully", snippet_pagination_model)
@setup_required
@login_required
@account_initialization_required
def get(self):
"""List customized snippets with pagination and search."""
_, current_tenant_id = current_account_with_tenant()
query_params = request.args.to_dict()
query = SnippetListQuery.model_validate(query_params)
snippets, total, has_more = SnippetService.get_snippets(
tenant_id=current_tenant_id,
page=query.page,
limit=query.limit,
keyword=query.keyword,
)
return {
"data": marshal(snippets, snippet_list_fields),
"page": query.page,
"limit": query.limit,
"total": total,
"has_more": has_more,
}, 200
@console_ns.doc("create_customized_snippet")
@console_ns.expect(console_ns.models.get(CreateSnippetPayload.__name__))
@console_ns.response(201, "Snippet created successfully", snippet_model)
@console_ns.response(400, "Invalid request or name already exists")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self):
"""Create a new customized snippet."""
current_user, current_tenant_id = current_account_with_tenant()
payload = CreateSnippetPayload.model_validate(console_ns.payload or {})
try:
snippet_type = SnippetType(payload.type)
except ValueError:
snippet_type = SnippetType.NODE
try:
snippet = SnippetService.create_snippet(
tenant_id=current_tenant_id,
name=payload.name,
description=payload.description,
snippet_type=snippet_type,
icon_info=payload.icon_info.model_dump() if payload.icon_info else None,
graph=payload.graph,
input_fields=[f.model_dump() for f in payload.input_fields] if payload.input_fields else None,
account=current_user,
)
except ValueError as e:
return {"message": str(e)}, 400
return marshal(snippet, snippet_fields), 201
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>")
class CustomizedSnippetDetailApi(Resource):
@console_ns.doc("get_customized_snippet")
@console_ns.response(200, "Snippet retrieved successfully", snippet_model)
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
def get(self, snippet_id: str):
"""Get customized snippet details."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
return marshal(snippet, snippet_fields), 200
@console_ns.doc("update_customized_snippet")
@console_ns.expect(console_ns.models.get(UpdateSnippetPayload.__name__))
@console_ns.response(200, "Snippet updated successfully", snippet_model)
@console_ns.response(400, "Invalid request or name already exists")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def patch(self, snippet_id: str):
"""Update customized snippet."""
current_user, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
payload = UpdateSnippetPayload.model_validate(console_ns.payload or {})
update_data = payload.model_dump(exclude_unset=True)
if "icon_info" in update_data and update_data["icon_info"] is not None:
update_data["icon_info"] = payload.icon_info.model_dump() if payload.icon_info else None
if not update_data:
return {"message": "No valid fields to update"}, 400
try:
with Session(db.engine, expire_on_commit=False) as session:
snippet = session.merge(snippet)
snippet = SnippetService.update_snippet(
session=session,
snippet=snippet,
account_id=current_user.id,
data=update_data,
)
session.commit()
except ValueError as e:
return {"message": str(e)}, 400
return marshal(snippet, snippet_fields), 200
@console_ns.doc("delete_customized_snippet")
@console_ns.response(204, "Snippet deleted successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, snippet_id: str):
"""Delete customized snippet."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
with Session(db.engine) as session:
snippet = session.merge(snippet)
SnippetService.delete_snippet(
session=session,
snippet=snippet,
)
session.commit()
return "", 204

File diff suppressed because it is too large

View File

@@ -1,16 +1,16 @@
from typing import Literal
from flask import request
from flask_restx import Resource
from flask_restx import Namespace, Resource, fields
from flask_restx.api import HTTPStatus
from pydantic import BaseModel, Field, TypeAdapter
from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_models
from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import validate_app_token
from extensions.ext_redis import redis_client
from fields.annotation_fields import Annotation, AnnotationList
from fields.annotation_fields import annotation_fields, build_annotation_model
from models.model import App
from services.annotation_service import AppAnnotationService
@@ -26,9 +26,7 @@ class AnnotationReplyActionPayload(BaseModel):
embedding_model_name: str = Field(description="Embedding model name")
register_schema_models(
service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload, Annotation, AnnotationList
)
register_schema_models(service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload)
@service_api_ns.route("/apps/annotation-reply/<string:action>")
@@ -47,11 +45,10 @@ class AnnotationReplyActionApi(Resource):
def post(self, app_model: App, action: Literal["enable", "disable"]):
"""Enable or disable annotation reply feature."""
args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump()
match action:
case "enable":
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
case "disable":
result = AppAnnotationService.disable_app_annotation(app_model.id)
if action == "enable":
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_model.id)
return result, 200
@@ -85,6 +82,23 @@ class AnnotationReplyActionStatusApi(Resource):
return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200
# Define annotation list response model
annotation_list_fields = {
"data": fields.List(fields.Nested(annotation_fields)),
"has_more": fields.Boolean,
"limit": fields.Integer,
"total": fields.Integer,
"page": fields.Integer,
}
def build_annotation_list_model(api_or_ns: Namespace):
"""Build the annotation list model for the API or Namespace."""
copied_annotation_list_fields = annotation_list_fields.copy()
copied_annotation_list_fields["data"] = fields.List(fields.Nested(build_annotation_model(api_or_ns)))
return api_or_ns.model("AnnotationList", copied_annotation_list_fields)
@service_api_ns.route("/apps/annotations")
class AnnotationListApi(Resource):
@service_api_ns.doc("list_annotations")
@@ -95,12 +109,8 @@ class AnnotationListApi(Resource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.response(
200,
"Annotations retrieved successfully",
service_api_ns.models[AnnotationList.__name__],
)
@validate_app_token
@service_api_ns.marshal_with(build_annotation_list_model(service_api_ns))
def get(self, app_model: App):
"""List annotations for the application."""
page = request.args.get("page", default=1, type=int)
@@ -108,15 +118,13 @@ class AnnotationListApi(Resource):
keyword = request.args.get("keyword", default="", type=str)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_model.id, page, limit, keyword)
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
response = AnnotationList(
data=annotation_models,
has_more=len(annotation_list) == limit,
limit=limit,
total=total,
page=page,
)
return response.model_dump(mode="json")
return {
"data": annotation_list,
"has_more": len(annotation_list) == limit,
"limit": limit,
"total": total,
"page": page,
}
@service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__])
@service_api_ns.doc("create_annotation")
@@ -127,18 +135,13 @@ class AnnotationListApi(Resource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.response(
HTTPStatus.CREATED,
"Annotation created successfully",
service_api_ns.models[Annotation.__name__],
)
@validate_app_token
@service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED)
def post(self, app_model: App):
"""Create a new annotation."""
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
response = Annotation.model_validate(annotation, from_attributes=True)
return response.model_dump(mode="json"), HTTPStatus.CREATED
return annotation, 201
@service_api_ns.route("/apps/annotations/<uuid:annotation_id>")
@@ -155,19 +158,14 @@ class AnnotationUpdateDeleteApi(Resource):
404: "Annotation not found",
}
)
@service_api_ns.response(
200,
"Annotation updated successfully",
service_api_ns.models[Annotation.__name__],
)
@validate_app_token
@edit_permission_required
@service_api_ns.marshal_with(build_annotation_model(service_api_ns))
def put(self, app_model: App, annotation_id: str):
"""Update an existing annotation."""
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
response = Annotation.model_validate(annotation, from_attributes=True)
return response.model_dump(mode="json")
return annotation
@service_api_ns.doc("delete_annotation")
@service_api_ns.doc(description="Delete an annotation")

View File

@@ -30,7 +30,6 @@ from core.errors.error import (
from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import UUIDStrOrEmpty
from models.model import App, AppMode, EndUser
from services.app_generate_service import AppGenerateService
from services.app_task_service import AppTaskService
@@ -53,7 +52,7 @@ class ChatRequestPayload(BaseModel):
query: str
files: list[dict[str, Any]] | None = None
response_mode: Literal["blocking", "streaming"] | None = None
conversation_id: UUIDStrOrEmpty | None = Field(default=None, description="Conversation UUID")
conversation_id: str | None = Field(default=None, description="Conversation UUID")
retriever_from: str = Field(default="dev")
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")

View File

@@ -1,4 +1,5 @@
from typing import Any, Literal
from uuid import UUID
from flask import request
from flask_restx import Resource
@@ -22,13 +23,12 @@ from fields.conversation_variable_fields import (
build_conversation_variable_infinite_scroll_pagination_model,
build_conversation_variable_model,
)
from libs.helper import UUIDStrOrEmpty
from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService
class ConversationListQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last conversation ID for pagination")
last_id: UUID | None = Field(default=None, description="Last conversation ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return")
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
default="-updated_at", description="Sort order for conversations"
@@ -48,7 +48,7 @@ class ConversationRenamePayload(BaseModel):
class ConversationVariablesQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination")
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
variable_name: str | None = Field(
default=None, description="Filter variables by name", min_length=1, max_length=255

View File

@@ -1,5 +1,6 @@
import logging
from typing import Literal
from uuid import UUID
from flask import request
from flask_restx import Resource
@@ -14,7 +15,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
from libs.helper import UUIDStrOrEmpty
from models.model import App, AppMode, EndUser
from services.errors.message import (
FirstMessageNotExistsError,
@@ -27,8 +27,8 @@ logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: UUIDStrOrEmpty
first_id: UUIDStrOrEmpty | None = None
conversation_id: UUID
first_id: UUID | None = None
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")

View File

@@ -17,7 +17,7 @@ from controllers.service_api.wraps import (
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import DataSetTag
from fields.tag_fields import build_dataset_tag_fields
from libs.login import current_user
from models.account import Account
from models.dataset import DatasetPermissionEnum
@@ -114,7 +114,6 @@ register_schema_models(
TagBindingPayload,
TagUnbindingPayload,
DatasetListQuery,
DataSetTag,
)
@@ -481,14 +480,15 @@ class DatasetTagsApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
def get(self, _):
"""Get all knowledge type tags."""
assert isinstance(current_user, Account)
cid = current_user.current_tenant_id
assert cid is not None
tags = TagService.get_tags("knowledge", cid)
tag_models = TypeAdapter(list[DataSetTag]).validate_python(tags, from_attributes=True)
return [tag.model_dump(mode="json") for tag in tag_models], 200
return tags, 200
@service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])
@service_api_ns.doc("create_dataset_tag")
@@ -500,6 +500,7 @@ class DatasetTagsApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
def post(self, _):
"""Add a knowledge type tag."""
assert isinstance(current_user, Account)
@@ -509,9 +510,7 @@ class DatasetTagsApi(DatasetApiResource):
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
tag = TagService.save_tags({"name": payload.name, "type": "knowledge"})
response = DataSetTag.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
).model_dump(mode="json")
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
return response, 200
@service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__])
@@ -524,6 +523,7 @@ class DatasetTagsApi(DatasetApiResource):
403: "Forbidden - insufficient permissions",
}
)
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
def patch(self, _):
assert isinstance(current_user, Account)
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@@ -536,9 +536,8 @@ class DatasetTagsApi(DatasetApiResource):
binding_count = TagService.get_tag_binding_count(tag_id)
response = DataSetTag.model_validate(
{"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
).model_dump(mode="json")
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}
return response, 200
@service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__])

View File

@@ -1,10 +1,7 @@
from controllers.common.schema import register_schema_model
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
register_schema_model(service_api_ns, HitTestingPayload)
@service_api_ns.route("/datasets/<uuid:dataset_id>/hit-testing", "/datasets/<uuid:dataset_id>/retrieve")
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
@@ -18,7 +15,6 @@ class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
404: "Dataset not found",
}
)
@service_api_ns.expect(service_api_ns.models[HitTestingPayload.__name__])
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
"""Perform hit testing on a dataset.

View File

@@ -168,11 +168,10 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
match action:
case "enable":
MetadataService.enable_built_in_field(dataset)
case "disable":
MetadataService.disable_built_in_field(dataset)
if action == "enable":
MetadataService.enable_built_in_field(dataset)
elif action == "disable":
MetadataService.disable_built_in_field(dataset)
return {"result": "success"}, 200

View File

@@ -73,14 +73,14 @@ def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: Fe
# If caller needs end-user context, attach EndUser to current_user
if fetch_user_arg:
user_id = None
match fetch_user_arg.fetch_from:
case WhereisUserArg.QUERY:
user_id = request.args.get("user")
case WhereisUserArg.JSON:
user_id = request.get_json().get("user")
case WhereisUserArg.FORM:
user_id = request.form.get("user")
if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
user_id = request.args.get("user")
elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
user_id = request.get_json().get("user")
elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
user_id = request.form.get("user")
else:
user_id = None
if not user_id and fetch_user_arg.required:
raise ValueError("Arg user must be provided.")

View File

@@ -14,17 +14,16 @@ class AgentConfigManager:
agent_dict = config.get("agent_mode", {})
agent_strategy = agent_dict.get("strategy", "cot")
match agent_strategy:
case "function_call":
if agent_strategy == "function_call":
strategy = AgentEntity.Strategy.FUNCTION_CALLING
elif agent_strategy in {"cot", "react"}:
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
else:
# old configs, try to detect default strategy
if config["model"]["provider"] == "openai":
strategy = AgentEntity.Strategy.FUNCTION_CALLING
case "cot" | "react":
else:
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
case _:
# old configs, try to detect default strategy
if config["model"]["provider"] == "openai":
strategy = AgentEntity.Strategy.FUNCTION_CALLING
else:
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
agent_tools = []
for tool in agent_dict.get("tools", []):

View File

@@ -250,7 +250,7 @@ class WorkflowResponseConverter:
data=WorkflowFinishStreamResponse.Data(
id=run_id,
workflow_id=workflow_id,
status=status,
status=status.value,
outputs=encoded_outputs,
error=error,
elapsed_time=elapsed_time,
@@ -340,13 +340,13 @@ class WorkflowResponseConverter:
metadata = self._merge_metadata(event.execution_metadata, snapshot)
if isinstance(event, QueueNodeSucceededEvent):
status = WorkflowNodeExecutionStatus.SUCCEEDED
status = WorkflowNodeExecutionStatus.SUCCEEDED.value
error_message = event.error
elif isinstance(event, QueueNodeFailedEvent):
status = WorkflowNodeExecutionStatus.FAILED
status = WorkflowNodeExecutionStatus.FAILED.value
error_message = event.error
else:
status = WorkflowNodeExecutionStatus.EXCEPTION
status = WorkflowNodeExecutionStatus.EXCEPTION.value
error_message = event.error
return NodeFinishStreamResponse(
@@ -413,7 +413,7 @@ class WorkflowResponseConverter:
process_data_truncated=process_data_truncated,
outputs=outputs,
outputs_truncated=outputs_truncated,
status=WorkflowNodeExecutionStatus.RETRY,
status=WorkflowNodeExecutionStatus.RETRY.value,
error=event.error,
elapsed_time=elapsed_time,
execution_metadata=metadata,
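Note: this hunk pairs with the entity change further down where `status` becomes a plain `str`; passing `.value` keeps the streamed payload a bare string. A minimal sketch, assuming a simplified status enum and response model rather than the real definitions:

```python
from enum import Enum

from pydantic import BaseModel


class WorkflowNodeExecutionStatus(str, Enum):
    # str-backed enum: .value is the plain string form.
    SUCCEEDED = "succeeded"
    FAILED = "failed"


class NodeFinishStreamResponse(BaseModel):
    status: str  # typed as str, so callers pass status.value explicitly


resp = NodeFinishStreamResponse(status=WorkflowNodeExecutionStatus.SUCCEEDED.value)
print(resp.model_dump())  # {'status': 'succeeded'}
```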

View File

@@ -120,7 +120,7 @@ class PipelineGenerator(BaseAppGenerator):
raise ValueError("Pipeline dataset is required")
inputs: Mapping[str, Any] = args["inputs"]
start_node_id: str = args["start_node_id"]
datasource_type = DatasourceProviderType(args["datasource_type"])
datasource_type: str = args["datasource_type"]
datasource_info_list: list[Mapping[str, Any]] = self._format_datasource_info_list(
datasource_type, args["datasource_info_list"], pipeline, workflow, start_node_id, user
)
@@ -660,7 +660,7 @@ class PipelineGenerator(BaseAppGenerator):
tenant_id: str,
dataset_id: str,
built_in_field_enabled: bool,
datasource_type: DatasourceProviderType,
datasource_type: str,
datasource_info: Mapping[str, Any],
created_from: str,
position: int,
@@ -668,17 +668,17 @@ class PipelineGenerator(BaseAppGenerator):
batch: str,
document_form: str,
):
match datasource_type:
case DatasourceProviderType.LOCAL_FILE:
name = datasource_info.get("name", "untitled")
case DatasourceProviderType.ONLINE_DOCUMENT:
name = datasource_info.get("page", {}).get("page_name", "untitled")
case DatasourceProviderType.WEBSITE_CRAWL:
name = datasource_info.get("title", "untitled")
case DatasourceProviderType.ONLINE_DRIVE:
name = datasource_info.get("name", "untitled")
case _:
raise ValueError(f"Unsupported datasource type: {datasource_type}")
if datasource_type == "local_file":
name = datasource_info.get("name", "untitled")
elif datasource_type == "online_document":
name = datasource_info.get("page", {}).get("page_name", "untitled")
elif datasource_type == "website_crawl":
name = datasource_info.get("title", "untitled")
elif datasource_type == "online_drive":
name = datasource_info.get("name", "untitled")
else:
raise ValueError(f"Unsupported datasource type: {datasource_type}")
document = Document(
tenant_id=tenant_id,
dataset_id=dataset_id,
@@ -706,7 +706,7 @@ class PipelineGenerator(BaseAppGenerator):
def _format_datasource_info_list(
self,
datasource_type: DatasourceProviderType,
datasource_type: str,
datasource_info_list: list[Mapping[str, Any]],
pipeline: Pipeline,
workflow: Workflow,
@@ -716,7 +716,7 @@ class PipelineGenerator(BaseAppGenerator):
"""
Format datasource info list.
"""
if datasource_type == DatasourceProviderType.ONLINE_DRIVE:
if datasource_type == "online_drive":
all_files: list[Mapping[str, Any]] = []
datasource_node_data = None
datasource_nodes = workflow.graph_dict.get("nodes", [])

View File

@@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
class AnnotationReplyAccount(BaseModel):
@@ -223,7 +223,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
id: str
workflow_id: str
status: WorkflowExecutionStatus
status: str
outputs: Mapping[str, Any] | None = None
error: str | None = None
elapsed_time: float
@@ -311,7 +311,7 @@ class NodeFinishStreamResponse(StreamResponse):
process_data_truncated: bool = False
outputs: Mapping[str, Any] | None = None
outputs_truncated: bool = True
status: WorkflowNodeExecutionStatus
status: str
error: str | None = None
elapsed_time: float
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
@@ -375,7 +375,7 @@ class NodeRetryStreamResponse(StreamResponse):
process_data_truncated: bool = False
outputs: Mapping[str, Any] | None = None
outputs_truncated: bool = False
status: WorkflowNodeExecutionStatus
status: str
error: str | None = None
elapsed_time: float
execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None
@@ -719,7 +719,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
id: str
workflow_id: str
status: WorkflowExecutionStatus
status: str
outputs: Mapping[str, Any] | None = None
error: str | None = None
elapsed_time: float

View File

@@ -4,14 +4,13 @@ from typing import TYPE_CHECKING, final
from typing_extensions import override
from configs import dify_config
from core.file.file_manager import file_manager
from core.file import file_manager
from core.helper import ssrf_proxy
from core.helper.code_executor.code_executor import CodeExecutor
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.ssrf_proxy import ssrf_proxy
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import NodeType
from core.workflow.graph.graph import NodeFactory
from core.workflow.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
@@ -23,6 +22,7 @@ from core.workflow.nodes.template_transform.template_renderer import (
Jinja2TemplateRenderer,
)
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from libs.typing import is_str, is_str_dict
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
@@ -47,9 +47,9 @@ class DifyNodeFactory(NodeFactory):
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits | None = None,
template_renderer: Jinja2TemplateRenderer | None = None,
http_request_http_client: HttpClientProtocol | None = None,
http_request_http_client: HttpClientProtocol = ssrf_proxy,
http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
http_request_file_manager: FileManagerProtocol | None = None,
http_request_file_manager: FileManagerProtocol = file_manager,
) -> None:
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
@@ -68,12 +68,12 @@ class DifyNodeFactory(NodeFactory):
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
self._http_request_http_client = http_request_http_client or ssrf_proxy
self._http_request_http_client = http_request_http_client
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
self._http_request_file_manager = http_request_file_manager or file_manager
self._http_request_file_manager = http_request_file_manager
@override
def create_node(self, node_config: NodeConfigDict) -> Node:
def create_node(self, node_config: dict[str, object]) -> Node:
"""
Create a Node instance from node configuration data using the traditional mapping.
@@ -82,14 +82,23 @@ class DifyNodeFactory(NodeFactory):
:raises ValueError: if node type is unknown or configuration is invalid
"""
# Get node_id from config
node_id = node_config["id"]
node_id = node_config.get("id")
if not is_str(node_id):
raise ValueError("Node config missing id")
# Get node type from config
node_data = node_config["data"]
node_data = node_config.get("data", {})
if not is_str_dict(node_data):
raise ValueError(f"Node {node_id} missing data information")
node_type_str = node_data.get("type")
if not is_str(node_type_str):
raise ValueError(f"Node {node_id} missing or invalid type information")
try:
node_type = NodeType(node_data["type"])
node_type = NodeType(node_type_str)
except ValueError:
raise ValueError(f"Unknown node type: {node_data['type']}")
raise ValueError(f"Unknown node type: {node_type_str}")
# Get node class
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
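Note: one side of this hunk narrows an untyped `dict[str, object]` config with `is_str`/`is_str_dict` guards from `libs.typing`. A self-contained sketch of such TypeGuard helpers (illustrative stand-ins, not the real implementations):

```python
from typing import TypeGuard


def is_str(value: object) -> TypeGuard[str]:
    return isinstance(value, str)


def is_str_dict(value: object) -> TypeGuard[dict[str, object]]:
    return isinstance(value, dict) and all(isinstance(k, str) for k in value)


def read_node_type(node_config: dict[str, object]) -> str:
    # Narrow step by step so a type checker sees str / dict[str, object] afterwards.
    node_id = node_config.get("id")
    if not is_str(node_id):
        raise ValueError("Node config missing id")
    node_data = node_config.get("data", {})
    if not is_str_dict(node_data):
        raise ValueError(f"Node {node_id} missing data information")
    node_type = node_data.get("type")
    if not is_str(node_type):
        raise ValueError(f"Node {node_id} missing or invalid type information")
    return node_type


print(read_node_type({"id": "n1", "data": {"type": "code"}}))  # code
```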

View File

@@ -168,18 +168,3 @@ def _to_url(f: File, /):
return sign_tool_file(tool_file_id=f.related_id, extension=f.extension)
else:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
class FileManager:
"""
Adapter exposing file manager helpers behind FileManagerProtocol.
This is intentionally a thin wrapper over the existing module-level functions so callers can inject it
where a protocol-typed file manager is expected.
"""
def download(self, f: File, /) -> bytes:
return download(f)
file_manager = FileManager()

View File

@@ -47,16 +47,15 @@ class CodeNodeProvider(BaseModel, ABC):
@classmethod
def get_default_config(cls) -> DefaultConfig:
variables: list[VariableConfig] = [
{"variable": "arg1", "value_selector": []},
{"variable": "arg2", "value_selector": []},
]
outputs: dict[str, OutputConfig] = {"result": {"type": "string", "children": None}}
config: CodeConfig = {
"variables": variables,
"code_language": cls.get_language(),
"code": cls.get_default_code(),
"outputs": outputs,
return {
"type": "code",
"config": {
"variables": [
{"variable": "arg1", "value_selector": []},
{"variable": "arg2", "value_selector": []},
],
"code_language": cls.get_language(),
"code": cls.get_default_code(),
"outputs": {"result": {"type": "string", "children": None}},
},
}
return {"type": "code", "config": config}

View File

@@ -230,41 +230,3 @@ def delete(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any)
def head(url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return make_request("HEAD", url, max_retries=max_retries, **kwargs)
class SSRFProxy:
"""
Adapter exposing SSRF-protected HTTP helpers behind HttpClientProtocol.
This is intentionally a thin wrapper over the existing module-level functions so callers can inject it
where a protocol-typed HTTP client is expected.
"""
@property
def max_retries_exceeded_error(self) -> type[Exception]:
return max_retries_exceeded_error
@property
def request_error(self) -> type[Exception]:
return request_error
def get(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return get(url=url, max_retries=max_retries, **kwargs)
def head(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return head(url=url, max_retries=max_retries, **kwargs)
def post(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return post(url=url, max_retries=max_retries, **kwargs)
def put(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return put(url=url, max_retries=max_retries, **kwargs)
def delete(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return delete(url=url, max_retries=max_retries, **kwargs)
def patch(self, url: str, max_retries: int = SSRF_DEFAULT_MAX_RETRIES, **kwargs: Any) -> httpx.Response:
return patch(url=url, max_retries=max_retries, **kwargs)
ssrf_proxy = SSRFProxy()
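Note: the `FileManager` and `SSRFProxy` classes removed in the last two hunks follow the same pattern: wrap existing module-level functions in a thin object that satisfies a Protocol, so callers depend only on the protocol and tests can inject fakes. A minimal generic sketch with hypothetical names (`get`, `HttpClientProtocol`, `ModuleFunctionAdapter`):

```python
from typing import Protocol


def get(url: str, **kwargs) -> str:
    # Stand-in for a module-level helper (the real one returns an httpx.Response).
    return f"GET {url}"


class HttpClientProtocol(Protocol):
    def get(self, url: str, **kwargs) -> str: ...


class ModuleFunctionAdapter:
    """Thin wrapper exposing the module-level helper behind HttpClientProtocol."""

    def get(self, url: str, **kwargs) -> str:
        return get(url, **kwargs)


def fetch(client: HttpClientProtocol, url: str) -> str:
    # Callers see only the protocol, so any conforming object can be injected.
    return client.get(url)


print(fetch(ModuleFunctionAdapter(), "https://example.com"))  # GET https://example.com
```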

View File

@@ -369,78 +369,77 @@ class IndexingRunner:
# Generate summary preview
summary_index_setting = tmp_processing_rule.get("summary_index_setting")
if summary_index_setting and summary_index_setting.get("enable") and preview_texts:
preview_texts = index_processor.generate_summary_preview(
tenant_id, preview_texts, summary_index_setting, doc_language
)
preview_texts = index_processor.generate_summary_preview(tenant_id, preview_texts, summary_index_setting)
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
def _extract(
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
) -> list[Document]:
# load file
if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}:
return []
data_source_info = dataset_document.data_source_info_dict
text_docs = []
match dataset_document.data_source_type:
case "upload_file":
if not data_source_info or "upload_file_id" not in data_source_info:
raise ValueError("no upload file found")
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
file_detail = db.session.scalars(stmt).one_or_none()
if dataset_document.data_source_type == "upload_file":
if not data_source_info or "upload_file_id" not in data_source_info:
raise ValueError("no upload file found")
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
file_detail = db.session.scalars(stmt).one_or_none()
if file_detail:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE,
upload_file=file_detail,
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
case "notion_import":
if (
not data_source_info
or "notion_workspace_id" not in data_source_info
or "notion_page_id" not in data_source_info
):
raise ValueError("no notion import info found")
if file_detail:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"document": dataset_document,
"tenant_id": dataset_document.tenant_id,
}
),
datasource_type=DatasourceType.FILE,
upload_file=file_detail,
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
case "website_crawl":
if (
not data_source_info
or "provider" not in data_source_info
or "url" not in data_source_info
or "job_id" not in data_source_info
):
raise ValueError("no website import info found")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"tenant_id": dataset_document.tenant_id,
"url": data_source_info["url"],
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
case _:
return []
elif dataset_document.data_source_type == "notion_import":
if (
not data_source_info
or "notion_workspace_id" not in data_source_info
or "notion_page_id" not in data_source_info
):
raise ValueError("no notion import info found")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"document": dataset_document,
"tenant_id": dataset_document.tenant_id,
}
),
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
elif dataset_document.data_source_type == "website_crawl":
if (
not data_source_info
or "provider" not in data_source_info
or "url" not in data_source_info
or "job_id" not in data_source_info
):
raise ValueError("no website import info found")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"tenant_id": dataset_document.tenant_id,
"url": data_source_info["url"],
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
# update document status to splitting
self._update_document_index_status(
document_id=dataset_document.id,

View File

@@ -1,20 +0,0 @@
"""Shared payload models for LLM generator helpers and controllers."""
from pydantic import BaseModel, Field
from core.app.app_config.entities import ModelConfig
class RuleGeneratePayload(BaseModel):
instruction: str = Field(..., description="Rule generation instruction")
model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
no_variable: bool = Field(default=False, description="Whether to exclude variables")
class RuleCodeGeneratePayload(RuleGeneratePayload):
code_language: str = Field(default="javascript", description="Programming language for code generation")
class RuleStructuredOutputPayload(BaseModel):
instruction: str = Field(..., description="Structured output generation instruction")
model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
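Note: the payload models deleted here alias the incoming `model_config` key to `model_config_data`, since `model_config` is reserved by Pydantic on `BaseModel`. A minimal sketch of the alias mechanics with a trimmed-down `ModelConfig` stand-in; the `protected_namespaces=()` setting (to silence the `model_`-prefix warning) is an assumption, not taken from the diff:

```python
from pydantic import BaseModel, ConfigDict, Field


class ModelConfig(BaseModel):
    provider: str
    name: str
    completion_params: dict = Field(default_factory=dict)


class RuleGeneratePayload(BaseModel):
    model_config = ConfigDict(populate_by_name=True, protected_namespaces=())

    instruction: str
    # "model_config" is reserved, so the field gets another name plus an alias.
    model_config_data: ModelConfig = Field(..., alias="model_config")
    no_variable: bool = False


payload = RuleGeneratePayload.model_validate(
    {"instruction": "Summarize", "model_config": {"provider": "openai", "name": "gpt-4o-mini"}}
)
print(payload.model_config_data.provider)  # openai
```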

View File

@@ -6,8 +6,6 @@ from typing import Protocol, cast
import json_repair
from core.app.app_config.entities import ModelConfig
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.llm_generator.prompts import (
@@ -153,19 +151,19 @@ class LLMGenerator:
return questions
@classmethod
def generate_rule_config(cls, tenant_id: str, args: RuleGeneratePayload):
def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool):
output_parser = RuleConfigGeneratorOutputParser()
error = ""
error_step = ""
rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""}
model_parameters = args.model_config_data.completion_params
if args.no_variable:
model_parameters = model_config.get("completion_params", {})
if no_variable:
prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE)
prompt_generate = prompt_template.format(
inputs={
"TASK_DESCRIPTION": args.instruction,
"TASK_DESCRIPTION": instruction,
},
remove_template_variables=False,
)
@@ -177,8 +175,8 @@ class LLMGenerator:
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=args.model_config_data.provider,
model=args.model_config_data.name,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
)
try:
@@ -192,7 +190,7 @@ class LLMGenerator:
error = str(e)
error_step = "generate rule config"
except Exception as e:
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
logger.exception("Failed to generate rule config, model: %s", model_config.get("name"))
rule_config["error"] = str(e)
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
@@ -211,7 +209,7 @@ class LLMGenerator:
# format the prompt_generate_prompt
prompt_generate_prompt = prompt_template.format(
inputs={
"TASK_DESCRIPTION": args.instruction,
"TASK_DESCRIPTION": instruction,
},
remove_template_variables=False,
)
@@ -222,8 +220,8 @@ class LLMGenerator:
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=args.model_config_data.provider,
model=args.model_config_data.name,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
)
try:
@@ -252,7 +250,7 @@ class LLMGenerator:
# the second step to generate the task_parameter and task_statement
statement_generate_prompt = statement_template.format(
inputs={
"TASK_DESCRIPTION": args.instruction,
"TASK_DESCRIPTION": instruction,
"INPUT_TEXT": prompt_content.message.get_text_content(),
},
remove_template_variables=False,
@@ -278,7 +276,7 @@ class LLMGenerator:
error_step = "generate conversation opener"
except Exception as e:
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
logger.exception("Failed to generate rule config, model: %s", model_config.get("name"))
rule_config["error"] = str(e)
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
@@ -286,20 +284,16 @@ class LLMGenerator:
return rule_config
@classmethod
def generate_code(
cls,
tenant_id: str,
args: RuleCodeGeneratePayload,
):
if args.code_language == "python":
def generate_code(cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"):
if code_language == "python":
prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
else:
prompt_template = PromptTemplateParser(JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE)
prompt = prompt_template.format(
inputs={
"INSTRUCTION": args.instruction,
"CODE_LANGUAGE": args.code_language,
"INSTRUCTION": instruction,
"CODE_LANGUAGE": code_language,
},
remove_template_variables=False,
)
@@ -308,28 +302,28 @@ class LLMGenerator:
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=args.model_config_data.provider,
model=args.model_config_data.name,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
)
prompt_messages = [UserPromptMessage(content=prompt)]
model_parameters = args.model_config_data.completion_params
model_parameters = model_config.get("completion_params", {})
try:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
generated_code = response.message.get_text_content()
return {"code": generated_code, "language": args.code_language, "error": ""}
return {"code": generated_code, "language": code_language, "error": ""}
except InvokeError as e:
error = str(e)
return {"code": "", "language": args.code_language, "error": f"Failed to generate code. Error: {error}"}
return {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"}
except Exception as e:
logger.exception(
"Failed to invoke LLM model, model: %s, language: %s", args.model_config_data.name, args.code_language
"Failed to invoke LLM model, model: %s, language: %s", model_config.get("name"), code_language
)
return {"code": "", "language": args.code_language, "error": f"An unexpected error occurred: {str(e)}"}
return {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"}
@classmethod
def generate_qa_document(cls, tenant_id: str, query, document_language: str):
@@ -359,20 +353,20 @@ class LLMGenerator:
return answer.strip()
@classmethod
def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload):
def generate_structured_output(cls, tenant_id: str, instruction: str, model_config: dict):
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=args.model_config_data.provider,
model=args.model_config_data.name,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
)
prompt_messages = [
SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE),
UserPromptMessage(content=args.instruction),
UserPromptMessage(content=instruction),
]
model_parameters = args.model_config_data.completion_params
model_parameters = model_config.get("model_parameters", {})
try:
response: LLMResult = model_instance.invoke_llm(
@@ -396,17 +390,12 @@ class LLMGenerator:
error = str(e)
return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
except Exception as e:
logger.exception("Failed to invoke LLM model, model: %s", args.model_config_data.name)
logger.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
@staticmethod
def instruction_modify_legacy(
tenant_id: str,
flow_id: str,
current: str,
instruction: str,
model_config: ModelConfig,
ideal_output: str | None,
tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
):
last_run: Message | None = (
db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
@@ -445,7 +434,7 @@ class LLMGenerator:
node_id: str,
current: str,
instruction: str,
model_config: ModelConfig,
model_config: dict,
ideal_output: str | None,
workflow_service: WorkflowServiceInterface,
):
@@ -516,7 +505,7 @@ class LLMGenerator:
@staticmethod
def __instruction_modify_common(
tenant_id: str,
model_config: ModelConfig,
model_config: dict,
last_run: dict | None,
current: str | None,
error_message: str | None,
@@ -537,8 +526,8 @@ class LLMGenerator:
model_instance = ModelManager().get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.provider,
model=model_config.name,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
)
match node_type:
case "llm" | "agent":
@@ -581,5 +570,7 @@ class LLMGenerator:
error = str(e)
return {"error": f"Failed to generate code. Error: {error}"}
except Exception as e:
logger.exception("Failed to invoke LLM model, model: %s", json.dumps(model_config.name), exc_info=True)
logger.exception(
"Failed to invoke LLM model, model: %s", json.dumps(model_config.get("name")), exc_info=True
)
return {"error": f"An unexpected error occurred: {str(e)}"}

View File

@@ -441,13 +441,11 @@ DEFAULT_GENERATOR_SUMMARY_PROMPT = (
Requirements:
1. Write a concise summary in plain text
2. You must write in {language}. No language other than {language} should be used.
2. Use the same language as the input content
3. Focus on important facts, concepts, and details
4. If images are included, describe their key information
5. Do not use words like "好的", "ok", "I understand", "This text discusses", "The content mentions"
6. Write directly without extra words
7. If there is not enough content to generate a meaningful summary,
return an empty string without any explanation or prompt
Output only the summary text. Start summarizing now:

View File

@@ -347,7 +347,7 @@ class BaseSession(
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
)
responder = RequestResponder[ReceiveRequestT, SendResultT](
responder = RequestResponder(
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
request=validated_request,

View File

@@ -88,7 +88,7 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
DefaultParameterName.MAX_TOKENS: {
"label": {
"en_US": "Max Tokens",
"zh_Hans": "最大 Token 数",
"zh_Hans": "最大标记",
},
"type": "int",
"help": {

View File

@@ -92,10 +92,6 @@ def _build_llm_result_from_first_chunk(
Build a single `LLMResult` from the first returned chunk.
This is used for `stream=False` because the plugin side may still implement the response via a chunked stream.
Note:
This function always drains the `chunks` iterator after reading the first chunk to ensure any underlying
streaming resources are released (e.g., HTTP connections owned by the plugin runtime).
"""
content = ""
content_list: list[PromptMessageContentUnionTypes] = []
@@ -103,25 +99,18 @@ def _build_llm_result_from_first_chunk(
system_fingerprint: str | None = None
tools_calls: list[AssistantPromptMessage.ToolCall] = []
try:
first_chunk = next(chunks, None)
if first_chunk is not None:
if isinstance(first_chunk.delta.message.content, str):
content += first_chunk.delta.message.content
elif isinstance(first_chunk.delta.message.content, list):
content_list.extend(first_chunk.delta.message.content)
first_chunk = next(chunks, None)
if first_chunk is not None:
if isinstance(first_chunk.delta.message.content, str):
content += first_chunk.delta.message.content
elif isinstance(first_chunk.delta.message.content, list):
content_list.extend(first_chunk.delta.message.content)
if first_chunk.delta.message.tool_calls:
_increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls)
if first_chunk.delta.message.tool_calls:
_increase_tool_call(first_chunk.delta.message.tool_calls, tools_calls)
usage = first_chunk.delta.usage or LLMUsage.empty_usage()
system_fingerprint = first_chunk.system_fingerprint
finally:
try:
for _ in chunks:
pass
except Exception:
logger.debug("Failed to drain non-stream plugin chunk iterator.", exc_info=True)
usage = first_chunk.delta.usage or LLMUsage.empty_usage()
system_fingerprint = first_chunk.system_fingerprint
return LLMResult(
model=model,
@@ -294,7 +283,7 @@ class LargeLanguageModel(AIModel):
# TODO
raise self._transform_invoke_error(e)
if stream and not isinstance(result, LLMResult):
if stream and isinstance(result, Generator):
return self._invoke_result_generator(
model=model,
result=result,

View File

@@ -314,8 +314,6 @@ class ModelProviderFactory:
elif model_type == ModelType.TTS:
return TTSModel.model_validate(init_params)
raise ValueError(f"Unsupported model type: {model_type}")
def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
"""
Get provider icon

View File

@@ -48,22 +48,12 @@ class BaseIndexProcessor(ABC):
@abstractmethod
def generate_summary_preview(
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
doc_language: str | None = None,
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
) -> list[PreviewDetail]:
"""
For each segment in preview_texts, generate a summary using LLM and attach it to the segment.
The summary can be stored in a new attribute, e.g., summary.
This method should be implemented by subclasses.
Args:
tenant_id: Tenant ID
preview_texts: List of preview details to generate summaries for
summary_index_setting: Summary index configuration
doc_language: Optional document language to ensure summary is generated in the correct language
"""
raise NotImplementedError

View File

@@ -275,11 +275,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
raise ValueError("Chunks is not a list")
def generate_summary_preview(
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
doc_language: str | None = None,
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
) -> list[PreviewDetail]:
"""
For each segment, concurrently call generate_summary to generate a summary
@@ -302,15 +298,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
if flask_app:
# Ensure Flask app context in worker thread
with flask_app.app_context():
summary, _ = self.generate_summary(
tenant_id, preview.content, summary_index_setting, document_language=doc_language
)
summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting)
preview.summary = summary
else:
# Fallback: try without app context (may fail)
summary, _ = self.generate_summary(
tenant_id, preview.content, summary_index_setting, document_language=doc_language
)
summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting)
preview.summary = summary
# Generate summaries concurrently using ThreadPoolExecutor
@@ -364,7 +356,6 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
text: str,
summary_index_setting: dict | None = None,
segment_id: str | None = None,
document_language: str | None = None,
) -> tuple[str, LLMUsage]:
"""
Generate summary for the given text using ModelInstance.invoke_llm and the default or custom summary prompt,
@@ -375,8 +366,6 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
text: Text content to summarize
summary_index_setting: Summary index configuration
segment_id: Optional segment ID to fetch attachments from SegmentAttachmentBinding table
document_language: Optional document language (e.g., "Chinese", "English")
to ensure summary is generated in the correct language
Returns:
Tuple of (summary_content, llm_usage) where llm_usage is LLMUsage object
@@ -392,22 +381,8 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
raise ValueError("model_name and model_provider_name are required in summary_index_setting")
# Import default summary prompt
is_default_prompt = False
if not summary_prompt:
summary_prompt = DEFAULT_GENERATOR_SUMMARY_PROMPT
is_default_prompt = True
# Format prompt with document language only for default prompt
# Custom prompts are used as-is to avoid interfering with user-defined templates
# If document_language is provided, use it; otherwise, use "the same language as the input content"
# This is especially important for image-only chunks where text is empty or minimal
if is_default_prompt:
language_for_prompt = document_language or "the same language as the input content"
try:
summary_prompt = summary_prompt.format(language=language_for_prompt)
except KeyError:
# If default prompt doesn't have {language} placeholder, use it as-is
pass
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(

View File

@@ -358,11 +358,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
}
def generate_summary_preview(
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
doc_language: str | None = None,
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
) -> list[PreviewDetail]:
"""
For each parent chunk in preview_texts, concurrently call generate_summary to generate a summary
@@ -393,7 +389,6 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
tenant_id=tenant_id,
text=preview.content,
summary_index_setting=summary_index_setting,
document_language=doc_language,
)
preview.summary = summary
else:
@@ -402,7 +397,6 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
tenant_id=tenant_id,
text=preview.content,
summary_index_setting=summary_index_setting,
document_language=doc_language,
)
preview.summary = summary

View File

@@ -241,11 +241,7 @@ class QAIndexProcessor(BaseIndexProcessor):
}
def generate_summary_preview(
self,
tenant_id: str,
preview_texts: list[PreviewDetail],
summary_index_setting: dict,
doc_language: str | None = None,
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
) -> list[PreviewDetail]:
"""
QA model doesn't generate summaries, so this method returns preview_texts unchanged.

View File

@@ -35,7 +35,6 @@ class SchemaRegistry:
registry.load_all_versions()
cls._default_instance = registry
return cls._default_instance
return cls._default_instance

View File

@@ -189,13 +189,16 @@ class ToolManager:
raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
if not provider_controller.need_credentials:
return builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
return cast(
BuiltinTool,
builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
),
)
builtin_provider = None
if isinstance(provider_controller, PluginToolProviderController):
@@ -297,15 +300,18 @@ class ToolManager:
decrypted_credentials = refreshed_credentials.credentials
cache.delete()
return builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials=dict(decrypted_credentials),
credential_type=CredentialType.of(builtin_provider.credential_type),
runtime_parameters={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
return cast(
BuiltinTool,
builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials=dict(decrypted_credentials),
credential_type=CredentialType.of(builtin_provider.credential_type),
runtime_parameters={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
),
)
elif provider_type == ToolProviderType.API:

View File

@@ -7,6 +7,11 @@ from core.workflow.nodes.base.entities import OutputVariableEntity
class WorkflowToolConfigurationUtils:
@classmethod
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
for configuration in configurations:
WorkflowToolParameterConfiguration.model_validate(configuration)
@classmethod
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
"""

View File

@@ -23,8 +23,8 @@ class TriggerDebugEventBus:
"""
# LUA_SELECT: Atomic poll or register for event
# KEYS[1] = trigger_debug_inbox:{<tenant_id>}:<address_id>
# KEYS[2] = trigger_debug_waiting_pool:{<tenant_id>}:...
# KEYS[1] = trigger_debug_inbox:{tenant_id}:{address_id}
# KEYS[2] = trigger_debug_waiting_pool:{tenant_id}:...
# ARGV[1] = address_id
LUA_SELECT = (
"local v=redis.call('GET',KEYS[1]);"
@@ -35,7 +35,7 @@ class TriggerDebugEventBus:
)
# LUA_DISPATCH: Dispatch event to all waiting addresses
# KEYS[1] = trigger_debug_waiting_pool:{<tenant_id>}:...
# KEYS[1] = trigger_debug_waiting_pool:{tenant_id}:...
# ARGV[1] = tenant_id
# ARGV[2] = event_json
LUA_DISPATCH = (
@@ -43,7 +43,7 @@ class TriggerDebugEventBus:
"if #a==0 then return 0 end;"
"redis.call('DEL',KEYS[1]);"
"for i=1,#a do "
f"redis.call('SET','trigger_debug_inbox:{{'..ARGV[1]..'}}'..':'..a[i],ARGV[2],'EX',{TRIGGER_DEBUG_EVENT_TTL});"
f"redis.call('SET','trigger_debug_inbox:'..ARGV[1]..':'..a[i],ARGV[2],'EX',{TRIGGER_DEBUG_EVENT_TTL});"
"end;"
"return #a"
)
@@ -108,7 +108,7 @@ class TriggerDebugEventBus:
Event object if available, None otherwise
"""
address_id: str = hashlib.sha256(f"{user_id}|{app_id}|{node_id}".encode()).hexdigest()
address: str = f"trigger_debug_inbox:{{{tenant_id}}}:{address_id}"
address: str = f"trigger_debug_inbox:{tenant_id}:{address_id}"
try:
event_data = redis_client.eval(

View File

@@ -42,7 +42,7 @@ def build_webhook_pool_key(tenant_id: str, app_id: str, node_id: str) -> str:
app_id: App ID
node_id: Node ID
"""
return f"{TriggerDebugPoolKey.WEBHOOK}:{{{tenant_id}}}:{app_id}:{node_id}"
return f"{TriggerDebugPoolKey.WEBHOOK}:{tenant_id}:{app_id}:{node_id}"
class PluginTriggerDebugEvent(BaseDebugEvent):
@@ -64,4 +64,4 @@ def build_plugin_pool_key(tenant_id: str, provider_id: str, subscription_id: str
provider_id: Provider ID
subscription_id: Subscription ID
"""
return f"{TriggerDebugPoolKey.PLUGIN}:{{{tenant_id}}}:{str(provider_id)}:{subscription_id}:{name}"
return f"{TriggerDebugPoolKey.PLUGIN}:{tenant_id}:{str(provider_id)}:{subscription_id}:{name}"

View File

@@ -1,24 +0,0 @@
from __future__ import annotations
import sys
from pydantic import TypeAdapter, with_config
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
@with_config(extra="allow")
class NodeConfigData(TypedDict):
type: str
@with_config(extra="allow")
class NodeConfigDict(TypedDict):
id: str
data: NodeConfigData
NodeConfigDictAdapter = TypeAdapter(NodeConfigDict)
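Note: the module deleted here typed node configs as TypedDicts with `extra="allow"` and validated them through a `TypeAdapter`; elsewhere in this compare the factory and graph code fall back to `dict[str, object]` plus runtime guards. A minimal sketch of the TypedDict approach, importing `TypedDict` from `typing_extensions` for simplicity instead of branching on the Python version as the deleted file did:

```python
from pydantic import TypeAdapter, with_config
from typing_extensions import TypedDict


@with_config(extra="allow")
class NodeConfigData(TypedDict):
    type: str


@with_config(extra="allow")
class NodeConfigDict(TypedDict):
    id: str
    data: NodeConfigData


NodeConfigDictAdapter = TypeAdapter(NodeConfigDict)

# Extra keys such as "title" are allowed; validation enforces that id, data, and type exist.
node = NodeConfigDictAdapter.validate_python({"id": "n1", "data": {"type": "code", "title": "Code"}})
print(node["data"]["type"])  # code
```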

View File

@@ -5,20 +5,15 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Protocol, cast, final
from pydantic import TypeAdapter
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
from core.workflow.nodes.base.node import Node
from libs.typing import is_str
from libs.typing import is_str, is_str_dict
from .edge import Edge
from .validation import get_graph_validator
logger = logging.getLogger(__name__)
_ListNodeConfigDict = TypeAdapter(list[NodeConfigDict])
class NodeFactory(Protocol):
"""
@@ -28,7 +23,7 @@ class NodeFactory(Protocol):
allowing for different node creation strategies while maintaining type safety.
"""
def create_node(self, node_config: NodeConfigDict) -> Node:
def create_node(self, node_config: dict[str, object]) -> Node:
"""
Create a Node instance from node configuration data.
@@ -68,24 +63,28 @@ class Graph:
self.root_node = root_node
@classmethod
def _parse_node_configs(cls, node_configs: list[NodeConfigDict]) -> dict[str, NodeConfigDict]:
def _parse_node_configs(cls, node_configs: list[dict[str, object]]) -> dict[str, dict[str, object]]:
"""
Parse node configurations and build a mapping of node IDs to configs.
:param node_configs: list of node configuration dictionaries
:return: mapping of node ID to node config
"""
node_configs_map: dict[str, NodeConfigDict] = {}
node_configs_map: dict[str, dict[str, object]] = {}
for node_config in node_configs:
node_configs_map[node_config["id"]] = node_config
node_id = node_config.get("id")
if not node_id or not isinstance(node_id, str):
continue
node_configs_map[node_id] = node_config
return node_configs_map
@classmethod
def _find_root_node_id(
cls,
node_configs_map: Mapping[str, NodeConfigDict],
node_configs_map: Mapping[str, Mapping[str, object]],
edge_configs: Sequence[Mapping[str, object]],
root_node_id: str | None = None,
) -> str:
@@ -114,8 +113,10 @@ class Graph:
# Prefer START node if available
start_node_id = None
for nid in root_candidates:
node_data = node_configs_map[nid]["data"]
node_type = node_data["type"]
node_data = node_configs_map[nid].get("data")
if not is_str_dict(node_data):
continue
node_type = node_data.get("type")
if not isinstance(node_type, str):
continue
if NodeType(node_type).is_start_node:
@@ -175,7 +176,7 @@ class Graph:
@classmethod
def _create_node_instances(
cls,
node_configs_map: dict[str, NodeConfigDict],
node_configs_map: dict[str, dict[str, object]],
node_factory: NodeFactory,
) -> dict[str, Node]:
"""
@@ -302,7 +303,7 @@ class Graph:
node_configs = graph_config.get("nodes", [])
edge_configs = cast(list[dict[str, object]], edge_configs)
node_configs = _ListNodeConfigDict.validate_python(node_configs)
node_configs = cast(list[dict[str, object]], node_configs)
if not node_configs:
raise ValueError("Graph must have at least one node")

View File

@@ -46,6 +46,7 @@ from .graph_traversal import EdgeProcessor, SkipPropagator
from .layers.base import GraphEngineLayer
from .orchestration import Dispatcher, ExecutionCoordinator
from .protocols.command_channel import CommandChannel
from .ready_queue import ReadyQueue
from .worker_management import WorkerPool
if TYPE_CHECKING:
@@ -89,7 +90,7 @@ class GraphEngine:
self._graph_execution.workflow_id = workflow_id
# === Execution Queues ===
self._ready_queue = self._graph_runtime_state.ready_queue
self._ready_queue = cast(ReadyQueue, self._graph_runtime_state.ready_queue)
# Queue for events generated during execution
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()

View File

@@ -15,10 +15,10 @@ from uuid import uuid4
from pydantic import BaseModel, Field
from core.workflow.enums import NodeExecutionType, NodeState
from core.workflow.graph import Graph
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
from core.workflow.nodes.base.template import TextSegment, VariableSegment
from core.workflow.runtime import VariablePool
from core.workflow.runtime.graph_runtime_state import GraphProtocol
from .path import Path
from .session import ResponseSession
@@ -75,7 +75,7 @@ class ResponseStreamCoordinator:
Ensures ordered streaming of responses based on upstream node outputs and constants.
"""
def __init__(self, variable_pool: "VariablePool", graph: GraphProtocol) -> None:
def __init__(self, variable_pool: "VariablePool", graph: "Graph") -> None:
"""
Initialize coordinator with variable pool.

View File

@@ -10,10 +10,10 @@ from __future__ import annotations
from dataclasses import dataclass
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
from core.workflow.runtime.graph_runtime_state import NodeProtocol
@dataclass
@@ -29,26 +29,21 @@ class ResponseSession:
index: int = 0 # Current position in the template segments
@classmethod
def from_node(cls, node: NodeProtocol) -> ResponseSession:
def from_node(cls, node: Node) -> ResponseSession:
"""
Create a ResponseSession from a response-capable node.
The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer,
but at runtime this must be an `AnswerNode`, `EndNode`, or `KnowledgeIndexNode` that provides:
- `id: str`
- `get_streaming_template() -> Template`
Create a ResponseSession from an AnswerNode or EndNode.
Args:
node: Node from the materialized workflow graph.
node: Must be either an AnswerNode or EndNode instance
Returns:
ResponseSession configured with the node's streaming template
Raises:
TypeError: If node is not a supported response node type.
TypeError: If node is not an AnswerNode or EndNode
"""
if not isinstance(node, AnswerNode | EndNode | KnowledgeIndexNode):
raise TypeError("ResponseSession.from_node only supports AnswerNode, EndNode, or KnowledgeIndexNode")
raise TypeError
return cls(
node_id=node.id,
template=node.get_streaming_template(),

View File

@@ -192,33 +192,32 @@ class AgentNode(Node[AgentNodeData]):
result[parameter_name] = None
continue
agent_input = node_data.agent_parameters[parameter_name]
match agent_input.type:
case "variable":
variable = variable_pool.get(agent_input.value) # type: ignore
if variable is None:
raise AgentVariableNotFoundError(str(agent_input.value))
parameter_value = variable.value
case "mixed" | "constant":
# variable_pool.convert_template expects a string template,
# but if passing a dict, convert to JSON string first before rendering
try:
if not isinstance(agent_input.value, str):
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
else:
parameter_value = str(agent_input.value)
except TypeError:
if agent_input.type == "variable":
variable = variable_pool.get(agent_input.value) # type: ignore
if variable is None:
raise AgentVariableNotFoundError(str(agent_input.value))
parameter_value = variable.value
elif agent_input.type in {"mixed", "constant"}:
# variable_pool.convert_template expects a string template,
# but if passing a dict, convert to JSON string first before rendering
try:
if not isinstance(agent_input.value, str):
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
else:
parameter_value = str(agent_input.value)
segment_group = variable_pool.convert_template(parameter_value)
parameter_value = segment_group.log if for_log else segment_group.text
# variable_pool.convert_template returns a string,
# so we need to convert it back to a dictionary
try:
if not isinstance(agent_input.value, str):
parameter_value = json.loads(parameter_value)
except json.JSONDecodeError:
parameter_value = parameter_value
case _:
raise AgentInputTypeError(agent_input.type)
except TypeError:
parameter_value = str(agent_input.value)
segment_group = variable_pool.convert_template(parameter_value)
parameter_value = segment_group.log if for_log else segment_group.text
# variable_pool.convert_template returns a string,
# so we need to convert it back to a dictionary
try:
if not isinstance(agent_input.value, str):
parameter_value = json.loads(parameter_value)
except json.JSONDecodeError:
parameter_value = parameter_value
else:
raise AgentInputTypeError(agent_input.type)
value = parameter_value
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
@@ -375,13 +374,12 @@ class AgentNode(Node[AgentNodeData]):
result: dict[str, Any] = {}
for parameter_name in typed_node_data.agent_parameters:
input = typed_node_data.agent_parameters[parameter_name]
match input.type:
case "mixed" | "constant":
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
result[parameter_name] = input.value
if input.type in ["mixed", "constant"]:
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == "variable":
result[parameter_name] = input.value
result = {node_id + "." + key: value for key, value in result.items()}

View File

@@ -115,7 +115,7 @@ class DefaultValue(BaseModel):
@model_validator(mode="after")
def validate_value_type(self) -> DefaultValue:
# Type validation configuration
type_validators: dict[DefaultValueType, dict[str, Any]] = {
type_validators = {
DefaultValueType.STRING: {
"type": str,
"converter": lambda x: x,

View File

@@ -1,4 +1,4 @@
from typing import Annotated, Literal
from typing import Annotated, Literal, Self
from pydantic import AfterValidator, BaseModel
@@ -34,7 +34,7 @@ class CodeNodeData(BaseNodeData):
class Output(BaseModel):
type: Annotated[SegmentType, AfterValidator(_validate_type)]
children: dict[str, "CodeNodeData.Output"] | None = None
children: dict[str, Self] | None = None
class Dependency(BaseModel):
name: str

View File

@@ -69,13 +69,11 @@ class DatasourceNode(Node[DatasourceNodeData]):
if datasource_type is None:
raise DatasourceNodeError("Datasource type is not set")
datasource_type = DatasourceProviderType.value_of(datasource_type)
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
datasource_name=node_data.datasource_name or "",
tenant_id=self.tenant_id,
datasource_type=datasource_type,
datasource_type=DatasourceProviderType.value_of(datasource_type),
)
datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
@@ -270,18 +268,15 @@ class DatasourceNode(Node[DatasourceNodeData]):
if typed_node_data.datasource_parameters:
for parameter_name in typed_node_data.datasource_parameters:
input = typed_node_data.datasource_parameters[parameter_name]
match input.type:
case "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
result[parameter_name] = input.value
case "constant":
pass
case None:
pass
if input.type == "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == "variable":
result[parameter_name] = input.value
elif input.type == "constant":
pass
result = {node_id + "." + key: value for key, value in result.items()}
@@ -311,107 +306,99 @@ class DatasourceNode(Node[DatasourceNodeData]):
variables: dict[str, Any] = {}
for message in message_stream:
match message.type:
case (
DatasourceMessage.MessageType.IMAGE_LINK
| DatasourceMessage.MessageType.BINARY_LINK
| DatasourceMessage.MessageType.IMAGE
):
assert isinstance(message.message, DatasourceMessage.TextMessage)
if message.type in {
DatasourceMessage.MessageType.IMAGE_LINK,
DatasourceMessage.MessageType.BINARY_LINK,
DatasourceMessage.MessageType.IMAGE,
}:
assert isinstance(message.message, DatasourceMessage.TextMessage)
url = message.message.text
transfer_method = FileTransferMethod.TOOL_FILE
url = message.message.text
transfer_method = FileTransferMethod.TOOL_FILE
datasource_file_id = str(url).split("/")[-1].split(".")[0]
datasource_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
mapping = {
"tool_file_id": datasource_file_id,
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
files.append(file)
case DatasourceMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, DatasourceMessage.TextMessage)
assert message.meta
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
mapping = {
"tool_file_id": datasource_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
)
case DatasourceMessage.MessageType.TEXT:
assert isinstance(message.message, DatasourceMessage.TextMessage)
text += message.message.text
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=message.message.text,
is_final=False,
)
case DatasourceMessage.MessageType.JSON:
assert isinstance(message.message, DatasourceMessage.JsonMessage)
json.append(message.message.json_object)
case DatasourceMessage.MessageType.LINK:
assert isinstance(message.message, DatasourceMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=stream_text,
is_final=False,
)
case DatasourceMessage.MessageType.VARIABLE:
assert isinstance(message.message, DatasourceMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield StreamChunkEvent(
selector=[self._node_id, variable_name],
chunk=variable_value,
is_final=False,
)
else:
variables[variable_name] = variable_value
case DatasourceMessage.MessageType.FILE:
assert message.meta is not None
files.append(message.meta["file"])
case (
DatasourceMessage.MessageType.BLOB_CHUNK
| DatasourceMessage.MessageType.LOG
| DatasourceMessage.MessageType.RETRIEVER_RESOURCES
):
pass
elif message.type == DatasourceMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, DatasourceMessage.TextMessage)
assert message.meta
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
mapping = {
"tool_file_id": datasource_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
)
elif message.type == DatasourceMessage.MessageType.TEXT:
assert isinstance(message.message, DatasourceMessage.TextMessage)
text += message.message.text
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=message.message.text,
is_final=False,
)
elif message.type == DatasourceMessage.MessageType.JSON:
assert isinstance(message.message, DatasourceMessage.JsonMessage)
json.append(message.message.json_object)
elif message.type == DatasourceMessage.MessageType.LINK:
assert isinstance(message.message, DatasourceMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=stream_text,
is_final=False,
)
elif message.type == DatasourceMessage.MessageType.VARIABLE:
assert isinstance(message.message, DatasourceMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield StreamChunkEvent(
selector=[self._node_id, variable_name],
chunk=variable_value,
is_final=False,
)
else:
variables[variable_name] = variable_value
elif message.type == DatasourceMessage.MessageType.FILE:
assert message.meta is not None
files.append(message.meta["file"])
# mark the end of the stream
yield StreamChunkEvent(
selector=[self._node_id, "text"],
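
The hunk above rewrites the message-type dispatch between a `match`/`case` statement and an equivalent `if`/`elif` chain over `DatasourceMessage.MessageType`. A minimal, self-contained sketch of the two equivalent forms, using a hypothetical enum and handler rather than Dify's types:

```python
# Minimal sketch (not Dify code): dispatching on a message-type enum,
# written once with match/case and once with an if/elif chain.
from enum import Enum


class MessageType(Enum):
    TEXT = "text"
    JSON = "json"
    LINK = "link"


def handle_with_match(msg_type: MessageType, payload: str) -> str:
    match msg_type:
        case MessageType.TEXT:
            return payload
        case MessageType.LINK:
            return f"Link: {payload}\n"
        case _:
            return ""


def handle_with_elif(msg_type: MessageType, payload: str) -> str:
    # Behaviourally equivalent; only the dispatch syntax differs.
    if msg_type == MessageType.TEXT:
        return payload
    elif msg_type == MessageType.LINK:
        return f"Link: {payload}\n"
    else:
        return ""
```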

View File

@@ -2,7 +2,7 @@ import base64
import json
import secrets
import string
from collections.abc import Callable, Mapping
from collections.abc import Mapping
from copy import deepcopy
from typing import Any, Literal
from urllib.parse import urlencode, urlparse
@@ -11,9 +11,9 @@ import httpx
from json_repair import repair_json
from configs import dify_config
from core.file import file_manager
from core.file.enums import FileTransferMethod
from core.file.file_manager import file_manager as default_file_manager
from core.helper.ssrf_proxy import ssrf_proxy
from core.helper import ssrf_proxy
from core.variables.segments import ArrayFileSegment, FileSegment
from core.workflow.runtime import VariablePool
@@ -79,8 +79,8 @@ class Executor:
timeout: HttpRequestNodeTimeout,
variable_pool: VariablePool,
max_retries: int = dify_config.SSRF_DEFAULT_MAX_RETRIES,
http_client: HttpClientProtocol | None = None,
file_manager: FileManagerProtocol | None = None,
http_client: HttpClientProtocol = ssrf_proxy,
file_manager: FileManagerProtocol = file_manager,
):
# If authorization API key is present, convert the API key using the variable pool
if node_data.authorization.type == "api-key":
@@ -107,8 +107,8 @@ class Executor:
self.data = None
self.json = None
self.max_retries = max_retries
self._http_client = http_client or ssrf_proxy
self._file_manager = file_manager or default_file_manager
self._http_client = http_client
self._file_manager = file_manager
# init template
self.variable_pool = variable_pool
@@ -336,7 +336,7 @@ class Executor:
"""
do http request depending on api bundle
"""
_METHOD_MAP: dict[str, Callable[..., httpx.Response]] = {
_METHOD_MAP = {
"get": self._http_client.get,
"head": self._http_client.head,
"post": self._http_client.post,
@@ -348,7 +348,7 @@ class Executor:
if method_lc not in _METHOD_MAP:
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
request_args: dict[str, Any] = {
request_args = {
"data": self.data,
"files": self.files,
"json": self.json,
@@ -361,13 +361,14 @@ class Executor:
}
# request_args = {k: v for k, v in request_args.items() if v is not None}
try:
response = _METHOD_MAP[method_lc](
response: httpx.Response = _METHOD_MAP[method_lc](
url=self.url,
**request_args,
max_retries=self.max_retries,
)
except (self._http_client.max_retries_exceeded_error, self._http_client.request_error) as e:
raise HttpRequestNodeError(str(e)) from e
# FIXME: fix type ignore, this maybe httpx type issue
return response
def invoke(self) -> Response:
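
The executor hunk above changes how the HTTP client and file manager defaults are supplied: a `None` sentinel resolved with `or` at construction time versus a module-level default placed directly in the signature. A minimal sketch of both styles, with hypothetical names rather than the Dify API:

```python
# Minimal sketch (hypothetical names): two ways to default an injected
# collaborator, as contrasted in the constructor change above.
from typing import Protocol


class Client(Protocol):
    def get(self, url: str) -> str: ...


class DefaultClient:
    def get(self, url: str) -> str:
        return f"GET {url}"


default_client = DefaultClient()


class ExecutorWithNoneDefault:
    # None sentinel: the fallback is resolved at call time, keeping the
    # signature free of a concrete module-level object.
    def __init__(self, client: Client | None = None) -> None:
        self._client = client or default_client


class ExecutorWithModuleDefault:
    # Module-level default: shorter, but the default object is bound once
    # when the function is defined.
    def __init__(self, client: Client = default_client) -> None:
        self._client = client
```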

View File

@@ -4,9 +4,8 @@ from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from configs import dify_config
from core.file import File, FileTransferMethod
from core.file.file_manager import file_manager as default_file_manager
from core.helper.ssrf_proxy import ssrf_proxy
from core.file import File, FileTransferMethod, file_manager
from core.helper import ssrf_proxy
from core.tools.tool_file_manager import ToolFileManager
from core.variables.segments import ArrayFileSegment
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@@ -48,9 +47,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
http_client: HttpClientProtocol | None = None,
http_client: HttpClientProtocol = ssrf_proxy,
tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
file_manager: FileManagerProtocol | None = None,
file_manager: FileManagerProtocol = file_manager,
) -> None:
super().__init__(
id=id,
@@ -58,9 +57,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._http_client = http_client or ssrf_proxy
self._http_client = http_client
self._tool_file_manager_factory = tool_file_manager_factory
self._file_manager = file_manager or default_file_manager
self._file_manager = file_manager
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:

View File

@@ -397,7 +397,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
return outputs
# Check if all non-None outputs are lists
non_none_outputs: list[object] = [output for output in outputs if output is not None]
non_none_outputs = [output for output in outputs if output is not None]
if not non_none_outputs:
return outputs

View File

@@ -78,21 +78,12 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
indexing_technique = node_data.indexing_technique or dataset.indexing_technique
summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting
# Try to get document language if document_id is available
doc_language = None
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
if document_id:
document = db.session.query(Document).filter_by(id=document_id.value).first()
if document and document.doc_language:
doc_language = document.doc_language
outputs = self._get_preview_output_with_summaries(
node_data.chunk_structure,
chunks,
dataset=dataset,
indexing_technique=indexing_technique,
summary_index_setting=summary_index_setting,
doc_language=doc_language,
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -324,7 +315,6 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
dataset: Dataset,
indexing_technique: str | None = None,
summary_index_setting: dict | None = None,
doc_language: str | None = None,
) -> Mapping[str, Any]:
"""
Generate preview output with summaries for chunks in preview mode.
@@ -336,7 +326,6 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
dataset: Dataset object (for tenant_id)
indexing_technique: Indexing technique from node config or dataset
summary_index_setting: Summary index setting from node config or dataset
doc_language: Optional document language to ensure summary is generated in the correct language
"""
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
preview_output = index_processor.format_preview(chunks)
@@ -376,7 +365,6 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
tenant_id=dataset.tenant_id,
text=preview_item["content"],
summary_index_setting=summary_index_setting,
document_language=doc_language,
)
if summary:
preview_item["summary"] = summary
@@ -386,7 +374,6 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
tenant_id=dataset.tenant_id,
text=preview_item["content"],
summary_index_setting=summary_index_setting,
document_language=doc_language,
)
if summary:
preview_item["summary"] = summary

View File

@@ -303,34 +303,33 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
if node_data.multiple_retrieval_config is None:
raise ValueError("multiple_retrieval_config is required")
match node_data.multiple_retrieval_config.reranking_mode:
case "reranking_model":
if node_data.multiple_retrieval_config.reranking_model:
reranking_model = {
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
}
else:
reranking_model = None
weights = None
case "weighted_score":
if node_data.multiple_retrieval_config.weights is None:
raise ValueError("weights is required")
reranking_model = None
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
weights = {
"vector_setting": {
"vector_weight": vector_setting.vector_weight,
"embedding_provider_name": vector_setting.embedding_provider_name,
"embedding_model_name": vector_setting.embedding_model_name,
},
"keyword_setting": {
"keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
},
}
case _:
reranking_model = None
weights = None
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
if node_data.multiple_retrieval_config.reranking_model:
reranking_model = {
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
}
else:
reranking_model = None
weights = None
elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score":
if node_data.multiple_retrieval_config.weights is None:
raise ValueError("weights is required")
reranking_model = None
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
weights = {
"vector_setting": {
"vector_weight": vector_setting.vector_weight,
"embedding_provider_name": vector_setting.embedding_provider_name,
"embedding_model_name": vector_setting.embedding_model_name,
},
"keyword_setting": {
"keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
},
}
else:
reranking_model = None
weights = None
all_documents = dataset_retrieval.multiple_retrieve(
app_id=self.app_id,
tenant_id=self.tenant_id,
@@ -454,74 +453,73 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
)
filters: list[Any] = []
metadata_condition = None
match node_data.metadata_filtering_mode:
case "disabled":
return None, None, usage
case "automatic":
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
dataset_ids, query, node_data
)
usage = self._merge_usage(usage, automatic_usage)
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
DatasetRetrieval.process_metadata_filter_func(
sequence,
filter.get("condition", ""),
filter.get("metadata_name", ""),
filter.get("value"),
filters,
)
conditions.append(
Condition(
name=filter.get("metadata_name"), # type: ignore
comparison_operator=filter.get("condition"), # type: ignore
value=filter.get("value"),
)
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator
if node_data.metadata_filtering_conditions
else "or",
conditions=conditions,
)
case "manual":
if node_data.metadata_filtering_conditions:
conditions = []
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
metadata_name = condition.name
expected_value = condition.value
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
if isinstance(expected_value, str):
expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value
).value[0]
if expected_value.value_type in {"number", "integer", "float"}:
expected_value = expected_value.value
elif expected_value.value_type == "string":
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
else:
raise ValueError("Invalid expected metadata value type")
conditions.append(
Condition(
name=metadata_name,
comparison_operator=condition.comparison_operator,
value=expected_value,
)
)
filters = DatasetRetrieval.process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
expected_value,
filters,
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
conditions=conditions,
)
case _:
raise ValueError("Invalid metadata filtering mode")
if node_data.metadata_filtering_mode == "disabled":
return None, None, usage
elif node_data.metadata_filtering_mode == "automatic":
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
dataset_ids, query, node_data
)
usage = self._merge_usage(usage, automatic_usage)
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
DatasetRetrieval.process_metadata_filter_func(
sequence,
filter.get("condition", ""),
filter.get("metadata_name", ""),
filter.get("value"),
filters,
)
conditions.append(
Condition(
name=filter.get("metadata_name"), # type: ignore
comparison_operator=filter.get("condition"), # type: ignore
value=filter.get("value"),
)
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator
if node_data.metadata_filtering_conditions
else "or",
conditions=conditions,
)
elif node_data.metadata_filtering_mode == "manual":
if node_data.metadata_filtering_conditions:
conditions = []
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
metadata_name = condition.name
expected_value = condition.value
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
if isinstance(expected_value, str):
expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value
).value[0]
if expected_value.value_type in {"number", "integer", "float"}:
expected_value = expected_value.value
elif expected_value.value_type == "string":
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
else:
raise ValueError("Invalid expected metadata value type")
conditions.append(
Condition(
name=metadata_name,
comparison_operator=condition.comparison_operator,
value=expected_value,
)
)
filters = DatasetRetrieval.process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
expected_value,
filters,
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
conditions=conditions,
)
else:
raise ValueError("Invalid metadata filtering mode")
if filters:
if (
node_data.metadata_filtering_conditions

View File

@@ -196,13 +196,13 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
case "name":
return lambda x: x.filename or ""
case "type":
return lambda x: str(x.type)
return lambda x: x.type
case "extension":
return lambda x: x.extension or ""
case "mime_type":
return lambda x: x.mime_type or ""
case "transfer_method":
return lambda x: str(x.transfer_method)
return lambda x: x.transfer_method
case "url":
return lambda x: x.remote_url or ""
case "related_id":
@@ -276,6 +276,7 @@ def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Calla
def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]:
extract_func: Callable[[File], Any]
if key in {"name", "extension", "mime_type", "url", "related_id"} and isinstance(value, str):
extract_func = _get_file_extract_string_func(key=key)
return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x))
@@ -283,8 +284,8 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str
extract_func = _get_file_extract_string_func(key=key)
return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x))
elif key == "size" and isinstance(value, str):
extract_number = _get_file_extract_number_func(key=key)
return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_number(x))
extract_func = _get_file_extract_number_func(key=key)
return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x))
else:
raise InvalidKeyError(f"Invalid key: {key}")
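
The extractor change above turns on whether the lambda returns the attribute itself or `str(...)` of it. When the attribute is an enum member the difference is observable; a small, self-contained illustration with hypothetical types (Dify's own enums may stringify differently, for example if they subclass `str`):

```python
# Minimal sketch (hypothetical types): extracting a string key from an
# attribute that may be an Enum member, as in the extractor lambdas above.
from enum import Enum


class FileType(Enum):
    IMAGE = "image"
    DOCUMENT = "document"


class FakeFile:
    def __init__(self, type_: FileType) -> None:
        self.type = type_


extract_raw = lambda f: f.type        # returns the Enum member itself
extract_str = lambda f: str(f.type)   # normalises to a plain string

f = FakeFile(FileType.IMAGE)
print(extract_raw(f))   # FileType.IMAGE (an Enum member)
print(extract_str(f))   # "FileType.IMAGE" for a plain Enum; .value would give "image"
```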

View File

@@ -852,16 +852,18 @@ class LLMNode(Node[LLMNodeData]):
# Insert histories into the prompt
prompt_content = prompt_messages[0].content
# For issue #11247 - Check if prompt content is a string or a list
if isinstance(prompt_content, str):
prompt_content_type = type(prompt_content)
if prompt_content_type == str:
prompt_content = str(prompt_content)
if "#histories#" in prompt_content:
prompt_content = prompt_content.replace("#histories#", memory_text)
else:
prompt_content = memory_text + "\n" + prompt_content
prompt_messages[0].content = prompt_content
elif isinstance(prompt_content, list):
elif prompt_content_type == list:
prompt_content = prompt_content if isinstance(prompt_content, list) else []
for content_item in prompt_content:
if isinstance(content_item, TextPromptMessageContent):
if content_item.type == PromptMessageContentType.TEXT:
if "#histories#" in content_item.data:
content_item.data = content_item.data.replace("#histories#", memory_text)
else:
@@ -871,12 +873,13 @@ class LLMNode(Node[LLMNodeData]):
# Add current query to the prompt message
if sys_query:
if isinstance(prompt_content, str):
if prompt_content_type == str:
prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
prompt_messages[0].content = prompt_content
elif isinstance(prompt_content, list):
elif prompt_content_type == list:
prompt_content = prompt_content if isinstance(prompt_content, list) else []
for content_item in prompt_content:
if isinstance(content_item, TextPromptMessageContent):
if content_item.type == PromptMessageContentType.TEXT:
content_item.data = sys_query + "\n" + content_item.data
else:
raise ValueError("Invalid prompt content type")
@@ -1030,14 +1033,14 @@ class LLMNode(Node[LLMNodeData]):
if typed_node_data.prompt_config:
enable_jinja = False
if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
if prompt_template.edition_type == "jinja2":
enable_jinja = True
else:
if isinstance(prompt_template, list):
for prompt in prompt_template:
if prompt.edition_type == "jinja2":
enable_jinja = True
break
else:
if prompt_template.edition_type == "jinja2":
enable_jinja = True
if enable_jinja:
for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:

View File

@@ -1,4 +1,4 @@
from typing import Any, Protocol
from typing import Protocol
import httpx
@@ -12,17 +12,17 @@ class HttpClientProtocol(Protocol):
@property
def request_error(self) -> type[Exception]: ...
def get(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
def get(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def head(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
def head(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def post(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
def post(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def put(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
def put(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def delete(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
def delete(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
def patch(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ...
def patch(self, url: str, max_retries: int = ..., **kwargs: object) -> httpx.Response: ...
class FileManagerProtocol(Protocol):
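
For readers unfamiliar with `typing.Protocol`: the interface above is structural, so any object with matching method signatures satisfies it, and typing `**kwargs` as `Any` versus `object` only changes what a checker allows callers and implementations to do with those values. A minimal sketch with hypothetical names:

```python
# Minimal sketch (hypothetical names): a structural protocol in the spirit of
# the HttpClientProtocol edited above.
from typing import Any, Protocol


class MiniClient(Protocol):
    def get(self, url: str, max_retries: int = ..., **kwargs: Any) -> str: ...


class EchoClient:
    def get(self, url: str, max_retries: int = 3, **kwargs: Any) -> str:
        # A real implementation would issue the request; this just echoes.
        return f"GET {url} (retries={max_retries}, extra={sorted(kwargs)})"


def fetch(client: MiniClient, url: str) -> str:
    # Any object with a matching .get signature satisfies the protocol.
    return client.get(url, timeout=5)


print(fetch(EchoClient(), "https://example.com"))
```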

View File

@@ -482,17 +482,16 @@ class ToolNode(Node[ToolNodeData]):
result = {}
for parameter_name in typed_node_data.tool_parameters:
input = typed_node_data.tool_parameters[parameter_name]
match input.type:
case "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
selector_key = ".".join(input.value)
result[f"#{selector_key}#"] = input.value
case "constant":
pass
if input.type == "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == "variable":
selector_key = ".".join(input.value)
result[f"#{selector_key}#"] = input.value
elif input.type == "constant":
pass
result = {node_id + "." + key: value for key, value in result.items()}


View File

@@ -6,13 +6,12 @@ import threading
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, ClassVar, Protocol
from typing import Any, Protocol
from pydantic.json import pydantic_encoder
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
from core.workflow.runtime.variable_pool import VariablePool
@@ -104,33 +103,14 @@ class ResponseStreamCoordinatorProtocol(Protocol):
...
class NodeProtocol(Protocol):
"""Structural interface for graph nodes."""
id: str
state: NodeState
execution_type: NodeExecutionType
node_type: ClassVar[NodeType]
def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: ...
class EdgeProtocol(Protocol):
id: str
state: NodeState
tail: str
head: str
source_handle: str
class GraphProtocol(Protocol):
"""Structural interface required from graph instances attached to the runtime state."""
nodes: Mapping[str, NodeProtocol]
edges: Mapping[str, EdgeProtocol]
root_node: NodeProtocol
nodes: Mapping[str, object]
edges: Mapping[str, object]
root_node: object
def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ...
def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ...
@dataclass(slots=True)
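
The `GraphProtocol` change above loosens the node and edge mappings from protocol-typed values to plain `object`. In practice that pushes narrowing (a cast or `isinstance` check) onto callers before they can touch node attributes; a small sketch with hypothetical classes:

```python
# Minimal sketch (hypothetical classes): Mapping values typed with a node
# protocol versus plain object, as in the GraphProtocol change above.
from collections.abc import Mapping
from typing import Protocol, cast


class NodeLike(Protocol):
    id: str


class Node:
    def __init__(self, id: str) -> None:
        self.id = id


typed_nodes: Mapping[str, NodeLike] = {"n1": Node("n1")}
loose_nodes: Mapping[str, object] = {"n1": Node("n1")}

print(typed_nodes["n1"].id)              # a type checker accepts this directly
print(cast(Node, loose_nodes["n1"]).id)  # object-typed values need narrowing first
```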

View File

@@ -144,11 +144,11 @@ class WorkflowEntry:
:param user_inputs: user inputs
:return:
"""
node_config = workflow.get_node_config_by_id(node_id)
node_config_data = node_config["data"]
node_config = dict(workflow.get_node_config_by_id(node_id))
node_config_data = node_config.get("data", {})
# Get node type
node_type = NodeType(node_config_data["type"])
node_type = NodeType(node_config_data.get("type"))
# init graph init params and runtime state
graph_init_params = GraphInitParams(

View File

@@ -27,13 +27,10 @@ def init_app(app: DifyApp) -> None:
)
# Ensure route decorators are evaluated.
import controllers.console.init_validate as init_validate_module
import controllers.console.ping as ping_module
from controllers.console import remote_files, setup
from controllers.console import setup
_ = init_validate_module
_ = ping_module
_ = remote_files
_ = setup
router.include_router(console_router, prefix="/console/api")

View File

@@ -390,7 +390,8 @@ class ClickZettaVolumeStorage(BaseStorage):
"""
content = self.load_once(filename)
Path(target_filepath).write_bytes(content)
with Path(target_filepath).open("wb") as f:
f.write(content)
logger.debug("File %s downloaded from ClickZetta Volume to %s", filename, target_filepath)

View File

@@ -1,69 +1,36 @@
from __future__ import annotations
from flask_restx import Namespace, fields
from datetime import datetime
from libs.helper import TimestampField
from pydantic import BaseModel, ConfigDict, Field, field_validator
annotation_fields = {
"id": fields.String,
"question": fields.String,
"answer": fields.Raw(attribute="content"),
"hit_count": fields.Integer,
"created_at": TimestampField,
# 'account': fields.Nested(simple_account_fields, allow_null=True)
}
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
def build_annotation_model(api_or_ns: Namespace):
"""Build the annotation model for the API or Namespace."""
return api_or_ns.model("Annotation", annotation_fields)
class ResponseModel(BaseModel):
model_config = ConfigDict(
from_attributes=True,
extra="ignore",
populate_by_name=True,
serialize_by_alias=True,
protected_namespaces=(),
)
annotation_list_fields = {
"data": fields.List(fields.Nested(annotation_fields)),
}
annotation_hit_history_fields = {
"id": fields.String,
"source": fields.String,
"score": fields.Float,
"question": fields.String,
"created_at": TimestampField,
"match": fields.String(attribute="annotation_question"),
"response": fields.String(attribute="annotation_content"),
}
class Annotation(ResponseModel):
id: str
question: str | None = None
answer: str | None = Field(default=None, validation_alias="content")
hit_count: int | None = None
created_at: int | None = None
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class AnnotationList(ResponseModel):
data: list[Annotation]
has_more: bool
limit: int
total: int
page: int
class AnnotationExportList(ResponseModel):
data: list[Annotation]
class AnnotationHitHistory(ResponseModel):
id: str
source: str | None = None
score: float | None = None
question: str | None = None
created_at: int | None = None
match: str | None = Field(default=None, validation_alias="annotation_question")
response: str | None = Field(default=None, validation_alias="annotation_content")
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class AnnotationHitHistoryList(ResponseModel):
data: list[AnnotationHitHistory]
has_more: bool
limit: int
total: int
page: int
annotation_hit_history_list_fields = {
"data": fields.List(fields.Nested(annotation_hit_history_fields)),
}
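
For context on the `flask_restx` side of this hunk: a field mapping like `annotation_fields` is applied with `marshal`, and `attribute=` redirects a field to a differently named source attribute. A minimal sketch with a hypothetical data class, not part of the diff:

```python
# Minimal sketch (hypothetical data class): marshalling an object with a
# flask_restx field mapping that renames content -> answer via attribute=.
from dataclasses import dataclass

from flask_restx import fields, marshal


@dataclass
class FakeAnnotation:
    id: str
    question: str
    content: str
    hit_count: int


example_fields = {
    "id": fields.String,
    "question": fields.String,
    "answer": fields.Raw(attribute="content"),  # read from .content, emit as "answer"
    "hit_count": fields.Integer,
}

print(marshal(FakeAnnotation("a1", "Q?", "A!", 3), example_fields))
# {'id': 'a1', 'question': 'Q?', 'answer': 'A!', 'hit_count': 3}
```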

View File

@@ -1,7 +1,4 @@
from __future__ import annotations
from flask_restx import fields
from pydantic import BaseModel, ConfigDict
from flask_restx import Namespace, fields
simple_end_user_fields = {
"id": fields.String,
@@ -11,18 +8,5 @@ simple_end_user_fields = {
}
class ResponseModel(BaseModel):
model_config = ConfigDict(
from_attributes=True,
extra="ignore",
populate_by_name=True,
serialize_by_alias=True,
protected_namespaces=(),
)
class SimpleEndUser(ResponseModel):
id: str
type: str
is_anonymous: bool
session_id: str | None = None
def build_simple_end_user_model(api_or_ns: Namespace):
return api_or_ns.model("SimpleEndUser", simple_end_user_fields)

View File

@@ -1,11 +1,6 @@
from __future__ import annotations
from flask_restx import Namespace, fields
from datetime import datetime
from flask_restx import fields
from pydantic import BaseModel, ConfigDict, computed_field, field_validator
from core.file import helpers as file_helpers
from libs.helper import AvatarUrlField, TimestampField
simple_account_fields = {
"id": fields.String,
@@ -14,78 +9,36 @@ simple_account_fields = {
}
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
def build_simple_account_model(api_or_ns: Namespace):
return api_or_ns.model("SimpleAccount", simple_account_fields)
def _build_avatar_url(avatar: str | None) -> str | None:
if avatar is None:
return None
if avatar.startswith(("http://", "https://")):
return avatar
return file_helpers.get_signed_file_url(avatar)
account_fields = {
"id": fields.String,
"name": fields.String,
"avatar": fields.String,
"avatar_url": AvatarUrlField,
"email": fields.String,
"is_password_set": fields.Boolean,
"interface_language": fields.String,
"interface_theme": fields.String,
"timezone": fields.String,
"last_login_at": TimestampField,
"last_login_ip": fields.String,
"created_at": TimestampField,
}
account_with_role_fields = {
"id": fields.String,
"name": fields.String,
"avatar": fields.String,
"avatar_url": AvatarUrlField,
"email": fields.String,
"last_login_at": TimestampField,
"last_active_at": TimestampField,
"created_at": TimestampField,
"role": fields.String,
"status": fields.String,
}
class ResponseModel(BaseModel):
model_config = ConfigDict(
from_attributes=True,
extra="ignore",
populate_by_name=True,
serialize_by_alias=True,
protected_namespaces=(),
)
class SimpleAccount(ResponseModel):
id: str
name: str
email: str
class _AccountAvatar(ResponseModel):
avatar: str | None = None
@computed_field(return_type=str | None) # type: ignore[prop-decorator]
@property
def avatar_url(self) -> str | None:
return _build_avatar_url(self.avatar)
class Account(_AccountAvatar):
id: str
name: str
email: str
is_password_set: bool
interface_language: str | None = None
interface_theme: str | None = None
timezone: str | None = None
last_login_at: int | None = None
last_login_ip: str | None = None
created_at: int | None = None
@field_validator("last_login_at", "created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class AccountWithRole(_AccountAvatar):
id: str
name: str
email: str
last_login_at: int | None = None
last_active_at: int | None = None
created_at: int | None = None
role: str
status: str
@field_validator("last_login_at", "last_active_at", "created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class AccountWithRoleList(ResponseModel):
accounts: list[AccountWithRole]
account_with_role_list_fields = {"accounts": fields.List(fields.Nested(account_with_role_fields))}

View File

@@ -0,0 +1,45 @@
from flask_restx import fields
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
# Snippet list item fields (lightweight for list display)
snippet_list_fields = {
"id": fields.String,
"name": fields.String,
"description": fields.String,
"type": fields.String,
"version": fields.Integer,
"use_count": fields.Integer,
"is_published": fields.Boolean,
"icon_info": fields.Raw,
"created_at": TimestampField,
"updated_at": TimestampField,
}
# Full snippet fields (includes creator info and graph data)
snippet_fields = {
"id": fields.String,
"name": fields.String,
"description": fields.String,
"type": fields.String,
"version": fields.Integer,
"use_count": fields.Integer,
"is_published": fields.Boolean,
"icon_info": fields.Raw,
"graph": fields.Raw(attribute="graph_dict"),
"input_fields": fields.Raw(attribute="input_fields_list"),
"created_by": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
"created_at": TimestampField,
"updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True),
"updated_at": TimestampField,
}
# Pagination response fields
snippet_pagination_fields = {
"data": fields.List(fields.Nested(snippet_list_fields)),
"page": fields.Integer,
"limit": fields.Integer,
"total": fields.Integer,
"has_more": fields.Boolean,
}
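
A possible way such a mapping is consumed, mirroring the `build_*_model` helpers elsewhere in this diff: register it on a `Namespace` and marshal the paginated response. A minimal, hypothetical sketch, not the actual Dify controller:

```python
# Hypothetical usage sketch: registering a pagination fields mapping on a
# flask_restx Namespace and marshalling a list response with it.
from flask import Flask
from flask_restx import Api, Namespace, Resource, fields

app = Flask(__name__)
api = Api(app)
ns = Namespace("snippets", description="Snippet operations")
api.add_namespace(ns)

page_fields = ns.model(
    "SnippetPage",
    {
        "page": fields.Integer,
        "limit": fields.Integer,
        "total": fields.Integer,
        "has_more": fields.Boolean,
    },
)


@ns.route("/")
class SnippetList(Resource):
    @ns.marshal_with(page_fields)
    def get(self):
        return {"page": 1, "limit": 20, "total": 0, "has_more": False}
```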

Some files were not shown because too many files have changed in this diff.