mirror of
https://github.com/langgenius/dify.git
synced 2026-04-14 20:42:39 +00:00
Compare commits
40 Commits
feat/tidb-
...
feat/new-a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e9ee897973 | ||
|
|
971828615e | ||
|
|
b804c7ed47 | ||
|
|
c7a7c73034 | ||
|
|
94b3087b98 | ||
|
|
3e0578a1c6 | ||
|
|
5f87239abc | ||
|
|
c03b25a940 | ||
|
|
90cce7693f | ||
|
|
77c182f738 | ||
|
|
e04f00d29b | ||
|
|
bbed99a4cb | ||
|
|
df6c1064c6 | ||
|
|
f4e04fc872 | ||
|
|
59b9221501 | ||
|
|
218c10ba4f | ||
|
|
4c878da9e6 | ||
|
|
698af54c4f | ||
|
|
10bb276e97 | ||
|
|
73fd439541 | ||
|
|
5cdae671d5 | ||
|
|
e50c36526e | ||
|
|
2de2a8fd3a | ||
|
|
e2e16772a1 | ||
|
|
b21a443d56 | ||
|
|
4f010cd4f5 | ||
|
|
3d4be88d97 | ||
|
|
482a004efe | ||
|
|
7052257c8d | ||
|
|
edfcab6455 | ||
|
|
66212e3575 | ||
|
|
96374d7f6a | ||
|
|
44491e427c | ||
|
|
d3d9f21cdf | ||
|
|
0c7e7e0c4e | ||
|
|
d9d1e9b63a | ||
|
|
bebafaa346 | ||
|
|
1835a1dc5d | ||
|
|
8f3a3ea03e | ||
|
|
96641a93f6 |
@@ -1,79 +0,0 @@
|
||||
---
|
||||
name: e2e-cucumber-playwright
|
||||
description: Write, update, or review Dify end-to-end tests under `e2e/` that use Cucumber, Gherkin, and Playwright. Use when the task involves `.feature` files, `features/step-definitions/`, `features/support/`, `DifyWorld`, scenario tags, locator/assertion choices, or E2E testing best practices for this repository.
|
||||
---
|
||||
|
||||
# Dify E2E Cucumber + Playwright
|
||||
|
||||
Use this skill for Dify's repository-level E2E suite in `e2e/`. Use [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) as the canonical guide for local architecture and conventions, then apply Playwright/Cucumber best practices only where they fit the current suite.
|
||||
|
||||
## Scope
|
||||
|
||||
- Use this skill for `.feature` files, Cucumber step definitions, `DifyWorld`, hooks, tags, and E2E review work under `e2e/`.
|
||||
- Do not use this skill for Vitest or React Testing Library work under `web/`; use `frontend-testing` instead.
|
||||
- Do not use this skill for backend test or API review tasks under `api/`.
|
||||
|
||||
## Read Order
|
||||
|
||||
1. Read [`e2e/AGENTS.md`](../../../e2e/AGENTS.md) first.
|
||||
2. Read only the files directly involved in the task:
|
||||
- target `.feature` files under `e2e/features/`
|
||||
- related step files under `e2e/features/step-definitions/`
|
||||
- `e2e/features/support/hooks.ts` and `e2e/features/support/world.ts` when session lifecycle or shared state matters
|
||||
- `e2e/scripts/run-cucumber.ts` and `e2e/cucumber.config.ts` when tags or execution flow matter
|
||||
3. Read [`references/playwright-best-practices.md`](references/playwright-best-practices.md) only when locator, assertion, isolation, or waiting choices are involved.
|
||||
4. Read [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md) only when scenario wording, step granularity, tags, or expression design are involved.
|
||||
5. Re-check official docs with Context7 before introducing a new Playwright or Cucumber pattern.
|
||||
|
||||
## Local Rules
|
||||
|
||||
- `e2e/` uses Cucumber for scenarios and Playwright as the browser layer.
|
||||
- `DifyWorld` is the per-scenario context object. Type `this` as `DifyWorld` and use `async function`, not arrow functions.
|
||||
- Keep glue organized by capability under `e2e/features/step-definitions/`; use `common/` only for broadly reusable steps.
|
||||
- Browser session behavior comes from `features/support/hooks.ts`:
|
||||
- default: authenticated session with shared storage state
|
||||
- `@unauthenticated`: clean browser context
|
||||
- `@authenticated`: readability/selective-run tag only unless implementation changes
|
||||
- `@fresh`: only for `e2e:full*` flows
|
||||
- Do not import Playwright Test runner patterns that bypass the current Cucumber + `DifyWorld` architecture unless the task is explicitly about changing that architecture.
|
||||
|
||||
## Workflow
|
||||
|
||||
1. Rebuild local context.
|
||||
- Inspect the target feature area.
|
||||
- Reuse an existing step when wording and behavior already match.
|
||||
- Add a new step only for a genuinely new user action or assertion.
|
||||
- Keep edits close to the current capability folder unless the step is broadly reusable.
|
||||
2. Write behavior-first scenarios.
|
||||
- Describe user-observable behavior, not DOM mechanics.
|
||||
- Keep each scenario focused on one workflow or outcome.
|
||||
- Keep scenarios independent and re-runnable.
|
||||
3. Write step definitions in the local style.
|
||||
- Keep one step to one user-visible action or one assertion.
|
||||
- Prefer Cucumber Expressions such as `{string}` and `{int}`.
|
||||
- Scope locators to stable containers when the page has repeated elements.
|
||||
- Avoid page-object layers or extra helper abstractions unless repeated complexity clearly justifies them.
|
||||
4. Use Playwright in the local style.
|
||||
- Prefer user-facing locators: `getByRole`, `getByLabel`, `getByPlaceholder`, `getByText`, then `getByTestId` for explicit contracts.
|
||||
- Use web-first `expect(...)` assertions.
|
||||
- Do not use `waitForTimeout`, manual polling, or raw visibility checks when a locator action or retrying assertion already expresses the behavior.
|
||||
5. Validate narrowly.
|
||||
- Run the narrowest tagged scenario or flow that exercises the change.
|
||||
- Run `pnpm -C e2e check`.
|
||||
- Broaden verification only when the change affects hooks, tags, setup, or shared step semantics.
|
||||
|
||||
## Review Checklist
|
||||
|
||||
- Does the scenario describe behavior rather than implementation?
|
||||
- Does it fit the current session model, tags, and `DifyWorld` usage?
|
||||
- Should an existing step be reused instead of adding a new one?
|
||||
- Are locators user-facing and assertions web-first?
|
||||
- Does the change introduce hidden coupling across scenarios, tags, or instance state?
|
||||
- Does it document or implement behavior that differs from the real hooks or configuration?
|
||||
|
||||
Lead findings with correctness, flake risk, and architecture drift.
|
||||
|
||||
## References
|
||||
|
||||
- [`references/playwright-best-practices.md`](references/playwright-best-practices.md)
|
||||
- [`references/cucumber-best-practices.md`](references/cucumber-best-practices.md)
|
||||
@@ -1,4 +0,0 @@
|
||||
interface:
|
||||
display_name: "E2E Cucumber + Playwright"
|
||||
short_description: "Write and review Dify E2E scenarios."
|
||||
default_prompt: "Use $e2e-cucumber-playwright to write or review a Dify E2E scenario under e2e/."
|
||||
@@ -1,93 +0,0 @@
|
||||
# Cucumber Best Practices For Dify E2E
|
||||
|
||||
Use this reference when writing or reviewing Gherkin scenarios, step definitions, parameter expressions, and step reuse in Dify's `e2e/` suite.
|
||||
|
||||
Official sources:
|
||||
|
||||
- https://cucumber.io/docs/guides/10-minute-tutorial/
|
||||
- https://cucumber.io/docs/cucumber/step-definitions/
|
||||
- https://cucumber.io/docs/cucumber/cucumber-expressions/
|
||||
|
||||
## What Matters Most
|
||||
|
||||
### 1. Treat scenarios as executable specifications
|
||||
|
||||
Cucumber scenarios should describe examples of behavior, not test implementation recipes.
|
||||
|
||||
Apply it like this:
|
||||
|
||||
- write what the user does and what should happen
|
||||
- avoid UI-internal wording such as selector details, DOM structure, or component names
|
||||
- keep language concrete enough that the scenario reads like living documentation
|
||||
|
||||
### 2. Keep scenarios focused
|
||||
|
||||
A scenario should usually prove one workflow or business outcome. If a scenario wanders across several unrelated behaviors, split it.
|
||||
|
||||
In Dify's suite, this means:
|
||||
|
||||
- one capability-focused scenario per feature path
|
||||
- no long setup chains when existing bootstrap or reusable steps already cover them
|
||||
- no hidden dependency on another scenario's side effects
|
||||
|
||||
### 3. Reuse steps, but only when behavior really matches
|
||||
|
||||
Good reuse reduces duplication. Bad reuse hides meaning.
|
||||
|
||||
Prefer reuse when:
|
||||
|
||||
- the user action is genuinely the same
|
||||
- the expected outcome is genuinely the same
|
||||
- the wording stays natural across features
|
||||
|
||||
Write a new step when:
|
||||
|
||||
- the behavior is materially different
|
||||
- reusing the old wording would make the scenario misleading
|
||||
- a supposedly generic step would become an implementation-detail wrapper
|
||||
|
||||
### 4. Prefer Cucumber Expressions
|
||||
|
||||
Use Cucumber Expressions for parameters unless regex is clearly necessary.
|
||||
|
||||
Common examples:
|
||||
|
||||
- `{string}` for labels, names, and visible text
|
||||
- `{int}` for counts
|
||||
- `{float}` for decimal values
|
||||
- `{word}` only when the value is truly a single token
|
||||
|
||||
Keep expressions readable. If a step needs complicated parsing logic, first ask whether the scenario wording should be simpler.
|
||||
|
||||
### 5. Keep step definitions thin and meaningful
|
||||
|
||||
Step definitions are glue between Gherkin and automation, not a second abstraction language.
|
||||
|
||||
For Dify:
|
||||
|
||||
- type `this` as `DifyWorld`
|
||||
- use `async function`
|
||||
- keep each step to one user-visible action or assertion
|
||||
- rely on `DifyWorld` and existing support code for shared context
|
||||
- avoid leaking cross-scenario state
|
||||
|
||||
### 6. Use tags intentionally
|
||||
|
||||
Tags should communicate run scope or session semantics, not become ad hoc metadata.
|
||||
|
||||
In Dify's current suite:
|
||||
|
||||
- capability tags group related scenarios
|
||||
- `@unauthenticated` changes session behavior
|
||||
- `@authenticated` is descriptive/selective, not a behavior switch by itself
|
||||
- `@fresh` belongs to reset/full-install flows only
|
||||
|
||||
If a proposed tag implies behavior, verify that hooks or runner configuration actually implement it.
|
||||
|
||||
## Review Questions
|
||||
|
||||
- Does the scenario read like a real example of product behavior?
|
||||
- Are the steps behavior-oriented instead of implementation-oriented?
|
||||
- Is a reused step still truthful in this feature?
|
||||
- Is a new tag documenting real behavior, or inventing semantics that the suite does not implement?
|
||||
- Would a new reader understand the outcome without opening the step-definition file?
|
||||
@@ -1,96 +0,0 @@
|
||||
# Playwright Best Practices For Dify E2E
|
||||
|
||||
Use this reference when writing or reviewing locator, assertion, isolation, or synchronization logic for Dify's Cucumber-based E2E suite.
|
||||
|
||||
Official sources:
|
||||
|
||||
- https://playwright.dev/docs/best-practices
|
||||
- https://playwright.dev/docs/locators
|
||||
- https://playwright.dev/docs/test-assertions
|
||||
- https://playwright.dev/docs/browser-contexts
|
||||
|
||||
## What Matters Most
|
||||
|
||||
### 1. Keep scenarios isolated
|
||||
|
||||
Playwright's model is built around clean browser contexts so one test does not leak into another. In Dify's suite, that principle maps to per-scenario session setup in `features/support/hooks.ts` and `DifyWorld`.
|
||||
|
||||
Apply it like this:
|
||||
|
||||
- do not depend on another scenario having run first
|
||||
- do not persist ad hoc scenario state outside `DifyWorld`
|
||||
- do not couple ordinary scenarios to `@fresh` behavior
|
||||
- when a flow needs special auth/session semantics, express that through the existing tag model or explicit hook changes
|
||||
|
||||
### 2. Prefer user-facing locators
|
||||
|
||||
Playwright recommends built-in locators that reflect what users perceive on the page.
|
||||
|
||||
Preferred order in this repository:
|
||||
|
||||
1. `getByRole`
|
||||
2. `getByLabel`
|
||||
3. `getByPlaceholder`
|
||||
4. `getByText`
|
||||
5. `getByTestId` when an explicit test contract is the most stable option
|
||||
|
||||
Avoid raw CSS/XPath selectors unless no stable user-facing contract exists and adding one is not practical.
|
||||
|
||||
Also remember:
|
||||
|
||||
- repeated content usually needs scoping to a stable container
|
||||
- exact text matching is often too brittle when role/name or label already exists
|
||||
- `getByTestId` is acceptable when semantics are weak but the contract is intentional
|
||||
|
||||
### 3. Use web-first assertions
|
||||
|
||||
Playwright assertions auto-wait and retry. Prefer them over manual state inspection.
|
||||
|
||||
Prefer:
|
||||
|
||||
- `await expect(page).toHaveURL(...)`
|
||||
- `await expect(locator).toBeVisible()`
|
||||
- `await expect(locator).toBeHidden()`
|
||||
- `await expect(locator).toBeEnabled()`
|
||||
- `await expect(locator).toHaveText(...)`
|
||||
|
||||
Avoid:
|
||||
|
||||
- `expect(await locator.isVisible()).toBe(true)`
|
||||
- custom polling loops for DOM state
|
||||
- `waitForTimeout` as synchronization
|
||||
|
||||
If a condition genuinely needs custom retry logic, use Playwright's polling/assertion tools deliberately and keep that choice local and explicit.
|
||||
|
||||
### 4. Let actions wait for actionability
|
||||
|
||||
Locator actions already wait for the element to be actionable. Do not preface every click/fill with extra timing logic unless the action needs a specific visible/ready assertion for clarity.
|
||||
|
||||
Good pattern:
|
||||
|
||||
- assert a meaningful visible state when that is part of the behavior
|
||||
- then click/fill/select via locator APIs
|
||||
|
||||
Bad pattern:
|
||||
|
||||
- stack arbitrary waits before every action
|
||||
- wait on unstable implementation details instead of the visible state the user cares about
|
||||
|
||||
### 5. Match debugging to the current suite
|
||||
|
||||
Playwright's wider ecosystem supports traces and rich debugging tools. Dify's current suite already captures:
|
||||
|
||||
- full-page screenshots
|
||||
- page HTML
|
||||
- console errors
|
||||
- page errors
|
||||
|
||||
Use the existing artifact flow by default. If a task is specifically about improving diagnostics, confirm the change fits the current Cucumber architecture before importing broader Playwright tooling.
|
||||
|
||||
## Review Questions
|
||||
|
||||
- Would this locator survive DOM refactors that do not change user-visible behavior?
|
||||
- Is this assertion using Playwright's retrying semantics?
|
||||
- Is any explicit wait masking a real readiness problem?
|
||||
- Does this code preserve per-scenario isolation?
|
||||
- Is a new abstraction really needed, or does it bypass the existing `DifyWorld` + step-definition model?
|
||||
@@ -1 +0,0 @@
|
||||
../../.agents/skills/e2e-cucumber-playwright
|
||||
7
.github/workflows/docker-build.yml
vendored
7
.github/workflows/docker-build.yml
vendored
@@ -6,7 +6,14 @@ on:
|
||||
- "main"
|
||||
paths:
|
||||
- api/Dockerfile
|
||||
- web/docker/**
|
||||
- web/Dockerfile
|
||||
- packages/**
|
||||
- package.json
|
||||
- pnpm-lock.yaml
|
||||
- pnpm-workspace.yaml
|
||||
- .npmrc
|
||||
- .nvmrc
|
||||
|
||||
concurrency:
|
||||
group: docker-build-${{ github.head_ref || github.run_id }}
|
||||
|
||||
1
.github/workflows/main-ci.yml
vendored
1
.github/workflows/main-ci.yml
vendored
@@ -92,7 +92,6 @@ jobs:
|
||||
vdb:
|
||||
- 'api/core/rag/datasource/**'
|
||||
- 'api/tests/integration_tests/vdb/**'
|
||||
- 'api/providers/vdb/*/tests/**'
|
||||
- '.github/workflows/vdb-tests.yml'
|
||||
- '.github/workflows/expose_service_ports.sh'
|
||||
- 'docker/.env.example'
|
||||
|
||||
2
.github/workflows/vdb-tests-full.yml
vendored
2
.github/workflows/vdb-tests-full.yml
vendored
@@ -89,7 +89,7 @@ jobs:
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
# - name: Check VDB Ready (TiDB)
|
||||
# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
|
||||
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: uv run --project api bash dev/pytest/pytest_vdb.sh
|
||||
|
||||
10
.github/workflows/vdb-tests.yml
vendored
10
.github/workflows/vdb-tests.yml
vendored
@@ -81,12 +81,12 @@ jobs:
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
# - name: Check VDB Ready (TiDB)
|
||||
# run: uv run --project api python api/providers/vdb/tidb-vector/tests/integration_tests/check_tiflash_ready.py
|
||||
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: |
|
||||
uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \
|
||||
api/providers/vdb/vdb-chroma/tests/integration_tests \
|
||||
api/providers/vdb/vdb-pgvector/tests/integration_tests \
|
||||
api/providers/vdb/vdb-qdrant/tests/integration_tests \
|
||||
api/providers/vdb/vdb-weaviate/tests/integration_tests
|
||||
api/tests/integration_tests/vdb/chroma \
|
||||
api/tests/integration_tests/vdb/pgvector \
|
||||
api/tests/integration_tests/vdb/qdrant \
|
||||
api/tests/integration_tests/vdb/weaviate
|
||||
|
||||
@@ -69,6 +69,8 @@ ignore = [
|
||||
"FURB152", # math-constant
|
||||
"UP007", # non-pep604-annotation
|
||||
"UP032", # f-string
|
||||
"UP045", # non-pep604-annotation-optional
|
||||
"B005", # strip-with-multi-characters
|
||||
"B006", # mutable-argument-default
|
||||
"B007", # unused-loop-control-variable
|
||||
"B026", # star-arg-unpacking-after-keyword-arg
|
||||
@@ -82,6 +84,7 @@ ignore = [
|
||||
"SIM102", # collapsible-if
|
||||
"SIM103", # needless-bool
|
||||
"SIM105", # suppressible-exception
|
||||
"SIM107", # return-in-try-except-finally
|
||||
"SIM108", # if-else-block-instead-of-if-exp
|
||||
"SIM113", # enumerate-for-loop
|
||||
"SIM117", # multiple-with-statements
|
||||
@@ -90,16 +93,29 @@ ignore = [
|
||||
]
|
||||
|
||||
[lint.per-file-ignores]
|
||||
"__init__.py" = [
|
||||
"F401", # unused-import
|
||||
"F811", # redefined-while-unused
|
||||
]
|
||||
"configs/*" = [
|
||||
"N802", # invalid-function-name
|
||||
]
|
||||
"graphon/model_runtime/callbacks/base_callback.py" = ["T201"]
|
||||
"core/workflow/callbacks/workflow_logging_callback.py" = ["T201"]
|
||||
"libs/gmpy2_pkcs10aep_cipher.py" = [
|
||||
"N803", # invalid-argument-name
|
||||
]
|
||||
"tests/*" = [
|
||||
"F811", # redefined-while-unused
|
||||
"T201", # allow print in tests,
|
||||
"S110", # allow ignoring exceptions in tests code (currently)
|
||||
|
||||
]
|
||||
"controllers/console/explore/trial.py" = ["TID251"]
|
||||
"controllers/console/human_input_form.py" = ["TID251"]
|
||||
"controllers/web/human_input_form.py" = ["TID251"]
|
||||
|
||||
[lint.flake8-tidy-imports]
|
||||
|
||||
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]
|
||||
msg = "Use Pydantic payload/query models instead of reqparse."
|
||||
|
||||
@@ -21,9 +21,8 @@ RUN apt-get update \
|
||||
# for building gmpy2
|
||||
libmpfr-dev libmpc-dev
|
||||
|
||||
# Install Python dependencies (workspace members under providers/vdb/)
|
||||
# Install Python dependencies
|
||||
COPY pyproject.toml uv.lock ./
|
||||
COPY providers ./providers
|
||||
RUN uv sync --locked --no-dev
|
||||
|
||||
# production stage
|
||||
|
||||
@@ -341,10 +341,11 @@ def add_qdrant_index(field: str):
|
||||
click.echo(click.style("No dataset collection bindings found.", fg="red"))
|
||||
return
|
||||
import qdrant_client
|
||||
from dify_vdb_qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
from qdrant_client.http.models import PayloadSchemaType
|
||||
|
||||
from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig
|
||||
|
||||
for binding in bindings:
|
||||
if dify_config.QDRANT_URL is None:
|
||||
raise ValueError("Qdrant URL is required.")
|
||||
|
||||
@@ -287,6 +287,27 @@ class MarketplaceConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class CreatorsPlatformConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for creators platform
|
||||
"""
|
||||
|
||||
CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(
|
||||
description="Enable or disable creators platform features",
|
||||
default=True,
|
||||
)
|
||||
|
||||
CREATORS_PLATFORM_API_URL: HttpUrl = Field(
|
||||
description="Creators Platform API URL",
|
||||
default=HttpUrl("https://creators.dify.ai"),
|
||||
)
|
||||
|
||||
CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(
|
||||
description="OAuth client_id for the Creators Platform app registered in Dify",
|
||||
default="",
|
||||
)
|
||||
|
||||
|
||||
class EndpointConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for various application endpoints and URLs
|
||||
@@ -341,6 +362,15 @@ class FileAccessConfig(BaseSettings):
|
||||
default="",
|
||||
)
|
||||
|
||||
FILES_API_URL: str = Field(
|
||||
description="Base URL for storage file ticket API endpoints."
|
||||
" Used by sandbox containers (internal or external like e2b) that need"
|
||||
" an absolute, routable address to upload/download files via the API."
|
||||
" For all-in-one Docker deployments, set to http://localhost."
|
||||
" For public sandbox environments, set to a public domain or IP.",
|
||||
default="",
|
||||
)
|
||||
|
||||
FILES_ACCESS_TIMEOUT: int = Field(
|
||||
description="Expiration time in seconds for file access URLs",
|
||||
default=300,
|
||||
@@ -1274,6 +1304,52 @@ class PositionConfig(BaseSettings):
|
||||
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
|
||||
|
||||
|
||||
class CollaborationConfig(BaseSettings):
|
||||
ENABLE_COLLABORATION_MODE: bool = Field(
|
||||
description="Whether to enable collaboration mode features across the workspace",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class SandboxExpiredRecordsCleanConfig(BaseSettings):
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: NonNegativeInt = Field(
|
||||
description="Graceful period in days for sandbox records clean after subscription expiration",
|
||||
default=21,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: PositiveInt = Field(
|
||||
description="Maximum number of records to process in each batch",
|
||||
default=1000,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL: PositiveInt = Field(
|
||||
description="Maximum interval in milliseconds between batches",
|
||||
default=200,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field(
|
||||
description="Retention days for sandbox expired workflow_run records and message records",
|
||||
default=30,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL: PositiveInt = Field(
|
||||
description="Lock TTL for sandbox expired records clean task in seconds",
|
||||
default=90000,
|
||||
)
|
||||
|
||||
|
||||
class AgentV2UpgradeConfig(BaseSettings):
|
||||
"""Feature flags for transparent Agent V2 upgrade."""
|
||||
|
||||
AGENT_V2_TRANSPARENT_UPGRADE: bool = Field(
|
||||
description="Transparently run old apps (chat/completion/agent-chat) through the Agent V2 workflow engine. "
|
||||
"When enabled, old apps synthesize a virtual workflow at runtime instead of using legacy runners.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
AGENT_V2_REPLACES_LLM: bool = Field(
|
||||
description="Transparently replace LLM nodes in workflows with Agent V2 nodes at runtime. "
|
||||
"LLMNodeData is remapped to AgentV2NodeData with tools=[] (identical behavior).",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class LoginConfig(BaseSettings):
|
||||
ENABLE_EMAIL_CODE_LOGIN: bool = Field(
|
||||
description="whether to enable email code login",
|
||||
@@ -1343,29 +1419,6 @@ class TenantIsolatedTaskQueueConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class SandboxExpiredRecordsCleanConfig(BaseSettings):
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: NonNegativeInt = Field(
|
||||
description="Graceful period in days for sandbox records clean after subscription expiration",
|
||||
default=21,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: PositiveInt = Field(
|
||||
description="Maximum number of records to process in each batch",
|
||||
default=1000,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL: PositiveInt = Field(
|
||||
description="Maximum interval in milliseconds between batches",
|
||||
default=200,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field(
|
||||
description="Retention days for sandbox expired workflow_run records and message records",
|
||||
default=30,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL: PositiveInt = Field(
|
||||
description="Lock TTL for sandbox expired records clean task in seconds",
|
||||
default=90000,
|
||||
)
|
||||
|
||||
|
||||
class FeatureConfig(
|
||||
# place the configs in alphabet order
|
||||
AppExecutionConfig,
|
||||
@@ -1376,6 +1429,7 @@ class FeatureConfig(
|
||||
AsyncWorkflowConfig,
|
||||
PluginConfig,
|
||||
MarketplaceConfig,
|
||||
CreatorsPlatformConfig,
|
||||
DataSetConfig,
|
||||
EndpointConfig,
|
||||
FileAccessConfig,
|
||||
@@ -1391,7 +1445,6 @@ class FeatureConfig(
|
||||
PositionConfig,
|
||||
RagEtlConfig,
|
||||
RepositoryConfig,
|
||||
SandboxExpiredRecordsCleanConfig,
|
||||
SecurityConfig,
|
||||
TenantIsolatedTaskQueueConfig,
|
||||
ToolConfig,
|
||||
@@ -1399,6 +1452,9 @@ class FeatureConfig(
|
||||
WorkflowConfig,
|
||||
WorkflowNodeExecutionConfig,
|
||||
WorkspaceConfig,
|
||||
CollaborationConfig,
|
||||
AgentV2UpgradeConfig,
|
||||
SandboxExpiredRecordsCleanConfig,
|
||||
LoginConfig,
|
||||
AccountConfig,
|
||||
SwaggerUIConfig,
|
||||
|
||||
@@ -160,16 +160,6 @@ class DatabaseConfig(BaseSettings):
|
||||
default="",
|
||||
)
|
||||
|
||||
DB_SESSION_TIMEZONE_OVERRIDE: str = Field(
|
||||
description=(
|
||||
"PostgreSQL session timezone override injected via startup options."
|
||||
" Default is 'UTC' for out-of-the-box consistency."
|
||||
" Set to empty string to disable app-level timezone injection, for example when using RDS Proxy"
|
||||
" together with a database-side default timezone."
|
||||
),
|
||||
default="UTC",
|
||||
)
|
||||
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def SQLALCHEMY_DATABASE_URI_SCHEME(self) -> str:
|
||||
@@ -237,13 +227,12 @@ class DatabaseConfig(BaseSettings):
|
||||
connect_args: dict[str, str] = {}
|
||||
# Use the dynamic SQLALCHEMY_DATABASE_URI_SCHEME property
|
||||
if self.SQLALCHEMY_DATABASE_URI_SCHEME.startswith("postgresql"):
|
||||
merged_options = options.strip()
|
||||
session_timezone_override = self.DB_SESSION_TIMEZONE_OVERRIDE.strip()
|
||||
if session_timezone_override:
|
||||
timezone_opt = f"-c timezone={session_timezone_override}"
|
||||
merged_options = f"{merged_options} {timezone_opt}".strip() if merged_options else timezone_opt
|
||||
if merged_options:
|
||||
connect_args = {"options": merged_options}
|
||||
timezone_opt = "-c timezone=UTC"
|
||||
if options:
|
||||
merged_options = f"{options} {timezone_opt}"
|
||||
else:
|
||||
merged_options = timezone_opt
|
||||
connect_args = {"options": merged_options}
|
||||
|
||||
result: SQLAlchemyEngineOptionsDict = {
|
||||
"pool_size": self.SQLALCHEMY_POOL_SIZE,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
@@ -41,17 +42,17 @@ class HologresConfig(BaseSettings):
|
||||
default="public",
|
||||
)
|
||||
|
||||
HOLOGRES_TOKENIZER: str = Field(
|
||||
HOLOGRES_TOKENIZER: TokenizerType = Field(
|
||||
description="Tokenizer for full-text search index (e.g., 'jieba', 'ik', 'standard', 'simple').",
|
||||
default="jieba",
|
||||
)
|
||||
|
||||
HOLOGRES_DISTANCE_METHOD: str = Field(
|
||||
HOLOGRES_DISTANCE_METHOD: DistanceType = Field(
|
||||
description="Distance method for vector index (e.g., 'Cosine', 'Euclidean', 'InnerProduct').",
|
||||
default="Cosine",
|
||||
)
|
||||
|
||||
HOLOGRES_BASE_QUANTIZATION_TYPE: str = Field(
|
||||
HOLOGRES_BASE_QUANTIZATION_TYPE: BaseQuantizationType = Field(
|
||||
description="Base quantization type for vector index (e.g., 'rabitq', 'sq8', 'fp16', 'fp32').",
|
||||
default="rabitq",
|
||||
)
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""Configuration for InterSystems IRIS vector database."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, PositiveInt, model_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
@@ -66,7 +64,7 @@ class IrisVectorConfig(BaseSettings):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict[str, Any]) -> dict[str, Any]:
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
"""Validate IRIS configuration values.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -81,4 +81,20 @@ default_app_templates: Mapping[AppMode, Mapping] = {
|
||||
},
|
||||
},
|
||||
},
|
||||
# agent default mode (new agent backed by single-node workflow)
|
||||
AppMode.AGENT: {
|
||||
"app": {
|
||||
"mode": AppMode.AGENT,
|
||||
"enable_site": True,
|
||||
"enable_api": True,
|
||||
},
|
||||
"model_config": {
|
||||
"model": {
|
||||
"provider": "openai",
|
||||
"name": "gpt-4o",
|
||||
"mode": "chat",
|
||||
"completion_params": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import login_required
|
||||
from services.advanced_prompt_template_service import AdvancedPromptTemplateArgs, AdvancedPromptTemplateService
|
||||
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
|
||||
|
||||
|
||||
class AdvancedPromptTemplateQuery(BaseModel):
|
||||
@@ -35,10 +35,5 @@ class AdvancedPromptTemplateList(Resource):
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
args = AdvancedPromptTemplateQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
prompt_args: AdvancedPromptTemplateArgs = {
|
||||
"app_mode": args.app_mode,
|
||||
"model_mode": args.model_mode,
|
||||
"model_name": args.model_name,
|
||||
"has_context": args.has_context,
|
||||
}
|
||||
return AdvancedPromptTemplateService.get_prompt(prompt_args)
|
||||
|
||||
return AdvancedPromptTemplateService.get_prompt(args.model_dump())
|
||||
|
||||
@@ -52,7 +52,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||
)
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion", "agent"]
|
||||
|
||||
register_enum_models(console_ns, IconType)
|
||||
|
||||
@@ -62,7 +62,7 @@ _logger = logging.getLogger(__name__)
|
||||
class AppListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
|
||||
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
|
||||
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "agent", "channel", "all"] = Field(
|
||||
default="all", description="App mode filter"
|
||||
)
|
||||
name: str | None = Field(default=None, description="Filter by app name")
|
||||
@@ -94,7 +94,9 @@ class AppListQuery(BaseModel):
|
||||
class CreateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
|
||||
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion", "agent"] = Field(
|
||||
..., description="App mode"
|
||||
)
|
||||
icon_type: IconType | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@@ -161,7 +161,7 @@ class ChatMessageApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
args_model = ChatMessagePayload.model_validate(console_ns.payload)
|
||||
@@ -215,7 +215,7 @@ class ChatMessageStopApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
def post(self, app_model, task_id):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
@@ -26,13 +26,13 @@ def _to_timestamp(value: datetime | int | None) -> int | None:
|
||||
|
||||
class MCPServerCreatePayload(BaseModel):
|
||||
description: str | None = Field(default=None, description="Server description")
|
||||
parameters: dict[str, Any] = Field(..., description="Server parameters configuration")
|
||||
parameters: dict = Field(..., description="Server parameters configuration")
|
||||
|
||||
|
||||
class MCPServerUpdatePayload(BaseModel):
|
||||
id: str = Field(..., description="Server ID")
|
||||
description: str | None = Field(default=None, description="Server description")
|
||||
parameters: dict[str, Any] = Field(..., description="Server parameters configuration")
|
||||
parameters: dict = Field(..., description="Server parameters configuration")
|
||||
status: str | None = Field(default=None, description="Server status")
|
||||
|
||||
|
||||
|
||||
@@ -237,7 +237,7 @@ class ChatMessageListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@setup_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@marshal_with(message_infinite_scroll_pagination_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
@@ -393,7 +393,7 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
def get(self, app_model, message_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
message_id = str(message_id)
|
||||
|
||||
@@ -206,7 +206,7 @@ class DraftWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@marshal_with(workflow_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
@@ -226,7 +226,7 @@ class DraftWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@console_ns.doc("sync_draft_workflow")
|
||||
@console_ns.doc(description="Sync draft workflow configuration")
|
||||
@console_ns.expect(console_ns.models[SyncDraftWorkflowPayload.__name__])
|
||||
@@ -310,7 +310,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
@@ -356,7 +356,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
@@ -432,7 +432,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
@@ -534,7 +534,7 @@ class AdvancedChatDraftHumanInputFormPreviewApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
@@ -563,7 +563,7 @@ class AdvancedChatDraftHumanInputFormRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
@@ -718,7 +718,7 @@ class WorkflowTaskStopApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, task_id: str):
|
||||
"""
|
||||
@@ -746,7 +746,7 @@ class DraftWorkflowNodeRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@marshal_with(workflow_run_node_execution_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
@@ -792,7 +792,7 @@ class PublishedWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@marshal_with(workflow_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
@@ -810,7 +810,7 @@ class PublishedWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
@@ -854,7 +854,7 @@ class DefaultBlockConfigsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
@@ -876,7 +876,7 @@ class DefaultBlockConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App, block_type: str):
|
||||
"""
|
||||
@@ -941,7 +941,7 @@ class PublishedAllWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@marshal_with(workflow_pagination_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
@@ -990,7 +990,7 @@ class DraftWorkflowRestoreApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, workflow_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@@ -1028,7 +1028,7 @@ class WorkflowByIdApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@marshal_with(workflow_model)
|
||||
@edit_permission_required
|
||||
def patch(self, app_model: App, workflow_id: str):
|
||||
@@ -1068,7 +1068,7 @@ class WorkflowByIdApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, workflow_id: str):
|
||||
"""
|
||||
@@ -1103,7 +1103,7 @@ class DraftWorkflowNodeLastRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@marshal_with(workflow_run_node_execution_model)
|
||||
def get(self, app_model: App, node_id: str):
|
||||
srv = WorkflowService()
|
||||
|
||||
@@ -87,7 +87,7 @@ class WorkflowAppLogApi(Resource):
|
||||
|
||||
# get paginate workflow app logs
|
||||
workflow_app_service = WorkflowAppService()
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
@@ -124,7 +124,7 @@ class WorkflowArchivedLogApi(Resource):
|
||||
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
workflow_app_service = WorkflowAppService()
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
|
||||
322
api/controllers/console/app/workflow_comment.py
Normal file
322
api/controllers/console/app/workflow_comment.py
Normal file
@@ -0,0 +1,322 @@
|
||||
import logging
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.member_fields import AccountWithRole
|
||||
from fields.workflow_comment_fields import (
|
||||
workflow_comment_basic_fields,
|
||||
workflow_comment_create_fields,
|
||||
workflow_comment_detail_fields,
|
||||
workflow_comment_reply_create_fields,
|
||||
workflow_comment_reply_update_fields,
|
||||
workflow_comment_resolve_fields,
|
||||
workflow_comment_update_fields,
|
||||
)
|
||||
from libs.login import current_user, login_required
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.workflow_comment_service import WorkflowCommentService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowCommentCreatePayload(BaseModel):
|
||||
position_x: float = Field(..., description="Comment X position")
|
||||
position_y: float = Field(..., description="Comment Y position")
|
||||
content: str = Field(..., description="Comment content")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentUpdatePayload(BaseModel):
|
||||
content: str = Field(..., description="Comment content")
|
||||
position_x: float | None = Field(default=None, description="Comment X position")
|
||||
position_y: float | None = Field(default=None, description="Comment Y position")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentReplyCreatePayload(BaseModel):
|
||||
content: str = Field(..., description="Reply content")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentReplyUpdatePayload(BaseModel):
|
||||
content: str = Field(..., description="Reply content")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentMentionUsersResponse(BaseModel):
|
||||
users: list[AccountWithRole] = Field(description="Mentionable users")
|
||||
|
||||
|
||||
for model in (
|
||||
WorkflowCommentCreatePayload,
|
||||
WorkflowCommentUpdatePayload,
|
||||
WorkflowCommentReplyCreatePayload,
|
||||
WorkflowCommentReplyUpdatePayload,
|
||||
):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
for model in (AccountWithRole, WorkflowCommentMentionUsersResponse):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields)
|
||||
workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields)
|
||||
workflow_comment_create_model = console_ns.model("WorkflowCommentCreate", workflow_comment_create_fields)
|
||||
workflow_comment_update_model = console_ns.model("WorkflowCommentUpdate", workflow_comment_update_fields)
|
||||
workflow_comment_resolve_model = console_ns.model("WorkflowCommentResolve", workflow_comment_resolve_fields)
|
||||
workflow_comment_reply_create_model = console_ns.model(
|
||||
"WorkflowCommentReplyCreate", workflow_comment_reply_create_fields
|
||||
)
|
||||
workflow_comment_reply_update_model = console_ns.model(
|
||||
"WorkflowCommentReplyUpdate", workflow_comment_reply_update_fields
|
||||
)
|
||||
workflow_comment_mention_users_model = console_ns.models[WorkflowCommentMentionUsersResponse.__name__]
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments")
|
||||
class WorkflowCommentListApi(Resource):
|
||||
"""API for listing and creating workflow comments."""
|
||||
|
||||
@console_ns.doc("list_workflow_comments")
|
||||
@console_ns.doc(description="Get all comments for a workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Comments retrieved successfully", workflow_comment_basic_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_basic_model, envelope="data")
|
||||
def get(self, app_model: App):
|
||||
"""Get all comments for a workflow."""
|
||||
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
|
||||
|
||||
return comments
|
||||
|
||||
@console_ns.doc("create_workflow_comment")
|
||||
@console_ns.doc(description="Create a new workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentCreatePayload.__name__])
|
||||
@console_ns.response(201, "Comment created successfully", workflow_comment_create_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_create_model)
|
||||
def post(self, app_model: App):
|
||||
"""Create a new workflow comment."""
|
||||
payload = WorkflowCommentCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.create_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
created_by=current_user.id,
|
||||
content=payload.content,
|
||||
position_x=payload.position_x,
|
||||
position_y=payload.position_y,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
|
||||
class WorkflowCommentDetailApi(Resource):
|
||||
"""API for managing individual workflow comments."""
|
||||
|
||||
@console_ns.doc("get_workflow_comment")
|
||||
@console_ns.doc(description="Get a specific workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(200, "Comment retrieved successfully", workflow_comment_detail_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_detail_model)
|
||||
def get(self, app_model: App, comment_id: str):
|
||||
"""Get a specific workflow comment."""
|
||||
comment = WorkflowCommentService.get_comment(
|
||||
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
|
||||
)
|
||||
|
||||
return comment
|
||||
|
||||
@console_ns.doc("update_workflow_comment")
|
||||
@console_ns.doc(description="Update a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Comment updated successfully", workflow_comment_update_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_update_model)
|
||||
def put(self, app_model: App, comment_id: str):
|
||||
"""Update a workflow comment."""
|
||||
payload = WorkflowCommentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.update_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
content=payload.content,
|
||||
position_x=payload.position_x,
|
||||
position_y=payload.position_y,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@console_ns.doc("delete_workflow_comment")
|
||||
@console_ns.doc(description="Delete a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(204, "Comment deleted successfully")
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def delete(self, app_model: App, comment_id: str):
|
||||
"""Delete a workflow comment."""
|
||||
WorkflowCommentService.delete_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
|
||||
class WorkflowCommentResolveApi(Resource):
|
||||
"""API for resolving and reopening workflow comments."""
|
||||
|
||||
@console_ns.doc("resolve_workflow_comment")
|
||||
@console_ns.doc(description="Resolve a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(200, "Comment resolved successfully", workflow_comment_resolve_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_resolve_model)
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Resolve a workflow comment."""
|
||||
comment = WorkflowCommentService.resolve_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return comment
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
|
||||
class WorkflowCommentReplyApi(Resource):
|
||||
"""API for managing comment replies."""
|
||||
|
||||
@console_ns.doc("create_workflow_comment_reply")
|
||||
@console_ns.doc(description="Add a reply to a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentReplyCreatePayload.__name__])
|
||||
@console_ns.response(201, "Reply created successfully", workflow_comment_reply_create_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_reply_create_model)
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Add a reply to a workflow comment."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
payload = WorkflowCommentReplyCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.create_reply(
|
||||
comment_id=comment_id,
|
||||
content=payload.content,
|
||||
created_by=current_user.id,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>")
|
||||
class WorkflowCommentReplyDetailApi(Resource):
|
||||
"""API for managing individual comment replies."""
|
||||
|
||||
@console_ns.doc("update_workflow_comment_reply")
|
||||
@console_ns.doc(description="Update a comment reply")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentReplyUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Reply updated successfully", workflow_comment_reply_update_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_reply_update_model)
|
||||
def put(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Update a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
payload = WorkflowCommentReplyUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
reply = WorkflowCommentService.update_reply(
|
||||
reply_id=reply_id,
|
||||
user_id=current_user.id,
|
||||
content=payload.content,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return reply
|
||||
|
||||
@console_ns.doc("delete_workflow_comment_reply")
|
||||
@console_ns.doc(description="Delete a comment reply")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
|
||||
@console_ns.response(204, "Reply deleted successfully")
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def delete(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Delete a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
WorkflowCommentService.delete_reply(reply_id=reply_id, user_id=current_user.id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/mention-users")
|
||||
class WorkflowCommentMentionUsersApi(Resource):
|
||||
"""API for getting mentionable users for workflow comments."""
|
||||
|
||||
@console_ns.doc("workflow_comment_mention_users")
|
||||
@console_ns.doc(description="Get all users in current tenant for mentions")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Mentionable users retrieved successfully", workflow_comment_mention_users_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App):
|
||||
"""Get all users in current tenant for mentions."""
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
|
||||
response = WorkflowCommentMentionUsersResponse(users=member_models)
|
||||
return response.model_dump(mode="json"), 200
|
||||
@@ -216,7 +216,7 @@ def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@wraps(f)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
@@ -36,7 +36,7 @@ from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowR
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
|
||||
from services.workflow_run_service import WorkflowRunListArgs, WorkflowRunService
|
||||
from services.workflow_run_service import WorkflowRunService
|
||||
|
||||
|
||||
def _build_backstage_input_url(form_token: str | None) -> str | None:
|
||||
@@ -207,18 +207,14 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@marshal_with(advanced_chat_workflow_run_pagination_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get advanced chat app workflow run list
|
||||
"""
|
||||
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args: WorkflowRunListArgs = {"limit": args_model.limit}
|
||||
if args_model.last_id is not None:
|
||||
args["last_id"] = args_model.last_id
|
||||
if args_model.status is not None:
|
||||
args["status"] = args_model.status
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
# Default to DEBUGGING if not specified
|
||||
triggered_from = (
|
||||
@@ -309,7 +305,7 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@marshal_with(workflow_run_count_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
@@ -353,18 +349,14 @@ class WorkflowRunListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@marshal_with(workflow_run_pagination_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
Get workflow run list
|
||||
"""
|
||||
args_model = WorkflowRunListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args: WorkflowRunListArgs = {"limit": args_model.limit}
|
||||
if args_model.last_id is not None:
|
||||
args["last_id"] = args_model.last_id
|
||||
if args_model.status is not None:
|
||||
args["status"] = args_model.status
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
# Default to DEBUGGING for workflow if not specified (backward compatibility)
|
||||
triggered_from = (
|
||||
@@ -405,7 +397,7 @@ class WorkflowRunCountApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@marshal_with(workflow_run_count_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
@@ -442,7 +434,7 @@ class WorkflowRunDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@marshal_with(workflow_run_detail_model)
|
||||
def get(self, app_model: App, run_id):
|
||||
"""
|
||||
@@ -466,7 +458,7 @@ class WorkflowRunNodeExecutionListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@marshal_with(workflow_run_node_execution_list_model)
|
||||
def get(self, app_model: App, run_id):
|
||||
"""
|
||||
|
||||
@@ -64,7 +64,7 @@ class WebhookTriggerApi(Resource):
|
||||
|
||||
node_id = args.node_id
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
# Get webhook trigger for this app and node
|
||||
webhook_trigger = session.scalar(
|
||||
select(WorkflowWebhookTrigger)
|
||||
@@ -95,7 +95,7 @@ class AppTriggersApi(Resource):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
# Get all triggers for this app using select API
|
||||
triggers = (
|
||||
session.execute(
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
import logging
|
||||
|
||||
import flask_login
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
@@ -45,13 +42,12 @@ from libs.token import (
|
||||
)
|
||||
from services.account_service import AccountService, InvitationDetailDict, RegisterService, TenantService
|
||||
from services.billing_service import BillingService
|
||||
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
|
||||
from services.entities.auth_entities import LoginPayloadBase
|
||||
from services.errors.account import AccountRegisterError
|
||||
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoginPayload(LoginPayloadBase):
|
||||
@@ -95,12 +91,10 @@ class LoginApi(Resource):
|
||||
normalized_email = request_email.lower()
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
|
||||
raise AccountInFreezeError()
|
||||
|
||||
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
|
||||
if is_login_error_rate_limit:
|
||||
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.LOGIN_RATE_LIMITED)
|
||||
raise EmailPasswordLoginLimitError()
|
||||
|
||||
invite_token = args.invite_token
|
||||
@@ -116,20 +110,14 @@ class LoginApi(Resource):
|
||||
invitee_email = data.get("email") if data else None
|
||||
invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
|
||||
if invitee_email_normalized != normalized_email:
|
||||
_log_console_login_failure(
|
||||
email=normalized_email,
|
||||
reason=LoginFailureReason.INVALID_INVITATION_EMAIL,
|
||||
)
|
||||
raise InvalidEmailError()
|
||||
account = _authenticate_account_with_case_fallback(
|
||||
request_email, normalized_email, args.password, invite_token
|
||||
)
|
||||
except services.errors.account.AccountLoginError:
|
||||
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_BANNED)
|
||||
raise AccountBannedError()
|
||||
except services.errors.account.AccountPasswordError as exc:
|
||||
AccountService.add_login_error_rate_limit(normalized_email)
|
||||
_log_console_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
|
||||
raise AuthenticationFailedError() from exc
|
||||
# SELF_HOSTED only have one workspace
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
@@ -252,27 +240,20 @@ class EmailCodeLoginApi(Resource):
|
||||
|
||||
token_data = AccountService.get_email_code_login_data(args.token)
|
||||
if token_data is None:
|
||||
_log_console_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE_TOKEN)
|
||||
raise InvalidTokenError()
|
||||
|
||||
token_email = token_data.get("email")
|
||||
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
|
||||
if normalized_token_email != user_email:
|
||||
_log_console_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
|
||||
raise InvalidEmailError()
|
||||
|
||||
if token_data["code"] != args.code:
|
||||
_log_console_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE)
|
||||
raise EmailCodeError()
|
||||
|
||||
AccountService.revoke_email_code_login_token(args.token)
|
||||
try:
|
||||
account = _get_account_with_case_fallback(original_email)
|
||||
except Unauthorized as exc:
|
||||
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_BANNED)
|
||||
raise AccountBannedError() from exc
|
||||
except AccountRegisterError:
|
||||
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
|
||||
raise AccountInFreezeError()
|
||||
if account:
|
||||
tenants = TenantService.get_join_tenants(account)
|
||||
@@ -298,7 +279,6 @@ class EmailCodeLoginApi(Resource):
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
raise NotAllowedCreateWorkspace()
|
||||
except AccountRegisterError:
|
||||
_log_console_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_IN_FREEZE)
|
||||
raise AccountInFreezeError()
|
||||
except WorkspacesLimitExceededError:
|
||||
raise WorkspacesLimitExceeded()
|
||||
@@ -356,12 +336,3 @@ def _authenticate_account_with_case_fallback(
|
||||
if original_email == normalized_email:
|
||||
raise
|
||||
return AccountService.authenticate(normalized_email, password, invite_token)
|
||||
|
||||
|
||||
def _log_console_login_failure(*, email: str, reason: LoginFailureReason) -> None:
|
||||
logger.warning(
|
||||
"Console login failed: email=%s reason=%s ip_address=%s",
|
||||
email,
|
||||
reason,
|
||||
extract_remote_ip(request),
|
||||
)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import TypedDict
|
||||
|
||||
from flask import request
|
||||
@@ -14,14 +13,6 @@ from services.billing_service import BillingService
|
||||
_FALLBACK_LANG = "en-US"
|
||||
|
||||
|
||||
class NotificationLangContent(TypedDict, total=False):
|
||||
lang: str
|
||||
title: str
|
||||
subtitle: str
|
||||
body: str
|
||||
titlePicUrl: str
|
||||
|
||||
|
||||
class NotificationItemDict(TypedDict):
|
||||
notification_id: str | None
|
||||
frequency: str | None
|
||||
@@ -37,11 +28,9 @@ class NotificationResponseDict(TypedDict):
|
||||
notifications: list[NotificationItemDict]
|
||||
|
||||
|
||||
def _pick_lang_content(contents: Mapping[str, NotificationLangContent], lang: str) -> NotificationLangContent:
|
||||
def _pick_lang_content(contents: dict, lang: str) -> dict:
|
||||
"""Return the single LangContent for *lang*, falling back to English."""
|
||||
return (
|
||||
contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), NotificationLangContent())
|
||||
)
|
||||
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
|
||||
|
||||
|
||||
class DismissNotificationPayload(BaseModel):
|
||||
@@ -82,7 +71,7 @@ class NotificationApi(Resource):
|
||||
|
||||
notifications: list[NotificationItemDict] = []
|
||||
for notification in result.get("notifications") or []:
|
||||
contents: Mapping[str, NotificationLangContent] = notification.get("contents") or {}
|
||||
contents: dict = notification.get("contents") or {}
|
||||
lang_content = _pick_lang_content(contents, lang)
|
||||
item: NotificationItemDict = {
|
||||
"notification_id": notification.get("notificationId"),
|
||||
|
||||
1
api/controllers/console/socketio/__init__.py
Normal file
1
api/controllers/console/socketio/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
119
api/controllers/console/socketio/workflow.py
Normal file
119
api/controllers/console/socketio/workflow.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from flask import Request as FlaskRequest
|
||||
|
||||
from extensions.ext_socketio import sio
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_access_token
|
||||
from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository
|
||||
from services.account_service import AccountService
|
||||
from services.workflow_collaboration_service import WorkflowCollaborationService
|
||||
|
||||
repository = WorkflowCollaborationRepository()
|
||||
collaboration_service = WorkflowCollaborationService(repository, sio)
|
||||
|
||||
|
||||
def _sio_on(event: str) -> Callable[[Callable[..., object]], Callable[..., object]]:
|
||||
return cast(Callable[[Callable[..., object]], Callable[..., object]], sio.on(event))
|
||||
|
||||
|
||||
@_sio_on("connect")
|
||||
def socket_connect(sid, environ, auth):
|
||||
"""
|
||||
WebSocket connect event, do authentication here.
|
||||
"""
|
||||
try:
|
||||
request_environ = FlaskRequest(environ)
|
||||
token = extract_access_token(request_environ)
|
||||
except Exception:
|
||||
logging.exception("Failed to extract token")
|
||||
token = None
|
||||
|
||||
if not token:
|
||||
logging.warning("Socket connect rejected: missing token (sid=%s)", sid)
|
||||
return False
|
||||
|
||||
try:
|
||||
decoded = PassportService().verify(token)
|
||||
user_id = decoded.get("user_id")
|
||||
if not user_id:
|
||||
logging.warning("Socket connect rejected: missing user_id (sid=%s)", sid)
|
||||
return False
|
||||
|
||||
with sio.app.app_context():
|
||||
user = AccountService.load_logged_in_account(account_id=user_id)
|
||||
if not user:
|
||||
logging.warning("Socket connect rejected: user not found (user_id=%s, sid=%s)", user_id, sid)
|
||||
return False
|
||||
if not user.has_edit_permission:
|
||||
logging.warning("Socket connect rejected: no edit permission (user_id=%s, sid=%s)", user_id, sid)
|
||||
return False
|
||||
|
||||
collaboration_service.save_session(sid, user)
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
logging.exception("Socket authentication failed")
|
||||
return False
|
||||
|
||||
|
||||
@_sio_on("user_connect")
|
||||
def handle_user_connect(sid, data):
|
||||
"""
|
||||
Handle user connect event. Each session (tab) is treated as an independent collaborator.
|
||||
"""
|
||||
workflow_id = data.get("workflow_id")
|
||||
if not workflow_id:
|
||||
return {"msg": "workflow_id is required"}, 400
|
||||
|
||||
result = collaboration_service.register_session(workflow_id, sid)
|
||||
if not result:
|
||||
return {"msg": "unauthorized"}, 401
|
||||
|
||||
user_id, is_leader = result
|
||||
return {"msg": "connected", "user_id": user_id, "sid": sid, "isLeader": is_leader}
|
||||
|
||||
|
||||
@_sio_on("disconnect")
|
||||
def handle_disconnect(sid):
|
||||
"""
|
||||
Handle session disconnect event. Remove the specific session from online users.
|
||||
"""
|
||||
collaboration_service.disconnect_session(sid)
|
||||
|
||||
|
||||
@_sio_on("collaboration_event")
|
||||
def handle_collaboration_event(sid, data):
|
||||
"""
|
||||
Handle general collaboration events, include:
|
||||
1. mouse_move
|
||||
2. vars_and_features_update
|
||||
3. sync_request (ask leader to update graph)
|
||||
4. app_state_update
|
||||
5. mcp_server_update
|
||||
6. workflow_update
|
||||
7. comments_update
|
||||
8. node_panel_presence
|
||||
9. skill_file_active
|
||||
10. skill_sync_request
|
||||
11. skill_resync_request
|
||||
"""
|
||||
return collaboration_service.relay_collaboration_event(sid, data)
|
||||
|
||||
|
||||
@_sio_on("graph_event")
|
||||
def handle_graph_event(sid, data):
|
||||
"""
|
||||
Handle graph events - simple broadcast relay.
|
||||
"""
|
||||
return collaboration_service.relay_graph_event(sid, data)
|
||||
|
||||
|
||||
@_sio_on("skill_event")
|
||||
def handle_skill_event(sid, data):
|
||||
"""
|
||||
Handle skill events - simple broadcast relay.
|
||||
"""
|
||||
return collaboration_service.relay_skill_event(sid, data)
|
||||
@@ -35,24 +35,22 @@ def plugin_permission_required(
|
||||
return view(*args, **kwargs)
|
||||
|
||||
if install_required:
|
||||
match permission.install_permission:
|
||||
case TenantPluginPermission.InstallPermission.NOBODY:
|
||||
if permission.install_permission == TenantPluginPermission.InstallPermission.NOBODY:
|
||||
raise Forbidden()
|
||||
if permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS:
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
case TenantPluginPermission.InstallPermission.ADMINS:
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
case TenantPluginPermission.InstallPermission.EVERYONE:
|
||||
pass
|
||||
if permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE:
|
||||
pass
|
||||
|
||||
if debug_required:
|
||||
match permission.debug_permission:
|
||||
case TenantPluginPermission.DebugPermission.NOBODY:
|
||||
if permission.debug_permission == TenantPluginPermission.DebugPermission.NOBODY:
|
||||
raise Forbidden()
|
||||
if permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS:
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
case TenantPluginPermission.DebugPermission.ADMINS:
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
case TenantPluginPermission.DebugPermission.EVERYONE:
|
||||
pass
|
||||
if permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE:
|
||||
pass
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
from typing import Literal
|
||||
|
||||
import pytz
|
||||
from flask import request
|
||||
@@ -174,7 +174,7 @@ reg(CheckEmailUniquePayload)
|
||||
register_schema_models(console_ns, AccountResponse)
|
||||
|
||||
|
||||
def _serialize_account(account) -> dict[str, Any]:
|
||||
def _serialize_account(account) -> dict:
|
||||
return AccountResponse.model_validate(account, from_attributes=True).model_dump(mode="json")
|
||||
|
||||
|
||||
|
||||
67
api/controllers/console/workspace/dsl.py
Normal file
67
api/controllers/console/workspace/dsl.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import json
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import App
|
||||
from models.workflow import Workflow
|
||||
from services.app_dsl_service import AppDslService
|
||||
|
||||
|
||||
class DSLPredictRequest(BaseModel):
|
||||
app_id: str
|
||||
current_node_id: str
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/dsl/predict")
|
||||
class DSLPredictApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, _ = current_account_with_tenant()
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
args = DSLPredictRequest.model_validate(request.get_json())
|
||||
|
||||
app_id: str = args.app_id
|
||||
current_node_id: str = args.current_node_id
|
||||
|
||||
with Session(db.engine) as session:
|
||||
app = session.query(App).filter_by(id=app_id).first()
|
||||
workflow = session.query(Workflow).filter_by(app_id=app_id, version=Workflow.VERSION_DRAFT).first()
|
||||
|
||||
if not app:
|
||||
raise ValueError("App not found")
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
try:
|
||||
i = 0
|
||||
for node_id, _ in workflow.walk_nodes():
|
||||
if node_id == current_node_id:
|
||||
break
|
||||
i += 1
|
||||
|
||||
dsl = yaml.safe_load(AppDslService.export_dsl(app_model=app))
|
||||
|
||||
response = httpx.post(
|
||||
"http://spark-832c:8000/predict",
|
||||
json={"graph_data": dsl, "source_node_index": i},
|
||||
)
|
||||
return {
|
||||
"nodes": json.loads(response.json()),
|
||||
}
|
||||
except PluginPermissionDeniedError as e:
|
||||
raise ValueError(e.description) from e
|
||||
@@ -20,7 +20,7 @@ from models.account import AccountStatus
|
||||
from models.dataset import RateLimitLog
|
||||
from models.model import DifySetup
|
||||
from services.feature_service import FeatureService, LicenseStatus
|
||||
from services.operation_service import OperationService, UtmInfo
|
||||
from services.operation_service import OperationService
|
||||
|
||||
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
|
||||
|
||||
@@ -205,7 +205,7 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
utm_info = request.cookies.get("utm_info")
|
||||
|
||||
if utm_info:
|
||||
utm_info_dict: UtmInfo = json.loads(utm_info)
|
||||
utm_info_dict: dict = json.loads(utm_info)
|
||||
OperationService.record_utm(current_tenant_id, utm_info_dict)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
80
api/controllers/files/storage_files.py
Normal file
80
api/controllers/files/storage_files.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Token-based file proxy controller for storage operations.
|
||||
|
||||
This controller handles file download and upload operations using opaque UUID tokens.
|
||||
The token maps to the real storage key in Redis, so the actual storage path is never
|
||||
exposed in the URL.
|
||||
|
||||
Routes:
|
||||
GET /files/storage-files/{token} - Download a file
|
||||
PUT /files/storage-files/{token} - Upload a file
|
||||
|
||||
The operation type (download/upload) is determined by the ticket stored in Redis,
|
||||
not by the HTTP method. This ensures a download ticket cannot be used for upload
|
||||
and vice versa.
|
||||
"""
|
||||
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import Forbidden, NotFound, RequestEntityTooLarge
|
||||
|
||||
from controllers.files import files_ns
|
||||
from extensions.ext_storage import storage
|
||||
from services.storage_ticket_service import StorageTicketService
|
||||
|
||||
|
||||
@files_ns.route("/storage-files/<string:token>")
|
||||
class StorageFilesApi(Resource):
|
||||
"""Handle file operations through token-based URLs."""
|
||||
|
||||
def get(self, token: str):
|
||||
"""Download a file using a token.
|
||||
|
||||
The ticket must have op="download", otherwise returns 403.
|
||||
"""
|
||||
ticket = StorageTicketService.get_ticket(token)
|
||||
if ticket is None:
|
||||
raise Forbidden("Invalid or expired token")
|
||||
|
||||
if ticket.op != "download":
|
||||
raise Forbidden("This token is not valid for download")
|
||||
|
||||
try:
|
||||
generator = storage.load_stream(ticket.storage_key)
|
||||
except FileNotFoundError:
|
||||
raise NotFound("File not found")
|
||||
|
||||
filename = ticket.filename or ticket.storage_key.rsplit("/", 1)[-1]
|
||||
encoded_filename = quote(filename)
|
||||
|
||||
return Response(
|
||||
generator,
|
||||
mimetype="application/octet-stream",
|
||||
direct_passthrough=True,
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
|
||||
},
|
||||
)
|
||||
|
||||
def put(self, token: str):
|
||||
"""Upload a file using a token.
|
||||
|
||||
The ticket must have op="upload", otherwise returns 403.
|
||||
If the request body exceeds max_bytes, returns 413.
|
||||
"""
|
||||
ticket = StorageTicketService.get_ticket(token)
|
||||
if ticket is None:
|
||||
raise Forbidden("Invalid or expired token")
|
||||
|
||||
if ticket.op != "upload":
|
||||
raise Forbidden("This token is not valid for upload")
|
||||
|
||||
content = request.get_data()
|
||||
|
||||
if ticket.max_bytes is not None and len(content) > ticket.max_bytes:
|
||||
raise RequestEntityTooLarge(f"Upload exceeds maximum size of {ticket.max_bytes} bytes")
|
||||
|
||||
storage.save(ticket.storage_key, content)
|
||||
|
||||
return Response(status=204)
|
||||
@@ -2,7 +2,7 @@ from typing import Any, Union
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource
|
||||
from graphon.variables.input_entities import VariableEntity, VariableEntityType
|
||||
from graphon.variables.input_entities import VariableEntity
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
@@ -158,20 +158,14 @@ class MCPAppApi(Resource):
|
||||
except ValidationError as e:
|
||||
raise MCPRequestError(mcp_types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
|
||||
|
||||
def _convert_user_input_form(self, raw_form: list[dict[str, Any]]) -> list[VariableEntity]:
|
||||
def _convert_user_input_form(self, raw_form: list[dict]) -> list[VariableEntity]:
|
||||
"""Convert raw user input form to VariableEntity objects"""
|
||||
return [self._create_variable_entity(item) for item in raw_form]
|
||||
|
||||
def _create_variable_entity(self, item: dict[str, Any]) -> VariableEntity:
|
||||
def _create_variable_entity(self, item: dict) -> VariableEntity:
|
||||
"""Create a single VariableEntity from raw form item"""
|
||||
variable_type_raw: str = item.get("type", "") or list(item.keys())[0]
|
||||
try:
|
||||
variable_type = VariableEntityType(variable_type_raw)
|
||||
except ValueError as e:
|
||||
raise MCPRequestError(
|
||||
mcp_types.INVALID_PARAMS, f"Invalid user_input_form variable type: {variable_type_raw}"
|
||||
) from e
|
||||
variable = item[variable_type_raw]
|
||||
variable_type = item.get("type", "") or list(item.keys())[0]
|
||||
variable = item[variable_type]
|
||||
|
||||
return VariableEntity(
|
||||
type=variable_type,
|
||||
@@ -184,7 +178,7 @@ class MCPAppApi(Resource):
|
||||
json_schema=variable.get("json_schema"),
|
||||
)
|
||||
|
||||
def _parse_mcp_request(self, args: dict[str, Any]) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
|
||||
def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
|
||||
"""Parse and validate MCP request"""
|
||||
try:
|
||||
return mcp_types.ClientRequest.model_validate(args)
|
||||
|
||||
@@ -194,7 +194,7 @@ class ChatApi(Resource):
|
||||
Supports conversation management and both blocking and streaming response modes.
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
payload = ChatRequestPayload.model_validate(service_api_ns.payload or {})
|
||||
@@ -258,7 +258,7 @@ class ChatStopApi(Resource):
|
||||
def post(self, app_model: App, end_user: EndUser, task_id: str):
|
||||
"""Stop a running chat message generation."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
AppTaskService.stop_task(
|
||||
|
||||
@@ -98,7 +98,7 @@ class ConversationApi(Resource):
|
||||
Supports pagination using last_id and limit parameters.
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
query_args = ConversationListQuery.model_validate(request.args.to_dict())
|
||||
@@ -142,7 +142,7 @@ class ConversationDetailApi(Resource):
|
||||
def delete(self, app_model: App, end_user: EndUser, c_id):
|
||||
"""Delete a specific conversation."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@@ -171,7 +171,7 @@ class ConversationRenameApi(Resource):
|
||||
def post(self, app_model: App, end_user: EndUser, c_id):
|
||||
"""Rename a conversation or auto-generate a name."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@@ -213,7 +213,7 @@ class ConversationVariablesApi(Resource):
|
||||
"""
|
||||
# conversational variable only for chat app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@@ -252,7 +252,7 @@ class ConversationVariableDetailApi(Resource):
|
||||
The value must match the variable's expected type.
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
||||
@@ -53,7 +53,7 @@ class MessageListApi(Resource):
|
||||
Retrieves messages with pagination support using first_id.
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
query_args = MessageListQuery.model_validate(request.args.to_dict())
|
||||
@@ -158,7 +158,7 @@ class MessageSuggestedApi(Resource):
|
||||
"""
|
||||
message_id = str(message_id)
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
try:
|
||||
|
||||
@@ -33,25 +33,25 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
|
||||
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict[str, Any]:
|
||||
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict:
|
||||
"""Marshal a single segment and enrich it with summary content."""
|
||||
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
|
||||
segment_dict["summary"] = summary.summary_content if summary else None
|
||||
return segment_dict
|
||||
|
||||
|
||||
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict[str, Any]]:
|
||||
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]:
|
||||
"""Marshal multiple segments and enrich them with summary content (batch query)."""
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
summaries: dict[str, str | None] = {}
|
||||
summaries: dict = {}
|
||||
if segment_ids:
|
||||
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
|
||||
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
|
||||
|
||||
result: list[dict[str, Any]] = []
|
||||
result = []
|
||||
for segment in segments:
|
||||
segment_dict: dict[str, Any] = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
segment_dict["summary"] = summaries.get(segment.id)
|
||||
result.append(segment_dict)
|
||||
return result
|
||||
|
||||
@@ -5,7 +5,6 @@ Web App Human Input Form APIs.
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
@@ -59,19 +58,10 @@ def _to_timestamp(value: datetime) -> int:
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
class FormDefinitionPayload(TypedDict):
|
||||
form_content: Any
|
||||
inputs: Any
|
||||
resolved_default_values: dict[str, str]
|
||||
user_actions: Any
|
||||
expiration_time: int
|
||||
site: NotRequired[dict]
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Response:
|
||||
"""Return the form payload (optionally with site) as a JSON response."""
|
||||
definition_payload = form.get_definition().model_dump()
|
||||
payload: FormDefinitionPayload = {
|
||||
payload = {
|
||||
"form_content": definition_payload["rendered_content"],
|
||||
"inputs": definition_payload["inputs"],
|
||||
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
|
||||
@@ -102,7 +92,7 @@ class HumanInputFormApi(Resource):
|
||||
_FORM_ACCESS_RATE_LIMITER.increment_rate_limit(ip_address)
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
# TODO(QuantumGhost): forbid submission for form tokens
|
||||
# TODO(QuantumGhost): forbid submision for form tokens
|
||||
# that are only for console.
|
||||
form = service.get_form_by_token(form_token)
|
||||
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
import logging
|
||||
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource
|
||||
from jwt import InvalidTokenError
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
@@ -23,7 +20,7 @@ from controllers.console.wraps import (
|
||||
)
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.wraps import decode_jwt_token
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.helper import EmailStr
|
||||
from libs.passport import PassportService
|
||||
from libs.password import valid_password
|
||||
from libs.token import (
|
||||
@@ -32,11 +29,9 @@ from libs.token import (
|
||||
)
|
||||
from services.account_service import AccountService
|
||||
from services.app_service import AppService
|
||||
from services.entities.auth_entities import LoginFailureReason, LoginPayloadBase
|
||||
from services.entities.auth_entities import LoginPayloadBase
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoginPayload(LoginPayloadBase):
|
||||
@field_validator("password")
|
||||
@@ -81,18 +76,14 @@ class LoginApi(Resource):
|
||||
def post(self):
|
||||
"""Authenticate user and login."""
|
||||
payload = LoginPayload.model_validate(web_ns.payload or {})
|
||||
normalized_email = payload.email.lower()
|
||||
|
||||
try:
|
||||
account = WebAppAuthService.authenticate(payload.email, payload.password)
|
||||
except services.errors.account.AccountLoginError:
|
||||
_log_web_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_BANNED)
|
||||
raise AccountBannedError()
|
||||
except services.errors.account.AccountPasswordError:
|
||||
_log_web_login_failure(email=normalized_email, reason=LoginFailureReason.INVALID_CREDENTIALS)
|
||||
raise AuthenticationFailedError()
|
||||
except services.errors.account.AccountNotFoundError:
|
||||
_log_web_login_failure(email=normalized_email, reason=LoginFailureReason.ACCOUNT_NOT_FOUND)
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
token = WebAppAuthService.login(account=account)
|
||||
@@ -221,30 +212,21 @@ class EmailCodeLoginApi(Resource):
|
||||
|
||||
token_data = WebAppAuthService.get_email_code_login_data(payload.token)
|
||||
if token_data is None:
|
||||
_log_web_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE_TOKEN)
|
||||
raise InvalidTokenError()
|
||||
|
||||
token_email = token_data.get("email")
|
||||
if not isinstance(token_email, str):
|
||||
_log_web_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
|
||||
raise InvalidEmailError()
|
||||
normalized_token_email = token_email.lower()
|
||||
if normalized_token_email != user_email:
|
||||
_log_web_login_failure(email=user_email, reason=LoginFailureReason.EMAIL_CODE_EMAIL_MISMATCH)
|
||||
raise InvalidEmailError()
|
||||
|
||||
if token_data["code"] != payload.code:
|
||||
_log_web_login_failure(email=user_email, reason=LoginFailureReason.INVALID_EMAIL_CODE)
|
||||
raise EmailCodeError()
|
||||
|
||||
WebAppAuthService.revoke_email_code_login_token(payload.token)
|
||||
try:
|
||||
account = WebAppAuthService.get_user_through_email(token_email)
|
||||
except Unauthorized as exc:
|
||||
_log_web_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_BANNED)
|
||||
raise AccountBannedError() from exc
|
||||
account = WebAppAuthService.get_user_through_email(token_email)
|
||||
if not account:
|
||||
_log_web_login_failure(email=user_email, reason=LoginFailureReason.ACCOUNT_NOT_FOUND)
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
token = WebAppAuthService.login(account=account)
|
||||
@@ -252,12 +234,3 @@ class EmailCodeLoginApi(Resource):
|
||||
response = make_response({"result": "success", "data": {"access_token": token}})
|
||||
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
|
||||
return response
|
||||
|
||||
|
||||
def _log_web_login_failure(*, email: str, reason: LoginFailureReason) -> None:
|
||||
logger.warning(
|
||||
"Web login failed: email=%s reason=%s ip_address=%s",
|
||||
email,
|
||||
reason,
|
||||
extract_remote_ip(request),
|
||||
)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource
|
||||
@@ -104,23 +103,21 @@ class PassportResource(Resource):
|
||||
return response
|
||||
|
||||
|
||||
def decode_enterprise_webapp_user_id(jwt_token: str | None) -> dict[str, Any] | None:
|
||||
def decode_enterprise_webapp_user_id(jwt_token: str | None):
|
||||
"""
|
||||
Decode the enterprise user session from the Authorization header.
|
||||
"""
|
||||
if not jwt_token:
|
||||
return None
|
||||
|
||||
decoded: dict[str, Any] = PassportService().verify(jwt_token)
|
||||
decoded = PassportService().verify(jwt_token)
|
||||
source = decoded.get("token_source")
|
||||
if not source or source != "webapp_login_token":
|
||||
raise Unauthorized("Invalid token source. Expected 'webapp_login_token'.")
|
||||
return decoded
|
||||
|
||||
|
||||
def exchange_token_for_existing_web_user(
|
||||
app_code: str, enterprise_user_decoded: dict[str, Any], auth_type: WebAppAuthType
|
||||
):
|
||||
def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict, auth_type: WebAppAuthType):
|
||||
"""
|
||||
Exchange a token for an existing web user session.
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, cast
|
||||
from typing import cast
|
||||
|
||||
from flask_restx import fields, marshal, marshal_with
|
||||
from sqlalchemy import select
|
||||
@@ -113,12 +113,12 @@ class AppSiteInfo:
|
||||
}
|
||||
|
||||
|
||||
def serialize_site(site: Site) -> dict[str, Any]:
|
||||
def serialize_site(site: Site) -> dict:
|
||||
"""Serialize Site model using the same schema as AppSiteApi."""
|
||||
return cast(dict[str, Any], marshal(site, AppSiteApi.site_fields))
|
||||
return cast(dict, marshal(site, AppSiteApi.site_fields))
|
||||
|
||||
|
||||
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict[str, Any]:
|
||||
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict:
|
||||
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
|
||||
app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo)
|
||||
return cast(dict[str, Any], marshal(app_site_info, AppSiteApi.app_fields))
|
||||
return cast(dict, marshal(app_site_info, AppSiteApi.app_fields))
|
||||
|
||||
399
api/core/agent/agent_app_runner.py
Normal file
399
api/core/agent/agent_app_runner.py
Normal file
@@ -0,0 +1,399 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from typing import Any, cast
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentResult, ExecutionContext
|
||||
from core.agent.patterns.strategy_factory import StrategyFactory
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from graphon.file import file_manager
|
||||
from graphon.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentAppRunner(BaseAgentRunner):
|
||||
|
||||
@property
|
||||
def model_features(self) -> list[ModelFeature]:
|
||||
llm_model = cast(LargeLanguageModel, self.model_instance.model_type_instance)
|
||||
model_schema = llm_model.get_model_schema(self.model_instance.model_name, self.model_instance.credentials)
|
||||
if not model_schema:
|
||||
return []
|
||||
return list(model_schema.features or [])
|
||||
|
||||
def build_execution_context(self) -> ExecutionContext:
|
||||
return ExecutionContext(
|
||||
user_id=self.user_id,
|
||||
app_id=self.application_generate_entity.app_config.app_id,
|
||||
conversation_id=self.conversation.id if self.conversation else None,
|
||||
message_id=self.message.id if self.message else None,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
def _create_tool_invoke_hook(self, message: Message):
|
||||
"""
|
||||
Create a tool invoke hook that uses ToolEngine.agent_invoke.
|
||||
This hook handles file creation and returns proper meta information.
|
||||
"""
|
||||
# Get trace manager from app generate entity
|
||||
trace_manager = self.application_generate_entity.trace_manager
|
||||
|
||||
def tool_invoke_hook(
|
||||
tool: Tool, tool_args: dict[str, Any], tool_name: str
|
||||
) -> tuple[str, list[str], ToolInvokeMeta]:
|
||||
"""Hook that uses agent_invoke for proper file and meta handling."""
|
||||
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
|
||||
tool=tool,
|
||||
tool_parameters=tool_args,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
message=message,
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
agent_tool_callback=self.agent_callback,
|
||||
trace_manager=trace_manager,
|
||||
app_id=self.application_generate_entity.app_config.app_id,
|
||||
message_id=message.id,
|
||||
conversation_id=self.conversation.id,
|
||||
)
|
||||
|
||||
# Publish files and track IDs
|
||||
for message_file_id in message_files:
|
||||
self.queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
self._current_message_file_ids.append(message_file_id)
|
||||
|
||||
return tool_invoke_response, message_files, tool_invoke_meta
|
||||
|
||||
return tool_invoke_hook
|
||||
|
||||
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
Run Agent application
|
||||
"""
|
||||
self.query = query
|
||||
app_generate_entity = self.application_generate_entity
|
||||
|
||||
app_config = self.app_config
|
||||
assert app_config is not None, "app_config is required"
|
||||
assert app_config.agent is not None, "app_config.agent is required"
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, _ = self._init_prompt_tools()
|
||||
|
||||
assert app_config.agent
|
||||
|
||||
# Create tool invoke hook for agent_invoke
|
||||
tool_invoke_hook = self._create_tool_invoke_hook(message)
|
||||
|
||||
# Get instruction for ReAct strategy
|
||||
instruction = self.app_config.prompt_template.simple_prompt_template or ""
|
||||
|
||||
# Use factory to create appropriate strategy
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=self.model_features,
|
||||
model_instance=self.model_instance,
|
||||
tools=list(tool_instances.values()),
|
||||
files=list(self.files),
|
||||
max_iterations=app_config.agent.max_iteration,
|
||||
context=self.build_execution_context(),
|
||||
agent_strategy=self.config.strategy,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
instruction=instruction,
|
||||
)
|
||||
|
||||
# Initialize state variables
|
||||
current_agent_thought_id: str | None = None
|
||||
has_published_thought = False
|
||||
current_tool_name: str | None = None
|
||||
self._current_message_file_ids: list[str] = []
|
||||
|
||||
# organize prompt messages
|
||||
prompt_messages = self._organize_prompt_messages()
|
||||
|
||||
# Run strategy
|
||||
generator = strategy.run(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=app_generate_entity.model_conf.parameters,
|
||||
stop=app_generate_entity.model_conf.stop,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Consume generator and collect result
|
||||
result: AgentResult | None = None
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
output = next(generator)
|
||||
except StopIteration as e:
|
||||
# Generator finished, get the return value
|
||||
result = e.value
|
||||
break
|
||||
|
||||
if isinstance(output, LLMResultChunk):
|
||||
# Handle LLM chunk
|
||||
if current_agent_thought_id and not has_published_thought:
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
has_published_thought = True
|
||||
|
||||
yield output
|
||||
|
||||
elif isinstance(output, AgentLog):
|
||||
# Handle Agent Log using log_type for type-safe dispatch
|
||||
if output.status == AgentLog.LogStatus.START:
|
||||
if output.log_type == AgentLog.LogType.ROUND:
|
||||
# Start of a new round
|
||||
message_file_ids: list[str] = []
|
||||
current_agent_thought_id = self.create_agent_thought(
|
||||
message_id=message.id,
|
||||
message="",
|
||||
tool_name="",
|
||||
tool_input="",
|
||||
messages_ids=message_file_ids,
|
||||
)
|
||||
has_published_thought = False
|
||||
|
||||
elif output.log_type == AgentLog.LogType.TOOL_CALL:
|
||||
if current_agent_thought_id is None:
|
||||
continue
|
||||
|
||||
# Tool call start - extract data from structured fields
|
||||
current_tool_name = output.data.get("tool_name", "")
|
||||
tool_input = output.data.get("tool_args", {})
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=current_agent_thought_id,
|
||||
tool_name=current_tool_name,
|
||||
tool_input=tool_input,
|
||||
thought=None,
|
||||
observation=None,
|
||||
tool_invoke_meta=None,
|
||||
answer=None,
|
||||
messages_ids=[],
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
elif output.status == AgentLog.LogStatus.SUCCESS:
|
||||
if output.log_type == AgentLog.LogType.THOUGHT:
|
||||
if current_agent_thought_id is None:
|
||||
continue
|
||||
|
||||
thought_text = output.data.get("thought")
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=current_agent_thought_id,
|
||||
tool_name=None,
|
||||
tool_input=None,
|
||||
thought=thought_text,
|
||||
observation=None,
|
||||
tool_invoke_meta=None,
|
||||
answer=None,
|
||||
messages_ids=[],
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
elif output.log_type == AgentLog.LogType.TOOL_CALL:
|
||||
if current_agent_thought_id is None:
|
||||
continue
|
||||
|
||||
# Tool call finished
|
||||
tool_output = output.data.get("output")
|
||||
# Get meta from strategy output (now properly populated)
|
||||
tool_meta = output.data.get("meta")
|
||||
|
||||
# Wrap tool_meta with tool_name as key (required by agent_service)
|
||||
if tool_meta and current_tool_name:
|
||||
tool_meta = {current_tool_name: tool_meta}
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=current_agent_thought_id,
|
||||
tool_name=None,
|
||||
tool_input=None,
|
||||
thought=None,
|
||||
observation=tool_output,
|
||||
tool_invoke_meta=tool_meta,
|
||||
answer=None,
|
||||
messages_ids=self._current_message_file_ids,
|
||||
)
|
||||
# Clear message file ids after saving
|
||||
self._current_message_file_ids = []
|
||||
current_tool_name = None
|
||||
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
elif output.log_type == AgentLog.LogType.ROUND:
|
||||
if current_agent_thought_id is None:
|
||||
continue
|
||||
|
||||
# Round finished - save LLM usage and answer
|
||||
llm_usage = output.metadata.get(AgentLog.LogMetadata.LLM_USAGE)
|
||||
llm_result = output.data.get("llm_result")
|
||||
final_answer = output.data.get("final_answer")
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=current_agent_thought_id,
|
||||
tool_name=None,
|
||||
tool_input=None,
|
||||
thought=llm_result,
|
||||
observation=None,
|
||||
tool_invoke_meta=None,
|
||||
answer=final_answer,
|
||||
messages_ids=[],
|
||||
llm_usage=llm_usage,
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
# Re-raise any other exceptions
|
||||
raise
|
||||
|
||||
# Process final result
|
||||
if isinstance(result, AgentResult):
|
||||
final_answer = result.text
|
||||
usage = result.usage or LLMUsage.empty_usage()
|
||||
|
||||
# Publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=self.model_instance.model_name,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=final_answer),
|
||||
usage=usage,
|
||||
system_fingerprint="",
|
||||
)
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Initialize system message
|
||||
"""
|
||||
if not prompt_template:
|
||||
return prompt_messages or []
|
||||
|
||||
prompt_messages = prompt_messages or []
|
||||
|
||||
if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
|
||||
prompt_messages[0] = SystemPromptMessage(content=prompt_template)
|
||||
return prompt_messages
|
||||
|
||||
if not prompt_messages:
|
||||
return [SystemPromptMessage(content=prompt_template)]
|
||||
|
||||
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
|
||||
return prompt_messages
|
||||
|
||||
def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize user query
|
||||
"""
|
||||
if self.files:
|
||||
# get image detail config
|
||||
image_detail_config = (
|
||||
self.application_generate_entity.file_upload_config.image_config.detail
|
||||
if (
|
||||
self.application_generate_entity.file_upload_config
|
||||
and self.application_generate_entity.file_upload_config.image_config
|
||||
)
|
||||
else None
|
||||
)
|
||||
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
for file in self.files:
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(
|
||||
file,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=query))
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
As for now, gpt supports both fc and vision at the first iteration.
|
||||
We need to remove the image messages from the prompt messages at the first iteration.
|
||||
"""
|
||||
prompt_messages = deepcopy(prompt_messages)
|
||||
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message, UserPromptMessage):
|
||||
if isinstance(prompt_message.content, list):
|
||||
prompt_message.content = "\n".join(
|
||||
[
|
||||
content.data
|
||||
if content.type == PromptMessageContentType.TEXT
|
||||
else "[image]"
|
||||
if content.type == PromptMessageContentType.IMAGE
|
||||
else "[file]"
|
||||
for content in prompt_message.content
|
||||
]
|
||||
)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _organize_prompt_messages(self):
|
||||
# For ReAct strategy, use the agent prompt template
|
||||
if self.config.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT and self.config.prompt:
|
||||
prompt_template = self.config.prompt.first_prompt
|
||||
else:
|
||||
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
|
||||
|
||||
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
|
||||
query_prompt_messages = self._organize_user_query(self.query or "", [])
|
||||
|
||||
self.history_prompt_messages = AgentHistoryPromptTransform(
|
||||
model_config=self.model_config,
|
||||
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
|
||||
history_messages=self.history_prompt_messages,
|
||||
memory=self.memory,
|
||||
).get_prompt()
|
||||
|
||||
prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
|
||||
if len(self._current_thoughts) != 0:
|
||||
# clear messages after the first iteration
|
||||
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
|
||||
return prompt_messages
|
||||
@@ -1,3 +1,5 @@
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Any, Union
|
||||
|
||||
@@ -92,3 +94,79 @@ class AgentInvokeMessage(ToolInvokeMessage):
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ExecutionContext(BaseModel):
|
||||
"""Execution context containing trace and audit information.
|
||||
|
||||
Carries IDs and metadata needed for tracing, auditing, and correlation
|
||||
but not part of the core business logic.
|
||||
"""
|
||||
|
||||
user_id: str | None = None
|
||||
app_id: str | None = None
|
||||
conversation_id: str | None = None
|
||||
message_id: str | None = None
|
||||
tenant_id: str | None = None
|
||||
|
||||
@classmethod
|
||||
def create_minimal(cls, user_id: str | None = None) -> "ExecutionContext":
|
||||
return cls(user_id=user_id)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"user_id": self.user_id,
|
||||
"app_id": self.app_id,
|
||||
"conversation_id": self.conversation_id,
|
||||
"message_id": self.message_id,
|
||||
"tenant_id": self.tenant_id,
|
||||
}
|
||||
|
||||
def with_updates(self, **kwargs) -> "ExecutionContext":
|
||||
data = self.to_dict()
|
||||
data.update(kwargs)
|
||||
return ExecutionContext(**{k: v for k, v in data.items() if k in ExecutionContext.model_fields})
|
||||
|
||||
|
||||
class AgentLog(BaseModel):
|
||||
"""Structured log entry for agent execution tracing."""
|
||||
|
||||
class LogType(StrEnum):
|
||||
ROUND = "round"
|
||||
THOUGHT = "thought"
|
||||
TOOL_CALL = "tool_call"
|
||||
|
||||
class LogMetadata(StrEnum):
|
||||
STARTED_AT = "started_at"
|
||||
FINISHED_AT = "finished_at"
|
||||
ELAPSED_TIME = "elapsed_time"
|
||||
TOTAL_PRICE = "total_price"
|
||||
TOTAL_TOKENS = "total_tokens"
|
||||
PROVIDER = "provider"
|
||||
CURRENCY = "currency"
|
||||
LLM_USAGE = "llm_usage"
|
||||
ICON = "icon"
|
||||
ICON_DARK = "icon_dark"
|
||||
|
||||
class LogStatus(StrEnum):
|
||||
START = "start"
|
||||
ERROR = "error"
|
||||
SUCCESS = "success"
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
label: str = Field(...)
|
||||
log_type: LogType = Field(...)
|
||||
parent_id: str | None = Field(default=None)
|
||||
error: str | None = Field(default=None)
|
||||
status: LogStatus = Field(...)
|
||||
data: Mapping[str, Any] = Field(...)
|
||||
metadata: Mapping[LogMetadata, Any] = Field(default={})
|
||||
|
||||
|
||||
class AgentResult(BaseModel):
|
||||
"""Agent execution result."""
|
||||
|
||||
text: str = Field(default="")
|
||||
files: list[Any] = Field(default_factory=list)
|
||||
usage: Any | None = Field(default=None)
|
||||
finish_reason: str | None = Field(default=None)
|
||||
|
||||
19
api/core/agent/patterns/__init__.py
Normal file
19
api/core/agent/patterns/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Agent patterns module.
|
||||
|
||||
This module provides different strategies for agent execution:
|
||||
- FunctionCallStrategy: Uses native function/tool calling
|
||||
- ReActStrategy: Uses ReAct (Reasoning + Acting) approach
|
||||
- StrategyFactory: Factory for creating strategies based on model features
|
||||
"""
|
||||
|
||||
from .base import AgentPattern
|
||||
from .function_call import FunctionCallStrategy
|
||||
from .react import ReActStrategy
|
||||
from .strategy_factory import StrategyFactory
|
||||
|
||||
__all__ = [
|
||||
"AgentPattern",
|
||||
"FunctionCallStrategy",
|
||||
"ReActStrategy",
|
||||
"StrategyFactory",
|
||||
]
|
||||
506
api/core/agent/patterns/base.py
Normal file
506
api/core/agent/patterns/base.py
Normal file
@@ -0,0 +1,506 @@
|
||||
"""Base class for agent strategies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult, ExecutionContext
|
||||
from core.model_manager import ModelInstance
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMeta
|
||||
from graphon.file import File
|
||||
from graphon.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
)
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.model_runtime.entities.message_entities import TextPromptMessageContent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
# Type alias for tool invoke hook
|
||||
# Returns: (response_content, message_file_ids, tool_invoke_meta)
|
||||
ToolInvokeHook = Callable[["Tool", dict[str, Any], str], tuple[str, list[str], ToolInvokeMeta]]
|
||||
|
||||
|
||||
class AgentPattern(ABC):
|
||||
"""Base class for agent execution strategies."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_instance: ModelInstance,
|
||||
tools: list[Tool],
|
||||
context: ExecutionContext,
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
files: list[File] = [],
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
):
|
||||
"""Initialize the agent strategy."""
|
||||
self.model_instance = model_instance
|
||||
self.tools = tools
|
||||
self.context = context
|
||||
self.max_iterations = min(max_iterations, 99) # Cap at 99 iterations
|
||||
self.workflow_call_depth = workflow_call_depth
|
||||
self.files: list[File] = files
|
||||
self.tool_invoke_hook = tool_invoke_hook
|
||||
|
||||
@abstractmethod
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the agent strategy."""
|
||||
pass
|
||||
|
||||
def _accumulate_usage(self, total_usage: dict[str, Any], delta_usage: LLMUsage) -> None:
|
||||
"""Accumulate LLM usage statistics."""
|
||||
if not total_usage.get("usage"):
|
||||
# Create a copy to avoid modifying the original
|
||||
total_usage["usage"] = LLMUsage(
|
||||
prompt_tokens=delta_usage.prompt_tokens,
|
||||
prompt_unit_price=delta_usage.prompt_unit_price,
|
||||
prompt_price_unit=delta_usage.prompt_price_unit,
|
||||
prompt_price=delta_usage.prompt_price,
|
||||
completion_tokens=delta_usage.completion_tokens,
|
||||
completion_unit_price=delta_usage.completion_unit_price,
|
||||
completion_price_unit=delta_usage.completion_price_unit,
|
||||
completion_price=delta_usage.completion_price,
|
||||
total_tokens=delta_usage.total_tokens,
|
||||
total_price=delta_usage.total_price,
|
||||
currency=delta_usage.currency,
|
||||
latency=delta_usage.latency,
|
||||
)
|
||||
else:
|
||||
current: LLMUsage = total_usage["usage"]
|
||||
current.prompt_tokens += delta_usage.prompt_tokens
|
||||
current.completion_tokens += delta_usage.completion_tokens
|
||||
current.total_tokens += delta_usage.total_tokens
|
||||
current.prompt_price += delta_usage.prompt_price
|
||||
current.completion_price += delta_usage.completion_price
|
||||
current.total_price += delta_usage.total_price
|
||||
|
||||
def _extract_content(self, content: Any) -> str:
|
||||
"""Extract text content from message content."""
|
||||
if isinstance(content, list):
|
||||
# Content items are PromptMessageContentUnionTypes
|
||||
text_parts = []
|
||||
for c in content:
|
||||
# Check if it's a TextPromptMessageContent (which has data attribute)
|
||||
if isinstance(c, TextPromptMessageContent):
|
||||
text_parts.append(c.data)
|
||||
return "".join(text_parts)
|
||||
return str(content)
|
||||
|
||||
def _has_tool_calls(self, chunk: LLMResultChunk) -> bool:
|
||||
"""Check if chunk contains tool calls."""
|
||||
# LLMResultChunk always has delta attribute
|
||||
return bool(chunk.delta.message and chunk.delta.message.tool_calls)
|
||||
|
||||
def _has_tool_calls_result(self, result: LLMResult) -> bool:
|
||||
"""Check if result contains tool calls (non-streaming)."""
|
||||
# LLMResult always has message attribute
|
||||
return bool(result.message and result.message.tool_calls)
|
||||
|
||||
def _extract_tool_calls(self, chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""Extract tool calls from streaming chunk."""
|
||||
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
|
||||
if chunk.delta.message and chunk.delta.message.tool_calls:
|
||||
for tool_call in chunk.delta.message.tool_calls:
|
||||
if tool_call.function:
|
||||
try:
|
||||
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
|
||||
return tool_calls
|
||||
|
||||
def _extract_tool_calls_result(self, result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""Extract tool calls from non-streaming result."""
|
||||
tool_calls = []
|
||||
if result.message and result.message.tool_calls:
|
||||
for tool_call in result.message.tool_calls:
|
||||
if tool_call.function:
|
||||
try:
|
||||
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
|
||||
return tool_calls
|
||||
|
||||
def _extract_text_from_message(self, message: PromptMessage) -> str:
|
||||
"""Extract text content from a prompt message."""
|
||||
# PromptMessage always has content attribute
|
||||
content = message.content
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, list):
|
||||
# Extract text from content list
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, TextPromptMessageContent):
|
||||
text_parts.append(item.data)
|
||||
return " ".join(text_parts)
|
||||
return ""
|
||||
|
||||
def _get_tool_metadata(self, tool_instance: Tool) -> dict[AgentLog.LogMetadata, Any]:
|
||||
"""Get metadata for a tool including provider and icon info."""
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
metadata: dict[AgentLog.LogMetadata, Any] = {}
|
||||
if tool_instance.entity and tool_instance.entity.identity:
|
||||
identity = tool_instance.entity.identity
|
||||
if identity.provider:
|
||||
metadata[AgentLog.LogMetadata.PROVIDER] = identity.provider
|
||||
|
||||
# Get icon using ToolManager for proper URL generation
|
||||
tenant_id = self.context.tenant_id
|
||||
if tenant_id and identity.provider:
|
||||
try:
|
||||
provider_type = tool_instance.tool_provider_type()
|
||||
icon = ToolManager.get_tool_icon(tenant_id, provider_type, identity.provider)
|
||||
if isinstance(icon, str):
|
||||
metadata[AgentLog.LogMetadata.ICON] = icon
|
||||
elif isinstance(icon, dict):
|
||||
# Handle icon dict with background/content or light/dark variants
|
||||
metadata[AgentLog.LogMetadata.ICON] = icon
|
||||
except Exception:
|
||||
# Fallback to identity.icon if ToolManager fails
|
||||
if identity.icon:
|
||||
metadata[AgentLog.LogMetadata.ICON] = identity.icon
|
||||
elif identity.icon:
|
||||
metadata[AgentLog.LogMetadata.ICON] = identity.icon
|
||||
return metadata
|
||||
|
||||
def _create_log(
|
||||
self,
|
||||
label: str,
|
||||
log_type: AgentLog.LogType,
|
||||
status: AgentLog.LogStatus,
|
||||
data: dict[str, Any] | None = None,
|
||||
parent_id: str | None = None,
|
||||
extra_metadata: dict[AgentLog.LogMetadata, Any] | None = None,
|
||||
) -> AgentLog:
|
||||
"""Create a new AgentLog with standard metadata."""
|
||||
metadata: dict[AgentLog.LogMetadata, Any] = {
|
||||
AgentLog.LogMetadata.STARTED_AT: time.perf_counter(),
|
||||
}
|
||||
if extra_metadata:
|
||||
metadata.update(extra_metadata)
|
||||
|
||||
return AgentLog(
|
||||
label=label,
|
||||
log_type=log_type,
|
||||
status=status,
|
||||
data=data or {},
|
||||
parent_id=parent_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _finish_log(
|
||||
self,
|
||||
log: AgentLog,
|
||||
data: dict[str, Any] | None = None,
|
||||
usage: LLMUsage | None = None,
|
||||
) -> AgentLog:
|
||||
"""Finish an AgentLog by updating its status and metadata."""
|
||||
log.status = AgentLog.LogStatus.SUCCESS
|
||||
|
||||
if data is not None:
|
||||
log.data = data
|
||||
|
||||
# Calculate elapsed time
|
||||
started_at = log.metadata.get(AgentLog.LogMetadata.STARTED_AT, time.perf_counter())
|
||||
finished_at = time.perf_counter()
|
||||
|
||||
# Update metadata
|
||||
log.metadata = {
|
||||
**log.metadata,
|
||||
AgentLog.LogMetadata.FINISHED_AT: finished_at,
|
||||
# Calculate elapsed time in seconds
|
||||
AgentLog.LogMetadata.ELAPSED_TIME: round(finished_at - started_at, 4),
|
||||
}
|
||||
|
||||
# Add usage information if provided
|
||||
if usage:
|
||||
log.metadata.update(
|
||||
{
|
||||
AgentLog.LogMetadata.TOTAL_PRICE: usage.total_price,
|
||||
AgentLog.LogMetadata.CURRENCY: usage.currency,
|
||||
AgentLog.LogMetadata.TOTAL_TOKENS: usage.total_tokens,
|
||||
AgentLog.LogMetadata.LLM_USAGE: usage,
|
||||
}
|
||||
)
|
||||
|
||||
return log
|
||||
|
||||
def _replace_file_references(self, tool_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Replace file references in tool arguments with actual File objects.
|
||||
|
||||
Args:
|
||||
tool_args: Dictionary of tool arguments
|
||||
|
||||
Returns:
|
||||
Updated tool arguments with file references replaced
|
||||
"""
|
||||
# Process each argument in the dictionary
|
||||
processed_args: dict[str, Any] = {}
|
||||
for key, value in tool_args.items():
|
||||
processed_args[key] = self._process_file_reference(value)
|
||||
return processed_args
|
||||
|
||||
def _process_file_reference(self, data: Any) -> Any:
|
||||
"""
|
||||
Recursively process data to replace file references.
|
||||
Supports both single file [File: file_id] and multiple files [Files: file_id1, file_id2, ...].
|
||||
|
||||
Args:
|
||||
data: The data to process (can be dict, list, str, or other types)
|
||||
|
||||
Returns:
|
||||
Processed data with file references replaced
|
||||
"""
|
||||
single_file_pattern = re.compile(r"^\[File:\s*([^\]]+)\]$")
|
||||
multiple_files_pattern = re.compile(r"^\[Files:\s*([^\]]+)\]$")
|
||||
|
||||
if isinstance(data, dict):
|
||||
# Process dictionary recursively
|
||||
return {key: self._process_file_reference(value) for key, value in data.items()}
|
||||
elif isinstance(data, list):
|
||||
# Process list recursively
|
||||
return [self._process_file_reference(item) for item in data]
|
||||
elif isinstance(data, str):
|
||||
# Check for single file pattern [File: file_id]
|
||||
single_match = single_file_pattern.match(data.strip())
|
||||
if single_match:
|
||||
file_id = single_match.group(1).strip()
|
||||
# Find the file in self.files
|
||||
for file in self.files:
|
||||
if file.id and str(file.id) == file_id:
|
||||
return file
|
||||
# If file not found, return original value
|
||||
return data
|
||||
|
||||
# Check for multiple files pattern [Files: file_id1, file_id2, ...]
|
||||
multiple_match = multiple_files_pattern.match(data.strip())
|
||||
if multiple_match:
|
||||
file_ids_str = multiple_match.group(1).strip()
|
||||
# Split by comma and strip whitespace
|
||||
file_ids = [fid.strip() for fid in file_ids_str.split(",")]
|
||||
|
||||
# Find all matching files
|
||||
matched_files: list[File] = []
|
||||
for file_id in file_ids:
|
||||
for file in self.files:
|
||||
if file.id and str(file.id) == file_id:
|
||||
matched_files.append(file)
|
||||
break
|
||||
|
||||
# Return list of files if any were found, otherwise return original
|
||||
return matched_files or data
|
||||
|
||||
return data
|
||||
else:
|
||||
# Return other types as-is
|
||||
return data
|
||||
|
||||
def _create_text_chunk(self, text: str, prompt_messages: list[PromptMessage]) -> LLMResultChunk:
|
||||
"""Create a text chunk for streaming."""
|
||||
return LLMResultChunk(
|
||||
model=self.model_instance.model_name,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=text),
|
||||
usage=None,
|
||||
),
|
||||
system_fingerprint="",
|
||||
)
|
||||
|
||||
def _invoke_tool(
|
||||
self,
|
||||
tool_instance: Tool,
|
||||
tool_args: dict[str, Any],
|
||||
tool_name: str,
|
||||
) -> tuple[str, list[File], ToolInvokeMeta | None]:
|
||||
"""
|
||||
Invoke a tool and collect its response.
|
||||
|
||||
Args:
|
||||
tool_instance: The tool instance to invoke
|
||||
tool_args: Tool arguments
|
||||
tool_name: Name of the tool
|
||||
|
||||
Returns:
|
||||
Tuple of (response_content, tool_files, tool_invoke_meta)
|
||||
"""
|
||||
# Process tool_args to replace file references with actual File objects
|
||||
tool_args = self._replace_file_references(tool_args)
|
||||
|
||||
# If a tool invoke hook is set, use it instead of generic_invoke
|
||||
if self.tool_invoke_hook:
|
||||
response_content, _, tool_invoke_meta = self.tool_invoke_hook(tool_instance, tool_args, tool_name)
|
||||
# Note: message_file_ids are stored in DB, we don't convert them to File objects here
|
||||
# The caller (AgentAppRunner) handles file publishing
|
||||
return response_content, [], tool_invoke_meta
|
||||
|
||||
# Default: use generic_invoke for workflow scenarios
|
||||
# Import here to avoid circular import
|
||||
from core.tools.tool_engine import DifyWorkflowCallbackHandler, ToolEngine
|
||||
|
||||
tool_response = ToolEngine.generic_invoke(
|
||||
tool=tool_instance,
|
||||
tool_parameters=tool_args,
|
||||
user_id=self.context.user_id or "",
|
||||
workflow_tool_callback=DifyWorkflowCallbackHandler(),
|
||||
workflow_call_depth=self.workflow_call_depth,
|
||||
app_id=self.context.app_id,
|
||||
conversation_id=self.context.conversation_id,
|
||||
message_id=self.context.message_id,
|
||||
)
|
||||
|
||||
# Collect response and files
|
||||
response_content = ""
|
||||
tool_files: list[File] = []
|
||||
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(response.message, ToolInvokeMessage.TextMessage)
|
||||
response_content += response.message.text
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
# Handle link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Link: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
# Handle image URL messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Image: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK:
|
||||
# Handle image link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Image: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BINARY_LINK:
|
||||
# Handle binary file link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
filename = response.meta.get("filename", "file") if response.meta else "file"
|
||||
response_content += f"[File: {filename} - {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.JSON:
|
||||
# Handle JSON messages
|
||||
if isinstance(response.message, ToolInvokeMessage.JsonMessage):
|
||||
response_content += json.dumps(response.message.json_object, ensure_ascii=False, indent=2)
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# Handle blob messages - convert to text representation
|
||||
if isinstance(response.message, ToolInvokeMessage.BlobMessage):
|
||||
mime_type = (
|
||||
response.meta.get("mime_type", "application/octet-stream")
|
||||
if response.meta
|
||||
else "application/octet-stream"
|
||||
)
|
||||
size = len(response.message.blob)
|
||||
response_content += f"[Binary data: {mime_type}, size: {size} bytes]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
# Handle variable messages
|
||||
if isinstance(response.message, ToolInvokeMessage.VariableMessage):
|
||||
var_name = response.message.variable_name
|
||||
var_value = response.message.variable_value
|
||||
if isinstance(var_value, str):
|
||||
response_content += var_value
|
||||
else:
|
||||
response_content += f"[Variable {var_name}: {json.dumps(var_value, ensure_ascii=False)}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
|
||||
# Handle blob chunk messages - these are parts of a larger blob
|
||||
if isinstance(response.message, ToolInvokeMessage.BlobChunkMessage):
|
||||
response_content += f"[Blob chunk {response.message.sequence}: {len(response.message.blob)} bytes]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES:
|
||||
# Handle retriever resources messages
|
||||
if isinstance(response.message, ToolInvokeMessage.RetrieverResourceMessage):
|
||||
response_content += response.message.context
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.FILE:
|
||||
# Extract file from meta
|
||||
if response.meta and "file" in response.meta:
|
||||
file = response.meta["file"]
|
||||
if isinstance(file, File):
|
||||
# Check if file is for model or tool output
|
||||
if response.meta.get("target") == "self":
|
||||
# File is for model - add to files for next prompt
|
||||
self.files.append(file)
|
||||
response_content += f"File '{file.filename}' has been loaded into your context."
|
||||
else:
|
||||
# File is tool output
|
||||
tool_files.append(file)
|
||||
|
||||
return response_content, tool_files, None
|
||||
|
||||
def _validate_tool_args(self, tool_instance: Tool, tool_args: dict[str, Any]) -> str | None:
|
||||
"""Validate tool arguments against the tool's required parameters.
|
||||
|
||||
Checks that all required LLM-facing parameters are present and non-empty
|
||||
before actual execution, preventing wasted tool invocations when the model
|
||||
generates calls with missing arguments (e.g. empty ``{}``).
|
||||
|
||||
Returns:
|
||||
Error message if validation fails, None if all required parameters are satisfied.
|
||||
"""
|
||||
prompt_tool = tool_instance.to_prompt_message_tool()
|
||||
required_params: list[str] = prompt_tool.parameters.get("required", [])
|
||||
|
||||
if not required_params:
|
||||
return None
|
||||
|
||||
missing = [
|
||||
p
|
||||
for p in required_params
|
||||
if p not in tool_args
|
||||
or tool_args[p] is None
|
||||
or (isinstance(tool_args[p], str) and not tool_args[p].strip())
|
||||
]
|
||||
|
||||
if not missing:
|
||||
return None
|
||||
|
||||
return (
|
||||
f"Missing required parameter(s): {', '.join(missing)}. "
|
||||
f"Please provide all required parameters before calling this tool."
|
||||
)
|
||||
|
||||
def _find_tool_by_name(self, tool_name: str) -> Tool | None:
|
||||
"""Find a tool instance by its name."""
|
||||
for tool in self.tools:
|
||||
if tool.entity.identity.name == tool_name:
|
||||
return tool
|
||||
return None
|
||||
|
||||
def _convert_tools_to_prompt_format(self) -> list[PromptMessageTool]:
|
||||
"""Convert tools to prompt message format."""
|
||||
prompt_tools: list[PromptMessageTool] = []
|
||||
for tool in self.tools:
|
||||
prompt_tools.append(tool.to_prompt_message_tool())
|
||||
return prompt_tools
|
||||
|
||||
def _update_usage_with_empty(self, llm_usage: dict[str, Any]) -> None:
|
||||
"""Initialize usage tracking with empty usage if not set."""
|
||||
if "usage" not in llm_usage or llm_usage["usage"] is None:
|
||||
llm_usage["usage"] = LLMUsage.empty_usage()
|
||||
358
api/core/agent/patterns/function_call.py
Normal file
358
api/core/agent/patterns/function_call.py
Normal file
@@ -0,0 +1,358 @@
|
||||
"""Function Call strategy implementation.
|
||||
|
||||
Implements the Function Call agent pattern where the LLM uses native tool-calling
|
||||
capability to invoke tools. Includes pre-execution parameter validation that
|
||||
intercepts invalid calls (e.g. empty arguments) before they reach tool backends,
|
||||
and avoids counting purely-invalid rounds against the iteration budget.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Union
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from graphon.file import File
|
||||
from graphon.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
ToolPromptMessage,
|
||||
)
|
||||
|
||||
from .base import AgentPattern
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FunctionCallStrategy(AgentPattern):
|
||||
"""Function Call strategy using model's native tool calling capability."""
|
||||
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the function call agent strategy."""
|
||||
# Convert tools to prompt format
|
||||
prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format()
|
||||
|
||||
# Initialize tracking
|
||||
iteration_step: int = 1
|
||||
max_iterations: int = self.max_iterations + 1
|
||||
function_call_state: bool = True
|
||||
total_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy
|
||||
final_text: str = ""
|
||||
finish_reason: str | None = None
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
# Consecutive rounds where ALL tool calls failed parameter validation.
|
||||
# When this happens the round is "free" (iteration_step not incremented)
|
||||
# up to a safety cap to prevent infinite loops.
|
||||
consecutive_validation_failures: int = 0
|
||||
max_validation_retries: int = 3
|
||||
|
||||
while function_call_state and iteration_step <= max_iterations:
|
||||
function_call_state = False
|
||||
round_log = self._create_log(
|
||||
label=f"ROUND {iteration_step}",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
yield round_log
|
||||
# On last iteration, remove tools to force final answer
|
||||
current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools
|
||||
model_log = self._create_log(
|
||||
label=f"{self.model_instance.model_name} Thought",
|
||||
log_type=AgentLog.LogType.THOUGHT,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata={
|
||||
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
|
||||
},
|
||||
)
|
||||
yield model_log
|
||||
|
||||
# Track usage for this round only
|
||||
round_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
|
||||
# Invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
|
||||
prompt_messages=messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=current_tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
# Process response
|
||||
tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks(
|
||||
chunks, round_usage, model_log
|
||||
)
|
||||
messages.append(self._create_assistant_message(response_content, tool_calls))
|
||||
|
||||
# Accumulate to total usage
|
||||
round_usage_value = round_usage.get("usage")
|
||||
if round_usage_value:
|
||||
self._accumulate_usage(total_usage, round_usage_value)
|
||||
|
||||
# Update final text if no tool calls (this is likely the final answer)
|
||||
if not tool_calls:
|
||||
final_text = response_content
|
||||
|
||||
# Update finish reason
|
||||
if chunk_finish_reason:
|
||||
finish_reason = chunk_finish_reason
|
||||
|
||||
# Process tool calls
|
||||
tool_outputs: dict[str, str] = {}
|
||||
all_validation_errors: bool = True
|
||||
if tool_calls:
|
||||
function_call_state = True
|
||||
# Execute tools (with pre-execution parameter validation)
|
||||
for tool_call_id, tool_name, tool_args in tool_calls:
|
||||
tool_response, tool_files, _, is_validation_error = yield from self._handle_tool_call(
|
||||
tool_name, tool_args, tool_call_id, messages, round_log
|
||||
)
|
||||
tool_outputs[tool_name] = tool_response
|
||||
output_files.extend(tool_files)
|
||||
if not is_validation_error:
|
||||
all_validation_errors = False
|
||||
else:
|
||||
all_validation_errors = False
|
||||
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
data={
|
||||
"llm_result": response_content,
|
||||
"tool_calls": [
|
||||
{"name": tc[1], "args": tc[2], "output": tool_outputs.get(tc[1], "")} for tc in tool_calls
|
||||
]
|
||||
if tool_calls
|
||||
else [],
|
||||
"final_answer": final_text if not function_call_state else None,
|
||||
},
|
||||
usage=round_usage.get("usage"),
|
||||
)
|
||||
|
||||
# Skip iteration counter when every tool call in this round failed validation,
|
||||
# giving the model a free retry — but cap retries to prevent infinite loops.
|
||||
if tool_calls and all_validation_errors:
|
||||
consecutive_validation_failures += 1
|
||||
if consecutive_validation_failures >= max_validation_retries:
|
||||
logger.warning(
|
||||
"Agent hit %d consecutive validation-only rounds, forcing iteration increment",
|
||||
consecutive_validation_failures,
|
||||
)
|
||||
iteration_step += 1
|
||||
consecutive_validation_failures = 0
|
||||
else:
|
||||
logger.info(
|
||||
"All tool calls failed validation (attempt %d/%d), not counting iteration",
|
||||
consecutive_validation_failures,
|
||||
max_validation_retries,
|
||||
)
|
||||
else:
|
||||
consecutive_validation_failures = 0
|
||||
iteration_step += 1
|
||||
|
||||
# Return final result
|
||||
from core.agent.entities import AgentResult
|
||||
|
||||
return AgentResult(
|
||||
text=final_text,
|
||||
files=output_files,
|
||||
usage=total_usage.get("usage") or LLMUsage.empty_usage(),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
|
||||
def _handle_chunks(
|
||||
self,
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||
llm_usage: dict[str, LLMUsage | None],
|
||||
start_log: AgentLog,
|
||||
) -> Generator[
|
||||
LLMResultChunk | AgentLog,
|
||||
None,
|
||||
tuple[list[tuple[str, str, dict[str, Any]]], str, str | None],
|
||||
]:
|
||||
"""Handle LLM response chunks and extract tool calls and content.
|
||||
|
||||
Returns a tuple of (tool_calls, response_content, finish_reason).
|
||||
"""
|
||||
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
|
||||
response_content: str = ""
|
||||
finish_reason: str | None = None
|
||||
if not isinstance(chunks, LLMResult):
|
||||
# Streaming response
|
||||
for chunk in chunks:
|
||||
# Extract tool calls
|
||||
if self._has_tool_calls(chunk):
|
||||
tool_calls.extend(self._extract_tool_calls(chunk))
|
||||
|
||||
# Extract content
|
||||
if chunk.delta.message and chunk.delta.message.content:
|
||||
response_content += self._extract_content(chunk.delta.message.content)
|
||||
|
||||
# Track usage
|
||||
if chunk.delta.usage:
|
||||
self._accumulate_usage(llm_usage, chunk.delta.usage)
|
||||
|
||||
# Capture finish reason
|
||||
if chunk.delta.finish_reason:
|
||||
finish_reason = chunk.delta.finish_reason
|
||||
|
||||
yield chunk
|
||||
else:
|
||||
# Non-streaming response
|
||||
result: LLMResult = chunks
|
||||
|
||||
if self._has_tool_calls_result(result):
|
||||
tool_calls.extend(self._extract_tool_calls_result(result))
|
||||
|
||||
if result.message and result.message.content:
|
||||
response_content += self._extract_content(result.message.content)
|
||||
|
||||
if result.usage:
|
||||
self._accumulate_usage(llm_usage, result.usage)
|
||||
|
||||
# Convert to streaming format
|
||||
yield LLMResultChunk(
|
||||
model=result.model,
|
||||
prompt_messages=result.prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
|
||||
)
|
||||
yield self._finish_log(
|
||||
start_log,
|
||||
data={
|
||||
"result": response_content,
|
||||
},
|
||||
usage=llm_usage.get("usage"),
|
||||
)
|
||||
return tool_calls, response_content, finish_reason
|
||||
|
||||
def _create_assistant_message(
|
||||
self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None
|
||||
) -> AssistantPromptMessage:
|
||||
"""Create assistant message with tool calls."""
|
||||
if tool_calls is None:
|
||||
return AssistantPromptMessage(content=content)
|
||||
return AssistantPromptMessage(
|
||||
content=content or "",
|
||||
tool_calls=[
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=tc[0],
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tc[1], arguments=json.dumps(tc[2])),
|
||||
)
|
||||
for tc in tool_calls
|
||||
],
|
||||
)
|
||||
|
||||
def _handle_tool_call(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_args: dict[str, Any],
|
||||
tool_call_id: str,
|
||||
messages: list[PromptMessage],
|
||||
round_log: AgentLog,
|
||||
) -> Generator[AgentLog, None, tuple[str, list[File], ToolInvokeMeta | None, bool]]:
|
||||
"""Handle a single tool call and return response with files, meta, and validation status.
|
||||
|
||||
Validates required parameters before execution. When validation fails the tool
|
||||
is never invoked — a synthetic error is fed back to the model so it can self-correct
|
||||
without consuming a real iteration.
|
||||
|
||||
Returns:
|
||||
(response_content, tool_files, tool_invoke_meta, is_validation_error).
|
||||
``is_validation_error`` is True when the call was rejected due to missing
|
||||
required parameters, allowing the caller to skip the iteration counter.
|
||||
"""
|
||||
# Find tool
|
||||
tool_instance = self._find_tool_by_name(tool_name)
|
||||
if not tool_instance:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
|
||||
# Get tool metadata (provider, icon, etc.)
|
||||
tool_metadata = self._get_tool_metadata(tool_instance)
|
||||
|
||||
# Create tool call log
|
||||
tool_call_log = self._create_log(
|
||||
label=f"CALL {tool_name}",
|
||||
log_type=AgentLog.LogType.TOOL_CALL,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_args": tool_args,
|
||||
},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata=tool_metadata,
|
||||
)
|
||||
yield tool_call_log
|
||||
|
||||
# Validate required parameters before execution to avoid wasted invocations
|
||||
validation_error = self._validate_tool_args(tool_instance, tool_args)
|
||||
if validation_error:
|
||||
tool_call_log.status = AgentLog.LogStatus.ERROR
|
||||
tool_call_log.error = validation_error
|
||||
tool_call_log.data = {**tool_call_log.data, "error": validation_error}
|
||||
yield tool_call_log
|
||||
|
||||
messages.append(ToolPromptMessage(content=validation_error, tool_call_id=tool_call_id, name=tool_name))
|
||||
return validation_error, [], None, True
|
||||
|
||||
# Invoke tool using base class method with error handling
|
||||
try:
|
||||
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args, tool_name)
|
||||
|
||||
yield self._finish_log(
|
||||
tool_call_log,
|
||||
data={
|
||||
**tool_call_log.data,
|
||||
"output": response_content,
|
||||
"files": len(tool_files),
|
||||
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
|
||||
},
|
||||
)
|
||||
final_content = response_content or "Tool executed successfully"
|
||||
# Add tool response to messages
|
||||
messages.append(
|
||||
ToolPromptMessage(
|
||||
content=final_content,
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_name,
|
||||
)
|
||||
)
|
||||
return response_content, tool_files, tool_invoke_meta, False
|
||||
except Exception as e:
|
||||
# Tool invocation failed, yield error log
|
||||
error_message = str(e)
|
||||
tool_call_log.status = AgentLog.LogStatus.ERROR
|
||||
tool_call_log.error = error_message
|
||||
tool_call_log.data = {
|
||||
**tool_call_log.data,
|
||||
"error": error_message,
|
||||
}
|
||||
yield tool_call_log
|
||||
|
||||
# Add error message to conversation
|
||||
error_content = f"Tool execution failed: {error_message}"
|
||||
messages.append(
|
||||
ToolPromptMessage(
|
||||
content=error_content,
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_name,
|
||||
)
|
||||
)
|
||||
return error_content, [], None, False
|
||||
418
api/core/agent/patterns/react.py
Normal file
418
api/core/agent/patterns/react.py
Normal file
@@ -0,0 +1,418 @@
|
||||
"""ReAct strategy implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext
|
||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||
from core.model_manager import ModelInstance
|
||||
from graphon.file import File
|
||||
from graphon.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
PromptMessage,
|
||||
SystemPromptMessage,
|
||||
)
|
||||
|
||||
from .base import AgentPattern, ToolInvokeHook
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
|
||||
class ReActStrategy(AgentPattern):
|
||||
"""ReAct strategy using reasoning and acting approach."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_instance: ModelInstance,
|
||||
tools: list[Tool],
|
||||
context: ExecutionContext,
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
files: list[File] = [],
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
instruction: str = "",
|
||||
):
|
||||
"""Initialize the ReAct strategy with instruction support."""
|
||||
super().__init__(
|
||||
model_instance=model_instance,
|
||||
tools=tools,
|
||||
context=context,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
files=files,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
self.instruction = instruction
|
||||
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the ReAct agent strategy."""
|
||||
# Initialize tracking
|
||||
agent_scratchpad: list[AgentScratchpadUnit] = []
|
||||
iteration_step: int = 1
|
||||
max_iterations: int = self.max_iterations + 1
|
||||
react_state: bool = True
|
||||
total_usage: dict[str, Any] = {"usage": None}
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
final_text: str = ""
|
||||
finish_reason: str | None = None
|
||||
|
||||
# Add "Observation" to stop sequences
|
||||
if "Observation" not in stop:
|
||||
stop = stop.copy()
|
||||
stop.append("Observation")
|
||||
|
||||
while react_state and iteration_step <= max_iterations:
|
||||
react_state = False
|
||||
round_log = self._create_log(
|
||||
label=f"ROUND {iteration_step}",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
yield round_log
|
||||
|
||||
# Build prompt with/without tools based on iteration
|
||||
include_tools = iteration_step < max_iterations
|
||||
current_messages = self._build_prompt_with_react_format(
|
||||
prompt_messages, agent_scratchpad, include_tools, self.instruction
|
||||
)
|
||||
|
||||
model_log = self._create_log(
|
||||
label=f"{self.model_instance.model_name} Thought",
|
||||
log_type=AgentLog.LogType.THOUGHT,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata={
|
||||
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
|
||||
},
|
||||
)
|
||||
yield model_log
|
||||
|
||||
# Track usage for this round only
|
||||
round_usage: dict[str, Any] = {"usage": None}
|
||||
|
||||
# Use current messages directly (files are handled by base class if needed)
|
||||
messages_to_use = current_messages
|
||||
|
||||
# Invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
|
||||
prompt_messages=messages_to_use,
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
# Process response
|
||||
scratchpad, chunk_finish_reason = yield from self._handle_chunks(
|
||||
chunks, round_usage, model_log, current_messages
|
||||
)
|
||||
agent_scratchpad.append(scratchpad)
|
||||
|
||||
# Accumulate to total usage
|
||||
round_usage_value = round_usage.get("usage")
|
||||
if round_usage_value:
|
||||
self._accumulate_usage(total_usage, round_usage_value)
|
||||
|
||||
# Update finish reason
|
||||
if chunk_finish_reason:
|
||||
finish_reason = chunk_finish_reason
|
||||
|
||||
# Check if we have an action to execute
|
||||
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
|
||||
react_state = True
|
||||
# Execute tool
|
||||
observation, tool_files = yield from self._handle_tool_call(
|
||||
scratchpad.action, current_messages, round_log
|
||||
)
|
||||
scratchpad.observation = observation
|
||||
# Track files produced by tools
|
||||
output_files.extend(tool_files)
|
||||
|
||||
# Add observation to scratchpad for display
|
||||
yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages)
|
||||
else:
|
||||
# Extract final answer
|
||||
if scratchpad.action and scratchpad.action.action_input:
|
||||
final_answer = scratchpad.action.action_input
|
||||
if isinstance(final_answer, dict):
|
||||
final_answer = json.dumps(final_answer, ensure_ascii=False)
|
||||
final_text = str(final_answer)
|
||||
elif scratchpad.thought:
|
||||
# If no action but we have thought, use thought as final answer
|
||||
final_text = scratchpad.thought
|
||||
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
data={
|
||||
"thought": scratchpad.thought,
|
||||
"action": scratchpad.action_str if scratchpad.action else None,
|
||||
"observation": scratchpad.observation or None,
|
||||
"final_answer": final_text if not react_state else None,
|
||||
},
|
||||
usage=round_usage.get("usage"),
|
||||
)
|
||||
iteration_step += 1
|
||||
|
||||
# Return final result
|
||||
|
||||
from core.agent.entities import AgentResult
|
||||
|
||||
return AgentResult(
|
||||
text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason
|
||||
)
|
||||
|
||||
def _build_prompt_with_react_format(
|
||||
self,
|
||||
original_messages: list[PromptMessage],
|
||||
agent_scratchpad: list[AgentScratchpadUnit],
|
||||
include_tools: bool = True,
|
||||
instruction: str = "",
|
||||
) -> list[PromptMessage]:
|
||||
"""Build prompt messages with ReAct format."""
|
||||
# Copy messages to avoid modifying original
|
||||
messages = list(original_messages)
|
||||
|
||||
# Find and update the system prompt that should already exist
|
||||
system_prompt_found = False
|
||||
for i, msg in enumerate(messages):
|
||||
if isinstance(msg, SystemPromptMessage):
|
||||
system_prompt_found = True
|
||||
# The system prompt from frontend already has the template, just replace placeholders
|
||||
|
||||
# Format tools
|
||||
tools_str = ""
|
||||
tool_names = []
|
||||
if include_tools and self.tools:
|
||||
# Convert tools to prompt message tools format
|
||||
prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools]
|
||||
tool_names = [tool.name for tool in prompt_tools]
|
||||
|
||||
# Format tools as JSON for comprehensive information
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
tools_str = json.dumps(jsonable_encoder(prompt_tools), indent=2)
|
||||
tool_names_str = ", ".join(f'"{name}"' for name in tool_names)
|
||||
else:
|
||||
tools_str = "No tools available"
|
||||
tool_names_str = ""
|
||||
|
||||
# Replace placeholders in the existing system prompt
|
||||
updated_content = msg.content
|
||||
assert isinstance(updated_content, str)
|
||||
updated_content = updated_content.replace("{{instruction}}", instruction)
|
||||
updated_content = updated_content.replace("{{tools}}", tools_str)
|
||||
updated_content = updated_content.replace("{{tool_names}}", tool_names_str)
|
||||
|
||||
# Create new SystemPromptMessage with updated content
|
||||
messages[i] = SystemPromptMessage(content=updated_content)
|
||||
break
|
||||
|
||||
# If no system prompt found, that's unexpected but add scratchpad anyway
|
||||
if not system_prompt_found:
|
||||
# This shouldn't happen if frontend is working correctly
|
||||
pass
|
||||
|
||||
# Format agent scratchpad
|
||||
scratchpad_str = ""
|
||||
if agent_scratchpad:
|
||||
scratchpad_parts: list[str] = []
|
||||
for unit in agent_scratchpad:
|
||||
if unit.thought:
|
||||
scratchpad_parts.append(f"Thought: {unit.thought}")
|
||||
if unit.action_str:
|
||||
scratchpad_parts.append(f"Action:\n```\n{unit.action_str}\n```")
|
||||
if unit.observation:
|
||||
scratchpad_parts.append(f"Observation: {unit.observation}")
|
||||
scratchpad_str = "\n".join(scratchpad_parts)
|
||||
|
||||
# If there's a scratchpad, append it to the last message
|
||||
if scratchpad_str:
|
||||
messages.append(AssistantPromptMessage(content=scratchpad_str))
|
||||
|
||||
return messages
|
||||
|
||||
def _handle_chunks(
|
||||
self,
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||
llm_usage: dict[str, Any],
|
||||
model_log: AgentLog,
|
||||
current_messages: list[PromptMessage],
|
||||
) -> Generator[
|
||||
LLMResultChunk | AgentLog,
|
||||
None,
|
||||
tuple[AgentScratchpadUnit, str | None],
|
||||
]:
|
||||
"""Handle LLM response chunks and extract action/thought.
|
||||
|
||||
Returns a tuple of (scratchpad_unit, finish_reason).
|
||||
"""
|
||||
usage_dict: dict[str, Any] = {}
|
||||
|
||||
# Convert non-streaming to streaming format if needed
|
||||
if isinstance(chunks, LLMResult):
|
||||
result = chunks
|
||||
|
||||
def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
|
||||
yield LLMResultChunk(
|
||||
model=result.model,
|
||||
prompt_messages=result.prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=result.message,
|
||||
usage=result.usage,
|
||||
finish_reason=None,
|
||||
),
|
||||
system_fingerprint=result.system_fingerprint or "",
|
||||
)
|
||||
|
||||
streaming_chunks = result_to_chunks()
|
||||
else:
|
||||
streaming_chunks = chunks
|
||||
|
||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict)
|
||||
|
||||
# Initialize scratchpad unit
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="",
|
||||
thought="",
|
||||
action_str="",
|
||||
observation="",
|
||||
action=None,
|
||||
)
|
||||
|
||||
finish_reason: str | None = None
|
||||
|
||||
# Process chunks
|
||||
for chunk in react_chunks:
|
||||
if isinstance(chunk, AgentScratchpadUnit.Action):
|
||||
# Action detected
|
||||
action_str = json.dumps(chunk.model_dump())
|
||||
scratchpad.agent_response = (scratchpad.agent_response or "") + action_str
|
||||
scratchpad.action_str = action_str
|
||||
scratchpad.action = chunk
|
||||
|
||||
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
|
||||
else:
|
||||
# Text chunk
|
||||
chunk_text = str(chunk)
|
||||
scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text
|
||||
scratchpad.thought = (scratchpad.thought or "") + chunk_text
|
||||
|
||||
yield self._create_text_chunk(chunk_text, current_messages)
|
||||
|
||||
# Update usage
|
||||
if usage_dict.get("usage"):
|
||||
if llm_usage.get("usage"):
|
||||
self._accumulate_usage(llm_usage, usage_dict["usage"])
|
||||
else:
|
||||
llm_usage["usage"] = usage_dict["usage"]
|
||||
|
||||
# Clean up thought
|
||||
scratchpad.thought = (scratchpad.thought or "").strip() or "I am thinking about how to help you"
|
||||
|
||||
# Finish model log
|
||||
yield self._finish_log(
|
||||
model_log,
|
||||
data={
|
||||
"thought": scratchpad.thought,
|
||||
"action": scratchpad.action_str if scratchpad.action else None,
|
||||
},
|
||||
usage=llm_usage.get("usage"),
|
||||
)
|
||||
|
||||
return scratchpad, finish_reason
|
||||
|
||||
def _handle_tool_call(
|
||||
self,
|
||||
action: AgentScratchpadUnit.Action,
|
||||
prompt_messages: list[PromptMessage],
|
||||
round_log: AgentLog,
|
||||
) -> Generator[AgentLog, None, tuple[str, list[File]]]:
|
||||
"""Handle tool call and return observation with files."""
|
||||
tool_name = action.action_name
|
||||
tool_args: dict[str, Any] | str = action.action_input
|
||||
|
||||
# Find tool instance first to get metadata
|
||||
tool_instance = self._find_tool_by_name(tool_name)
|
||||
tool_metadata = self._get_tool_metadata(tool_instance) if tool_instance else {}
|
||||
|
||||
# Start tool log with tool metadata
|
||||
tool_log = self._create_log(
|
||||
label=f"CALL {tool_name}",
|
||||
log_type=AgentLog.LogType.TOOL_CALL,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={
|
||||
"tool_name": tool_name,
|
||||
"tool_args": tool_args,
|
||||
},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata=tool_metadata,
|
||||
)
|
||||
yield tool_log
|
||||
|
||||
if not tool_instance:
|
||||
# Finish tool log with error
|
||||
yield self._finish_log(
|
||||
tool_log,
|
||||
data={
|
||||
**tool_log.data,
|
||||
"error": f"Tool {tool_name} not found",
|
||||
},
|
||||
)
|
||||
return f"Tool {tool_name} not found", []
|
||||
|
||||
# Ensure tool_args is a dict
|
||||
tool_args_dict: dict[str, Any]
|
||||
if isinstance(tool_args, str):
|
||||
try:
|
||||
tool_args_dict = json.loads(tool_args)
|
||||
except json.JSONDecodeError:
|
||||
tool_args_dict = {"input": tool_args}
|
||||
elif not isinstance(tool_args, dict):
|
||||
tool_args_dict = {"input": str(tool_args)}
|
||||
else:
|
||||
tool_args_dict = tool_args
|
||||
|
||||
# Invoke tool using base class method with error handling
|
||||
try:
|
||||
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args_dict, tool_name)
|
||||
|
||||
# Finish tool log
|
||||
yield self._finish_log(
|
||||
tool_log,
|
||||
data={
|
||||
**tool_log.data,
|
||||
"output": response_content,
|
||||
"files": len(tool_files),
|
||||
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
|
||||
},
|
||||
)
|
||||
|
||||
return response_content or "Tool executed successfully", tool_files
|
||||
except Exception as e:
|
||||
# Tool invocation failed, yield error log
|
||||
error_message = str(e)
|
||||
tool_log.status = AgentLog.LogStatus.ERROR
|
||||
tool_log.error = error_message
|
||||
tool_log.data = {
|
||||
**tool_log.data,
|
||||
"error": error_message,
|
||||
}
|
||||
yield tool_log
|
||||
|
||||
return f"Tool execution failed: {error_message}", []
|
||||
108
api/core/agent/patterns/strategy_factory.py
Normal file
108
api/core/agent/patterns/strategy_factory.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Strategy factory for creating agent strategies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.agent.entities import AgentEntity, ExecutionContext
|
||||
from core.model_manager import ModelInstance
|
||||
from graphon.file.models import File
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature
|
||||
|
||||
from .base import AgentPattern, ToolInvokeHook
|
||||
from .function_call import FunctionCallStrategy
|
||||
from .react import ReActStrategy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
|
||||
class StrategyFactory:
|
||||
"""Factory for creating agent strategies based on model features."""
|
||||
|
||||
# Tool calling related features
|
||||
TOOL_CALL_FEATURES = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL}
|
||||
|
||||
@staticmethod
|
||||
def create_strategy(
|
||||
model_features: list[ModelFeature],
|
||||
model_instance: ModelInstance,
|
||||
context: ExecutionContext,
|
||||
tools: list[Tool],
|
||||
files: list[File],
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
agent_strategy: AgentEntity.Strategy | None = None,
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
instruction: str = "",
|
||||
) -> AgentPattern:
|
||||
"""
|
||||
Create an appropriate strategy based on model features.
|
||||
|
||||
Args:
|
||||
model_features: List of model features/capabilities
|
||||
model_instance: Model instance to use
|
||||
context: Execution context containing trace/audit information
|
||||
tools: Available tools
|
||||
files: Available files
|
||||
max_iterations: Maximum iterations for the strategy
|
||||
workflow_call_depth: Depth of workflow calls
|
||||
agent_strategy: Optional explicit strategy override
|
||||
tool_invoke_hook: Optional hook for custom tool invocation (e.g., agent_invoke)
|
||||
instruction: Optional instruction for ReAct strategy
|
||||
|
||||
Returns:
|
||||
AgentStrategy instance
|
||||
"""
|
||||
|
||||
# If explicit strategy is provided and it's Function Calling, try to use it if supported
|
||||
if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING:
|
||||
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
|
||||
return FunctionCallStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
# Fallback to ReAct if FC is requested but not supported
|
||||
|
||||
# If explicit strategy is Chain of Thought (ReAct)
|
||||
if agent_strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
|
||||
return ReActStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
instruction=instruction,
|
||||
)
|
||||
|
||||
# Default auto-selection logic
|
||||
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
|
||||
# Model supports native function calling
|
||||
return FunctionCallStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
else:
|
||||
# Use ReAct strategy for models without function calling
|
||||
return ReActStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
instruction=instruction,
|
||||
)
|
||||
@@ -84,7 +84,7 @@ class AgentStrategyEntity(BaseModel):
|
||||
identity: AgentStrategyIdentity
|
||||
parameters: list[AgentStrategyParameter] = Field(default_factory=list)
|
||||
description: I18nObject = Field(..., description="The description of the agent strategy")
|
||||
output_schema: dict[str, Any] | None = None
|
||||
output_schema: dict | None = None
|
||||
features: list[AgentFeature] | None = None
|
||||
meta_version: str | None = None
|
||||
# pydantic configs
|
||||
|
||||
@@ -22,8 +22,8 @@ class SensitiveWordAvoidanceConfigManager:
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(
|
||||
cls, tenant_id: str, config: dict[str, Any], only_structure_validate: bool = False
|
||||
) -> tuple[dict[str, Any], list[str]]:
|
||||
cls, tenant_id: str, config: dict, only_structure_validate: bool = False
|
||||
) -> tuple[dict, list[str]]:
|
||||
if not config.get("sensitive_word_avoidance"):
|
||||
config["sensitive_word_avoidance"] = {"enabled": False}
|
||||
|
||||
|
||||
@@ -138,9 +138,7 @@ class DatasetConfigManager:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(
|
||||
cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]
|
||||
) -> tuple[dict[str, Any], list[str]]:
|
||||
def validate_and_set_defaults(cls, tenant_id: str, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for dataset feature
|
||||
|
||||
@@ -174,7 +172,7 @@ class DatasetConfigManager:
|
||||
return config, ["agent_mode", "dataset_configs", "dataset_query_variable"]
|
||||
|
||||
@classmethod
|
||||
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict[str, Any]):
|
||||
def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mode: AppMode, config: dict):
|
||||
"""
|
||||
Extract dataset config for legacy compatibility
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ class ModelConfigManager:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for model config
|
||||
|
||||
@@ -108,7 +108,7 @@ class ModelConfigManager:
|
||||
return dict(config), ["model"]
|
||||
|
||||
@classmethod
|
||||
def validate_model_completion_params(cls, cp: dict[str, Any]):
|
||||
def validate_model_completion_params(cls, cp: dict):
|
||||
# model.completion_params
|
||||
if not isinstance(cp, dict):
|
||||
raise ValueError("model.completion_params must be of object type")
|
||||
|
||||
@@ -65,7 +65,7 @@ class PromptTemplateConfigManager:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, app_mode: AppMode, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
def validate_and_set_defaults(cls, app_mode: AppMode, config: dict) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate pre_prompt and set defaults for prompt feature
|
||||
depending on the config['model']
|
||||
@@ -130,7 +130,7 @@ class PromptTemplateConfigManager:
|
||||
return config, ["prompt_type", "pre_prompt", "chat_prompt_config", "completion_prompt_config"]
|
||||
|
||||
@classmethod
|
||||
def validate_post_prompt_and_set_defaults(cls, config: dict[str, Any]):
|
||||
def validate_post_prompt_and_set_defaults(cls, config: dict):
|
||||
"""
|
||||
Validate post_prompt and set defaults for prompt feature
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import re
|
||||
from typing import Any, cast
|
||||
from typing import cast
|
||||
|
||||
from graphon.variables.input_entities import VariableEntity, VariableEntityType
|
||||
|
||||
@@ -82,7 +82,7 @@ class BasicVariablesConfigManager:
|
||||
return variable_entities, external_data_variables
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, tenant_id: str, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for user input form
|
||||
|
||||
@@ -99,7 +99,7 @@ class BasicVariablesConfigManager:
|
||||
return config, related_config_keys
|
||||
|
||||
@classmethod
|
||||
def validate_variables_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
def validate_variables_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for user input form
|
||||
|
||||
@@ -164,9 +164,7 @@ class BasicVariablesConfigManager:
|
||||
return config, ["user_input_form"]
|
||||
|
||||
@classmethod
|
||||
def validate_external_data_tools_and_set_defaults(
|
||||
cls, tenant_id: str, config: dict[str, Any]
|
||||
) -> tuple[dict[str, Any], list[str]]:
|
||||
def validate_external_data_tools_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for external data fetch feature
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ class FileUploadConfigManager:
|
||||
return FileUploadConfig.model_validate(file_upload_dict)
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for file upload feature
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||
|
||||
|
||||
@@ -15,7 +13,7 @@ class AppConfigModel(BaseModel):
|
||||
|
||||
class MoreLikeThisConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict[str, Any]) -> bool:
|
||||
def convert(cls, config: dict) -> bool:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@@ -25,7 +23,7 @@ class MoreLikeThisConfigManager:
|
||||
return AppConfigModel.model_validate(validated_config).more_like_this.enabled
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
try:
|
||||
return AppConfigModel.model_validate(config).model_dump(), ["more_like_this"]
|
||||
except ValidationError:
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
class OpeningStatementConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict[str, Any]) -> tuple[str, list[str]]:
|
||||
def convert(cls, config: dict) -> tuple[str, list]:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@@ -18,7 +15,7 @@ class OpeningStatementConfigManager:
|
||||
return opening_statement, suggested_questions_list
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for opening statement feature
|
||||
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
class RetrievalResourceConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict[str, Any]) -> bool:
|
||||
def convert(cls, config: dict) -> bool:
|
||||
show_retrieve_source = False
|
||||
retriever_resource_dict = config.get("retriever_resource")
|
||||
if retriever_resource_dict:
|
||||
@@ -13,7 +10,7 @@ class RetrievalResourceConfigManager:
|
||||
return show_retrieve_source
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for retriever resource feature
|
||||
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
class SpeechToTextConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict[str, Any]) -> bool:
|
||||
def convert(cls, config: dict) -> bool:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@@ -18,7 +15,7 @@ class SpeechToTextConfigManager:
|
||||
return speech_to_text
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for speech to text feature
|
||||
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
class SuggestedQuestionsAfterAnswerConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict[str, Any]) -> bool:
|
||||
def convert(cls, config: dict) -> bool:
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@@ -18,7 +15,7 @@ class SuggestedQuestionsAfterAnswerConfigManager:
|
||||
return suggested_questions_after_answer
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for suggested questions feature
|
||||
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import TextToSpeechEntity
|
||||
|
||||
|
||||
class TextToSpeechConfigManager:
|
||||
@classmethod
|
||||
def convert(cls, config: dict[str, Any]):
|
||||
def convert(cls, config: dict):
|
||||
"""
|
||||
Convert model config to model config
|
||||
|
||||
@@ -24,7 +22,7 @@ class TextToSpeechConfigManager:
|
||||
return text_to_speech
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for text to speech feature
|
||||
|
||||
|
||||
@@ -177,6 +177,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
# always enable retriever resource in debugger mode
|
||||
app_config.additional_features.show_retrieve_source = True # type: ignore
|
||||
|
||||
# Resolve parent_message_id for thread continuity
|
||||
if invoke_from == InvokeFrom.SERVICE_API:
|
||||
parent_message_id: str | None = UUID_NIL
|
||||
else:
|
||||
parent_message_id = args.get("parent_message_id")
|
||||
if not parent_message_id and conversation:
|
||||
parent_message_id = self._resolve_latest_message_id(conversation.id)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = AdvancedChatAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
@@ -188,7 +196,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
),
|
||||
query=query,
|
||||
files=list(file_objs),
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
parent_message_id=parent_message_id,
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
@@ -689,3 +697,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
else:
|
||||
logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def _resolve_latest_message_id(conversation_id: str) -> str | None:
|
||||
"""Auto-resolve parent_message_id to the latest message when client doesn't provide one."""
|
||||
from sqlalchemy import select
|
||||
|
||||
stmt = (
|
||||
select(Message.id)
|
||||
.where(Message.conversation_id == conversation_id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
latest_id = db.session.scalar(stmt)
|
||||
return str(latest_id) if latest_id else None
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
|
||||
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
|
||||
from core.agent.agent_app_runner import AgentAppRunner
|
||||
from core.agent.entities import AgentEntity
|
||||
from core.agent.fc_agent_runner import FunctionCallAgentRunner
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
@@ -192,24 +189,8 @@ class AgentChatAppRunner(AppRunner):
|
||||
message_result = db.session.scalar(msg_stmt)
|
||||
if message_result is None:
|
||||
raise ValueError("Message not found")
|
||||
db.session.close()
|
||||
|
||||
runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
|
||||
# start agent runner
|
||||
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
|
||||
# check LLM mode
|
||||
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT:
|
||||
runner_cls = CotChatAgentRunner
|
||||
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION:
|
||||
runner_cls = CotCompletionAgentRunner
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}")
|
||||
elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
|
||||
runner_cls = FunctionCallAgentRunner
|
||||
else:
|
||||
raise ValueError(f"Invalid agent strategy: {agent_entity.strategy}")
|
||||
|
||||
runner = runner_cls(
|
||||
runner = AgentAppRunner(
|
||||
tenant_id=app_config.tenant_id,
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=conversation_result,
|
||||
|
||||
53
api/core/app/apps/common/legacy_response_adapter.py
Normal file
53
api/core/app/apps/common/legacy_response_adapter.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""Legacy Response Adapter for transparent upgrade.
|
||||
|
||||
When old apps (chat/completion/agent-chat) run through the Agent V2
|
||||
workflow engine via transparent upgrade, the SSE events are in workflow
|
||||
format (workflow_started, node_started, etc.). This adapter filters out
|
||||
workflow-specific events and passes through only the events that old
|
||||
clients expect (message, message_end, etc.).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
WORKFLOW_ONLY_EVENTS = frozenset({
|
||||
"workflow_started",
|
||||
"workflow_finished",
|
||||
"node_started",
|
||||
"node_finished",
|
||||
"iteration_started",
|
||||
"iteration_next",
|
||||
"iteration_completed",
|
||||
})
|
||||
|
||||
|
||||
def adapt_workflow_stream_for_legacy(
|
||||
stream: Generator[str, None, None],
|
||||
) -> Generator[str, None, None]:
|
||||
"""Filter workflow-specific SSE events from a streaming response.
|
||||
|
||||
Passes through message, message_end, agent_log, error, ping events.
|
||||
Suppresses workflow_started, workflow_finished, node_started, node_finished.
|
||||
|
||||
This makes the SSE stream look more like what old easy-UI apps produce,
|
||||
while still carrying the actual LLM response content.
|
||||
"""
|
||||
for chunk in stream:
|
||||
if not chunk or not chunk.strip():
|
||||
yield chunk
|
||||
continue
|
||||
|
||||
try:
|
||||
if chunk.startswith("data: "):
|
||||
data = json.loads(chunk[6:])
|
||||
event = data.get("event", "")
|
||||
if event in WORKFLOW_ONLY_EVENTS:
|
||||
continue
|
||||
yield chunk
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
yield chunk
|
||||
@@ -1,5 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, cast
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
from core.app.entities.task_entities import (
|
||||
@@ -17,7 +17,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
_blocking_response_type = WorkflowAppBlockingResponse
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
|
||||
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking full response.
|
||||
:param blocking_response: blocking response
|
||||
@@ -26,7 +26,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
return dict(blocking_response.model_dump())
|
||||
|
||||
@classmethod
|
||||
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict[str, Any]: # type: ignore[override]
|
||||
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
|
||||
"""
|
||||
Convert blocking simple response.
|
||||
:param blocking_response: blocking response
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
|
||||
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
|
||||
from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
|
||||
@@ -36,9 +34,7 @@ class PipelineConfigManager(BaseAppConfigManager):
|
||||
return pipeline_config
|
||||
|
||||
@classmethod
|
||||
def config_validate(
|
||||
cls, tenant_id: str, config: dict[str, Any], only_structure_validate: bool = False
|
||||
) -> dict[str, Any]:
|
||||
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
|
||||
"""
|
||||
Validate for pipeline config
|
||||
|
||||
|
||||
@@ -782,7 +782,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
user_id: str,
|
||||
all_files: list,
|
||||
datasource_info: Mapping[str, Any],
|
||||
next_page_parameters: dict[str, Any] | None = None,
|
||||
next_page_parameters: dict | None = None,
|
||||
):
|
||||
"""
|
||||
Get files in a folder.
|
||||
|
||||
@@ -146,8 +146,6 @@ class WorkflowBasedAppRunner:
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# Use the provided graph_runtime_state for consistent state management
|
||||
|
||||
node_factory = DifyNodeFactory.from_graph_init_context(
|
||||
graph_init_context=graph_init_context,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
|
||||
72
api/core/app/entities/llm_generation_entities.py
Normal file
72
api/core/app/entities/llm_generation_entities.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
LLM Generation Detail entities.
|
||||
|
||||
Defines the structure for storing and transmitting LLM generation details
|
||||
including reasoning content, tool calls, and their sequence.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ContentSegment(BaseModel):
|
||||
"""Represents a content segment in the generation sequence."""
|
||||
|
||||
type: Literal["content"] = "content"
|
||||
start: int = Field(..., description="Start position in the text")
|
||||
end: int = Field(..., description="End position in the text")
|
||||
|
||||
|
||||
class ReasoningSegment(BaseModel):
|
||||
"""Represents a reasoning segment in the generation sequence."""
|
||||
|
||||
type: Literal["reasoning"] = "reasoning"
|
||||
index: int = Field(..., description="Index into reasoning_content array")
|
||||
|
||||
|
||||
class ToolCallSegment(BaseModel):
|
||||
"""Represents a tool call segment in the generation sequence."""
|
||||
|
||||
type: Literal["tool_call"] = "tool_call"
|
||||
index: int = Field(..., description="Index into tool_calls array")
|
||||
|
||||
|
||||
SequenceSegment = ContentSegment | ReasoningSegment | ToolCallSegment
|
||||
|
||||
|
||||
class ToolCallDetail(BaseModel):
|
||||
"""Represents a tool call with its arguments and result."""
|
||||
|
||||
id: str = Field(default="", description="Unique identifier for the tool call")
|
||||
name: str = Field(..., description="Name of the tool")
|
||||
arguments: str = Field(default="", description="JSON string of tool arguments")
|
||||
result: str = Field(default="", description="Result from the tool execution")
|
||||
elapsed_time: float | None = Field(default=None, description="Elapsed time in seconds")
|
||||
icon: str | dict | None = Field(default=None, description="Icon of the tool")
|
||||
icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")
|
||||
|
||||
|
||||
class LLMGenerationDetailData(BaseModel):
|
||||
"""
|
||||
Domain model for LLM generation detail.
|
||||
|
||||
Contains the structured data for reasoning content, tool calls,
|
||||
and their display sequence.
|
||||
"""
|
||||
|
||||
reasoning_content: list[str] = Field(default_factory=list, description="List of reasoning segments")
|
||||
tool_calls: list[ToolCallDetail] = Field(default_factory=list, description="List of tool call details")
|
||||
sequence: list[SequenceSegment] = Field(default_factory=list, description="Display order of segments")
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""Check if there's any meaningful generation detail."""
|
||||
return not self.reasoning_content and not self.tool_calls
|
||||
|
||||
def to_response_dict(self) -> dict:
|
||||
"""Convert to dictionary for API response."""
|
||||
return {
|
||||
"reasoning_content": self.reasoning_content,
|
||||
"tool_calls": [tc.model_dump() for tc in self.tool_calls],
|
||||
"sequence": [seg.model_dump() for seg in self.sequence],
|
||||
}
|
||||
@@ -521,7 +521,7 @@ class IterationNodeStartStreamResponse(StreamResponse):
|
||||
node_type: str
|
||||
title: str
|
||||
created_at: int
|
||||
extras: dict[str, Any] = Field(default_factory=dict)
|
||||
extras: dict = Field(default_factory=dict)
|
||||
metadata: Mapping = {}
|
||||
inputs: Mapping = {}
|
||||
inputs_truncated: bool = False
|
||||
@@ -547,7 +547,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
|
||||
title: str
|
||||
index: int
|
||||
created_at: int
|
||||
extras: dict[str, Any] = Field(default_factory=dict)
|
||||
extras: dict = Field(default_factory=dict)
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_NEXT
|
||||
workflow_run_id: str
|
||||
@@ -571,7 +571,7 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
|
||||
outputs: Mapping | None = None
|
||||
outputs_truncated: bool = False
|
||||
created_at: int
|
||||
extras: dict[str, Any] | None = None
|
||||
extras: dict | None = None
|
||||
inputs: Mapping | None = None
|
||||
inputs_truncated: bool = False
|
||||
status: WorkflowNodeExecutionStatus
|
||||
@@ -602,7 +602,7 @@ class LoopNodeStartStreamResponse(StreamResponse):
|
||||
node_type: str
|
||||
title: str
|
||||
created_at: int
|
||||
extras: dict[str, Any] = Field(default_factory=dict)
|
||||
extras: dict = Field(default_factory=dict)
|
||||
metadata: Mapping = {}
|
||||
inputs: Mapping = {}
|
||||
inputs_truncated: bool = False
|
||||
@@ -653,7 +653,7 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
|
||||
outputs: Mapping | None = None
|
||||
outputs_truncated: bool = False
|
||||
created_at: int
|
||||
extras: dict[str, Any] | None = None
|
||||
extras: dict | None = None
|
||||
inputs: Mapping | None = None
|
||||
inputs_truncated: bool = False
|
||||
status: WorkflowNodeExecutionStatus
|
||||
|
||||
@@ -14,7 +14,7 @@ class DatasourceApiEntity(BaseModel):
|
||||
description: I18nObject
|
||||
parameters: list[DatasourceParameter] | None = None
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
output_schema: dict[str, Any] | None = None
|
||||
output_schema: dict | None = None
|
||||
|
||||
|
||||
ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow"] | None
|
||||
@@ -30,7 +30,7 @@ class DatasourceProviderApiEntityDict(TypedDict):
|
||||
icon: str | dict
|
||||
label: I18nObjectDict
|
||||
type: str
|
||||
team_credentials: dict[str, Any] | None
|
||||
team_credentials: dict | None
|
||||
is_team_authorization: bool
|
||||
allow_delete: bool
|
||||
datasources: list[Any]
|
||||
@@ -45,8 +45,8 @@ class DatasourceProviderApiEntity(BaseModel):
|
||||
icon: str | dict
|
||||
label: I18nObject # label
|
||||
type: str
|
||||
masked_credentials: dict[str, Any] | None = None
|
||||
original_credentials: dict[str, Any] | None = None
|
||||
masked_credentials: dict | None = None
|
||||
original_credentials: dict | None = None
|
||||
is_team_authorization: bool = False
|
||||
allow_delete: bool = True
|
||||
plugin_id: str | None = Field(default="", description="The plugin id of the datasource")
|
||||
|
||||
@@ -129,7 +129,7 @@ class DatasourceEntity(BaseModel):
|
||||
identity: DatasourceIdentity
|
||||
parameters: list[DatasourceParameter] = Field(default_factory=list)
|
||||
description: I18nObject = Field(..., description="The label of the datasource")
|
||||
output_schema: dict[str, Any] | None = None
|
||||
output_schema: dict | None = None
|
||||
|
||||
@field_validator("parameters", mode="before")
|
||||
@classmethod
|
||||
@@ -192,7 +192,7 @@ class DatasourceInvokeMeta(BaseModel):
|
||||
|
||||
time_cost: float = Field(..., description="The time cost of the tool invoke")
|
||||
error: str | None = None
|
||||
tool_config: dict[str, Any] | None = None
|
||||
tool_config: dict | None = None
|
||||
|
||||
@classmethod
|
||||
def empty(cls) -> DatasourceInvokeMeta:
|
||||
@@ -242,7 +242,7 @@ class OnlineDocumentPage(BaseModel):
|
||||
|
||||
page_id: str = Field(..., description="The page id")
|
||||
page_name: str = Field(..., description="The page title")
|
||||
page_icon: dict[str, Any] | None = Field(None, description="The page icon")
|
||||
page_icon: dict | None = Field(None, description="The page icon")
|
||||
type: str = Field(..., description="The type of the page")
|
||||
last_edited_time: str = Field(..., description="The last edited time")
|
||||
parent_id: str | None = Field(None, description="The parent page id")
|
||||
@@ -301,7 +301,7 @@ class GetWebsiteCrawlRequest(BaseModel):
|
||||
Get website crawl request
|
||||
"""
|
||||
|
||||
crawl_parameters: dict[str, Any] = Field(..., description="The crawl parameters")
|
||||
crawl_parameters: dict = Field(..., description="The crawl parameters")
|
||||
|
||||
|
||||
class WebSiteInfoDetail(BaseModel):
|
||||
@@ -358,7 +358,7 @@ class OnlineDriveFileBucket(BaseModel):
|
||||
bucket: str | None = Field(None, description="The file bucket")
|
||||
files: list[OnlineDriveFile] = Field(..., description="The file list")
|
||||
is_truncated: bool = Field(False, description="Whether the result is truncated")
|
||||
next_page_parameters: dict[str, Any] | None = Field(None, description="Parameters for fetching the next page")
|
||||
next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page")
|
||||
|
||||
|
||||
class OnlineDriveBrowseFilesRequest(BaseModel):
|
||||
@@ -369,7 +369,7 @@ class OnlineDriveBrowseFilesRequest(BaseModel):
|
||||
bucket: str | None = Field(None, description="The file bucket")
|
||||
prefix: str = Field(..., description="The parent folder ID")
|
||||
max_keys: int = Field(20, description="Page size for pagination")
|
||||
next_page_parameters: dict[str, Any] | None = Field(None, description="Parameters for fetching the next page")
|
||||
next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page")
|
||||
|
||||
|
||||
class OnlineDriveBrowseFilesResponse(BaseModel):
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
@@ -39,7 +37,7 @@ class PipelineDocument(BaseModel):
|
||||
id: str
|
||||
position: int
|
||||
data_source_type: str
|
||||
data_source_info: dict[str, Any] | None = None
|
||||
data_source_info: dict | None = None
|
||||
name: str
|
||||
indexing_status: str
|
||||
error: str | None = None
|
||||
|
||||
@@ -6,7 +6,6 @@ import re
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator, Sequence
|
||||
from json import JSONDecodeError
|
||||
from typing import Any
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import (
|
||||
@@ -112,7 +111,7 @@ class ProviderConfiguration(BaseModel):
|
||||
return ModelProviderFactory(model_runtime=self._bound_model_runtime)
|
||||
return create_plugin_model_provider_factory(tenant_id=self.tenant_id)
|
||||
|
||||
def get_current_credentials(self, model_type: ModelType, model: str) -> dict[str, Any] | None:
|
||||
def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
|
||||
"""
|
||||
Get current credentials.
|
||||
|
||||
@@ -234,7 +233,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
return session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
def _get_specific_provider_credential(self, credential_id: str) -> dict[str, Any] | None:
|
||||
def _get_specific_provider_credential(self, credential_id: str) -> dict | None:
|
||||
"""
|
||||
Get a specific provider credential by ID.
|
||||
:param credential_id: Credential ID
|
||||
@@ -298,7 +297,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = stmt.where(ProviderCredential.id != exclude_id)
|
||||
return session.execute(stmt).scalar_one_or_none() is not None
|
||||
|
||||
def get_provider_credential(self, credential_id: str | None = None) -> dict[str, Any] | None:
|
||||
def get_provider_credential(self, credential_id: str | None = None) -> dict | None:
|
||||
"""
|
||||
Get provider credentials.
|
||||
|
||||
@@ -318,9 +317,7 @@ class ProviderConfiguration(BaseModel):
|
||||
else [],
|
||||
)
|
||||
|
||||
def validate_provider_credentials(
|
||||
self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None
|
||||
):
|
||||
def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None):
|
||||
"""
|
||||
Validate custom credentials.
|
||||
:param credentials: provider credentials
|
||||
@@ -450,7 +447,7 @@ class ProviderConfiguration(BaseModel):
|
||||
provider_names.append(model_provider_id.provider_name)
|
||||
return provider_names
|
||||
|
||||
def create_provider_credential(self, credentials: dict[str, Any], credential_name: str | None):
|
||||
def create_provider_credential(self, credentials: dict, credential_name: str | None):
|
||||
"""
|
||||
Add custom provider credentials.
|
||||
:param credentials: provider credentials
|
||||
@@ -518,7 +515,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
def update_provider_credential(
|
||||
self,
|
||||
credentials: dict[str, Any],
|
||||
credentials: dict,
|
||||
credential_id: str,
|
||||
credential_name: str | None,
|
||||
):
|
||||
@@ -763,7 +760,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
def _get_specific_custom_model_credential(
|
||||
self, model_type: ModelType, model: str, credential_id: str
|
||||
) -> dict[str, Any] | None:
|
||||
) -> dict | None:
|
||||
"""
|
||||
Get a specific provider credential by ID.
|
||||
:param credential_id: Credential ID
|
||||
@@ -835,9 +832,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = stmt.where(ProviderModelCredential.id != exclude_id)
|
||||
return session.execute(stmt).scalar_one_or_none() is not None
|
||||
|
||||
def get_custom_model_credential(
|
||||
self, model_type: ModelType, model: str, credential_id: str | None
|
||||
) -> dict[str, Any] | None:
|
||||
def get_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str | None) -> dict | None:
|
||||
"""
|
||||
Get custom model credentials.
|
||||
|
||||
@@ -877,7 +872,7 @@ class ProviderConfiguration(BaseModel):
|
||||
self,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
credentials: dict,
|
||||
credential_id: str = "",
|
||||
session: Session | None = None,
|
||||
):
|
||||
@@ -944,7 +939,7 @@ class ProviderConfiguration(BaseModel):
|
||||
return _validate(new_session)
|
||||
|
||||
def create_custom_model_credential(
|
||||
self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None
|
||||
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
|
||||
) -> None:
|
||||
"""
|
||||
Create a custom model credential.
|
||||
@@ -1007,12 +1002,7 @@ class ProviderConfiguration(BaseModel):
|
||||
raise
|
||||
|
||||
def update_custom_model_credential(
|
||||
self,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
credential_name: str | None,
|
||||
credential_id: str,
|
||||
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
|
||||
) -> None:
|
||||
"""
|
||||
Update a custom model credential.
|
||||
@@ -1422,9 +1412,7 @@ class ProviderConfiguration(BaseModel):
|
||||
# Get model instance of LLM
|
||||
return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type)
|
||||
|
||||
def get_model_schema(
|
||||
self, model_type: ModelType, model: str, credentials: dict[str, Any] | None
|
||||
) -> AIModelEntity | None:
|
||||
def get_model_schema(self, model_type: ModelType, model: str, credentials: dict | None) -> AIModelEntity | None:
|
||||
"""
|
||||
Get model schema
|
||||
"""
|
||||
@@ -1483,7 +1471,7 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
return secret_input_form_variables
|
||||
|
||||
def obfuscated_credentials(self, credentials: dict[str, Any], credential_form_schemas: list[CredentialFormSchema]):
|
||||
def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]):
|
||||
"""
|
||||
Obfuscated credentials.
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Union
|
||||
from typing import Union
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
@@ -88,7 +88,7 @@ class SystemConfiguration(BaseModel):
|
||||
enabled: bool
|
||||
current_quota_type: ProviderQuotaType | None = None
|
||||
quota_configurations: list[QuotaConfiguration] = []
|
||||
credentials: dict[str, Any] | None = None
|
||||
credentials: dict | None = None
|
||||
|
||||
|
||||
class CustomProviderConfiguration(BaseModel):
|
||||
@@ -96,7 +96,7 @@ class CustomProviderConfiguration(BaseModel):
|
||||
Model class for provider custom configuration.
|
||||
"""
|
||||
|
||||
credentials: dict[str, Any]
|
||||
credentials: dict
|
||||
current_credential_id: str | None = None
|
||||
current_credential_name: str | None = None
|
||||
available_credentials: list[CredentialConfiguration] = []
|
||||
@@ -109,7 +109,7 @@ class CustomModelConfiguration(BaseModel):
|
||||
|
||||
model: str
|
||||
model_type: ModelType
|
||||
credentials: dict[str, Any] | None
|
||||
credentials: dict | None
|
||||
current_credential_id: str | None = None
|
||||
current_credential_name: str | None = None
|
||||
available_model_credentials: list[CredentialConfiguration] = []
|
||||
@@ -145,7 +145,7 @@ class ModelLoadBalancingConfiguration(BaseModel):
|
||||
|
||||
id: str
|
||||
name: str
|
||||
credentials: dict[str, Any]
|
||||
credentials: dict
|
||||
credential_source_type: str | None = None
|
||||
credential_id: str | None = None
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, cast
|
||||
from typing import cast
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -14,7 +14,7 @@ class APIBasedExtensionRequestor:
|
||||
self.api_endpoint = api_endpoint
|
||||
self.api_key = api_key
|
||||
|
||||
def request(self, point: APIBasedExtensionPoint, params: dict[str, Any]) -> dict[str, Any]:
|
||||
def request(self, point: APIBasedExtensionPoint, params: dict):
|
||||
"""
|
||||
Request the api.
|
||||
|
||||
@@ -49,4 +49,4 @@ class APIBasedExtensionRequestor:
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"request error, status_code: {response.status_code}, content: {response.text[:100]}")
|
||||
|
||||
return cast(dict[str, Any], response.json())
|
||||
return cast(dict, response.json())
|
||||
|
||||
@@ -21,8 +21,8 @@ class ExtensionModule(StrEnum):
|
||||
class ModuleExtension(BaseModel):
|
||||
extension_class: Any | None = None
|
||||
name: str
|
||||
label: dict[str, Any] | None = None
|
||||
form_schema: list[dict[str, Any]] | None = None
|
||||
label: dict | None = None
|
||||
form_schema: list | None = None
|
||||
builtin: bool = True
|
||||
position: int | None = None
|
||||
|
||||
|
||||
@@ -6,14 +6,14 @@ from extensions.ext_code_based_extension import code_based_extension
|
||||
|
||||
|
||||
class ExternalDataToolFactory:
|
||||
def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict[str, Any]):
|
||||
def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict):
|
||||
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
|
||||
self.__extension_instance = extension_class(
|
||||
tenant_id=tenant_id, app_id=app_id, variable=variable, config=config
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_config(cls, name: str, tenant_id: str, config: dict[str, Any]) -> None:
|
||||
def validate_config(cls, name: str, tenant_id: str, config: dict):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
|
||||
|
||||
75
api/core/helper/creators.py
Normal file
75
api/core/helper/creators.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
Helper module for Creators Platform integration.
|
||||
|
||||
Provides functionality to upload DSL files to the Creators Platform
|
||||
and generate redirect URLs with OAuth authorization codes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
creators_platform_api_url = URL(str(dify_config.CREATORS_PLATFORM_API_URL))
|
||||
|
||||
|
||||
def upload_dsl(dsl_file_bytes: bytes, filename: str = "template.yaml") -> str:
|
||||
"""Upload a DSL file to the Creators Platform anonymous upload endpoint.
|
||||
|
||||
Args:
|
||||
dsl_file_bytes: Raw bytes of the DSL file (YAML or ZIP).
|
||||
filename: Original filename for the upload.
|
||||
|
||||
Returns:
|
||||
The claim_code string used to retrieve the DSL later.
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If the upload request fails.
|
||||
ValueError: If the response does not contain a valid claim_code.
|
||||
"""
|
||||
url = str(creators_platform_api_url / "api/v1/templates/anonymous-upload")
|
||||
response = httpx.post(url, files={"file": (filename, dsl_file_bytes)}, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
claim_code = data.get("data", {}).get("claim_code")
|
||||
if not claim_code:
|
||||
raise ValueError("Creators Platform did not return a valid claim_code")
|
||||
|
||||
return claim_code
|
||||
|
||||
|
||||
def get_redirect_url(user_account_id: str, claim_code: str) -> str:
|
||||
"""Generate the redirect URL to the Creators Platform frontend.
|
||||
|
||||
Redirects to the Creators Platform root page with the dsl_claim_code.
|
||||
If CREATORS_PLATFORM_OAUTH_CLIENT_ID is configured (Dify Cloud),
|
||||
also signs an OAuth authorization code so the frontend can
|
||||
automatically authenticate the user via the OAuth callback.
|
||||
|
||||
For self-hosted Dify without OAuth client_id configured, only the
|
||||
dsl_claim_code is passed and the user must log in manually.
|
||||
|
||||
Args:
|
||||
user_account_id: The Dify user account ID.
|
||||
claim_code: The claim_code obtained from upload_dsl().
|
||||
|
||||
Returns:
|
||||
The full redirect URL string.
|
||||
"""
|
||||
base_url = str(dify_config.CREATORS_PLATFORM_API_URL).rstrip("/")
|
||||
params: dict[str, str] = {"dsl_claim_code": claim_code}
|
||||
|
||||
client_id = str(dify_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID or "")
|
||||
if client_id:
|
||||
from services.oauth_server import OAuthServerService
|
||||
|
||||
oauth_code = OAuthServerService.sign_oauth_authorization_code(client_id, user_account_id)
|
||||
params["oauth_code"] = oauth_code
|
||||
|
||||
return f"{base_url}?{urlencode(params)}"
|
||||
@@ -1,7 +1,6 @@
|
||||
import json
|
||||
from enum import StrEnum
|
||||
from json import JSONDecodeError
|
||||
from typing import Any
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
@@ -16,7 +15,7 @@ class ProviderCredentialsCache:
|
||||
def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType):
|
||||
self.cache_key = f"{cache_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
|
||||
|
||||
def get(self) -> dict[str, Any] | None:
|
||||
def get(self) -> dict | None:
|
||||
"""
|
||||
Get cached model provider credentials.
|
||||
|
||||
@@ -34,7 +33,7 @@ class ProviderCredentialsCache:
|
||||
else:
|
||||
return None
|
||||
|
||||
def set(self, credentials: dict[str, Any]):
|
||||
def set(self, credentials: dict):
|
||||
"""
|
||||
Cache model provider credentials.
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ class ProviderCredentialsCache(ABC):
|
||||
"""Generate cache key based on subclass implementation"""
|
||||
pass
|
||||
|
||||
def get(self) -> dict[str, Any] | None:
|
||||
def get(self) -> dict | None:
|
||||
"""Get cached provider credentials"""
|
||||
cached_credentials = redis_client.get(self.cache_key)
|
||||
if cached_credentials:
|
||||
@@ -71,7 +71,7 @@ class ToolProviderCredentialsCache(ProviderCredentialsCache):
|
||||
class NoOpProviderCredentialCache:
|
||||
"""No-op provider credential cache"""
|
||||
|
||||
def get(self) -> dict[str, Any] | None:
|
||||
def get(self) -> dict | None:
|
||||
"""Get cached provider credentials"""
|
||||
return None
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import json
|
||||
from enum import StrEnum
|
||||
from json import JSONDecodeError
|
||||
from typing import Any
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
@@ -19,7 +18,7 @@ class ToolParameterCache:
|
||||
f":identity_id:{identity_id}"
|
||||
)
|
||||
|
||||
def get(self) -> dict[str, Any] | None:
|
||||
def get(self) -> dict | None:
|
||||
"""
|
||||
Get cached model provider credentials.
|
||||
|
||||
@@ -37,7 +36,7 @@ class ToolParameterCache:
|
||||
else:
|
||||
return None
|
||||
|
||||
def set(self, parameters: dict[str, Any]):
|
||||
def set(self, parameters: dict):
|
||||
"""Cache model provider credentials."""
|
||||
redis_client.setex(self.cache_key, 86400, json.dumps(parameters))
|
||||
|
||||
|
||||
@@ -735,9 +735,7 @@ class IndexingRunner:
|
||||
|
||||
@staticmethod
|
||||
def _update_document_index_status(
|
||||
document_id: str,
|
||||
after_indexing_status: IndexingStatus,
|
||||
extra_update_params: Mapping[Any, Any] | None = None,
|
||||
document_id: str, after_indexing_status: IndexingStatus, extra_update_params: dict | None = None
|
||||
):
|
||||
"""
|
||||
Update the document indexing status.
|
||||
@@ -764,7 +762,7 @@ class IndexingRunner:
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
def _update_segments_by_document(dataset_document_id: str, update_params: Mapping[Any, Any]):
|
||||
def _update_segments_by_document(dataset_document_id: str, update_params: dict):
|
||||
"""
|
||||
Update the document segment by document id.
|
||||
"""
|
||||
|
||||
62
api/core/llm_generator/context_models.py
Normal file
62
api/core/llm_generator/context_models.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class VariableSelectorPayload(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
variable: str = Field(..., description="Variable name used in generated code")
|
||||
value_selector: list[str] = Field(..., description="Path to upstream node output, format: [node_id, output_name]")
|
||||
|
||||
|
||||
class CodeOutputPayload(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
type: str = Field(..., description="Output variable type")
|
||||
|
||||
|
||||
class CodeContextPayload(BaseModel):
|
||||
# From web/app/components/workflow/nodes/tool/components/context-generate-modal/index.tsx (code node snapshot).
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
code: str = Field(..., description="Existing code in the Code node")
|
||||
outputs: dict[str, CodeOutputPayload] | None = Field(
|
||||
default=None, description="Existing output definitions for the Code node"
|
||||
)
|
||||
variables: list[VariableSelectorPayload] | None = Field(
|
||||
default=None, description="Existing variable selectors used by the Code node"
|
||||
)
|
||||
|
||||
|
||||
class AvailableVarPayload(BaseModel):
|
||||
# From web/app/components/workflow/nodes/_base/hooks/use-available-var-list.ts (available variables).
|
||||
model_config = ConfigDict(extra="forbid", populate_by_name=True)
|
||||
|
||||
value_selector: list[str] = Field(..., description="Path to upstream node output")
|
||||
type: str = Field(..., description="Variable type, e.g. string, number, array[object]")
|
||||
description: str | None = Field(default=None, description="Optional variable description")
|
||||
node_id: str | None = Field(default=None, description="Source node ID")
|
||||
node_title: str | None = Field(default=None, description="Source node title")
|
||||
node_type: str | None = Field(default=None, description="Source node type")
|
||||
json_schema: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
alias="schema",
|
||||
description="Optional JSON schema for object variables",
|
||||
)
|
||||
|
||||
|
||||
class ParameterInfoPayload(BaseModel):
|
||||
# From web/app/components/workflow/nodes/tool/use-config.ts (ToolParameter metadata).
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
name: str = Field(..., description="Target parameter name")
|
||||
type: str = Field(default="string", description="Target parameter type")
|
||||
description: str = Field(default="", description="Parameter description")
|
||||
required: bool | None = Field(default=None, description="Whether the parameter is required")
|
||||
options: list[str] | None = Field(default=None, description="Allowed option values")
|
||||
min: float | None = Field(default=None, description="Minimum numeric value")
|
||||
max: float | None = Field(default=None, description="Maximum numeric value")
|
||||
default: str | int | float | bool | None = Field(default=None, description="Default value")
|
||||
multiple: bool | None = Field(default=None, description="Whether the parameter accepts multiple values")
|
||||
label: str | None = Field(default=None, description="Optional display label")
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Protocol, TypedDict, cast
|
||||
from typing import Protocol, TypedDict, cast
|
||||
|
||||
import json_repair
|
||||
from graphon.enums import WorkflowNodeExecutionMetadataKey
|
||||
@@ -533,7 +533,7 @@ class LLMGenerator:
|
||||
def __instruction_modify_common(
|
||||
tenant_id: str,
|
||||
model_config: ModelConfig,
|
||||
last_run: dict[str, Any] | None,
|
||||
last_run: dict | None,
|
||||
current: str | None,
|
||||
error_message: str | None,
|
||||
instruction: str,
|
||||
|
||||
67
api/core/llm_generator/output_models.py
Normal file
67
api/core/llm_generator/output_models.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from graphon.variables.types import SegmentType
|
||||
|
||||
|
||||
class SuggestedQuestionsOutput(BaseModel):
|
||||
"""Output model for suggested questions generation."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
questions: list[str] = Field(
|
||||
min_length=3,
|
||||
max_length=3,
|
||||
description="Exactly 3 suggested follow-up questions for the user",
|
||||
)
|
||||
|
||||
|
||||
class VariableSelectorOutput(BaseModel):
|
||||
"""Variable selector mapping code variable to upstream node output.
|
||||
|
||||
Note: Separate from VariableSelector to ensure 'additionalProperties: false'
|
||||
in JSON schema for OpenAI/Azure strict mode.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
variable: str = Field(description="Variable name used in the generated code")
|
||||
value_selector: list[str] = Field(description="Path to upstream node output, format: [node_id, output_name]")
|
||||
|
||||
|
||||
class CodeNodeOutputItem(BaseModel):
|
||||
"""Single output variable definition.
|
||||
|
||||
Note: OpenAI/Azure strict mode requires 'additionalProperties: false' and
|
||||
does not support dynamic object keys, so outputs use array format.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
name: str = Field(description="Output variable name returned by the main function")
|
||||
type: SegmentType = Field(description="Data type of the output variable")
|
||||
|
||||
|
||||
class CodeNodeStructuredOutput(BaseModel):
|
||||
"""Structured output for code node generation."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
variables: list[VariableSelectorOutput] = Field(
|
||||
description="Input variables mapping code variables to upstream node outputs"
|
||||
)
|
||||
code: str = Field(description="Generated code with a main function that processes inputs and returns outputs")
|
||||
outputs: list[CodeNodeOutputItem] = Field(
|
||||
description="Output variable definitions specifying name and type for each return value"
|
||||
)
|
||||
message: str = Field(description="Brief explanation of what the generated code does")
|
||||
|
||||
|
||||
class InstructionModifyOutput(BaseModel):
|
||||
"""Output model for instruction-based prompt modification."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
modified: str = Field(description="The modified prompt content after applying the instruction")
|
||||
message: str = Field(description="Brief explanation of what changes were made")
|
||||
203
api/core/llm_generator/output_parser/file_ref.py
Normal file
203
api/core/llm_generator/output_parser/file_ref.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""
|
||||
File path detection and conversion for structured output.
|
||||
|
||||
This module provides utilities to:
|
||||
1. Detect sandbox file path fields in JSON Schema (format: "file-path")
|
||||
2. Adapt schemas to add file-path descriptions before model invocation
|
||||
3. Convert sandbox file path strings into File objects via a resolver
|
||||
"""
|
||||
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from graphon.file import File
|
||||
from graphon.variables.segments import ArrayFileSegment, FileSegment
|
||||
|
||||
FILE_PATH_FORMAT = "file-path"
|
||||
FILE_PATH_DESCRIPTION_SUFFIX = "this field contains a file path from the Dify sandbox"
|
||||
|
||||
|
||||
def is_file_path_property(schema: Mapping[str, Any]) -> bool:
|
||||
"""Check if a schema property represents a sandbox file path."""
|
||||
if schema.get("type") != "string":
|
||||
return False
|
||||
format_value = schema.get("format")
|
||||
if not isinstance(format_value, str):
|
||||
return False
|
||||
normalized_format = format_value.lower().replace("_", "-")
|
||||
return normalized_format == FILE_PATH_FORMAT
|
||||
|
||||
|
||||
def detect_file_path_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
|
||||
"""Recursively detect file path fields in a JSON schema."""
|
||||
file_path_fields: list[str] = []
|
||||
schema_type = schema.get("type")
|
||||
|
||||
if schema_type == "object":
|
||||
properties = schema.get("properties")
|
||||
if isinstance(properties, Mapping):
|
||||
properties_mapping = cast(Mapping[str, Any], properties)
|
||||
for prop_name, prop_schema in properties_mapping.items():
|
||||
if not isinstance(prop_schema, Mapping):
|
||||
continue
|
||||
prop_schema_mapping = cast(Mapping[str, Any], prop_schema)
|
||||
current_path = f"{path}.{prop_name}" if path else prop_name
|
||||
|
||||
if is_file_path_property(prop_schema_mapping):
|
||||
file_path_fields.append(current_path)
|
||||
else:
|
||||
file_path_fields.extend(detect_file_path_fields(prop_schema_mapping, current_path))
|
||||
|
||||
elif schema_type == "array":
|
||||
items_schema = schema.get("items")
|
||||
if not isinstance(items_schema, Mapping):
|
||||
return file_path_fields
|
||||
items_schema_mapping = cast(Mapping[str, Any], items_schema)
|
||||
array_path = f"{path}[*]" if path else "[*]"
|
||||
|
||||
if is_file_path_property(items_schema_mapping):
|
||||
file_path_fields.append(array_path)
|
||||
else:
|
||||
file_path_fields.extend(detect_file_path_fields(items_schema_mapping, array_path))
|
||||
|
||||
return file_path_fields
|
||||
|
||||
|
||||
def adapt_schema_for_sandbox_file_paths(schema: Mapping[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""Normalize sandbox file path fields and collect their JSON paths."""
|
||||
result = _deep_copy_value(schema)
|
||||
if not isinstance(result, dict):
|
||||
raise ValueError("structured_output_schema must be a JSON object")
|
||||
result_dict = cast(dict[str, Any], result)
|
||||
|
||||
file_path_fields: list[str] = []
|
||||
_adapt_schema_in_place(result_dict, path="", file_path_fields=file_path_fields)
|
||||
return result_dict, file_path_fields
|
||||
|
||||
|
||||
def convert_sandbox_file_paths_in_output(
|
||||
output: Mapping[str, Any],
|
||||
file_path_fields: Sequence[str],
|
||||
file_resolver: Callable[[str], File],
|
||||
) -> tuple[dict[str, Any], list[File]]:
|
||||
"""Convert sandbox file paths into File objects using the resolver."""
|
||||
if not file_path_fields:
|
||||
return dict(output), []
|
||||
|
||||
result = _deep_copy_value(output)
|
||||
if not isinstance(result, dict):
|
||||
raise ValueError("Structured output must be a JSON object")
|
||||
result_dict = cast(dict[str, Any], result)
|
||||
|
||||
files: list[File] = []
|
||||
for path in file_path_fields:
|
||||
_convert_path_in_place(result_dict, path.split("."), file_resolver, files)
|
||||
|
||||
return result_dict, files
|
||||
|
||||
|
||||
def _adapt_schema_in_place(schema: dict[str, Any], path: str, file_path_fields: list[str]) -> None:
|
||||
schema_type = schema.get("type")
|
||||
|
||||
if schema_type == "object":
|
||||
properties = schema.get("properties")
|
||||
if isinstance(properties, Mapping):
|
||||
properties_mapping = cast(Mapping[str, Any], properties)
|
||||
for prop_name, prop_schema in properties_mapping.items():
|
||||
if not isinstance(prop_schema, dict):
|
||||
continue
|
||||
prop_schema_dict = cast(dict[str, Any], prop_schema)
|
||||
current_path = f"{path}.{prop_name}" if path else prop_name
|
||||
|
||||
if is_file_path_property(prop_schema_dict):
|
||||
_normalize_file_path_schema(prop_schema_dict)
|
||||
file_path_fields.append(current_path)
|
||||
else:
|
||||
_adapt_schema_in_place(prop_schema_dict, current_path, file_path_fields)
|
||||
|
||||
elif schema_type == "array":
|
||||
items_schema = schema.get("items")
|
||||
if not isinstance(items_schema, dict):
|
||||
return
|
||||
items_schema_dict = cast(dict[str, Any], items_schema)
|
||||
array_path = f"{path}[*]" if path else "[*]"
|
||||
|
||||
if is_file_path_property(items_schema_dict):
|
||||
_normalize_file_path_schema(items_schema_dict)
|
||||
file_path_fields.append(array_path)
|
||||
else:
|
||||
_adapt_schema_in_place(items_schema_dict, array_path, file_path_fields)
|
||||
|
||||
|
||||
def _normalize_file_path_schema(schema: dict[str, Any]) -> None:
|
||||
schema["type"] = "string"
|
||||
schema["format"] = FILE_PATH_FORMAT
|
||||
description = schema.get("description", "")
|
||||
if description:
|
||||
if FILE_PATH_DESCRIPTION_SUFFIX not in description:
|
||||
schema["description"] = f"{description}\n{FILE_PATH_DESCRIPTION_SUFFIX}"
|
||||
else:
|
||||
schema["description"] = FILE_PATH_DESCRIPTION_SUFFIX
|
||||
|
||||
|
||||
def _deep_copy_value(value: Any) -> Any:
|
||||
if isinstance(value, Mapping):
|
||||
mapping = cast(Mapping[str, Any], value)
|
||||
return {key: _deep_copy_value(item) for key, item in mapping.items()}
|
||||
if isinstance(value, list):
|
||||
list_value = cast(list[Any], value)
|
||||
return [_deep_copy_value(item) for item in list_value]
|
||||
return value
|
||||
|
||||
|
||||
def _convert_path_in_place(
|
||||
obj: dict[str, Any],
|
||||
path_parts: list[str],
|
||||
file_resolver: Callable[[str], File],
|
||||
files: list[File],
|
||||
) -> None:
|
||||
if not path_parts:
|
||||
return
|
||||
|
||||
current = path_parts[0]
|
||||
remaining = path_parts[1:]
|
||||
|
||||
if current.endswith("[*]"):
|
||||
key = current[:-3] if current != "[*]" else ""
|
||||
target_value = obj.get(key) if key else obj
|
||||
|
||||
if isinstance(target_value, list):
|
||||
target_list = cast(list[Any], target_value)
|
||||
if remaining:
|
||||
for item in target_list:
|
||||
if isinstance(item, dict):
|
||||
item_dict = cast(dict[str, Any], item)
|
||||
_convert_path_in_place(item_dict, remaining, file_resolver, files)
|
||||
else:
|
||||
resolved_files: list[File] = []
|
||||
for item in target_list:
|
||||
if not isinstance(item, str):
|
||||
raise ValueError("File path must be a string")
|
||||
file = file_resolver(item)
|
||||
files.append(file)
|
||||
resolved_files.append(file)
|
||||
if key:
|
||||
obj[key] = ArrayFileSegment(value=resolved_files)
|
||||
return
|
||||
|
||||
if not remaining:
|
||||
if current not in obj:
|
||||
return
|
||||
value = obj[current]
|
||||
if value is None:
|
||||
obj[current] = None
|
||||
return
|
||||
if not isinstance(value, str):
|
||||
raise ValueError("File path must be a string")
|
||||
file = file_resolver(value)
|
||||
files.append(file)
|
||||
obj[current] = FileSegment(value=file)
|
||||
return
|
||||
|
||||
if current in obj and isinstance(obj[current], dict):
|
||||
_convert_path_in_place(obj[current], remaining, file_resolver, files)
|
||||
@@ -200,9 +200,9 @@ def _handle_native_json_schema(
|
||||
provider: str,
|
||||
model_schema: AIModelEntity,
|
||||
structured_output_schema: Mapping,
|
||||
model_parameters: dict[str, Any],
|
||||
model_parameters: dict,
|
||||
rules: list[ParameterRule],
|
||||
) -> dict[str, Any]:
|
||||
):
|
||||
"""
|
||||
Handle structured output for models with native JSON schema support.
|
||||
|
||||
@@ -224,7 +224,7 @@ def _handle_native_json_schema(
|
||||
return model_parameters
|
||||
|
||||
|
||||
def _set_response_format(model_parameters: dict[str, Any], rules: list[ParameterRule]) -> None:
|
||||
def _set_response_format(model_parameters: dict, rules: list):
|
||||
"""
|
||||
Set the appropriate response format parameter based on model rules.
|
||||
|
||||
@@ -326,7 +326,7 @@ def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema
|
||||
return {"schema": processed_schema, "name": "llm_response"}
|
||||
|
||||
|
||||
def remove_additional_properties(schema: dict[str, Any]) -> None:
|
||||
def remove_additional_properties(schema: dict):
|
||||
"""
|
||||
Remove additionalProperties fields from JSON schema.
|
||||
Used for models like Gemini that don't support this property.
|
||||
@@ -349,7 +349,7 @@ def remove_additional_properties(schema: dict[str, Any]) -> None:
|
||||
remove_additional_properties(item)
|
||||
|
||||
|
||||
def convert_boolean_to_string(schema: dict[str, Any]) -> None:
|
||||
def convert_boolean_to_string(schema: dict):
|
||||
"""
|
||||
Convert boolean type specifications to string in JSON schema.
|
||||
|
||||
|
||||
45
api/core/llm_generator/utils.py
Normal file
45
api/core/llm_generator/utils.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Utility functions for LLM generator."""
|
||||
|
||||
from graphon.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageRole,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
||||
|
||||
def deserialize_prompt_messages(messages: list[dict]) -> list[PromptMessage]:
|
||||
"""
|
||||
Deserialize list of dicts to list[PromptMessage].
|
||||
|
||||
Expected format:
|
||||
[
|
||||
{"role": "user", "content": "..."},
|
||||
{"role": "assistant", "content": "..."},
|
||||
]
|
||||
"""
|
||||
result: list[PromptMessage] = []
|
||||
for msg in messages:
|
||||
role = PromptMessageRole.value_of(msg["role"])
|
||||
content = msg.get("content", "")
|
||||
|
||||
match role:
|
||||
case PromptMessageRole.USER:
|
||||
result.append(UserPromptMessage(content=content))
|
||||
case PromptMessageRole.ASSISTANT:
|
||||
result.append(AssistantPromptMessage(content=content))
|
||||
case PromptMessageRole.SYSTEM:
|
||||
result.append(SystemPromptMessage(content=content))
|
||||
case PromptMessageRole.TOOL:
|
||||
result.append(ToolPromptMessage(content=content, tool_call_id=msg.get("tool_call_id", "")))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def serialize_prompt_messages(messages: list[PromptMessage]) -> list[dict]:
|
||||
"""
|
||||
Serialize list[PromptMessage] to list of dicts.
|
||||
"""
|
||||
return [{"role": msg.role.value, "content": msg.content} for msg in messages]
|
||||
11
api/core/memory/__init__.py
Normal file
11
api/core/memory/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from core.memory.base import BaseMemory
|
||||
from core.memory.node_token_buffer_memory import (
|
||||
NodeTokenBufferMemory,
|
||||
)
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
|
||||
__all__ = [
|
||||
"BaseMemory",
|
||||
"NodeTokenBufferMemory",
|
||||
"TokenBufferMemory",
|
||||
]
|
||||
82
api/core/memory/base.py
Normal file
82
api/core/memory/base.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""
|
||||
Base memory interfaces and types.
|
||||
|
||||
This module defines the common protocol for memory implementations.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
|
||||
from graphon.model_runtime.entities import ImagePromptMessageContent, PromptMessage
|
||||
|
||||
|
||||
class BaseMemory(ABC):
|
||||
"""
|
||||
Abstract base class for memory implementations.
|
||||
|
||||
Provides a common interface for both conversation-level and node-level memory.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_history_prompt_messages(
|
||||
self,
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""
|
||||
Get history prompt messages.
|
||||
|
||||
:param max_token_limit: Maximum tokens for history
|
||||
:param message_limit: Maximum number of messages
|
||||
:return: Sequence of PromptMessage for LLM context
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_history_prompt_text(
|
||||
self,
|
||||
human_prefix: str = "Human",
|
||||
ai_prefix: str = "Assistant",
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get history prompt as formatted text.
|
||||
|
||||
:param human_prefix: Prefix for human messages
|
||||
:param ai_prefix: Prefix for assistant messages
|
||||
:param max_token_limit: Maximum tokens for history
|
||||
:param message_limit: Maximum number of messages
|
||||
:return: Formatted history text
|
||||
"""
|
||||
from graphon.model_runtime.entities import (
|
||||
PromptMessageRole,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
|
||||
prompt_messages = self.get_history_prompt_messages(
|
||||
max_token_limit=max_token_limit,
|
||||
message_limit=message_limit,
|
||||
)
|
||||
|
||||
string_messages = []
|
||||
for m in prompt_messages:
|
||||
if m.role == PromptMessageRole.USER:
|
||||
role = human_prefix
|
||||
elif m.role == PromptMessageRole.ASSISTANT:
|
||||
role = ai_prefix
|
||||
else:
|
||||
continue
|
||||
|
||||
if isinstance(m.content, list):
|
||||
inner_msg = ""
|
||||
for content in m.content:
|
||||
if isinstance(content, TextPromptMessageContent):
|
||||
inner_msg += f"{content.data}\n"
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
inner_msg += "[image]\n"
|
||||
string_messages.append(f"{role}: {inner_msg.strip()}")
|
||||
else:
|
||||
message = f"{role}: {m.content}"
|
||||
string_messages.append(message)
|
||||
|
||||
return "\n".join(string_messages)
|
||||
196
api/core/memory/node_token_buffer_memory.py
Normal file
196
api/core/memory/node_token_buffer_memory.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""
|
||||
Node-level Token Buffer Memory for Chatflow.
|
||||
|
||||
This module provides node-scoped memory within a conversation.
|
||||
Each LLM node in a workflow can maintain its own independent conversation history.
|
||||
|
||||
Note: This is only available in Chatflow (advanced-chat mode) because it requires
|
||||
both conversation_id and node_id.
|
||||
|
||||
Design:
|
||||
- History is read directly from WorkflowNodeExecutionModel.outputs["context"]
|
||||
- No separate storage needed - the context is already saved during node execution
|
||||
- Thread tracking leverages Message table's parent_message_id structure
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.memory.base import BaseMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from graphon.file import file_manager
|
||||
from graphon.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
MultiModalPromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageRole,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NodeTokenBufferMemory(BaseMemory):
|
||||
"""
|
||||
Node-level Token Buffer Memory.
|
||||
|
||||
Provides node-scoped memory within a conversation. Each LLM node can maintain
|
||||
its own independent conversation history.
|
||||
|
||||
Key design: History is read directly from WorkflowNodeExecutionModel.outputs["context"],
|
||||
which is already saved during node execution. No separate storage needed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_id: str,
|
||||
conversation_id: str,
|
||||
node_id: str,
|
||||
tenant_id: str,
|
||||
model_instance: ModelInstance,
|
||||
):
|
||||
self.app_id = app_id
|
||||
self.conversation_id = conversation_id
|
||||
self.node_id = node_id
|
||||
self.tenant_id = tenant_id
|
||||
self.model_instance = model_instance
|
||||
|
||||
def _get_thread_workflow_run_ids(self) -> list[str]:
|
||||
"""
|
||||
Get workflow_run_ids for the current thread by querying Message table.
|
||||
Returns workflow_run_ids in chronological order (oldest first).
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = (
|
||||
select(Message)
|
||||
.where(Message.conversation_id == self.conversation_id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(500)
|
||||
)
|
||||
messages = list(session.scalars(stmt).all())
|
||||
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
# Extract thread messages using existing logic
|
||||
thread_messages = extract_thread_messages(messages)
|
||||
|
||||
# For newly created message, its answer is temporarily empty, skip it
|
||||
if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
|
||||
thread_messages.pop(0)
|
||||
|
||||
# Reverse to get chronological order, extract workflow_run_ids
|
||||
return [msg.workflow_run_id for msg in reversed(thread_messages) if msg.workflow_run_id]
|
||||
|
||||
def _deserialize_prompt_message(self, msg_dict: dict) -> PromptMessage:
|
||||
"""Deserialize a dict to PromptMessage based on role."""
|
||||
role = msg_dict.get("role")
|
||||
if role in (PromptMessageRole.USER, "user"):
|
||||
return UserPromptMessage.model_validate(msg_dict)
|
||||
elif role in (PromptMessageRole.ASSISTANT, "assistant"):
|
||||
return AssistantPromptMessage.model_validate(msg_dict)
|
||||
elif role in (PromptMessageRole.SYSTEM, "system"):
|
||||
return SystemPromptMessage.model_validate(msg_dict)
|
||||
elif role in (PromptMessageRole.TOOL, "tool"):
|
||||
return ToolPromptMessage.model_validate(msg_dict)
|
||||
else:
|
||||
return PromptMessage.model_validate(msg_dict)
|
||||
|
||||
def _deserialize_context(self, context_data: list[dict]) -> list[PromptMessage]:
|
||||
"""Deserialize context data from outputs to list of PromptMessage."""
|
||||
messages = []
|
||||
for msg_dict in context_data:
|
||||
try:
|
||||
msg = self._deserialize_prompt_message(msg_dict)
|
||||
msg = self._restore_multimodal_content(msg)
|
||||
messages.append(msg)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to deserialize prompt message: %s", e)
|
||||
return messages
|
||||
|
||||
def _restore_multimodal_content(self, message: PromptMessage) -> PromptMessage:
|
||||
"""
|
||||
Restore multimodal content (base64 or url) from file_ref.
|
||||
|
||||
When context is saved, base64_data is cleared to save storage space.
|
||||
This method restores the content by parsing file_ref (format: "method:id_or_url").
|
||||
"""
|
||||
content = message.content
|
||||
if content is None or isinstance(content, str):
|
||||
return message
|
||||
|
||||
# Process list content, restoring multimodal data from file references
|
||||
restored_content: list[PromptMessageContentUnionTypes] = []
|
||||
for item in content:
|
||||
if isinstance(item, MultiModalPromptMessageContent):
|
||||
# restore_multimodal_content preserves the concrete subclass type
|
||||
restored_item = file_manager.restore_multimodal_content(item)
|
||||
restored_content.append(cast(PromptMessageContentUnionTypes, restored_item))
|
||||
else:
|
||||
restored_content.append(item)
|
||||
|
||||
return message.model_copy(update={"content": restored_content})
|
||||
|
||||
def get_history_prompt_messages(
|
||||
self,
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""
|
||||
Retrieve history as PromptMessage sequence.
|
||||
History is read directly from the last completed node execution's outputs["context"].
|
||||
"""
|
||||
_ = message_limit # unused, kept for interface compatibility
|
||||
|
||||
thread_workflow_run_ids = self._get_thread_workflow_run_ids()
|
||||
if not thread_workflow_run_ids:
|
||||
return []
|
||||
|
||||
# Get the last completed workflow_run_id (contains accumulated context)
|
||||
last_run_id = thread_workflow_run_ids[-1]
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(WorkflowNodeExecutionModel).where(
|
||||
WorkflowNodeExecutionModel.workflow_run_id == last_run_id,
|
||||
WorkflowNodeExecutionModel.node_id == self.node_id,
|
||||
WorkflowNodeExecutionModel.status == "succeeded",
|
||||
)
|
||||
execution = session.scalars(stmt).first()
|
||||
|
||||
if not execution:
|
||||
return []
|
||||
|
||||
outputs = execution.outputs_dict
|
||||
if not outputs:
|
||||
return []
|
||||
|
||||
context_data = outputs.get("context")
|
||||
|
||||
if not context_data or not isinstance(context_data, list):
|
||||
return []
|
||||
|
||||
prompt_messages = self._deserialize_context(context_data)
|
||||
if not prompt_messages:
|
||||
return []
|
||||
|
||||
# Truncate by token limit
|
||||
try:
|
||||
current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||
while current_tokens > max_token_limit and len(prompt_messages) > 1:
|
||||
prompt_messages.pop(0)
|
||||
current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to count tokens for truncation: %s", e)
|
||||
|
||||
return prompt_messages
|
||||
@@ -64,7 +64,7 @@ class TokenBufferMemory:
|
||||
match self.conversation.mode:
|
||||
case AppMode.AGENT_CHAT | AppMode.COMPLETION | AppMode.CHAT:
|
||||
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
|
||||
case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW:
|
||||
case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW | AppMode.AGENT:
|
||||
app = self.conversation.app
|
||||
if not app:
|
||||
raise ValueError("App not found for conversation")
|
||||
|
||||
@@ -77,7 +77,7 @@ class ModelInstance:
|
||||
|
||||
@staticmethod
|
||||
def _get_load_balancing_manager(
|
||||
configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict[str, Any]
|
||||
configuration: ProviderConfiguration, model_type: ModelType, model: str, credentials: dict
|
||||
) -> Optional["LBModelManager"]:
|
||||
"""
|
||||
Get load balancing model credentials
|
||||
@@ -115,7 +115,7 @@ class ModelInstance:
|
||||
def invoke_llm(
|
||||
self,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: dict[str, Any] | None = None,
|
||||
model_parameters: dict | None = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: Literal[True] = True,
|
||||
@@ -126,7 +126,7 @@ class ModelInstance:
|
||||
def invoke_llm(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any] | None = None,
|
||||
model_parameters: dict | None = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: Literal[False] = False,
|
||||
@@ -137,7 +137,7 @@ class ModelInstance:
|
||||
def invoke_llm(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any] | None = None,
|
||||
model_parameters: dict | None = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
@@ -147,7 +147,7 @@ class ModelInstance:
|
||||
def invoke_llm(
|
||||
self,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: dict[str, Any] | None = None,
|
||||
model_parameters: dict | None = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
@@ -528,7 +528,7 @@ class LBModelManager:
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
load_balancing_configs: list[ModelLoadBalancingConfiguration],
|
||||
managed_credentials: dict[str, Any] | None = None,
|
||||
managed_credentials: dict | None = None,
|
||||
):
|
||||
"""
|
||||
Load balancing model manager
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
|
||||
@@ -12,7 +10,7 @@ from models.api_based_extension import APIBasedExtension
|
||||
|
||||
class ModerationInputParams(BaseModel):
|
||||
app_id: str = ""
|
||||
inputs: dict[str, Any] = Field(default_factory=dict)
|
||||
inputs: dict = Field(default_factory=dict)
|
||||
query: str = ""
|
||||
|
||||
|
||||
@@ -25,7 +23,7 @@ class ApiModeration(Moderation):
|
||||
name: str = "api"
|
||||
|
||||
@classmethod
|
||||
def validate_config(cls, tenant_id: str, config: dict[str, Any]):
|
||||
def validate_config(cls, tenant_id: str, config: dict):
|
||||
"""
|
||||
Validate the incoming form config data.
|
||||
|
||||
@@ -43,7 +41,7 @@ class ApiModeration(Moderation):
|
||||
if not extension:
|
||||
raise ValueError("API-based Extension not found. Please check it again.")
|
||||
|
||||
def moderation_for_inputs(self, inputs: dict[str, Any], query: str = "") -> ModerationInputsResult:
|
||||
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
|
||||
flagged = False
|
||||
preset_response = ""
|
||||
if self.config is None:
|
||||
@@ -75,7 +73,7 @@ class ApiModeration(Moderation):
|
||||
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
|
||||
)
|
||||
|
||||
def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict[str, Any]):
|
||||
def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict):
|
||||
if self.config is None:
|
||||
raise ValueError("The config is not set.")
|
||||
extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id", ""))
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user