Compare commits

..

1 Commits

Author SHA1 Message Date
dependabot[bot]
5659c8aa70 chore(deps): bump the llm group across 1 directory with 2 updates
Updates the requirements on [transformers](https://github.com/huggingface/transformers) and [weave](https://github.com/wandb/weave) to permit the latest version.

Updates `transformers` to 5.5.3
- [Release notes](https://github.com/huggingface/transformers/releases)
- [Commits](https://github.com/huggingface/transformers/compare/v5.3.0...v5.5.3)

Updates `weave` to 0.52.36
- [Release notes](https://github.com/wandb/weave/releases)
- [Changelog](https://github.com/wandb/weave/blob/master/dev_docs/RELEASE.md)
- [Commits](https://github.com/wandb/weave/compare/v0.52.16...v0.52.36)

---
updated-dependencies:
- dependency-name: transformers
  dependency-version: 5.5.3
  dependency-type: direct:production
  dependency-group: llm
- dependency-name: weave
  dependency-version: 0.52.36
  dependency-type: direct:production
  dependency-group: llm
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-04-13 02:54:51 +00:00
524 changed files with 7805 additions and 11384 deletions

View File

@@ -1,79 +0,0 @@
---
name: e2e-cucumber-playwright
description: Write, update, or review Dify end-to-end tests under `e2e/` that use Cucumber, Gherkin, and Playwright. Use when the task involves `.feature` files, `features/step-definitions/`, `features/support/`, `DifyWorld`, scenario tags, locator/assertion choices, or E2E testing best practices for this repository.
---
# Dify E2E Cucumber + Playwright
Use this skill for Dify's repository-level E2E suite in `e2e/`. Use [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) as the canonical guide for local architecture and conventions, then apply Playwright/Cucumber best practices only where they fit the current suite.
## Scope
- Use this skill for `.feature` files, Cucumber step definitions, `DifyWorld`, hooks, tags, and E2E review work under `e2e/`.
- Do not use this skill for Vitest or React Testing Library work under `web/`; use `frontend-testing` instead.
- Do not use this skill for backend test or API review tasks under `api/`.
## Read Order
1. Read [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) first.
2. Read only the files directly involved in the task:
- target `.feature` files under `e2e/features/`
- related step files under `e2e/features/step-definitions/`
- `e2e/features/support/hooks.ts` and `e2e/features/support/world.ts` when session lifecycle or shared state matters
- `e2e/scripts/run-cucumber.ts` and `e2e/cucumber.config.ts` when tags or execution flow matter
3. Read [`references/playwright-best-practices.md`](references/playwright-best-practices.md) only when locator, assertion, isolation, or waiting choices are involved.
4. Read [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md) only when scenario wording, step granularity, tags, or expression design are involved.
5. Re-check official docs with Context7 before introducing a new Playwright or Cucumber pattern.
## Local Rules
- `e2e/` uses Cucumber for scenarios and Playwright as the browser layer.
- `DifyWorld` is the per-scenario context object. Type `this` as `DifyWorld` and use `async function`, not arrow functions.
- Keep glue organized by capability under `e2e/features/step-definitions/`; use `common/` only for broadly reusable steps.
- Browser session behavior comes from `features/support/hooks.ts`:
- default: authenticated session with shared storage state
- `@unauthenticated`: clean browser context
- `@authenticated`: readability/selective-run tag only unless implementation changes
- `@fresh`: only for `e2e:full*` flows
- Do not import Playwright Test runner patterns that bypass the current Cucumber + `DifyWorld` architecture unless the task is explicitly about changing that architecture.
## Workflow
1. Rebuild local context.
- Inspect the target feature area.
- Reuse an existing step when wording and behavior already match.
- Add a new step only for a genuinely new user action or assertion.
- Keep edits close to the current capability folder unless the step is broadly reusable.
2. Write behavior-first scenarios.
- Describe user-observable behavior, not DOM mechanics.
- Keep each scenario focused on one workflow or outcome.
- Keep scenarios independent and re-runnable.
3. Write step definitions in the local style.
- Keep one step to one user-visible action or one assertion.
- Prefer Cucumber Expressions such as `{string}` and `{int}`.
- Scope locators to stable containers when the page has repeated elements.
- Avoid page-object layers or extra helper abstractions unless repeated complexity clearly justifies them.
4. Use Playwright in the local style.
- Prefer user-facing locators: `getByRole`, `getByLabel`, `getByPlaceholder`, `getByText`, then `getByTestId` for explicit contracts.
- Use web-first `expect(...)` assertions.
- Do not use `waitForTimeout`, manual polling, or raw visibility checks when a locator action or retrying assertion already expresses the behavior.
5. Validate narrowly.
- Run the narrowest tagged scenario or flow that exercises the change.
- Run `pnpm -C e2e check`.
- Broaden verification only when the change affects hooks, tags, setup, or shared step semantics.
## Review Checklist
- Does the scenario describe behavior rather than implementation?
- Does it fit the current session model, tags, and `DifyWorld` usage?
- Should an existing step be reused instead of adding a new one?
- Are locators user-facing and assertions web-first?
- Does the change introduce hidden coupling across scenarios, tags, or instance state?
- Does it document or implement behavior that differs from the real hooks or configuration?
Lead findings with correctness, flake risk, and architecture drift.
## References
- [`references/playwright-best-practices.md`](references/playwright-best-practices.md)
- [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md)

View File

@@ -1,4 +0,0 @@
interface:
display_name: "E2E Cucumber + Playwright"
short_description: "Write and review Dify E2E scenarios."
default_prompt: "Use $e2e-cucumber-playwright to write or review a Dify E2E scenario under e2e/."

View File

@@ -1,93 +0,0 @@
# Cucumber Best Practices For Dify E2E
Use this reference when writing or reviewing Gherkin scenarios, step definitions, parameter expressions, and step reuse in Dify's `e2e/` suite.
Official sources:
- https://cucumber.io/docs/guides/10-minute-tutorial/
- https://cucumber.io/docs/cucumber/step-definitions/
- https://cucumber.io/docs/cucumber/cucumber-expressions/
## What Matters Most
### 1. Treat scenarios as executable specifications
Cucumber scenarios should describe examples of behavior, not test implementation recipes.
Apply it like this:
- write what the user does and what should happen
- avoid UI-internal wording such as selector details, DOM structure, or component names
- keep language concrete enough that the scenario reads like living documentation
### 2. Keep scenarios focused
A scenario should usually prove one workflow or business outcome. If a scenario wanders across several unrelated behaviors, split it.
In Dify's suite, this means:
- one capability-focused scenario per feature path
- no long setup chains when existing bootstrap or reusable steps already cover them
- no hidden dependency on another scenario's side effects
### 3. Reuse steps, but only when behavior really matches
Good reuse reduces duplication. Bad reuse hides meaning.
Prefer reuse when:
- the user action is genuinely the same
- the expected outcome is genuinely the same
- the wording stays natural across features
Write a new step when:
- the behavior is materially different
- reusing the old wording would make the scenario misleading
- a supposedly generic step would become an implementation-detail wrapper
### 4. Prefer Cucumber Expressions
Use Cucumber Expressions for parameters unless regex is clearly necessary.
Common examples:
- `{string}` for labels, names, and visible text
- `{int}` for counts
- `{float}` for decimal values
- `{word}` only when the value is truly a single token
Keep expressions readable. If a step needs complicated parsing logic, first ask whether the scenario wording should be simpler.
### 5. Keep step definitions thin and meaningful
Step definitions are glue between Gherkin and automation, not a second abstraction language.
For Dify:
- type `this` as `DifyWorld`
- use `async function`
- keep each step to one user-visible action or assertion
- rely on `DifyWorld` and existing support code for shared context
- avoid leaking cross-scenario state
### 6. Use tags intentionally
Tags should communicate run scope or session semantics, not become ad hoc metadata.
In Dify's current suite:
- capability tags group related scenarios
- `@unauthenticated` changes session behavior
- `@authenticated` is descriptive/selective, not a behavior switch by itself
- `@fresh` belongs to reset/full-install flows only
If a proposed tag implies behavior, verify that hooks or runner configuration actually implement it.
## Review Questions
- Does the scenario read like a real example of product behavior?
- Are the steps behavior-oriented instead of implementation-oriented?
- Is a reused step still truthful in this feature?
- Is a new tag documenting real behavior, or inventing semantics that the suite does not implement?
- Would a new reader understand the outcome without opening the step-definition file?

View File

@@ -1,96 +0,0 @@
# Playwright Best Practices For Dify E2E
Use this reference when writing or reviewing locator, assertion, isolation, or synchronization logic for Dify's Cucumber-based E2E suite.
Official sources:
- https://playwright.dev/docs/best-practices
- https://playwright.dev/docs/locators
- https://playwright.dev/docs/test-assertions
- https://playwright.dev/docs/browser-contexts
## What Matters Most
### 1. Keep scenarios isolated
Playwright's model is built around clean browser contexts so one test does not leak into another. In Dify's suite, that principle maps to per-scenario session setup in `features/support/hooks.ts` and `DifyWorld`.
Apply it like this:
- do not depend on another scenario having run first
- do not persist ad hoc scenario state outside `DifyWorld`
- do not couple ordinary scenarios to `@fresh` behavior
- when a flow needs special auth/session semantics, express that through the existing tag model or explicit hook changes
### 2. Prefer user-facing locators
Playwright recommends built-in locators that reflect what users perceive on the page.
Preferred order in this repository:
1. `getByRole`
2. `getByLabel`
3. `getByPlaceholder`
4. `getByText`
5. `getByTestId` when an explicit test contract is the most stable option
Avoid raw CSS/XPath selectors unless no stable user-facing contract exists and adding one is not practical.
Also remember:
- repeated content usually needs scoping to a stable container
- exact text matching is often too brittle when role/name or label already exists
- `getByTestId` is acceptable when semantics are weak but the contract is intentional
### 3. Use web-first assertions
Playwright assertions auto-wait and retry. Prefer them over manual state inspection.
Prefer:
- `await expect(page).toHaveURL(...)`
- `await expect(locator).toBeVisible()`
- `await expect(locator).toBeHidden()`
- `await expect(locator).toBeEnabled()`
- `await expect(locator).toHaveText(...)`
Avoid:
- `expect(await locator.isVisible()).toBe(true)`
- custom polling loops for DOM state
- `waitForTimeout` as synchronization
If a condition genuinely needs custom retry logic, use Playwright's polling/assertion tools deliberately and keep that choice local and explicit.
### 4. Let actions wait for actionability
Locator actions already wait for the element to be actionable. Do not preface every click/fill with extra timing logic unless the action needs a specific visible/ready assertion for clarity.
Good pattern:
- assert a meaningful visible state when that is part of the behavior
- then click/fill/select via locator APIs
Bad pattern:
- stack arbitrary waits before every action
- wait on unstable implementation details instead of the visible state the user cares about
### 5. Match debugging to the current suite
Playwright's wider ecosystem supports traces and rich debugging tools. Dify's current suite already captures:
- full-page screenshots
- page HTML
- console errors
- page errors
Use the existing artifact flow by default. If a task is specifically about improving diagnostics, confirm the change fits the current Cucumber architecture before importing broader Playwright tooling.
## Review Questions
- Would this locator survive DOM refactors that do not change user-visible behavior?
- Is this assertion using Playwright's retrying semantics?
- Is any explicit wait masking a real readiness problem?
- Does this code preserve per-scenario isolation?
- Is a new abstraction really needed, or does it bypass the existing `DifyWorld` + step-definition model?

View File

@@ -1 +0,0 @@
../../.agents/skills/e2e-cucumber-playwright

100
.github/dependabot.yml vendored
View File

@@ -1,6 +1,106 @@
version: 2
updates:
- package-ecosystem: "pip"
directory: "/api"
open-pull-requests-limit: 10
schedule:
interval: "weekly"
groups:
flask:
patterns:
- "flask"
- "flask-*"
- "werkzeug"
- "gunicorn"
google:
patterns:
- "google-*"
- "googleapis-*"
opentelemetry:
patterns:
- "opentelemetry-*"
pydantic:
patterns:
- "pydantic"
- "pydantic-*"
llm:
patterns:
- "langfuse"
- "langsmith"
- "litellm"
- "mlflow*"
- "opik"
- "weave*"
- "arize*"
- "tiktoken"
- "transformers"
database:
patterns:
- "sqlalchemy"
- "psycopg2*"
- "psycogreen"
- "redis*"
- "alembic*"
storage:
patterns:
- "boto3*"
- "botocore*"
- "azure-*"
- "bce-*"
- "cos-python-*"
- "esdk-obs-*"
- "google-cloud-storage"
- "opendal"
- "oss2"
- "supabase*"
- "tos*"
vdb:
patterns:
- "alibabacloud*"
- "chromadb"
- "clickhouse-*"
- "clickzetta-*"
- "couchbase"
- "elasticsearch"
- "opensearch-py"
- "oracledb"
- "pgvect*"
- "pymilvus"
- "pymochow"
- "pyobvector"
- "qdrant-client"
- "intersystems-*"
- "tablestore"
- "tcvectordb"
- "tidb-vector"
- "upstash-*"
- "volcengine-*"
- "weaviate-*"
- "xinference-*"
- "mo-vector"
- "mysql-connector-*"
dev:
patterns:
- "coverage"
- "dotenv-linter"
- "faker"
- "lxml-stubs"
- "basedpyright"
- "ruff"
- "pytest*"
- "types-*"
- "boto3-stubs"
- "hypothesis"
- "pandas-stubs"
- "scipy-stubs"
- "import-linter"
- "celery-types"
- "mypy*"
- "pyrefly"
python-packages:
patterns:
- "*"
- package-ecosystem: "uv"
directory: "/api"
open-pull-requests-limit: 10

View File

@@ -18,7 +18,7 @@
## Checklist
- [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs)
- [ ] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
- [ ] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
- [ ] I've updated the documentation accordingly.
- [ ] I ran `make lint && make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods
- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
- [x] I've updated the documentation accordingly.
- [x] I ran `make lint` and `make type-check` (backend) and `cd web && pnpm exec vp staged` (frontend) to appease the lint gods

View File

@@ -6,7 +6,14 @@ on:
- "main"
paths:
- api/Dockerfile
- web/docker/**
- web/Dockerfile
- packages/**
- package.json
- pnpm-lock.yaml
- pnpm-workspace.yaml
- .npmrc
- .nvmrc
concurrency:
group: docker-build-${{ github.head_ref || github.run_id }}

View File

@@ -92,7 +92,6 @@ jobs:
vdb:
- 'api/core/rag/datasource/**'
- 'api/tests/integration_tests/vdb/**'
- 'api/providers/vdb/*/tests/**'
- '.github/workflows/vdb-tests.yml'
- '.github/workflows/expose_service_ports.sh'
- 'docker/.env.example'

View File

@@ -23,8 +23,8 @@ jobs:
days-before-issue-stale: 15
days-before-issue-close: 3
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: "Closed due to inactivity. If you have any questions, you can reopen it."
stale-pr-message: "Closed due to inactivity. If you have any questions, you can reopen it."
stale-issue-message: "Close due to it's no longer active, if you have any questions, you can reopen it."
stale-pr-message: "Close due to it's no longer active, if you have any questions, you can reopen it."
stale-issue-label: 'no-issue-activity'
stale-pr-label: 'no-pr-activity'
any-of-labels: '🌚 invalid,🙋‍♂️ question,wont-fix,no-issue-activity,no-pr-activity,💪 enhancement,🤔 cant-reproduce,🙏 help wanted'
any-of-labels: 'duplicate,question,invalid,wontfix,no-issue-activity,no-pr-activity,enhancement,cant-reproduce,help-wanted'

View File

@@ -89,7 +89,7 @@ jobs:
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
# - name: Check VDB Ready (TiDB)
# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
- name: Test Vector Stores
run: uv run --project api bash dev/pytest/pytest_vdb.sh

View File

@@ -81,12 +81,12 @@ jobs:
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
# - name: Check VDB Ready (TiDB)
# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
- name: Test Vector Stores
run: |
uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \
api/providers/vdb/vdb-chroma/tests/integration_tests \
api/providers/vdb/vdb-pgvector/tests/integration_tests \
api/providers/vdb/vdb-qdrant/tests/integration_tests \
api/providers/vdb/vdb-weaviate/tests/integration_tests
api/tests/integration_tests/vdb/chroma \
api/tests/integration_tests/vdb/pgvector \
api/tests/integration_tests/vdb/qdrant \
api/tests/integration_tests/vdb/weaviate

View File

@@ -69,6 +69,8 @@ ignore = [
"FURB152", # math-constant
"UP007", # non-pep604-annotation
"UP032", # f-string
"UP045", # non-pep604-annotation-optional
"B005", # strip-with-multi-characters
"B006", # mutable-argument-default
"B007", # unused-loop-control-variable
"B026", # star-arg-unpacking-after-keyword-arg
@@ -82,6 +84,7 @@ ignore = [
"SIM102", # collapsible-if
"SIM103", # needless-bool
"SIM105", # suppressible-exception
"SIM107", # return-in-try-except-finally
"SIM108", # if-else-block-instead-of-if-exp
"SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements
@@ -90,16 +93,29 @@ ignore = [
]
[lint.per-file-ignores]
"__init__.py" = [
"F401", # unused-import
"F811", # redefined-while-unused
]
"configs/*" = [
"N802", # invalid-function-name
]
"graphon/model_runtime/callbacks/base_callback.py" = ["T201"]
"core/workflow/callbacks/workflow_logging_callback.py" = ["T201"]
"libs/gmpy2_pkcs10aep_cipher.py" = [
"N803", # invalid-argument-name
]
"tests/*" = [
"F811", # redefined-while-unused
"T201", # allow print in tests,
"S110", # allow ignoring exceptions in tests code (currently)
]
"controllers/console/explore/trial.py" = ["TID251"]
"controllers/console/human_input_form.py" = ["TID251"]
"controllers/web/human_input_form.py" = ["TID251"]
[lint.flake8-tidy-imports]
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]
msg = "Use Pydantic payload/query models instead of reqparse."

View File

@@ -21,9 +21,8 @@ RUN apt-get update \
# for building gmpy2
libmpfr-dev libmpc-dev
# Install Python dependencies (workspace members under providers/vdb/)
# Install Python dependencies
COPY pyproject.toml uv.lock ./
COPY providers ./providers
RUN uv sync --locked --no-dev
# production stage

View File

@@ -341,10 +341,11 @@ def add_qdrant_index(field: str):
click.echo(click.style("No dataset collection bindings found.", fg="red"))
return
import qdrant_client
from dify_vdb_qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import PayloadSchemaType
from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
for binding in bindings:
if dify_config.QDRANT_URL is None:
raise ValueError("Qdrant URL is required.")

View File

@@ -160,16 +160,6 @@ class DatabaseConfig(BaseSettings):
default="",
)
DB_SESSION_TIMEZONE_OVERRIDE: str = Field(
description=(
"PostgreSQL session timezone override injected via startup options."
" Default is 'UTC' for out-of-the-box consistency."
" Set to empty string to disable app-level timezone injection, for example when using RDS Proxy"
" together with a database-side default timezone."
),
default="UTC",
)
@computed_field # type: ignore[prop-decorator]
@property
def SQLALCHEMY_DATABASE_URI_SCHEME(self) -> str:
@@ -237,13 +227,12 @@ class DatabaseConfig(BaseSettings):
connect_args: dict[str, str] = {}
# Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property
if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"):
merged_options = options.strip()
session_timezone_override = self.DB_SESSION_TIMEZONE_OVERRIDE.strip()
if session_timezone_override:
timezone_opt = f"-c timezone={session_timezone_override}"
merged_options = f"{merged_options} {timezone_opt}".strip() if merged_options else timezone_opt
if merged_options:
connect_args = {"options": merged_options}
timezone_opt = "-c timezone=UTC"
if options:
merged_options = f"{options} {timezone_opt}"
else:
merged_options = timezone_opt
connect_args = {"options": merged_options}
result: SQLAlchemyEngineOptionsDict = {
"pool_size": self.SQLALCHEMY_POOL_SIZE,

View File

@@ -1,3 +1,4 @@
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
from pydantic import Field
from pydantic_settings import BaseSettings
@@ -41,17 +42,17 @@ class HologresConfig(BaseSettings):
default="public",
)
HOLOGRES_TOKENIZER: str = Field(
HOLOGRES_TOKENIZER: TokenizerType = Field(
description="Tokenizer for full-text search index (e.g., 'jieba', 'ik', 'standard', 'simple').",
default="jieba",
)
HOLOGRES_DISTANCE_METHOD: str = Field(
HOLOGRES_DISTANCE_METHOD: DistanceType = Field(
description="Distance method for vector index (e.g., 'Cosine', 'Euclidean', 'InnerProduct').",
default="Cosine",
)
HOLOGRES_BASE_QUANTIZATION_TYPE: str = Field(
HOLOGRES_BASE_QUANTIZATION_TYPE: BaseQuantizationType = Field(
description="Base quantization type for vector index (e.g., 'rabitq', 'sq8', 'fp16', 'fp32').",
default="rabitq",
)

View File

@@ -1,7 +1,5 @@
"""Configuration for InterSystems IRIS vector database."""
from typing import Any
from pydantic import Field, PositiveInt, model_validator
from pydantic_settings import BaseSettings
@@ -66,7 +64,7 @@ class IrisVectorConfig(BaseSettings):
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict[str, Any]) -> dict[str, Any]:
def validate_config(cls, values: dict) -> dict:
"""Validate IRIS configuration values.
Args:

View File

@@ -1,5 +1,4 @@
from typing import Any, Literal
from uuid import UUID
from pydantic import BaseModel, Field, model_validator
@@ -24,9 +23,9 @@ class ConversationRenamePayload(BaseModel):
class MessageListQuery(BaseModel):
conversation_id: UUIDStrOrEmpty = Field(description="Conversation UUID")
first_id: UUIDStrOrEmpty | None = Field(default=None, description="First message ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
conversation_id: UUIDStrOrEmpty
first_id: UUIDStrOrEmpty | None = None
limit: int = Field(default=20, ge=1, le=100)
class MessageFeedbackPayload(BaseModel):
@@ -70,35 +69,11 @@ class WorkflowUpdatePayload(BaseModel):
marked_comment: str | None = Field(default=None, max_length=100)
# --- Dataset schemas ---
DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
class ChildChunkCreatePayload(BaseModel):
content: str
class ChildChunkUpdatePayload(BaseModel):
content: str
class DocumentBatchDownloadZipPayload(BaseModel):
"""Request payload for bulk downloading documents as a zip archive."""
document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS)
class MetadataUpdatePayload(BaseModel):
name: str
# --- Audio schemas ---
class TextToAudioPayload(BaseModel):
message_id: str | None = Field(default=None, description="Message ID")
voice: str | None = Field(default=None, description="Voice to use for TTS")
text: str | None = Field(default=None, description="Text to convert to audio")
streaming: bool | None = Field(default=None, description="Enable streaming response")
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = None

View File

@@ -1,16 +1,12 @@
from datetime import datetime
import flask_restx
from flask_restx import Resource
from flask_restx import Resource, fields, marshal_with
from flask_restx._http import HTTPStatus
from pydantic import field_validator
from sqlalchemy import delete, func, select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset
from models.enums import ApiTokenType
@@ -20,31 +16,21 @@ from services.api_token_service import ApiTokenCache
from . import console_ns
from .wraps import account_initialization_required, edit_permission_required, setup_required
api_key_fields = {
"id": fields.String,
"type": fields.String,
"token": fields.String,
"last_used_at": TimestampField,
"created_at": TimestampField,
}
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
api_key_item_model = console_ns.model("ApiKeyItem", api_key_fields)
api_key_list = {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
class ApiKeyItem(ResponseModel):
id: str
type: str
token: str
last_used_at: int | None = None
created_at: int | None = None
@field_validator("last_used_at", "created_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
class ApiKeyList(ResponseModel):
data: list[ApiKeyItem]
register_schema_models(console_ns, ApiKeyItem, ApiKeyList)
api_key_list_model = console_ns.model(
"ApiKeyList", {"data": fields.List(fields.Nested(api_key_item_model), attribute="items")}
)
def _get_resource(resource_id, tenant_id, resource_model):
@@ -68,6 +54,7 @@ class BaseApiKeyListResource(Resource):
token_prefix: str | None = None
max_keys = 10
@marshal_with(api_key_list_model)
def get(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
resource_id = str(resource_id)
@@ -79,8 +66,9 @@ class BaseApiKeyListResource(Resource):
ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id
)
).all()
return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
return {"items": keys}
@marshal_with(api_key_item_model)
@edit_permission_required
def post(self, resource_id):
assert self.resource_id_field is not None, "resource_id_field must be set"
@@ -112,7 +100,7 @@ class BaseApiKeyListResource(Resource):
api_token.type = self.resource_type
db.session.add(api_token)
db.session.commit()
return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 201
return api_token, 201
class BaseApiKeyResource(Resource):
@@ -159,7 +147,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc("get_app_api_keys")
@console_ns.doc(description="Get all API keys for an app")
@console_ns.doc(params={"resource_id": "App ID"})
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
@console_ns.response(200, "Success", api_key_list_model)
def get(self, resource_id): # type: ignore
"""Get all API keys for an app"""
return super().get(resource_id)
@@ -167,7 +155,7 @@ class AppApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc("create_app_api_key")
@console_ns.doc(description="Create a new API key for an app")
@console_ns.doc(params={"resource_id": "App ID"})
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(201, "API key created successfully", api_key_item_model)
@console_ns.response(400, "Maximum keys exceeded")
def post(self, resource_id): # type: ignore
"""Create a new API key for an app"""
@@ -199,7 +187,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc("get_dataset_api_keys")
@console_ns.doc(description="Get all API keys for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID"})
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
@console_ns.response(200, "Success", api_key_list_model)
def get(self, resource_id): # type: ignore
"""Get all API keys for a dataset"""
return super().get(resource_id)
@@ -207,7 +195,7 @@ class DatasetApiKeyListResource(BaseApiKeyListResource):
@console_ns.doc("create_dataset_api_key")
@console_ns.doc(description="Create a new API key for a dataset")
@console_ns.doc(params={"resource_id": "Dataset ID"})
@console_ns.response(201, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(201, "API key created successfully", api_key_item_model)
@console_ns.response(400, "Maximum keys exceeded")
def post(self, resource_id): # type: ignore
"""Create a new API key for a dataset"""

View File

@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateArgs, AdvancedPromptTemplateService
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
class AdvancedPromptTemplateQuery(BaseModel):
@@ -35,10 +35,5 @@ class AdvancedPromptTemplateList(Resource):
@account_initialization_required
def get(self):
args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
prompt_args: AdvancedPromptTemplateArgs = {
"app_mode": args.app_mode,
"model_mode": args.model_mode,
"model_name": args.model_name,
"has_context": args.has_context,
}
return AdvancedPromptTemplateService.get_prompt(prompt_args)
return AdvancedPromptTemplateService.get_prompt(args.model_dump())

View File

@@ -25,13 +25,7 @@ from fields.annotation_fields import (
)
from libs.helper import uuid_value
from libs.login import login_required
from services.annotation_service import (
AppAnnotationService,
EnableAnnotationArgs,
UpdateAnnotationArgs,
UpdateAnnotationSettingArgs,
UpsertAnnotationArgs,
)
from services.annotation_service import AppAnnotationService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -126,12 +120,7 @@ class AnnotationReplyActionApi(Resource):
args = AnnotationReplyPayload.model_validate(console_ns.payload)
match action:
case "enable":
enable_args: EnableAnnotationArgs = {
"score_threshold": args.score_threshold,
"embedding_provider_name": args.embedding_provider_name,
"embedding_model_name": args.embedding_model_name,
}
result = AppAnnotationService.enable_app_annotation(enable_args, app_id)
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
case "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
return result, 200
@@ -172,8 +161,7 @@ class AppAnnotationSettingUpdateApi(Resource):
args = AnnotationSettingUpdatePayload.model_validate(console_ns.payload)
setting_args: UpdateAnnotationSettingArgs = {"score_threshold": args.score_threshold}
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, setting_args)
result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args.model_dump())
return result, 200
@@ -249,16 +237,8 @@ class AnnotationApi(Resource):
def post(self, app_id):
app_id = str(app_id)
args = CreateAnnotationPayload.model_validate(console_ns.payload)
upsert_args: UpsertAnnotationArgs = {}
if args.answer is not None:
upsert_args["answer"] = args.answer
if args.content is not None:
upsert_args["content"] = args.content
if args.message_id is not None:
upsert_args["message_id"] = args.message_id
if args.question is not None:
upsert_args["question"] = args.question
annotation = AppAnnotationService.up_insert_app_annotation_from_message(upsert_args, app_id)
data = args.model_dump(exclude_none=True)
annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
@setup_required
@@ -335,12 +315,9 @@ class AnnotationUpdateDeleteApi(Resource):
app_id = str(app_id)
annotation_id = str(annotation_id)
args = UpdateAnnotationPayload.model_validate(console_ns.payload)
update_args: UpdateAnnotationArgs = {}
if args.answer is not None:
update_args["answer"] = args.answer
if args.question is not None:
update_args["question"] = args.question
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_id, annotation_id)
annotation = AppAnnotationService.update_app_annotation_directly(
args.model_dump(exclude_none=True), app_id, annotation_id
)
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
@setup_required

View File

@@ -1,8 +1,7 @@
from flask_restx import Resource
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from controllers.common.schema import register_schema_models
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
account_initialization_required,
@@ -11,15 +10,35 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.app_fields import (
app_import_check_dependencies_fields,
app_import_fields,
leaked_dependency_fields,
)
from libs.login import current_account_with_tenant, login_required
from models.model import App
from services.app_dsl_service import AppDslService, Import
from services.app_dsl_service import AppDslService
from services.enterprise.enterprise_service import EnterpriseService
from services.entities.dsl_entities import CheckDependenciesResult, ImportStatus
from services.entities.dsl_entities import ImportStatus
from services.feature_service import FeatureService
from .. import console_ns
# Register models for flask_restx to avoid dict type issues in Swagger
# Register base model first
leaked_dependency_model = console_ns.model("LeakedDependency", leaked_dependency_fields)
app_import_model = console_ns.model("AppImport", app_import_fields)
# For nested models, need to replace nested dict with registered model
app_import_check_dependencies_fields_copy = app_import_check_dependencies_fields.copy()
app_import_check_dependencies_fields_copy["leaked_dependencies"] = fields.List(fields.Nested(leaked_dependency_model))
app_import_check_dependencies_model = console_ns.model(
"AppImportCheckDependencies", app_import_check_dependencies_fields_copy
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppImportPayload(BaseModel):
mode: str = Field(..., description="Import mode")
@@ -33,18 +52,18 @@ class AppImportPayload(BaseModel):
app_id: str | None = Field(None)
register_schema_models(console_ns, AppImportPayload, Import, CheckDependenciesResult)
console_ns.schema_model(
AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@console_ns.route("/apps/imports")
class AppImportApi(Resource):
@console_ns.expect(console_ns.models[AppImportPayload.__name__])
@console_ns.response(200, "Import completed", console_ns.models[Import.__name__])
@console_ns.response(202, "Import pending confirmation", console_ns.models[Import.__name__])
@console_ns.response(400, "Import failed", console_ns.models[Import.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_import_model)
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self):
@@ -85,11 +104,10 @@ class AppImportApi(Resource):
@console_ns.route("/apps/imports/<string:import_id>/confirm")
class AppImportConfirmApi(Resource):
@console_ns.response(200, "Import confirmed", console_ns.models[Import.__name__])
@console_ns.response(400, "Import failed", console_ns.models[Import.__name__])
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_import_model)
@edit_permission_required
def post(self, import_id):
# Check user role first
@@ -110,11 +128,11 @@ class AppImportConfirmApi(Resource):
@console_ns.route("/apps/imports/<string:app_id>/check-dependencies")
class AppImportCheckDependenciesApi(Resource):
@console_ns.response(200, "Dependencies checked", console_ns.models[CheckDependenciesResult.__name__])
@setup_required
@login_required
@get_app_model
@account_initialization_required
@marshal_with(app_import_check_dependencies_model)
@edit_permission_required
def get(self, app_model: App):
with sessionmaker(db.engine).begin() as session:

View File

@@ -1,68 +1,39 @@
import json
from datetime import datetime
from typing import Any
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.app_fields import app_server_fields
from libs.login import current_account_with_tenant, login_required
from models.enums import AppMCPServerStatus
from models.model import AppMCPServer
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
def _to_timestamp(value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return int(value.timestamp())
return value
# Register model for flask_restx to avoid dict type issues in Swagger
app_server_model = console_ns.model("AppServer", app_server_fields)
class MCPServerCreatePayload(BaseModel):
description: str | None = Field(default=None, description="Server description")
parameters: dict[str, Any] = Field(..., description="Server parameters configuration")
parameters: dict = Field(..., description="Server parameters configuration")
class MCPServerUpdatePayload(BaseModel):
id: str = Field(..., description="Server ID")
description: str | None = Field(default=None, description="Server description")
parameters: dict[str, Any] = Field(..., description="Server parameters configuration")
parameters: dict = Field(..., description="Server parameters configuration")
status: str | None = Field(default=None, description="Server status")
class AppMCPServerResponse(ResponseModel):
id: str
name: str
server_code: str
description: str
status: str
parameters: dict[str, Any] | list[Any] | str
created_at: int | None = None
updated_at: int | None = None
@field_validator("parameters", mode="before")
@classmethod
def _parse_json_string(cls, value: Any) -> Any:
if isinstance(value, str):
try:
return json.loads(value)
except (json.JSONDecodeError, TypeError):
return value
return value
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
return _to_timestamp(value)
register_schema_models(console_ns, MCPServerCreatePayload, MCPServerUpdatePayload, AppMCPServerResponse)
for model in (MCPServerCreatePayload, MCPServerUpdatePayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/apps/<uuid:app_id>/server")
@@ -70,27 +41,27 @@ class AppMCPServerController(Resource):
@console_ns.doc("get_app_mcp_server")
@console_ns.doc(description="Get MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Server configuration", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(200, "MCP server configuration retrieved successfully", app_server_model)
@login_required
@account_initialization_required
@setup_required
@get_app_model
@marshal_with(app_server_model)
def get(self, app_model):
server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.app_id == app_model.id).limit(1))
if server is None:
return {}
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
return server
@console_ns.doc("create_app_mcp_server")
@console_ns.doc(description="Create MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[MCPServerCreatePayload.__name__])
@console_ns.response(200, "Server created", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(201, "MCP server configuration created successfully", app_server_model)
@console_ns.response(403, "Insufficient permissions")
@account_initialization_required
@get_app_model
@login_required
@setup_required
@marshal_with(app_server_model)
@edit_permission_required
def post(self, app_model):
_, current_tenant_id = current_account_with_tenant()
@@ -111,19 +82,20 @@ class AppMCPServerController(Resource):
)
db.session.add(server)
db.session.commit()
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
return server
@console_ns.doc("update_app_mcp_server")
@console_ns.doc(description="Update MCP server configuration for an application")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[MCPServerUpdatePayload.__name__])
@console_ns.response(200, "Server updated", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(200, "MCP server configuration updated successfully", app_server_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Server not found")
@get_app_model
@login_required
@setup_required
@account_initialization_required
@marshal_with(app_server_model)
@edit_permission_required
def put(self, app_model):
payload = MCPServerUpdatePayload.model_validate(console_ns.payload or {})
@@ -146,7 +118,7 @@ class AppMCPServerController(Resource):
except ValueError:
raise ValueError("Invalid status")
db.session.commit()
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
return server
@console_ns.route("/apps/<uuid:server_id>/server/refresh")
@@ -154,12 +126,13 @@ class AppMCPServerRefreshController(Resource):
@console_ns.doc("refresh_app_mcp_server")
@console_ns.doc(description="Refresh MCP server configuration and regenerate server code")
@console_ns.doc(params={"server_id": "Server ID"})
@console_ns.response(200, "Server refreshed", console_ns.models[AppMCPServerResponse.__name__])
@console_ns.response(200, "MCP server refreshed successfully", app_server_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "Server not found")
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_server_model)
@edit_permission_required
def get(self, server_id):
_, current_tenant_id = current_account_with_tenant()
@@ -172,4 +145,4 @@ class AppMCPServerRefreshController(Resource):
raise NotFound()
server.server_code = AppMCPServer.generate_server_code(16)
db.session.commit()
return AppMCPServerResponse.model_validate(server, from_attributes=True).model_dump(mode="json")
return server

View File

@@ -1,12 +1,11 @@
from typing import Literal
from flask_restx import Resource
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from constants.languages import supported_language
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
@@ -16,11 +15,13 @@ from controllers.console.wraps import (
setup_required,
)
from extensions.ext_database import db
from fields.base import ResponseModel
from fields.app_fields import app_site_fields
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import Site
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class AppSiteUpdatePayload(BaseModel):
title: str | None = Field(default=None)
@@ -48,26 +49,13 @@ class AppSiteUpdatePayload(BaseModel):
return supported_language(value)
class AppSiteResponse(ResponseModel):
app_id: str
access_token: str | None = Field(default=None, validation_alias="code")
code: str | None = None
title: str
icon: str | None = None
icon_background: str | None = None
description: str | None = None
default_language: str
customize_domain: str | None = None
copyright: str | None = None
privacy_policy: str | None = None
custom_disclaimer: str | None = None
customize_token_strategy: str
prompt_public: bool
show_workflow_steps: bool
use_icon_as_answer_icon: bool
console_ns.schema_model(
AppSiteUpdatePayload.__name__,
AppSiteUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
register_schema_models(console_ns, AppSiteUpdatePayload, AppSiteResponse)
# Register model for flask_restx to avoid dict type issues in Swagger
app_site_model = console_ns.model("AppSite", app_site_fields)
@console_ns.route("/apps/<uuid:app_id>/site")
@@ -76,7 +64,7 @@ class AppSite(Resource):
@console_ns.doc(description="Update application site configuration")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[AppSiteUpdatePayload.__name__])
@console_ns.response(200, "Site configuration updated successfully", console_ns.models[AppSiteResponse.__name__])
@console_ns.response(200, "Site configuration updated successfully", app_site_model)
@console_ns.response(403, "Insufficient permissions")
@console_ns.response(404, "App not found")
@setup_required
@@ -84,6 +72,7 @@ class AppSite(Resource):
@edit_permission_required
@account_initialization_required
@get_app_model
@marshal_with(app_site_model)
def post(self, app_model):
args = AppSiteUpdatePayload.model_validate(console_ns.payload or {})
current_user, _ = current_account_with_tenant()
@@ -117,7 +106,7 @@ class AppSite(Resource):
site.updated_at = naive_utc_now()
db.session.commit()
return AppSiteResponse.model_validate(site, from_attributes=True).model_dump(mode="json")
return site
@console_ns.route("/apps/<uuid:app_id>/site/access-token-reset")
@@ -125,7 +114,7 @@ class AppSiteAccessTokenReset(Resource):
@console_ns.doc("reset_app_site_access_token")
@console_ns.doc(description="Reset access token for application site")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Access token reset successfully", console_ns.models[AppSiteResponse.__name__])
@console_ns.response(200, "Access token reset successfully", app_site_model)
@console_ns.response(403, "Insufficient permissions (admin/owner required)")
@console_ns.response(404, "App or site not found")
@setup_required
@@ -133,6 +122,7 @@ class AppSiteAccessTokenReset(Resource):
@is_admin_or_owner_required
@account_initialization_required
@get_app_model
@marshal_with(app_site_model)
def post(self, app_model):
current_user, _ = current_account_with_tenant()
site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1))
@@ -145,4 +135,4 @@ class AppSiteAccessTokenReset(Resource):
site.updated_at = naive_utc_now()
db.session.commit()
return AppSiteResponse.model_validate(site, from_attributes=True).model_dump(mode="json")
return site

View File

@@ -87,7 +87,7 @@ class WorkflowAppLogApi(Resource):
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with sessionmaker(db.engine).begin() as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
session=session,
app_model=app_model,
@@ -124,7 +124,7 @@ class WorkflowArchivedLogApi(Resource):
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
workflow_app_service = WorkflowAppService()
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with sessionmaker(db.engine).begin() as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs(
session=session,
app_model=app_model,

View File

@@ -36,7 +36,7 @@ from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowR
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
from services.workflow_run_service import WorkflowRunListArgs, WorkflowRunService
from services.workflow_run_service import WorkflowRunService
def _build_backstage_input_url(form_token: str | None) -> str | None:
@@ -214,11 +214,7 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
Get advanced chat app workflow run list
"""
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args: WorkflowRunListArgs = {"limit": args_model.limit}
if args_model.last_id is not None:
args["last_id"] = args_model.last_id
if args_model.status is not None:
args["status"] = args_model.status
args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING if not specified
triggered_from = (
@@ -360,11 +356,7 @@ class WorkflowRunListApi(Resource):
Get workflow run list
"""
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
args: WorkflowRunListArgs = {"limit": args_model.limit}
if args_model.last_id is not None:
args["last_id"] = args_model.last_id
if args_model.status is not None:
args["status"] = args_model.status
args = args_model.model_dump(exclude_none=True)
# Default to DEBUGGING for workflow if not specified (backward compatibility)
triggered_from = (

View File

@@ -64,7 +64,7 @@ class WebhookTriggerApi(Resource):
node_id = args.node_id
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with sessionmaker(db.engine).begin() as session:
# Get webhook trigger for this app and node
webhook_trigger = session.scalar(
select(WorkflowWebhookTrigger)
@@ -95,7 +95,7 @@ class AppTriggersApi(Resource):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with sessionmaker(db.engine).begin() as session:
# Get all triggers for this app using select API
triggers = (
session.execute(

View File

@@ -1,9 +1,8 @@
from flask import request
from flask_restx import Resource
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from constants.languages import supported_language
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
@@ -12,6 +11,8 @@ from libs.helper import EmailStr, timezone
from models import AccountStatus
from services.account_service import RegisterService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ActivateCheckQuery(BaseModel):
workspace_id: str | None = Field(default=None)
@@ -38,16 +39,8 @@ class ActivatePayload(BaseModel):
return timezone(value)
class ActivationCheckResponse(BaseModel):
is_valid: bool = Field(description="Whether token is valid")
data: dict | None = Field(default=None, description="Activation data if valid")
class ActivationResponse(BaseModel):
result: str = Field(description="Operation result")
register_schema_models(console_ns, ActivateCheckQuery, ActivatePayload, ActivationCheckResponse, ActivationResponse)
for model in (ActivateCheckQuery, ActivatePayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/activate/check")
@@ -58,7 +51,13 @@ class ActivateCheckApi(Resource):
@console_ns.response(
200,
"Success",
console_ns.models[ActivationCheckResponse.__name__],
console_ns.model(
"ActivationCheckResponse",
{
"is_valid": fields.Boolean(description="Whether token is valid"),
"data": fields.Raw(description="Activation data if valid"),
},
),
)
def get(self):
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
@@ -96,7 +95,12 @@ class ActivateApi(Resource):
@console_ns.response(
200,
"Account activated successfully",
console_ns.models[ActivationResponse.__name__],
console_ns.model(
"ActivationResponse",
{
"result": fields.String(description="Operation result"),
},
),
)
@console_ns.response(400, "Already activated or invalid token")
def post(self):

View File

@@ -1,10 +1,7 @@
import logging
import flask_login
from flask import make_response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from werkzeug.exceptions import Unauthorized
import services
from configs import dify_config
@@ -45,13 +42,12 @@ from libs.token import (
)
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
from services.billing_service import BillingService
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
from services.entities.auth_entities import LoginPayloadBase
from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.feature_service import FeatureService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
logger = logging.getLogger(__name__)
class LoginPayload(LoginPayloadBase):
@@ -95,12 +91,10 @@ class LoginApi(Resource):
normalized_email = request_email.lower()
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
raise AccountInFreezeError()
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
if is_login_error_rate_limit:
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.LOGIN_RATE_LIMITED)
raise EmailPasswordLoginLimitError()
invite_token = args.invite_token
@@ -116,20 +110,14 @@ class LoginApi(Resource):
invitee_email = data.get("email") if data else None
invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
if invitee_email_normalized != normalized_email:
_log_console_login_failure(
email=normalized_email,
reason=LoginFailureReason.INVALID_INVITATION_EMAIL,
)
raise InvalidEmailError()
account = _authenticate_account_with_case_fallback(
request_email, normalized_email, args.password, invite_token
)
except services.errors.account.AccountLoginError:
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_BANNED)
raise AccountBannedError()
except services.errors.account.AccountPasswordError as exc:
AccountService.add_login_error_rate_limit(normalized_email)
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
raise AuthenticationFailedError() from exc
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
@@ -252,27 +240,20 @@ class EmailCodeLoginApi(Resource):
token_data = AccountService.get_email_code_login_data(args.token)
if token_data is None:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE_TOKEN)
raise InvalidTokenError()
token_email = token_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if normalized_token_email != user_email:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
raise InvalidEmailError()
if token_data["code"] != args.code:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE)
raise EmailCodeError()
AccountService.revoke_email_code_login_token(args.token)
try:
account = _get_account_with_case_fallback(original_email)
except Unauthorized as exc:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_BANNED)
raise AccountBannedError() from exc
except AccountRegisterError:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
raise AccountInFreezeError()
if account:
tenants = TenantService.get_join_tenants(account)
@@ -298,7 +279,6 @@ class EmailCodeLoginApi(Resource):
except WorkSpaceNotAllowedCreateError:
raise NotAllowedCreateWorkspace()
except AccountRegisterError:
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
raise AccountInFreezeError()
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()
@@ -356,12 +336,3 @@ def _authenticate_account_with_case_fallback(
if original_email == normalized_email:
raise
return AccountService.authenticate(normalized_email, password, invite_token)
def _log_console_login_failure(*, email: str, reason: LoginFailureReason) -> None:
logger.warning(
"Console login failed: email=%s reason=%s ip_address=%s",
email,
reason,
extract_remote_ip(request),
)

View File

@@ -11,7 +11,10 @@ import services
from configs import dify_config
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from controllers.console.apikey import ApiKeyItem, ApiKeyList
from controllers.console.apikey import (
api_key_item_model,
api_key_list_model,
)
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
from controllers.console.wraps import (
@@ -782,23 +785,23 @@ class DatasetApiKeyApi(Resource):
@console_ns.doc("get_dataset_api_keys")
@console_ns.doc(description="Get dataset API keys")
@console_ns.response(200, "API keys retrieved successfully", console_ns.models[ApiKeyList.__name__])
@console_ns.response(200, "API keys retrieved successfully", api_key_list_model)
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_key_list_model)
def get(self):
_, current_tenant_id = current_account_with_tenant()
keys = db.session.scalars(
select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
).all()
return ApiKeyList.model_validate({"data": keys}, from_attributes=True).model_dump(mode="json")
return {"items": keys}
@console_ns.response(200, "API key created successfully", console_ns.models[ApiKeyItem.__name__])
@console_ns.response(400, "Maximum keys exceeded")
@setup_required
@login_required
@is_admin_or_owner_required
@account_initialization_required
@marshal_with(api_key_item_model)
def post(self):
_, current_tenant_id = current_account_with_tenant()
@@ -825,7 +828,7 @@ class DatasetApiKeyApi(Resource):
api_token.type = self.resource_type
db.session.add(api_token)
db.session.commit()
return ApiKeyItem.model_validate(api_token, from_attributes=True).model_dump(mode="json"), 200
return api_token, 200
@console_ns.route("/datasets/api-keys/<uuid:api_key_id>")

View File

@@ -4,6 +4,7 @@ from argparse import ArgumentTypeError
from collections.abc import Sequence
from contextlib import ExitStack
from typing import Any, Literal, cast
from uuid import UUID
import sqlalchemy as sa
from flask import request, send_file
@@ -15,7 +16,6 @@ from sqlalchemy import asc, desc, func, select
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
from core.errors.error import (
@@ -71,6 +71,9 @@ from ..wraps import (
logger = logging.getLogger(__name__)
# NOTE: Keep constants near the top of the module for discoverability.
DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
# Register models for flask_restx to avoid dict type issues in Swagger
dataset_model = get_or_create_model("Dataset", dataset_fields)
@@ -107,6 +110,12 @@ class GenerateSummaryPayload(BaseModel):
document_list: list[str]
class DocumentBatchDownloadZipPayload(BaseModel):
"""Request payload for bulk downloading documents as a zip archive."""
document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS)
class DocumentDatasetListParam(BaseModel):
page: int = Field(1, title="Page", description="Page number.")
limit: int = Field(20, title="Limit", description="Page size.")

View File

@@ -10,7 +10,6 @@ from werkzeug.exceptions import Forbidden, NotFound
import services
from configs import dify_config
from controllers.common.controller_schemas import ChildChunkCreatePayload, ChildChunkUpdatePayload
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import ProviderNotInitializeError
@@ -83,6 +82,14 @@ class BatchImportPayload(BaseModel):
upload_file_id: str
class ChildChunkCreatePayload(BaseModel):
content: str
class ChildChunkUpdatePayload(BaseModel):
content: str
class ChildChunkBatchUpdatePayload(BaseModel):
chunks: list[ChildChunkUpdateArgs]

View File

@@ -1,9 +1,9 @@
from typing import Literal
from flask_restx import Resource, marshal_with
from pydantic import BaseModel
from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import MetadataUpdatePayload
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
@@ -18,6 +18,11 @@ from services.entities.knowledge_entities.knowledge_entities import (
)
from services.metadata_service import MetadataService
class MetadataUpdatePayload(BaseModel):
name: str
register_schema_models(
console_ns, MetadataArgs, MetadataOperationData, MetadataUpdatePayload, DocumentMetadataOperation, MetadataDetail
)

View File

@@ -1,4 +1,3 @@
from collections.abc import Mapping
from typing import TypedDict
from flask import request
@@ -14,14 +13,6 @@ from services.billing_service import BillingService
_FALLBACK_LANG = "en-US"
class NotificationLangContent(TypedDict, total=False):
lang: str
title: str
subtitle: str
body: str
titlePicUrl: str
class NotificationItemDict(TypedDict):
notification_id: str | None
frequency: str | None
@@ -37,11 +28,9 @@ class NotificationResponseDict(TypedDict):
notifications: list[NotificationItemDict]
def _pick_lang_content(contents: Mapping[str, NotificationLangContent], lang: str) -> NotificationLangContent:
def _pick_lang_content(contents: dict, lang: str) -> dict:
"""Return the single LangContent for *lang*, falling back to English."""
return (
contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), NotificationLangContent())
)
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
class DismissNotificationPayload(BaseModel):
@@ -82,7 +71,7 @@ class NotificationApi(Resource):
notifications: list[NotificationItemDict] = []
for notification in result.get("notifications") or []:
contents: Mapping[str, NotificationLangContent] = notification.get("contents") or {}
contents: dict = notification.get("contents") or {}
lang_content = _pick_lang_content(contents, lang)
item: NotificationItemDict = {
"notification_id": notification.get("notificationId"),

View File

@@ -35,24 +35,22 @@ def plugin_permission_required(
return view(*args, **kwargs)
if install_required:
match permission.install_permission:
case TenantPluginPermission.InstallPermission.NOBODY:
if permission.install_permission == TenantPluginPermission.InstallPermission.NOBODY:
raise Forbidden()
if permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS:
if not user.is_admin_or_owner:
raise Forbidden()
case TenantPluginPermission.InstallPermission.ADMINS:
if not user.is_admin_or_owner:
raise Forbidden()
case TenantPluginPermission.InstallPermission.EVERYONE:
pass
if permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE:
pass
if debug_required:
match permission.debug_permission:
case TenantPluginPermission.DebugPermission.NOBODY:
if permission.debug_permission == TenantPluginPermission.DebugPermission.NOBODY:
raise Forbidden()
if permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS:
if not user.is_admin_or_owner:
raise Forbidden()
case TenantPluginPermission.DebugPermission.ADMINS:
if not user.is_admin_or_owner:
raise Forbidden()
case TenantPluginPermission.DebugPermission.EVERYONE:
pass
if permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE:
pass
return view(*args, **kwargs)

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, Literal
from typing import Literal
import pytz
from flask import request
@@ -174,7 +174,7 @@ reg(CheckEmailUniquePayload)
register_schema_models(console_ns, AccountResponse)
def _serialize_account(account) -> dict[str, Any]:
def _serialize_account(account) -> dict:
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")

View File

@@ -20,7 +20,7 @@ from models.account import AccountStatus
from models.dataset import RateLimitLog
from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus
from services.operation_service import OperationService, UtmInfo
from services.operation_service import OperationService
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
@@ -205,7 +205,7 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
utm_info = request.cookies.get("utm_info")
if utm_info:
utm_info_dict: UtmInfo = json.loads(utm_info)
utm_info_dict: dict = json.loads(utm_info)
OperationService.record_utm(current_tenant_id, utm_info_dict)
return view(*args, **kwargs)

View File

@@ -94,9 +94,10 @@ def get_user_tenant[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
def plugin_data[**P, R](
view: Callable[P, R] | None = None,
*,
payload_type: type[BaseModel],
) -> Callable[[Callable[P, R]], Callable[P, R]]:
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
@@ -115,4 +116,7 @@ def plugin_data[**P, R](
return decorated_view
return decorator
if view is None:
return decorator
else:
return decorator(view)

View File

@@ -2,7 +2,7 @@ from typing import Any, Union
from flask import Response
from flask_restx import Resource
from graphon.variables.input_entities import VariableEntity, VariableEntityType
from graphon.variables.input_entities import VariableEntity
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
@@ -158,20 +158,14 @@ class MCPAppApi(Resource):
except ValidationError as e:
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
def _convert_user_input_form(self, raw_form: list[dict[str, Any]]) -> list[VariableEntity]:
def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]:
"""Convert raw user input form to VariableEntity objects"""
return [self._create_variable_entity(item) for item in raw_form]
def _create_variable_entity(self, item: dict[str, Any]) -> VariableEntity:
def _create_variable_entity(self, item: dict) -> VariableEntity:
"""Create a single VariableEntity from raw form item"""
variable_type_raw: str = item.get("type", "") or list(item.keys())[0]
try:
variable_type = VariableEntityType(variable_type_raw)
except ValueError as e:
raise MCPRequestError(
mcp_types.INVALID_PARAMS, f"Invalid user_input_form variable type: {variable_type_raw}"
) from e
variable = item[variable_type_raw]
variable_type = item.get("type", "") or list(item.keys())[0]
variable = item[variable_type]
return VariableEntity(
type=variable_type,
@@ -184,7 +178,7 @@ class MCPAppApi(Resource):
json_schema=variable.get("json_schema"),
)
def _parse_mcp_request(self, args: dict[str, Any]) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
"""Parse and validate MCP request"""
try:
return mcp_types.ClientRequest.model_validate(args)

View File

@@ -12,12 +12,7 @@ from controllers.service_api.wraps import validate_app_token
from extensions.ext_redis import redis_client
from fields.annotation_fields import Annotation, AnnotationList
from models.model import App
from services.annotation_service import (
AppAnnotationService,
EnableAnnotationArgs,
InsertAnnotationArgs,
UpdateAnnotationArgs,
)
from services.annotation_service import AppAnnotationService
class AnnotationCreatePayload(BaseModel):
@@ -51,15 +46,10 @@ class AnnotationReplyActionApi(Resource):
@validate_app_token
def post(self, app_model: App, action: Literal["enable", "disable"]):
"""Enable or disable annotation reply feature."""
payload = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {})
args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump()
match action:
case "enable":
enable_args: EnableAnnotationArgs = {
"score_threshold": payload.score_threshold,
"embedding_provider_name": payload.embedding_provider_name,
"embedding_model_name": payload.embedding_model_name,
}
result = AppAnnotationService.enable_app_annotation(enable_args, app_model.id)
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
case "disable":
result = AppAnnotationService.disable_app_annotation(app_model.id)
return result, 200
@@ -145,9 +135,8 @@ class AnnotationListApi(Resource):
@validate_app_token
def post(self, app_model: App):
"""Create a new annotation."""
payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
insert_args: InsertAnnotationArgs = {"question": payload.question, "answer": payload.answer}
annotation = AppAnnotationService.insert_app_annotation_directly(insert_args, app_model.id)
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
response = Annotation.model_validate(annotation, from_attributes=True)
return response.model_dump(mode="json"), HTTPStatus.CREATED
@@ -175,9 +164,8 @@ class AnnotationUpdateDeleteApi(Resource):
@edit_permission_required
def put(self, app_model: App, annotation_id: str):
"""Update an existing annotation."""
payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer}
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, annotation_id)
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
response = Annotation.model_validate(annotation, from_attributes=True)
return response.model_dump(mode="json")

View File

@@ -3,10 +3,10 @@ import logging
from flask import request
from flask_restx import Resource
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
import services
from controllers.common.controller_schemas import TextToAudioPayload
from controllers.common.schema import register_schema_model
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import (
@@ -86,6 +86,13 @@ class AudioApi(Resource):
raise InternalServerError()
class TextToAudioPayload(BaseModel):
message_id: str | None = Field(default=None, description="Message ID")
voice: str | None = Field(default=None, description="Voice to use for TTS")
text: str | None = Field(default=None, description="Text to convert to audio")
streaming: bool | None = Field(default=None, description="Enable streaming response")
register_schema_model(service_api_ns, TextToAudioPayload)

View File

@@ -10,7 +10,6 @@ from sqlalchemy import desc, func, select
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.controller_schemas import DocumentBatchDownloadZipPayload
from controllers.common.errors import (
FilenameNotExistsError,
FileTooLargeError,
@@ -101,6 +100,15 @@ class DocumentListQuery(BaseModel):
status: str | None = Field(default=None, description="Document status filter")
DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
class DocumentBatchDownloadZipPayload(BaseModel):
"""Request payload for bulk downloading uploaded documents as a ZIP archive."""
document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS)
register_enum_models(service_api_ns, RetrievalMethod)
register_schema_models(

View File

@@ -2,9 +2,9 @@ from typing import Literal
from flask_login import current_user
from flask_restx import marshal
from pydantic import BaseModel
from werkzeug.exceptions import NotFound
from controllers.common.controller_schemas import MetadataUpdatePayload
from controllers.common.schema import register_schema_model, register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
@@ -18,6 +18,11 @@ from services.entities.knowledge_entities.knowledge_entities import (
)
from services.metadata_service import MetadataService
class MetadataUpdatePayload(BaseModel):
name: str
register_schema_model(service_api_ns, MetadataUpdatePayload)
register_schema_models(
service_api_ns,

View File

@@ -8,7 +8,6 @@ from sqlalchemy import select
from werkzeug.exceptions import NotFound
from configs import dify_config
from controllers.common.controller_schemas import ChildChunkCreatePayload, ChildChunkUpdatePayload
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ProviderNotInitializeError
@@ -33,25 +32,25 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from services.summary_index_service import SummaryIndexService
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict[str, Any]:
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict:
"""Marshal a single segment and enrich it with summary content."""
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
segment_dict["summary"] = summary.summary_content if summary else None
return segment_dict
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict[str, Any]]:
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]:
"""Marshal multiple segments and enrich them with summary content (batch query)."""
segment_ids = [segment.id for segment in segments]
summaries: dict[str, str | None] = {}
summaries: dict = {}
if segment_ids:
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
result: list[dict[str, Any]] = []
result = []
for segment in segments:
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
segment_dict["summary"] = summaries.get(segment.id)
result.append(segment_dict)
return result
@@ -70,12 +69,20 @@ class SegmentUpdatePayload(BaseModel):
segment: SegmentUpdateArgs
class ChildChunkCreatePayload(BaseModel):
content: str
class ChildChunkListQuery(BaseModel):
limit: int = Field(default=20, ge=1)
keyword: str | None = None
page: int = Field(default=1, ge=1)
class ChildChunkUpdatePayload(BaseModel):
content: str
register_schema_models(
service_api_ns,
SegmentCreatePayload,

View File

@@ -5,7 +5,6 @@ Web App Human Input Form APIs.
import json
import logging
from datetime import datetime
from typing import Any, NotRequired, TypedDict
from flask import Response, request
from flask_restx import Resource
@@ -59,19 +58,10 @@ def _to_timestamp(value: datetime) -> int:
return int(value.timestamp())
class FormDefinitionPayload(TypedDict):
form_content: Any
inputs: Any
resolved_default_values: dict[str, str]
user_actions: Any
expiration_time: int
site: NotRequired[dict]
def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Response:
"""Return the form payload (optionally with site) as a JSON response."""
definition_payload = form.get_definition().model_dump()
payload: FormDefinitionPayload = {
payload = {
"form_content": definition_payload["rendered_content"],
"inputs": definition_payload["inputs"],
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
@@ -102,7 +92,7 @@ class HumanInputFormApi(Resource):
_FORM_ACCESS_RATE_LIMITER.increment_rate_limit(ip_address)
service = HumanInputService(db.engine)
# TODO(QuantumGhost): forbid submission for form tokens
# TODO(QuantumGhost): forbid submision for form tokens
# that are only for console.
form = service.get_form_by_token(form_token)

View File

@@ -1,10 +1,7 @@
import logging
from flask import make_response, request
from flask_restx import Resource
from jwt import InvalidTokenError
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import Unauthorized
import services
from configs import dify_config
@@ -23,7 +20,7 @@ from controllers.console.wraps import (
)
from controllers.web import web_ns
from controllers.web.wraps import decode_jwt_token
from libs.helper import EmailStr, extract_remote_ip
from libs.helper import EmailStr
from libs.passport import PassportService
from libs.password import valid_password
from libs.token import (
@@ -32,11 +29,9 @@ from libs.token import (
)
from services.account_service import AccountService
from services.app_service import AppService
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
from services.entities.auth_entities import LoginPayloadBase
from services.webapp_auth_service import WebAppAuthService
logger = logging.getLogger(__name__)
class LoginPayload(LoginPayloadBase):
@field_validator("password")
@@ -81,18 +76,14 @@ class LoginApi(Resource):
def post(self):
"""Authenticate user and login."""
payload = LoginPayload.model_validate(web_ns.payload or {})
normalized_email = payload.email.lower()
try:
account = WebAppAuthService.authenticate(payload.email, payload.password)
except services.errors.account.AccountLoginError:
_log_web_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_BANNED)
raise AccountBannedError()
except services.errors.account.AccountPasswordError:
_log_web_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
raise AuthenticationFailedError()
except services.errors.account.AccountNotFoundError:
_log_web_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_NOT_FOUND)
raise AuthenticationFailedError()
token = WebAppAuthService.login(account=account)
@@ -221,30 +212,21 @@ class EmailCodeLoginApi(Resource):
token_data = WebAppAuthService.get_email_code_login_data(payload.token)
if token_data is None:
_log_web_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE_TOKEN)
raise InvalidTokenError()
token_email = token_data.get("email")
if not isinstance(token_email, str):
_log_web_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
raise InvalidEmailError()
normalized_token_email = token_email.lower()
if normalized_token_email != user_email:
_log_web_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
raise InvalidEmailError()
if token_data["code"] != payload.code:
_log_web_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE)
raise EmailCodeError()
WebAppAuthService.revoke_email_code_login_token(payload.token)
try:
account = WebAppAuthService.get_user_through_email(token_email)
except Unauthorized as exc:
_log_web_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_BANNED)
raise AccountBannedError() from exc
account = WebAppAuthService.get_user_through_email(token_email)
if not account:
_log_web_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_NOT_FOUND)
raise AuthenticationFailedError()
token = WebAppAuthService.login(account=account)
@@ -252,12 +234,3 @@ class EmailCodeLoginApi(Resource):
response = make_response({"result": "success", "data": {"access_token": token}})
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
return response
def _log_web_login_failure(*, email: str, reason: LoginFailureReason) -> None:
logger.warning(
"Web login failed: email=%s reason=%s ip_address=%s",
email,
reason,
extract_remote_ip(request),
)

View File

@@ -3,10 +3,10 @@ from typing import Literal
from flask import request
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field, TypeAdapter
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
from controllers.common.controller_schemas import MessageFeedbackPayload
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import (
@@ -25,6 +25,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
from fields.conversation_fields import ResultResponse
from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem
from libs import helper
from libs.helper import uuid_value
from models.enums import FeedbackRating
from models.model import AppMode
from services.app_generate_service import AppGenerateService
@@ -40,6 +41,19 @@ from services.message_service import MessageService
logger = logging.getLogger(__name__)
class MessageListQuery(BaseModel):
conversation_id: str = Field(description="Conversation UUID")
first_id: str | None = Field(default=None, description="First message ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
@field_validator("conversation_id", "first_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class MessageMoreLikeThisQuery(BaseModel):
response_mode: Literal["blocking", "streaming"] = Field(
description="Response mode",

View File

@@ -1,6 +1,5 @@
import uuid
from datetime import UTC, datetime, timedelta
from typing import Any
from flask import make_response, request
from flask_restx import Resource
@@ -104,23 +103,21 @@ class PassportResource(Resource):
return response
def decode_enterprise_webapp_user_id(jwt_token: str | None) -> dict[str, Any] | None:
def decode_enterprise_webapp_user_id(jwt_token: str | None):
"""
Decode the enterprise user session from the Authorization header.
"""
if not jwt_token:
return None
decoded: dict[str, Any] = PassportService().verify(jwt_token)
decoded = PassportService().verify(jwt_token)
source = decoded.get("token_source")
if not source or source != "webapp_login_token":
raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.")
return decoded
def exchange_token_for_existing_web_user(
app_code: str, enterprise_user_decoded: dict[str, Any], auth_type: WebAppAuthType
):
def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict, auth_type: WebAppAuthType):
"""
Exchange a token for an existing web user session.
"""

View File

@@ -1,4 +1,4 @@
from typing import Any, cast
from typing import cast
from flask_restx import fields, marshal, marshal_with
from sqlalchemy import select
@@ -113,12 +113,12 @@ class AppSiteInfo:
}
def serialize_site(site: Site) -> dict[str, Any]:
def serialize_site(site: Site) -> dict:
"""Serialize Site model using the same schema as AppSiteApi."""
return cast(dict[str, Any], marshal(site, AppSiteApi.site_fields))
return cast(dict, marshal(site, AppSiteApi.site_fields))
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict[str, Any]:
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict:
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo)
return cast(dict[str, Any], marshal(app_site_info, AppSiteApi.app_fields))
return cast(dict, marshal(app_site_info, AppSiteApi.app_fields))

View File

@@ -84,7 +84,7 @@ class AgentStrategyEntity(BaseModel):
identity: AgentStrategyIdentity
parameters: list[AgentStrategyParameter] = Field(default_factory=list)
description: I18nObject = Field(..., description="The description of the agent strategy")
output_schema: dict[str, Any] | None = None
output_schema: dict | None = None
features: list[AgentFeature] | None = None
meta_version: str | None = None
# pydantic configs

View File

@@ -22,8 +22,8 @@ class SensitiveWordAvoidanceConfigManager:
@classmethod
def validate_and_set_defaults(
cls, tenant_id: str, config: dict[str, Any], only_structure_validate: bool = False
) -> tuple[dict[str, Any], list[str]]:
cls, tenant_id: str, config: dict, only_structure_validate: bool = False
) -> tuple[dict, list[str]]:
if not config.get("sensitive_word_avoidance"):
config["sensitive_word_avoidance"] = {"enabled": False}

View File

@@ -138,9 +138,7 @@ class DatasetConfigManager:
)
@classmethod
def validate_and_set_defaults(
cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]
) -> tuple[dict[str, Any], list[str]]:
def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
"""
Validate and set defaults for dataset feature
@@ -174,7 +172,7 @@ class DatasetConfigManager:
return config, ["agent_mode", "dataset_configs", "dataset_query_variable"]
@classmethod
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]):
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict):
"""
Extract dataset config for legacy compatibility

View File

@@ -41,7 +41,7 @@ class ModelConfigManager:
)
@classmethod
def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict[str, Any], list[str]]:
def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict, list[str]]:
"""
Validate and set defaults for model config
@@ -108,7 +108,7 @@ class ModelConfigManager:
return dict(config), ["model"]
@classmethod
def validate_model_completion_params(cls, cp: dict[str, Any]):
def validate_model_completion_params(cls, cp: dict):
# model.completion_params
if not isinstance(cp, dict):
raise ValueError("model.completion_params must be of object type")

View File

@@ -65,7 +65,7 @@ class PromptTemplateConfigManager:
)
@classmethod
def validate_and_set_defaults(cls, app_mode: AppMode, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
"""
Validate pre_prompt and set defaults for prompt feature
depending on the config['model']
@@ -130,7 +130,7 @@ class PromptTemplateConfigManager:
return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"]
@classmethod
def validate_post_prompt_and_set_defaults(cls, config: dict[str, Any]):
def validate_post_prompt_and_set_defaults(cls, config: dict):
"""
Validate post_prompt and set defaults for prompt feature

View File

@@ -1,5 +1,5 @@
import re
from typing import Any, cast
from typing import cast
from graphon.variables.input_entities import VariableEntity, VariableEntityType
@@ -82,7 +82,7 @@ class BasicVariablesConfigManager:
return variable_entities, external_data_variables
@classmethod
def validate_and_set_defaults(cls, tenant_id: str, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
"""
Validate and set defaults for user input form
@@ -99,7 +99,7 @@ class BasicVariablesConfigManager:
return config, related_config_keys
@classmethod
def validate_variables_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
def validate_variables_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
"""
Validate and set defaults for user input form
@@ -164,9 +164,7 @@ class BasicVariablesConfigManager:
return config, ["user_input_form"]
@classmethod
def validate_external_data_tools_and_set_defaults(
cls, tenant_id: str, config: dict[str, Any]
) -> tuple[dict[str, Any], list[str]]:
def validate_external_data_tools_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
"""
Validate and set defaults for external data fetch feature

View File

@@ -30,7 +30,7 @@ class FileUploadConfigManager:
return FileUploadConfig.model_validate(file_upload_dict)
@classmethod
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
"""
Validate and set defaults for file upload feature

View File

@@ -1,5 +1,3 @@
from typing import Any
from pydantic import BaseModel, ConfigDict, Field, ValidationError
@@ -15,7 +13,7 @@ class AppConfigModel(BaseModel):
class MoreLikeThisConfigManager:
@classmethod
def convert(cls, config: dict[str, Any]) -> bool:
def convert(cls, config: dict) -> bool:
"""
Convert model config to model config
@@ -25,7 +23,7 @@ class MoreLikeThisConfigManager:
return AppConfigModel.model_validate(validated_config).more_like_this.enabled
@classmethod
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
try:
return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
except ValidationError:

View File

@@ -1,9 +1,6 @@
from typing import Any
class OpeningStatementConfigManager:
@classmethod
def convert(cls, config: dict[str, Any]) -> tuple[str, list[str]]:
def convert(cls, config: dict) -> tuple[str, list]:
"""
Convert model config to model config
@@ -18,7 +15,7 @@ class OpeningStatementConfigManager:
return opening_statement, suggested_questions_list
@classmethod
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
"""
Validate and set defaults for opening statement feature

View File

@@ -1,9 +1,6 @@
from typing import Any
class RetrievalResourceConfigManager:
@classmethod
def convert(cls, config: dict[str, Any]) -> bool:
def convert(cls, config: dict) -> bool:
show_retrieve_source = False
retriever_resource_dict = config.get("retriever_resource")
if retriever_resource_dict:
@@ -13,7 +10,7 @@ class RetrievalResourceConfigManager:
return show_retrieve_source
@classmethod
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
"""
Validate and set defaults for retriever resource feature

View File

@@ -1,9 +1,6 @@
from typing import Any
class SpeechToTextConfigManager:
@classmethod
def convert(cls, config: dict[str, Any]) -> bool:
def convert(cls, config: dict) -> bool:
"""
Convert model config to model config
@@ -18,7 +15,7 @@ class SpeechToTextConfigManager:
return speech_to_text
@classmethod
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
"""
Validate and set defaults for speech to text feature

View File

@@ -1,9 +1,6 @@
from typing import Any
class SuggestedQuestionsAfterAnswerConfigManager:
@classmethod
def convert(cls, config: dict[str, Any]) -> bool:
def convert(cls, config: dict) -> bool:
"""
Convert model config to model config
@@ -18,7 +15,7 @@ class SuggestedQuestionsAfterAnswerConfigManager:
return suggested_questions_after_answer
@classmethod
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
"""
Validate and set defaults for suggested questions feature

View File

@@ -1,11 +1,9 @@
from typing import Any
from core.app.app_config.entities import TextToSpeechEntity
class TextToSpeechConfigManager:
@classmethod
def convert(cls, config: dict[str, Any]):
def convert(cls, config: dict):
"""
Convert model config to model config
@@ -24,7 +22,7 @@ class TextToSpeechConfigManager:
return text_to_speech
@classmethod
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
"""
Validate and set defaults for text to speech feature

View File

@@ -1,5 +1,5 @@
from collections.abc import Generator
from typing import Any, cast
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
from core.app.entities.task_entities import (
@@ -17,7 +17,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
_blocking_response_type = WorkflowAppBlockingResponse
@classmethod
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking full response.
:param blocking_response: blocking response
@@ -26,7 +26,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
return dict(blocking_response.model_dump())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
"""
Convert blocking simple response.
:param blocking_response: blocking response

View File

@@ -1,5 +1,3 @@
from typing import Any
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
@@ -36,9 +34,7 @@ class PipelineConfigManager(BaseAppConfigManager):
return pipeline_config
@classmethod
def config_validate(
cls, tenant_id: str, config: dict[str, Any], only_structure_validate: bool = False
) -> dict[str, Any]:
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
"""
Validate for pipeline config

View File

@@ -782,7 +782,7 @@ class PipelineGenerator(BaseAppGenerator):
user_id: str,
all_files: list,
datasource_info: Mapping[str, Any],
next_page_parameters: dict[str, Any] | None = None,
next_page_parameters: dict | None = None,
):
"""
Get files in a folder.

View File

@@ -521,7 +521,7 @@ class IterationNodeStartStreamResponse(StreamResponse):
node_type: str
title: str
created_at: int
extras: dict[str, Any] = Field(default_factory=dict)
extras: dict = Field(default_factory=dict)
metadata: Mapping = {}
inputs: Mapping = {}
inputs_truncated: bool = False
@@ -547,7 +547,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
title: str
index: int
created_at: int
extras: dict[str, Any] = Field(default_factory=dict)
extras: dict = Field(default_factory=dict)
event: StreamEvent = StreamEvent.ITERATION_NEXT
workflow_run_id: str
@@ -571,7 +571,7 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
outputs: Mapping | None = None
outputs_truncated: bool = False
created_at: int
extras: dict[str, Any] | None = None
extras: dict | None = None
inputs: Mapping | None = None
inputs_truncated: bool = False
status: WorkflowNodeExecutionStatus
@@ -602,7 +602,7 @@ class LoopNodeStartStreamResponse(StreamResponse):
node_type: str
title: str
created_at: int
extras: dict[str, Any] = Field(default_factory=dict)
extras: dict = Field(default_factory=dict)
metadata: Mapping = {}
inputs: Mapping = {}
inputs_truncated: bool = False
@@ -653,7 +653,7 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
outputs: Mapping | None = None
outputs_truncated: bool = False
created_at: int
extras: dict[str, Any] | None = None
extras: dict | None = None
inputs: Mapping | None = None
inputs_truncated: bool = False
status: WorkflowNodeExecutionStatus

View File

@@ -14,7 +14,7 @@ class DatasourceApiEntity(BaseModel):
description: I18nObject
parameters: list[DatasourceParameter] | None = None
labels: list[str] = Field(default_factory=list)
output_schema: dict[str, Any] | None = None
output_schema: dict | None = None
ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow"] | None
@@ -30,7 +30,7 @@ class DatasourceProviderApiEntityDict(TypedDict):
icon: str | dict
label: I18nObjectDict
type: str
team_credentials: dict[str, Any] | None
team_credentials: dict | None
is_team_authorization: bool
allow_delete: bool
datasources: list[Any]
@@ -45,8 +45,8 @@ class DatasourceProviderApiEntity(BaseModel):
icon: str | dict
label: I18nObject # label
type: str
masked_credentials: dict[str, Any] | None = None
original_credentials: dict[str, Any] | None = None
masked_credentials: dict | None = None
original_credentials: dict | None = None
is_team_authorization: bool = False
allow_delete: bool = True
plugin_id: str | None = Field(default="", description="The plugin id of the datasource")

View File

@@ -129,7 +129,7 @@ class DatasourceEntity(BaseModel):
identity: DatasourceIdentity
parameters: list[DatasourceParameter] = Field(default_factory=list)
description: I18nObject = Field(..., description="The label of the datasource")
output_schema: dict[str, Any] | None = None
output_schema: dict | None = None
@field_validator("parameters", mode="before")
@classmethod
@@ -192,7 +192,7 @@ class DatasourceInvokeMeta(BaseModel):
time_cost: float = Field(..., description="The time cost of the tool invoke")
error: str | None = None
tool_config: dict[str, Any] | None = None
tool_config: dict | None = None
@classmethod
def empty(cls) -> DatasourceInvokeMeta:
@@ -242,7 +242,7 @@ class OnlineDocumentPage(BaseModel):
page_id: str = Field(..., description="The page id")
page_name: str = Field(..., description="The page title")
page_icon: dict[str, Any] | None = Field(None, description="The page icon")
page_icon: dict | None = Field(None, description="The page icon")
type: str = Field(..., description="The type of the page")
last_edited_time: str = Field(..., description="The last edited time")
parent_id: str | None = Field(None, description="The parent page id")
@@ -301,7 +301,7 @@ class GetWebsiteCrawlRequest(BaseModel):
Get website crawl request
"""
crawl_parameters: dict[str, Any] = Field(..., description="The crawl parameters")
crawl_parameters: dict = Field(..., description="The crawl parameters")
class WebSiteInfoDetail(BaseModel):
@@ -358,7 +358,7 @@ class OnlineDriveFileBucket(BaseModel):
bucket: str | None = Field(None, description="The file bucket")
files: list[OnlineDriveFile] = Field(..., description="The file list")
is_truncated: bool = Field(False, description="Whether the result is truncated")
next_page_parameters: dict[str, Any] | None = Field(None, description="Parameters for fetching the next page")
next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page")
class OnlineDriveBrowseFilesRequest(BaseModel):
@@ -369,7 +369,7 @@ class OnlineDriveBrowseFilesRequest(BaseModel):
bucket: str | None = Field(None, description="The file bucket")
prefix: str = Field(..., description="The parent folder ID")
max_keys: int = Field(20, description="Page size for pagination")
next_page_parameters: dict[str, Any] | None = Field(None, description="Parameters for fetching the next page")
next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page")
class OnlineDriveBrowseFilesResponse(BaseModel):

View File

@@ -1,5 +1,3 @@
from typing import Any
from pydantic import BaseModel, Field, field_validator
@@ -39,7 +37,7 @@ class PipelineDocument(BaseModel):
id: str
position: int
data_source_type: str
data_source_info: dict[str, Any] | None = None
data_source_info: dict | None = None
name: str
indexing_status: str
error: str | None = None

View File

@@ -6,7 +6,6 @@ import re
from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
from typing import Any
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from graphon.model_runtime.entities.provider_entities import (
@@ -112,7 +111,7 @@ class ProviderConfiguration(BaseModel):
return ModelProviderFactory(model_runtime=self._bound_model_runtime)
return create_plugin_model_provider_factory(tenant_id=self.tenant_id)
def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None:
def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
"""
Get current credentials.
@@ -234,7 +233,7 @@ class ProviderConfiguration(BaseModel):
return session.execute(stmt).scalar_one_or_none()
def _get_specific_provider_credential(self, credential_id: str) -> dict[str, Any] | None:
def _get_specific_provider_credential(self, credential_id: str) -> dict | None:
"""
Get a specific provider credential by ID.
:param credential_id: Credential ID
@@ -298,7 +297,7 @@ class ProviderConfiguration(BaseModel):
stmt = stmt.where(ProviderCredential.id != exclude_id)
return session.execute(stmt).scalar_one_or_none() is not None
def get_provider_credential(self, credential_id: str | None = None) -> dict[str, Any] | None:
def get_provider_credential(self, credential_id: str | None = None) -> dict | None:
"""
Get provider credentials.
@@ -318,9 +317,7 @@ class ProviderConfiguration(BaseModel):
else [],
)
def validate_provider_credentials(
self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None
):
def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None):
"""
Validate custom credentials.
:param credentials: provider credentials
@@ -450,7 +447,7 @@ class ProviderConfiguration(BaseModel):
provider_names.append(model_provider_id.provider_name)
return provider_names
def create_provider_credential(self, credentials: dict[str, Any], credential_name: str | None):
def create_provider_credential(self, credentials: dict, credential_name: str | None):
"""
Add custom provider credentials.
:param credentials: provider credentials
@@ -518,7 +515,7 @@ class ProviderConfiguration(BaseModel):
def update_provider_credential(
self,
credentials: dict[str, Any],
credentials: dict,
credential_id: str,
credential_name: str | None,
):
@@ -763,7 +760,7 @@ class ProviderConfiguration(BaseModel):
def _get_specific_custom_model_credential(
self, model_type: ModelType, model: str, credential_id: str
) -> dict[str, Any] | None:
) -> dict | None:
"""
Get a specific provider credential by ID.
:param credential_id: Credential ID
@@ -835,9 +832,7 @@ class ProviderConfiguration(BaseModel):
stmt = stmt.where(ProviderModelCredential.id != exclude_id)
return session.execute(stmt).scalar_one_or_none() is not None
def get_custom_model_credential(
self, model_type: ModelType, model: str, credential_id: str | None
) -> dict[str, Any] | None:
def get_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str | None) -> dict | None:
"""
Get custom model credentials.
@@ -877,7 +872,7 @@ class ProviderConfiguration(BaseModel):
self,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
credentials: dict,
credential_id: str = "",
session: Session | None = None,
):
@@ -944,7 +939,7 @@ class ProviderConfiguration(BaseModel):
return _validate(new_session)
def create_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
) -> None:
"""
Create a custom model credential.
@@ -1007,12 +1002,7 @@ class ProviderConfiguration(BaseModel):
raise
def update_custom_model_credential(
self,
model_type: ModelType,
model: str,
credentials: dict[str, Any],
credential_name: str | None,
credential_id: str,
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
) -> None:
"""
Update a custom model credential.
@@ -1422,9 +1412,7 @@ class ProviderConfiguration(BaseModel):
# Get model instance of LLM
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
def get_model_schema(
self, model_type: ModelType, model: str, credentials: dict[str, Any] | None
) -> AIModelEntity | None:
def get_model_schema(self, model_type: ModelType, model: str, credentials: dict | None) -> AIModelEntity | None:
"""
Get model schema
"""
@@ -1483,7 +1471,7 @@ class ProviderConfiguration(BaseModel):
return secret_input_form_variables
def obfuscated_credentials(self, credentials: dict[str, Any], credential_form_schemas: list[CredentialFormSchema]):
def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]):
"""
Obfuscated credentials.

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from enum import StrEnum, auto
from typing import Any, Union
from typing import Union
from graphon.model_runtime.entities.model_entities import ModelType
from pydantic import BaseModel, ConfigDict, Field
@@ -88,7 +88,7 @@ class SystemConfiguration(BaseModel):
enabled: bool
current_quota_type: ProviderQuotaType | None = None
quota_configurations: list[QuotaConfiguration] = []
credentials: dict[str, Any] | None = None
credentials: dict | None = None
class CustomProviderConfiguration(BaseModel):
@@ -96,7 +96,7 @@ class CustomProviderConfiguration(BaseModel):
Model class for provider custom configuration.
"""
credentials: dict[str, Any]
credentials: dict
current_credential_id: str | None = None
current_credential_name: str | None = None
available_credentials: list[CredentialConfiguration] = []
@@ -109,7 +109,7 @@ class CustomModelConfiguration(BaseModel):
model: str
model_type: ModelType
credentials: dict[str, Any] | None
credentials: dict | None
current_credential_id: str | None = None
current_credential_name: str | None = None
available_model_credentials: list[CredentialConfiguration] = []
@@ -145,7 +145,7 @@ class ModelLoadBalancingConfiguration(BaseModel):
id: str
name: str
credentials: dict[str, Any]
credentials: dict
credential_source_type: str | None = None
credential_id: str | None = None

View File

@@ -1,4 +1,4 @@
from typing import Any, cast
from typing import cast
import httpx
@@ -14,7 +14,7 @@ class APIBasedExtensionRequestor:
self.api_endpoint = api_endpoint
self.api_key = api_key
def request(self, point: APIBasedExtensionPoint, params: dict[str, Any]) -> dict[str, Any]:
def request(self, point: APIBasedExtensionPoint, params: dict):
"""
Request the api.
@@ -49,4 +49,4 @@ class APIBasedExtensionRequestor:
if response.status_code != 200:
raise ValueError(f"request error, status_code: {response.status_code}, content: {response.text[:100]}")
return cast(dict[str, Any], response.json())
return cast(dict, response.json())

View File

@@ -21,8 +21,8 @@ class ExtensionModule(StrEnum):
class ModuleExtension(BaseModel):
extension_class: Any | None = None
name: str
label: dict[str, Any] | None = None
form_schema: list[dict[str, Any]] | None = None
label: dict | None = None
form_schema: list | None = None
builtin: bool = True
position: int | None = None

View File

@@ -6,14 +6,14 @@ from extensions.ext_code_based_extension import code_based_extension
class ExternalDataToolFactory:
def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict[str, Any]):
def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict):
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
self.__extension_instance = extension_class(
tenant_id=tenant_id, app_id=app_id, variable=variable, config=config
)
@classmethod
def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]) -> None:
def validate_config(cls, name: str, tenant_id: str, config: dict):
"""
Validate the incoming form config data.

View File

@@ -1,7 +1,6 @@
import json
from enum import StrEnum
from json import JSONDecodeError
from typing import Any
from extensions.ext_redis import redis_client
@@ -16,7 +15,7 @@ class ProviderCredentialsCache:
def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType):
self.cache_key = f"{cache_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
def get(self) -> dict[str, Any] | None:
def get(self) -> dict | None:
"""
Get cached model provider credentials.
@@ -34,7 +33,7 @@ class ProviderCredentialsCache:
else:
return None
def set(self, credentials: dict[str, Any]):
def set(self, credentials: dict):
"""
Cache model provider credentials.

View File

@@ -17,7 +17,7 @@ class ProviderCredentialsCache(ABC):
"""Generate cache key based on subclass implementation"""
pass
def get(self) -> dict[str, Any] | None:
def get(self) -> dict | None:
"""Get cached provider credentials"""
cached_credentials = redis_client.get(self.cache_key)
if cached_credentials:
@@ -71,7 +71,7 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache):
class NoOpProviderCredentialCache:
"""No-op provider credential cache"""
def get(self) -> dict[str, Any] | None:
def get(self) -> dict | None:
"""Get cached provider credentials"""
return None

View File

@@ -1,7 +1,6 @@
import json
from enum import StrEnum
from json import JSONDecodeError
from typing import Any
from extensions.ext_redis import redis_client
@@ -19,7 +18,7 @@ class ToolParameterCache:
f":identity_id:{identity_id}"
)
def get(self) -> dict[str, Any] | None:
def get(self) -> dict | None:
"""
Get cached model provider credentials.
@@ -37,7 +36,7 @@ class ToolParameterCache:
else:
return None
def set(self, parameters: dict[str, Any]):
def set(self, parameters: dict):
"""Cache model provider credentials."""
redis_client.setex(self.cache_key, 86400, json.dumps(parameters))

View File

@@ -735,9 +735,7 @@ class IndexingRunner:
@staticmethod
def _update_document_index_status(
document_id: str,
after_indexing_status: IndexingStatus,
extra_update_params: Mapping[Any, Any] | None = None,
document_id: str, after_indexing_status: IndexingStatus, extra_update_params: dict | None = None
):
"""
Update the document indexing status.
@@ -764,7 +762,7 @@ class IndexingRunner:
db.session.commit()
@staticmethod
def _update_segments_by_document(dataset_document_id: str, update_params: Mapping[Any, Any]):
def _update_segments_by_document(dataset_document_id: str, update_params: dict):
"""
Update the document segment by document id.
"""

View File

@@ -2,7 +2,7 @@ import json
import logging
import re
from collections.abc import Sequence
from typing import Any, Protocol, TypedDict, cast
from typing import Protocol, TypedDict, cast
import json_repair
from graphon.enums import WorkflowNodeExecutionMetadataKey
@@ -533,7 +533,7 @@ class LLMGenerator:
def __instruction_modify_common(
tenant_id: str,
model_config: ModelConfig,
last_run: dict[str, Any] | None,
last_run: dict | None,
current: str | None,
error_message: str | None,
instruction: str,

View File

@@ -200,9 +200,9 @@ def _handle_native_json_schema(
provider: str,
model_schema: AIModelEntity,
structured_output_schema: Mapping,
model_parameters: dict[str, Any],
model_parameters: dict,
rules: list[ParameterRule],
) -> dict[str, Any]:
):
"""
Handle structured output for models with native JSON schema support.
@@ -224,7 +224,7 @@ def _handle_native_json_schema(
return model_parameters
def _set_response_format(model_parameters: dict[str, Any], rules: list[ParameterRule]) -> None:
def _set_response_format(model_parameters: dict, rules: list):
"""
Set the appropriate response format parameter based on model rules.
@@ -326,7 +326,7 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema
return {"schema": processed_schema, "name": "llm_response"}
def remove_additional_properties(schema: dict[str, Any]) -> None:
def remove_additional_properties(schema: dict):
"""
Remove additionalProperties fields from JSON schema.
Used for models like Gemini that don't support this property.
@@ -349,7 +349,7 @@ def remove_additional_properties(schema: dict[str, Any]) -> None:
remove_additional_properties(item)
def convert_boolean_to_string(schema: dict[str, Any]) -> None:
def convert_boolean_to_string(schema: dict):
"""
Convert boolean type specifications to string in JSON schema.

View File

@@ -77,7 +77,7 @@ class ModelInstance:
@staticmethod
def _get_load_balancing_manager(
configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict[str, Any]
configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict
) -> Optional["LBModelManager"]:
"""
Get load balancing model credentials
@@ -115,7 +115,7 @@ class ModelInstance:
def invoke_llm(
self,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict[str, Any] | None = None,
model_parameters: dict | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[True] = True,
@@ -126,7 +126,7 @@ class ModelInstance:
def invoke_llm(
self,
prompt_messages: list[PromptMessage],
model_parameters: dict[str, Any] | None = None,
model_parameters: dict | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[False] = False,
@@ -137,7 +137,7 @@ class ModelInstance:
def invoke_llm(
self,
prompt_messages: list[PromptMessage],
model_parameters: dict[str, Any] | None = None,
model_parameters: dict | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
@@ -147,7 +147,7 @@ class ModelInstance:
def invoke_llm(
self,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict[str, Any] | None = None,
model_parameters: dict | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Sequence[str] | None = None,
stream: bool = True,
@@ -528,7 +528,7 @@ class LBModelManager:
model_type: ModelType,
model: str,
load_balancing_configs: list[ModelLoadBalancingConfiguration],
managed_credentials: dict[str, Any] | None = None,
managed_credentials: dict | None = None,
):
"""
Load balancing model manager

View File

@@ -1,5 +1,3 @@
from typing import Any
from pydantic import BaseModel, Field
from sqlalchemy import select
@@ -12,7 +10,7 @@ from models.api_based_extension import APIBasedExtension
class ModerationInputParams(BaseModel):
app_id: str = ""
inputs: dict[str, Any] = Field(default_factory=dict)
inputs: dict = Field(default_factory=dict)
query: str = ""
@@ -25,7 +23,7 @@ class ApiModeration(Moderation):
name: str = "api"
@classmethod
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
def validate_config(cls, tenant_id: str, config: dict):
"""
Validate the incoming form config data.
@@ -43,7 +41,7 @@ class ApiModeration(Moderation):
if not extension:
raise ValueError("API-based Extension not found. Please check it again.")
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config is None:
@@ -75,7 +73,7 @@ class ApiModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict[str, Any]):
def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict):
if self.config is None:
raise ValueError("The config is not set.")
extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", ""))

View File

@@ -1,6 +1,5 @@
from abc import ABC, abstractmethod
from enum import StrEnum, auto
from typing import Any
from pydantic import BaseModel, Field
@@ -16,7 +15,7 @@ class ModerationInputsResult(BaseModel):
flagged: bool = False
action: ModerationAction
preset_response: str = ""
inputs: dict[str, Any] = Field(default_factory=dict)
inputs: dict = Field(default_factory=dict)
query: str = ""
@@ -34,13 +33,13 @@ class Moderation(Extensible, ABC):
module: ExtensionModule = ExtensionModule.MODERATION
def __init__(self, app_id: str, tenant_id: str, config: dict[str, Any] | None = None):
def __init__(self, app_id: str, tenant_id: str, config: dict | None = None):
super().__init__(tenant_id, config)
self.app_id = app_id
@classmethod
@abstractmethod
def validate_config(cls, tenant_id: str, config: dict[str, Any]) -> None:
def validate_config(cls, tenant_id: str, config: dict) -> None:
"""
Validate the incoming form config data.
@@ -51,7 +50,7 @@ class Moderation(Extensible, ABC):
raise NotImplementedError
@abstractmethod
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
"""
Moderation for inputs.
After the user inputs, this method will be called to perform sensitive content review
@@ -76,7 +75,7 @@ class Moderation(Extensible, ABC):
raise NotImplementedError
@classmethod
def _validate_inputs_and_outputs_config(cls, config: dict[str, Any], is_preset_response_required: bool):
def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool):
# inputs_config
inputs_config = config.get("inputs_config")
if not isinstance(inputs_config, dict):

View File

@@ -1,5 +1,3 @@
from typing import Any
from core.extension.extensible import ExtensionModule
from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult
from extensions.ext_code_based_extension import code_based_extension
@@ -8,12 +6,12 @@ from extensions.ext_code_based_extension import code_based_extension
class ModerationFactory:
__extension_instance: Moderation
def __init__(self, name: str, app_id: str, tenant_id: str, config: dict[str, Any]):
def __init__(self, name: str, app_id: str, tenant_id: str, config: dict):
extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
self.__extension_instance = extension_class(app_id, tenant_id, config)
@classmethod
def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]):
def validate_config(cls, name: str, tenant_id: str, config: dict):
"""
Validate the incoming form config data.
@@ -26,7 +24,7 @@ class ModerationFactory:
# FIXME: mypy error, try to fix it instead of using type: ignore
extension_class.validate_config(tenant_id, config) # type: ignore
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
"""
Moderation for inputs.
After the user inputs, this method will be called to perform sensitive content review

View File

@@ -8,7 +8,7 @@ class KeywordsModeration(Moderation):
name: str = "keywords"
@classmethod
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
def validate_config(cls, tenant_id: str, config: dict):
"""
Validate the incoming form config data.
@@ -28,7 +28,7 @@ class KeywordsModeration(Moderation):
if len(keywords_row_len) > 100:
raise ValueError("the number of rows for the keywords must be less than 100")
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config is None:
@@ -66,7 +66,7 @@ class KeywordsModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
def _is_violated(self, inputs: dict[str, Any], keywords_list: list[str]) -> bool:
def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())
def _check_keywords_in_value(self, keywords_list: Sequence[str], value: Any) -> bool:

View File

@@ -1,5 +1,3 @@
from typing import Any
from graphon.model_runtime.entities.model_entities import ModelType
from core.model_manager import ModelManager
@@ -10,7 +8,7 @@ class OpenAIModeration(Moderation):
name: str = "openai_moderation"
@classmethod
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
def validate_config(cls, tenant_id: str, config: dict):
"""
Validate the incoming form config data.
@@ -20,7 +18,7 @@ class OpenAIModeration(Moderation):
"""
cls._validate_inputs_and_outputs_config(config, True)
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config is None:
@@ -51,7 +49,7 @@ class OpenAIModeration(Moderation):
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
)
def _is_violated(self, inputs: dict[str, Any]):
def _is_violated(self, inputs: dict):
text = "\n".join(str(inputs.values()))
model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id)
model_instance = model_manager.get_model_instance(

View File

@@ -778,7 +778,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
logger.info("[Arize/Phoenix] Failed to construct project URL: %s", str(e), exc_info=True)
raise ValueError(f"[Arize/Phoenix] Failed to construct project URL: {str(e)}")
def _construct_llm_attributes(self, prompts: dict[str, Any] | list[Any] | str | None) -> dict[str, str]:
def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
"""Construct LLM attributes with passed prompts for Arize/Phoenix."""
attributes: dict[str, str] = {}
@@ -797,9 +797,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
path = f"{SpanAttributes.LLM_INPUT_MESSAGES}.{message_index}.{key}"
set_attribute(path, value)
def set_tool_call_attributes(
message_index: int, tool_index: int, tool_call: dict[str, Any] | object | None
) -> None:
def set_tool_call_attributes(message_index: int, tool_index: int, tool_call: dict | object | None) -> None:
"""Extract and assign tool call details safely."""
if not tool_call:
return

View File

@@ -59,24 +59,6 @@ class LangFuseDataTrace(BaseTraceInstance):
)
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
@staticmethod
def _get_completion_start_time(
start_time: datetime | None, time_to_first_token: float | int | None
) -> datetime | None:
"""Convert a relative TTFT value in seconds into Langfuse's absolute completion start time."""
if start_time is None or time_to_first_token is None:
return None
try:
ttft_seconds = float(time_to_first_token)
except (TypeError, ValueError):
return None
if ttft_seconds < 0:
return None
return start_time + timedelta(seconds=ttft_seconds)
def trace(self, trace_info: BaseTraceInfo):
if isinstance(trace_info, WorkflowTraceInfo):
self.workflow_trace(trace_info)
@@ -207,18 +189,10 @@ class LangFuseDataTrace(BaseTraceInstance):
total_token = metadata.get("total_tokens", 0)
prompt_tokens = 0
completion_tokens = 0
completion_start_time = None
try:
usage_data = process_data.get("usage")
if not isinstance(usage_data, dict):
usage_data = outputs.get("usage")
if not isinstance(usage_data, dict):
usage_data = {}
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
prompt_tokens = usage_data.get("prompt_tokens", 0)
completion_tokens = usage_data.get("completion_tokens", 0)
completion_start_time = self._get_completion_start_time(
created_at, usage_data.get("time_to_first_token")
)
except Exception:
logger.error("Failed to extract usage", exc_info=True)
@@ -236,7 +210,6 @@ class LangFuseDataTrace(BaseTraceInstance):
trace_id=trace_id,
model=process_data.get("model_name"),
start_time=created_at,
completion_start_time=completion_start_time,
end_time=finished_at,
input=inputs,
output=outputs,
@@ -317,16 +290,11 @@ class LangFuseDataTrace(BaseTraceInstance):
unit=UnitEnum.TOKENS,
totalCost=message_data.total_price,
)
completion_start_time = self._get_completion_start_time(
trace_info.start_time,
trace_info.gen_ai_server_time_to_first_token,
)
langfuse_generation_data = LangfuseGeneration(
name="llm",
trace_id=trace_id,
start_time=trace_info.start_time,
completion_start_time=completion_start_time,
end_time=trace_info.end_time,
model=message_data.model_id,
input=trace_info.inputs,

View File

@@ -242,7 +242,7 @@ class MLflowDataTrace(BaseTraceInstance):
return inputs, attributes
def _parse_knowledge_retrieval_outputs(self, outputs: dict[str, Any]):
def _parse_knowledge_retrieval_outputs(self, outputs: dict):
"""Parse KR outputs and attributes from KR workflow node"""
retrieved = outputs.get("result", [])
@@ -319,7 +319,7 @@ class MLflowDataTrace(BaseTraceInstance):
end_time_ns=datetime_to_nanoseconds(trace_info.end_time),
)
def _get_message_user_id(self, metadata: dict[str, Any]) -> str | None:
def _get_message_user_id(self, metadata: dict) -> str | None:
if (end_user_id := metadata.get("from_end_user_id")) and (
end_user_data := db.session.get(EndUser, end_user_id)
):
@@ -468,7 +468,7 @@ class MLflowDataTrace(BaseTraceInstance):
}
return node_type_mapping.get(node_type, "CHAIN") # type: ignore[arg-type,call-overload]
def _set_trace_metadata(self, span: Span, metadata: dict[str, Any]):
def _set_trace_metadata(self, span: Span, metadata: dict):
token = None
try:
# NB: Set span in context such that we can use update_current_trace() API
@@ -490,7 +490,7 @@ class MLflowDataTrace(BaseTraceInstance):
return messages
return prompts # Fallback to original format
def _parse_single_message(self, item: dict[str, Any]):
def _parse_single_message(self, item: dict):
"""Postprocess single message format to be standard chat message"""
role = item.get("role", "user")
msg = {"role": role, "content": item.get("text", "")}

View File

@@ -3,7 +3,7 @@ import logging
import os
import uuid
from datetime import datetime, timedelta
from typing import Any, cast
from typing import cast
from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey
from opik import Opik, Trace
@@ -436,7 +436,7 @@ class OpikDataTrace(BaseTraceInstance):
self.add_span(span_data)
def add_trace(self, opik_trace_data: dict[str, Any]) -> Trace:
def add_trace(self, opik_trace_data: dict) -> Trace:
try:
trace = self.opik_client.trace(**opik_trace_data)
logger.debug("Opik Trace created successfully")
@@ -444,7 +444,7 @@ class OpikDataTrace(BaseTraceInstance):
except Exception as e:
raise ValueError(f"Opik Failed to create trace: {str(e)}")
def add_span(self, opik_span_data: dict[str, Any]):
def add_span(self, opik_span_data: dict):
try:
self.opik_client.span(**opik_span_data)
logger.debug("Opik Span created successfully")

View File

@@ -324,7 +324,7 @@ class OpsTraceManager:
@classmethod
def encrypt_tracing_config(
cls, tenant_id: str, tracing_provider: str, tracing_config: dict[str, Any], current_trace_config=None
cls, tenant_id: str, tracing_provider: str, tracing_config: dict, current_trace_config=None
):
"""
Encrypt tracing config.
@@ -363,7 +363,7 @@ class OpsTraceManager:
return encrypted_config.model_dump()
@classmethod
def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict[str, Any]):
def decrypt_tracing_config(cls, tenant_id: str, tracing_provider: str, tracing_config: dict):
"""
Decrypt tracing config
:param tenant_id: tenant id
@@ -408,7 +408,7 @@ class OpsTraceManager:
return dict(decrypted_config)
@classmethod
def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict[str, Any]):
def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
"""
Decrypt tracing config
:param tracing_provider: tracing provider
@@ -581,7 +581,7 @@ class OpsTraceManager:
return app_trace_config
@staticmethod
def check_trace_config_is_effective(tracing_config: dict[str, Any], tracing_provider: str):
def check_trace_config_is_effective(tracing_config: dict, tracing_provider: str):
"""
Check trace config is effective
:param tracing_config: tracing config
@@ -596,7 +596,7 @@ class OpsTraceManager:
return trace_instance(config).api_check()
@staticmethod
def get_trace_config_project_key(tracing_config: dict[str, Any], tracing_provider: str):
def get_trace_config_project_key(tracing_config: dict, tracing_provider: str):
"""
get trace config is project key
:param tracing_config: tracing config
@@ -611,7 +611,7 @@ class OpsTraceManager:
return trace_instance(config).get_project_key()
@staticmethod
def get_trace_config_project_url(tracing_config: dict[str, Any], tracing_provider: str):
def get_trace_config_project_url(tracing_config: dict, tracing_provider: str):
"""
get trace config is project key
:param tracing_config: tracing config
@@ -1322,8 +1322,8 @@ class TraceTask:
error=error,
)
def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict[str, Any]:
node_data: dict[str, Any] = kwargs.get("node_execution_data", {})
def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict:
node_data: dict = kwargs.get("node_execution_data", {})
if not node_data:
return {}
@@ -1431,7 +1431,7 @@ class TraceTask:
return node_trace
return DraftNodeExecutionTrace(**node_trace.model_dump())
def _extract_streaming_metrics(self, message_data) -> dict[str, Any]:
def _extract_streaming_metrics(self, message_data) -> dict:
if not message_data.message_metadata:
return {}

View File

@@ -1,5 +1,4 @@
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field, model_validator
@@ -32,7 +31,7 @@ class EndpointEntity(BasePluginEntity):
entity of an endpoint
"""
settings: dict[str, Any]
settings: dict
tenant_id: str
plugin_id: str
expired_at: datetime

View File

@@ -1,5 +1,3 @@
from typing import Any
from graphon.model_runtime.entities.provider_entities import ProviderEntity
from pydantic import BaseModel, Field, computed_field, model_validator
@@ -42,7 +40,7 @@ class MarketplacePluginDeclaration(BaseModel):
@model_validator(mode="before")
@classmethod
def transform_declaration(cls, data: dict[str, Any]) -> dict[str, Any]:
def transform_declaration(cls, data: dict):
if "endpoint" in data and not data["endpoint"]:
del data["endpoint"]
if "model" in data and not data["model"]:

View File

@@ -123,7 +123,7 @@ class PluginDeclaration(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_category(cls, values: dict[str, Any]) -> dict[str, Any]:
def validate_category(cls, values: dict):
# auto detect category
if values.get("tool"):
values["category"] = PluginCategory.Tool

View File

@@ -73,7 +73,7 @@ class PluginBasicBooleanResponse(BaseModel):
"""
result: bool
credentials: dict[str, Any] | None = None
credentials: dict | None = None
class PluginModelSchemaEntity(BaseModel):

View File

@@ -49,7 +49,7 @@ class RequestInvokeTool(BaseModel):
tool_type: Literal["builtin", "workflow", "api", "mcp"]
provider: str
tool: str
tool_parameters: dict[str, Any]
tool_parameters: dict
credential_id: str | None = None
@@ -209,7 +209,7 @@ class RequestInvokeEncrypt(BaseModel):
opt: Literal["encrypt", "decrypt", "clear"]
namespace: Literal["endpoint"]
identity: str
data: dict[str, Any] = Field(default_factory=dict)
data: dict = Field(default_factory=dict)
config: list[BasicProviderConfig] = Field(default_factory=list)

View File

@@ -26,7 +26,7 @@ class PluginDatasourceManager(BasePluginClient):
Fetch datasource providers for the given tenant.
"""
def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
def transformer(json_response: dict[str, Any]) -> dict:
if json_response.get("data"):
for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {}
@@ -68,7 +68,7 @@ class PluginDatasourceManager(BasePluginClient):
Fetch datasource providers for the given tenant.
"""
def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
def transformer(json_response: dict[str, Any]) -> dict:
if json_response.get("data"):
for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {}
@@ -110,7 +110,7 @@ class PluginDatasourceManager(BasePluginClient):
tool_provider_id = DatasourceProviderID(provider_id)
def transformer(json_response: dict[str, Any]) -> dict[str, Any]:
def transformer(json_response: dict[str, Any]) -> dict:
data = json_response.get("data")
if data:
for datasource in data.get("declaration", {}).get("datasources", []):

View File

@@ -1,5 +1,3 @@
from typing import Any
from core.plugin.entities.endpoint import EndpointEntityWithInstance
from core.plugin.impl.base import BasePluginClient
from core.plugin.impl.exc import PluginDaemonInternalServerError
@@ -7,12 +5,7 @@ from core.plugin.impl.exc import PluginDaemonInternalServerError
class PluginEndpointClient(BasePluginClient):
def create_endpoint(
self,
tenant_id: str,
user_id: str,
plugin_unique_identifier: str,
name: str,
settings: dict[str, Any],
self, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict
) -> bool:
"""
Create an endpoint for the given plugin.
@@ -56,9 +49,7 @@ class PluginEndpointClient(BasePluginClient):
params={"plugin_id": plugin_id, "page": page, "page_size": page_size},
)
def update_endpoint(
self, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict[str, Any]
) -> bool:
def update_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict):
"""
Update the settings of the given endpoint.
"""

View File

@@ -50,7 +50,7 @@ class PluginModelClient(BasePluginClient):
provider: str,
model_type: str,
model: str,
credentials: dict[str, Any],
credentials: dict,
) -> AIModelEntity | None:
"""
Get model schema
@@ -80,7 +80,7 @@ class PluginModelClient(BasePluginClient):
return None
def validate_provider_credentials(
self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict[str, Any]
self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict
) -> bool:
"""
validate the credentials of the provider
@@ -118,7 +118,7 @@ class PluginModelClient(BasePluginClient):
provider: str,
model_type: str,
model: str,
credentials: dict[str, Any],
credentials: dict,
) -> bool:
"""
validate the credentials of the provider
@@ -157,9 +157,9 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict[str, Any],
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict[str, Any] | None = None,
model_parameters: dict | None = None,
tools: list[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = True,
@@ -206,7 +206,7 @@ class PluginModelClient(BasePluginClient):
provider: str,
model_type: str,
model: str,
credentials: dict[str, Any],
credentials: dict,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None,
) -> int:
@@ -248,7 +248,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict[str, Any],
credentials: dict,
texts: list[str],
input_type: str,
) -> EmbeddingResult:
@@ -290,7 +290,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict[str, Any],
credentials: dict,
documents: list[dict],
input_type: str,
) -> EmbeddingResult:
@@ -332,7 +332,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict[str, Any],
credentials: dict,
texts: list[str],
) -> list[int]:
"""
@@ -372,7 +372,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict[str, Any],
credentials: dict,
query: str,
docs: list[str],
score_threshold: float | None = None,
@@ -418,7 +418,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict[str, Any],
credentials: dict,
query: MultimodalRerankInput,
docs: list[MultimodalRerankInput],
score_threshold: float | None = None,
@@ -463,7 +463,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict[str, Any],
credentials: dict,
content_text: str,
voice: str,
) -> Generator[bytes, None, None]:
@@ -508,7 +508,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict[str, Any],
credentials: dict,
language: str | None = None,
):
"""
@@ -552,7 +552,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict[str, Any],
credentials: dict,
file: IO[bytes],
) -> str:
"""
@@ -592,7 +592,7 @@ class PluginModelClient(BasePluginClient):
plugin_id: str,
provider: str,
model: str,
credentials: dict[str, Any],
credentials: dict,
text: str,
) -> bool:
"""

Some files were not shown because too many files have changed in this diff Show More