mirror of
https://github.com/langgenius/dify.git
synced 2026-04-11 11:02:43 +00:00
Compare commits
22 Commits
feat/new-a
...
feat/per-d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e93fdddb20 | ||
|
|
ccfc8c6f15 | ||
|
|
4fb3fab82d | ||
|
|
3cea0dfb07 | ||
|
|
0d6db3a3f3 | ||
|
|
3d5a81bd30 | ||
|
|
208604a3a8 | ||
|
|
63bfba0bdb | ||
|
|
9948a51b14 | ||
|
|
0e0bb3582f | ||
|
|
546062d2cd | ||
|
|
aad0b3c157 | ||
|
|
4d4265f531 | ||
|
|
e138523123 | ||
|
|
a65e1f71b4 | ||
|
|
909c062ee1 | ||
|
|
f5322e45fc | ||
|
|
017f09f1e9 | ||
|
|
0ba66ab155 | ||
|
|
5cd267d755 | ||
|
|
d30946dabf | ||
|
|
b0e524213e |
9
.github/labeler.yml
vendored
9
.github/labeler.yml
vendored
@@ -1,3 +1,10 @@
|
||||
web:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'web/**'
|
||||
- any-glob-to-any-file:
|
||||
- 'web/**'
|
||||
- 'packages/**'
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
|
||||
82
.github/scripts/generate-i18n-changes.mjs
vendored
Normal file
82
.github/scripts/generate-i18n-changes.mjs
vendored
Normal file
@@ -0,0 +1,82 @@
|
||||
import { execFileSync } from 'node:child_process'
|
||||
import fs from 'node:fs'
|
||||
import path from 'node:path'
|
||||
|
||||
const repoRoot = process.cwd()
|
||||
const baseSha = process.env.BASE_SHA || ''
|
||||
const headSha = process.env.HEAD_SHA || ''
|
||||
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
|
||||
const outputPath = process.env.I18N_CHANGES_OUTPUT_PATH || '/tmp/i18n-changes.json'
|
||||
|
||||
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
|
||||
|
||||
const readCurrentJson = (fileStem) => {
|
||||
const filePath = englishPath(fileStem)
|
||||
if (!fs.existsSync(filePath))
|
||||
return null
|
||||
|
||||
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
|
||||
}
|
||||
|
||||
const readBaseJson = (fileStem) => {
|
||||
if (!baseSha)
|
||||
return null
|
||||
|
||||
try {
|
||||
const relativePath = `web/i18n/en-US/${fileStem}.json`
|
||||
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
|
||||
return JSON.parse(content)
|
||||
}
|
||||
catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
|
||||
|
||||
const changes = {}
|
||||
|
||||
for (const fileStem of files) {
|
||||
const currentJson = readCurrentJson(fileStem)
|
||||
const beforeJson = readBaseJson(fileStem) || {}
|
||||
const afterJson = currentJson || {}
|
||||
const added = {}
|
||||
const updated = {}
|
||||
const deleted = []
|
||||
|
||||
for (const [key, value] of Object.entries(afterJson)) {
|
||||
if (!(key in beforeJson)) {
|
||||
added[key] = value
|
||||
continue
|
||||
}
|
||||
|
||||
if (!compareJson(beforeJson[key], value)) {
|
||||
updated[key] = {
|
||||
before: beforeJson[key],
|
||||
after: value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const key of Object.keys(beforeJson)) {
|
||||
if (!(key in afterJson))
|
||||
deleted.push(key)
|
||||
}
|
||||
|
||||
changes[fileStem] = {
|
||||
fileDeleted: currentJson === null,
|
||||
added,
|
||||
updated,
|
||||
deleted,
|
||||
}
|
||||
}
|
||||
|
||||
fs.writeFileSync(
|
||||
outputPath,
|
||||
JSON.stringify({
|
||||
baseSha,
|
||||
headSha,
|
||||
files,
|
||||
changes,
|
||||
})
|
||||
)
|
||||
2
.github/workflows/autofix.yml
vendored
2
.github/workflows/autofix.yml
vendored
@@ -39,9 +39,11 @@ jobs:
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
packages/**
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
pnpm-workspace.yaml
|
||||
.npmrc
|
||||
.nvmrc
|
||||
- name: Check api inputs
|
||||
if: github.event_name != 'merge_group'
|
||||
|
||||
2
.github/workflows/docker-build.yml
vendored
2
.github/workflows/docker-build.yml
vendored
@@ -8,9 +8,11 @@ on:
|
||||
- api/Dockerfile
|
||||
- web/docker/**
|
||||
- web/Dockerfile
|
||||
- packages/**
|
||||
- package.json
|
||||
- pnpm-lock.yaml
|
||||
- pnpm-workspace.yaml
|
||||
- .npmrc
|
||||
- .nvmrc
|
||||
|
||||
concurrency:
|
||||
|
||||
4
.github/workflows/main-ci.yml
vendored
4
.github/workflows/main-ci.yml
vendored
@@ -65,9 +65,11 @@ jobs:
|
||||
- 'docker/volumes/sandbox/conf/**'
|
||||
web:
|
||||
- 'web/**'
|
||||
- 'packages/**'
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
- '.github/workflows/web-tests.yml'
|
||||
- '.github/actions/setup-web/**'
|
||||
@@ -77,9 +79,11 @@ jobs:
|
||||
- 'api/uv.lock'
|
||||
- 'e2e/**'
|
||||
- 'web/**'
|
||||
- 'packages/**'
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.npmrc'
|
||||
- '.nvmrc'
|
||||
- 'docker/docker-compose.middleware.yaml'
|
||||
- 'docker/middleware.env.example'
|
||||
|
||||
2
.github/workflows/style.yml
vendored
2
.github/workflows/style.yml
vendored
@@ -77,9 +77,11 @@ jobs:
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
packages/**
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
pnpm-workspace.yaml
|
||||
.npmrc
|
||||
.nvmrc
|
||||
.github/workflows/style.yml
|
||||
.github/actions/setup-web/**
|
||||
|
||||
1
.github/workflows/tool-test-sdks.yaml
vendored
1
.github/workflows/tool-test-sdks.yaml
vendored
@@ -9,6 +9,7 @@ on:
|
||||
- package.json
|
||||
- pnpm-lock.yaml
|
||||
- pnpm-workspace.yaml
|
||||
- .npmrc
|
||||
|
||||
concurrency:
|
||||
group: sdk-tests-${{ github.head_ref || github.run_id }}
|
||||
|
||||
101
.github/workflows/translate-i18n-claude.yml
vendored
101
.github/workflows/translate-i18n-claude.yml
vendored
@@ -68,89 +68,7 @@ jobs:
|
||||
" web/i18n-config/languages.ts | sed 's/[[:space:]]*$//')
|
||||
|
||||
generate_changes_json() {
|
||||
node <<'NODE'
|
||||
const { execFileSync } = require('node:child_process')
|
||||
const fs = require('node:fs')
|
||||
const path = require('node:path')
|
||||
|
||||
const repoRoot = process.cwd()
|
||||
const baseSha = process.env.BASE_SHA || ''
|
||||
const headSha = process.env.HEAD_SHA || ''
|
||||
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
|
||||
|
||||
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
|
||||
|
||||
const readCurrentJson = (fileStem) => {
|
||||
const filePath = englishPath(fileStem)
|
||||
if (!fs.existsSync(filePath))
|
||||
return null
|
||||
|
||||
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
|
||||
}
|
||||
|
||||
const readBaseJson = (fileStem) => {
|
||||
if (!baseSha)
|
||||
return null
|
||||
|
||||
try {
|
||||
const relativePath = `web/i18n/en-US/${fileStem}.json`
|
||||
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
|
||||
return JSON.parse(content)
|
||||
}
|
||||
catch (error) {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
|
||||
|
||||
const changes = {}
|
||||
|
||||
for (const fileStem of files) {
|
||||
const currentJson = readCurrentJson(fileStem)
|
||||
const beforeJson = readBaseJson(fileStem) || {}
|
||||
const afterJson = currentJson || {}
|
||||
const added = {}
|
||||
const updated = {}
|
||||
const deleted = []
|
||||
|
||||
for (const [key, value] of Object.entries(afterJson)) {
|
||||
if (!(key in beforeJson)) {
|
||||
added[key] = value
|
||||
continue
|
||||
}
|
||||
|
||||
if (!compareJson(beforeJson[key], value)) {
|
||||
updated[key] = {
|
||||
before: beforeJson[key],
|
||||
after: value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const key of Object.keys(beforeJson)) {
|
||||
if (!(key in afterJson))
|
||||
deleted.push(key)
|
||||
}
|
||||
|
||||
changes[fileStem] = {
|
||||
fileDeleted: currentJson === null,
|
||||
added,
|
||||
updated,
|
||||
deleted,
|
||||
}
|
||||
}
|
||||
|
||||
fs.writeFileSync(
|
||||
'/tmp/i18n-changes.json',
|
||||
JSON.stringify({
|
||||
baseSha,
|
||||
headSha,
|
||||
files,
|
||||
changes,
|
||||
})
|
||||
)
|
||||
NODE
|
||||
node .github/scripts/generate-i18n-changes.mjs
|
||||
}
|
||||
|
||||
if [ "${{ github.event_name }}" = "repository_dispatch" ]; then
|
||||
@@ -270,7 +188,7 @@ jobs:
|
||||
Tool rules:
|
||||
- Use Read for repository files.
|
||||
- Use Edit for JSON updates.
|
||||
- Use Bash only for `pnpm`.
|
||||
- Use Bash only for `vp`.
|
||||
- Do not use Bash for `git`, `gh`, or branch management.
|
||||
|
||||
Required execution plan:
|
||||
@@ -292,7 +210,7 @@ jobs:
|
||||
- Read the current English JSON file for any file that still exists so wording, placeholders, and surrounding terminology stay accurate.
|
||||
- If `Structured change set available` is `false`, treat this as a scoped full sync and use the current English files plus scoped checks as the source of truth.
|
||||
4. Run a scoped pre-check before editing:
|
||||
- `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
|
||||
- `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
|
||||
- Use this command as the source of truth for missing and extra keys inside the current scope.
|
||||
5. Apply translations.
|
||||
- For every target language and scoped file:
|
||||
@@ -300,19 +218,19 @@ jobs:
|
||||
- If the locale file does not exist yet, create it with `Write` and then continue with `Edit` as needed.
|
||||
- ADD missing keys.
|
||||
- UPDATE stale translations when the English value changed.
|
||||
- DELETE removed keys. Prefer `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope.
|
||||
- DELETE removed keys. Prefer `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope.
|
||||
- Preserve placeholders exactly: `{{variable}}`, `${variable}`, HTML tags, component tags, and variable names.
|
||||
- Match the existing terminology and register used by each locale.
|
||||
- Prefer one Edit per file when stable, but prioritize correctness over batching.
|
||||
6. Verify only the edited files.
|
||||
- Run `pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- <relative edited i18n file paths>`
|
||||
- Run `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
|
||||
- Run `vp run dify-web#lint:fix --quiet -- <relative edited i18n file paths under web/>`
|
||||
- Run `vp run dify-web#i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
|
||||
- If verification fails, fix the remaining problems before continuing.
|
||||
7. Stop after the scoped locale files are updated and verification passes.
|
||||
- Do not create branches, commits, or pull requests.
|
||||
claude_args: |
|
||||
--max-turns 120
|
||||
--allowedTools "Read,Write,Edit,Bash(pnpm *),Bash(pnpm:*),Glob,Grep"
|
||||
--allowedTools "Read,Write,Edit,Bash(vp *),Bash(vp:*),Glob,Grep"
|
||||
|
||||
- name: Prepare branch metadata
|
||||
id: pr_meta
|
||||
@@ -354,6 +272,7 @@ jobs:
|
||||
- name: Create or update translation PR
|
||||
if: steps.pr_meta.outputs.has_changes == 'true'
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
BRANCH_NAME: ${{ steps.pr_meta.outputs.branch_name }}
|
||||
FILES_IN_SCOPE: ${{ steps.context.outputs.CHANGED_FILES }}
|
||||
TARGET_LANGS: ${{ steps.context.outputs.TARGET_LANGS }}
|
||||
@@ -402,8 +321,8 @@ jobs:
|
||||
'',
|
||||
'## Verification',
|
||||
'',
|
||||
`- \`pnpm --dir web run i18n:check --file ${process.env.FILES_IN_SCOPE} --lang ${process.env.TARGET_LANGS}\``,
|
||||
`- \`pnpm --dir web lint:fix --quiet -- <edited i18n files>\``,
|
||||
`- \`vp run dify-web#i18n:check --file ${process.env.FILES_IN_SCOPE} --lang ${process.env.TARGET_LANGS}\``,
|
||||
`- \`vp run dify-web#lint:fix --quiet -- <edited i18n files under web/>\``,
|
||||
'',
|
||||
'## Notes',
|
||||
'',
|
||||
|
||||
83
.github/workflows/trigger-i18n-sync.yml
vendored
83
.github/workflows/trigger-i18n-sync.yml
vendored
@@ -42,88 +42,7 @@ jobs:
|
||||
fi
|
||||
|
||||
export BASE_SHA HEAD_SHA CHANGED_FILES
|
||||
node <<'NODE'
|
||||
const { execFileSync } = require('node:child_process')
|
||||
const fs = require('node:fs')
|
||||
const path = require('node:path')
|
||||
|
||||
const repoRoot = process.cwd()
|
||||
const baseSha = process.env.BASE_SHA || ''
|
||||
const headSha = process.env.HEAD_SHA || ''
|
||||
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
|
||||
|
||||
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
|
||||
|
||||
const readCurrentJson = (fileStem) => {
|
||||
const filePath = englishPath(fileStem)
|
||||
if (!fs.existsSync(filePath))
|
||||
return null
|
||||
|
||||
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
|
||||
}
|
||||
|
||||
const readBaseJson = (fileStem) => {
|
||||
if (!baseSha)
|
||||
return null
|
||||
|
||||
try {
|
||||
const relativePath = `web/i18n/en-US/${fileStem}.json`
|
||||
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
|
||||
return JSON.parse(content)
|
||||
}
|
||||
catch (error) {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
|
||||
|
||||
const changes = {}
|
||||
|
||||
for (const fileStem of files) {
|
||||
const beforeJson = readBaseJson(fileStem) || {}
|
||||
const afterJson = readCurrentJson(fileStem) || {}
|
||||
const added = {}
|
||||
const updated = {}
|
||||
const deleted = []
|
||||
|
||||
for (const [key, value] of Object.entries(afterJson)) {
|
||||
if (!(key in beforeJson)) {
|
||||
added[key] = value
|
||||
continue
|
||||
}
|
||||
|
||||
if (!compareJson(beforeJson[key], value)) {
|
||||
updated[key] = {
|
||||
before: beforeJson[key],
|
||||
after: value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const key of Object.keys(beforeJson)) {
|
||||
if (!(key in afterJson))
|
||||
deleted.push(key)
|
||||
}
|
||||
|
||||
changes[fileStem] = {
|
||||
fileDeleted: readCurrentJson(fileStem) === null,
|
||||
added,
|
||||
updated,
|
||||
deleted,
|
||||
}
|
||||
}
|
||||
|
||||
fs.writeFileSync(
|
||||
'/tmp/i18n-changes.json',
|
||||
JSON.stringify({
|
||||
baseSha,
|
||||
headSha,
|
||||
files,
|
||||
changes,
|
||||
})
|
||||
)
|
||||
NODE
|
||||
node .github/scripts/generate-i18n-changes.mjs
|
||||
|
||||
if [ -n "$CHANGED_FILES" ]; then
|
||||
echo "has_changes=true" >> "$GITHUB_OUTPUT"
|
||||
|
||||
@@ -81,8 +81,8 @@ if $web_modified; then
|
||||
|
||||
if $web_ts_modified; then
|
||||
echo "Running TypeScript type-check:tsgo"
|
||||
if ! pnpm run type-check:tsgo; then
|
||||
echo "Type check failed. Please run 'pnpm run type-check:tsgo' to fix the errors."
|
||||
if ! npm run type-check:tsgo; then
|
||||
echo "Type check failed. Please run 'npm run type-check:tsgo' to fix the errors."
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
@@ -90,8 +90,8 @@ if $web_modified; then
|
||||
fi
|
||||
|
||||
echo "Running knip"
|
||||
if ! pnpm run knip; then
|
||||
echo "Knip check failed. Please run 'pnpm run knip' to fix the errors."
|
||||
if ! npm run knip; then
|
||||
echo "Knip check failed. Please run 'npm run knip' to fix the errors."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -271,27 +271,6 @@ class PluginConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class CliApiConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for CLI API (for dify-cli to call back from external sandbox environments)
|
||||
"""
|
||||
|
||||
CLI_API_URL: str = Field(
|
||||
description="CLI API URL for external sandbox (e.g., e2b) to call back.",
|
||||
default="http://localhost:5001",
|
||||
)
|
||||
|
||||
SANDBOX_DIFY_CLI_ROOT: str = Field(
|
||||
description="Root directory containing dify-cli binaries (dify-cli-{os}-{arch}).",
|
||||
default="",
|
||||
)
|
||||
|
||||
DIFY_PORT: int = Field(
|
||||
description="Dify API port, used by Docker sandbox for socat forwarding.",
|
||||
default=5001,
|
||||
)
|
||||
|
||||
|
||||
class MarketplaceConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for marketplace
|
||||
@@ -308,27 +287,6 @@ class MarketplaceConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class CreatorsPlatformConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for creators platform
|
||||
"""
|
||||
|
||||
CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(
|
||||
description="Enable or disable creators platform features",
|
||||
default=True,
|
||||
)
|
||||
|
||||
CREATORS_PLATFORM_API_URL: HttpUrl = Field(
|
||||
description="Creators Platform API URL",
|
||||
default=HttpUrl("https://creators.dify.ai"),
|
||||
)
|
||||
|
||||
CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(
|
||||
description="OAuth client_id for the Creators Platform app registered in Dify",
|
||||
default="",
|
||||
)
|
||||
|
||||
|
||||
class EndpointConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for various application endpoints and URLs
|
||||
@@ -383,15 +341,6 @@ class FileAccessConfig(BaseSettings):
|
||||
default="",
|
||||
)
|
||||
|
||||
FILES_API_URL: str = Field(
|
||||
description="Base URL for storage file ticket API endpoints."
|
||||
" Used by sandbox containers (internal or external like e2b) that need"
|
||||
" an absolute, routable address to upload/download files via the API."
|
||||
" For all-in-one Docker deployments, set to http://localhost."
|
||||
" For public sandbox environments, set to a public domain or IP.",
|
||||
default="",
|
||||
)
|
||||
|
||||
FILES_ACCESS_TIMEOUT: int = Field(
|
||||
description="Expiration time in seconds for file access URLs",
|
||||
default=300,
|
||||
@@ -1325,29 +1274,6 @@ class PositionConfig(BaseSettings):
|
||||
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
|
||||
|
||||
|
||||
class CollaborationConfig(BaseSettings):
|
||||
ENABLE_COLLABORATION_MODE: bool = Field(
|
||||
description="Whether to enable collaboration mode features across the workspace",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class AgentV2UpgradeConfig(BaseSettings):
|
||||
"""Feature flags for transparent Agent V2 upgrade."""
|
||||
|
||||
AGENT_V2_TRANSPARENT_UPGRADE: bool = Field(
|
||||
description="Transparently run old apps (chat/completion/agent-chat) through the Agent V2 workflow engine. "
|
||||
"When enabled, old apps synthesize a virtual workflow at runtime instead of using legacy runners.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
AGENT_V2_REPLACES_LLM: bool = Field(
|
||||
description="Transparently replace LLM nodes in workflows with Agent V2 nodes at runtime. "
|
||||
"LLMNodeData is remapped to AgentV2NodeData with tools=[] (identical behavior).",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class LoginConfig(BaseSettings):
|
||||
ENABLE_EMAIL_CODE_LOGIN: bool = Field(
|
||||
description="whether to enable email code login",
|
||||
@@ -1449,9 +1375,7 @@ class FeatureConfig(
|
||||
TriggerConfig,
|
||||
AsyncWorkflowConfig,
|
||||
PluginConfig,
|
||||
CliApiConfig,
|
||||
MarketplaceConfig,
|
||||
CreatorsPlatformConfig,
|
||||
DataSetConfig,
|
||||
EndpointConfig,
|
||||
FileAccessConfig,
|
||||
@@ -1475,8 +1399,6 @@ class FeatureConfig(
|
||||
WorkflowConfig,
|
||||
WorkflowNodeExecutionConfig,
|
||||
WorkspaceConfig,
|
||||
CollaborationConfig,
|
||||
AgentV2UpgradeConfig,
|
||||
LoginConfig,
|
||||
AccountConfig,
|
||||
SwaggerUIConfig,
|
||||
|
||||
@@ -134,6 +134,26 @@ class DatabaseConfig(BaseSettings):
|
||||
default="",
|
||||
)
|
||||
|
||||
DIFY_DB_USER: str | None = Field(
|
||||
description="Override for DB_USERNAME. Takes precedence when set.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
DIFY_DB_PASS: str | None = Field(
|
||||
description="Override for DB_PASSWORD. Takes precedence when set.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def effective_db_username(self) -> str:
|
||||
return self.DIFY_DB_USER if self.DIFY_DB_USER is not None else self.DB_USERNAME
|
||||
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def effective_db_password(self) -> str:
|
||||
return self.DIFY_DB_PASS if self.DIFY_DB_PASS is not None else self.DB_PASSWORD
|
||||
|
||||
DB_DATABASE: str = Field(
|
||||
description="Name of the database to connect to.",
|
||||
default="dify",
|
||||
@@ -163,7 +183,7 @@ class DatabaseConfig(BaseSettings):
|
||||
db_extras = f"?{db_extras}" if db_extras else ""
|
||||
return (
|
||||
f"{self.SQLALCHEMY_DATABASE_URI_SCHEME}://"
|
||||
f"{quote_plus(self.DB_USERNAME)}:{quote_plus(self.DB_PASSWORD)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}"
|
||||
f"{quote_plus(self.effective_db_username)}:{quote_plus(self.effective_db_password)}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_DATABASE}"
|
||||
f"{db_extras}"
|
||||
)
|
||||
|
||||
|
||||
@@ -81,20 +81,4 @@ default_app_templates: Mapping[AppMode, Mapping] = {
|
||||
},
|
||||
},
|
||||
},
|
||||
# agent default mode (new agent backed by single-node workflow)
|
||||
AppMode.AGENT: {
|
||||
"app": {
|
||||
"mode": AppMode.AGENT,
|
||||
"enable_site": True,
|
||||
"enable_api": True,
|
||||
},
|
||||
"model_config": {
|
||||
"model": {
|
||||
"provider": "openai",
|
||||
"name": "gpt-4o",
|
||||
"mode": "chat",
|
||||
"completion_params": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
from flask import Blueprint
|
||||
from flask_restx import Namespace
|
||||
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
bp = Blueprint("cli_api", __name__, url_prefix="/cli/api")
|
||||
|
||||
api = ExternalApi(
|
||||
bp,
|
||||
version="1.0",
|
||||
title="CLI API",
|
||||
description="APIs for Dify CLI to call back from external sandbox environments (e.g., e2b)",
|
||||
)
|
||||
|
||||
# Create namespace
|
||||
cli_api_ns = Namespace("cli_api", description="CLI API operations", path="/")
|
||||
|
||||
from .dify_cli import cli_api as _plugin
|
||||
|
||||
api.add_namespace(cli_api_ns)
|
||||
|
||||
__all__ = [
|
||||
"_plugin",
|
||||
"api",
|
||||
"bp",
|
||||
"cli_api_ns",
|
||||
]
|
||||
@@ -1,190 +0,0 @@
|
||||
from flask import abort
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
|
||||
from controllers.cli_api import cli_api_ns
|
||||
from controllers.cli_api.dify_cli.wraps import get_cli_user_tenant, plugin_data
|
||||
from controllers.cli_api.wraps import cli_api_only
|
||||
from controllers.console.wraps import setup_required
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
|
||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse
|
||||
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
|
||||
from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
|
||||
from core.plugin.entities.request import (
|
||||
RequestInvokeApp,
|
||||
RequestInvokeLLM,
|
||||
RequestInvokeTool,
|
||||
RequestRequestUploadFile,
|
||||
)
|
||||
from core.sandbox.bash.dify_cli import DifyCliToolConfig
|
||||
from core.session.cli_api import CliContext
|
||||
from core.skill.entities import ToolInvocationRequest
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from graphon.file.helpers import get_signed_file_url
|
||||
from libs.helper import length_prefixed_response
|
||||
from models.account import Account
|
||||
from models.model import EndUser, Tenant
|
||||
|
||||
|
||||
class FetchToolItem(BaseModel):
|
||||
tool_type: str
|
||||
tool_provider: str
|
||||
tool_name: str
|
||||
credential_id: str | None = None
|
||||
|
||||
|
||||
class FetchToolBatchRequest(BaseModel):
|
||||
tools: list[FetchToolItem]
|
||||
|
||||
|
||||
@cli_api_ns.route("/invoke/llm")
|
||||
class CliInvokeLLMApi(Resource):
|
||||
@cli_api_only
|
||||
@get_cli_user_tenant
|
||||
@setup_required
|
||||
@plugin_data(payload_type=RequestInvokeLLM)
|
||||
def post(
|
||||
self,
|
||||
user_model: Account | EndUser,
|
||||
tenant_model: Tenant,
|
||||
payload: RequestInvokeLLM,
|
||||
cli_context: CliContext,
|
||||
):
|
||||
def generator():
|
||||
response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload)
|
||||
return PluginModelBackwardsInvocation.convert_to_event_stream(response)
|
||||
|
||||
return length_prefixed_response(0xF, generator())
|
||||
|
||||
|
||||
@cli_api_ns.route("/invoke/tool")
|
||||
class CliInvokeToolApi(Resource):
|
||||
@cli_api_only
|
||||
@get_cli_user_tenant
|
||||
@setup_required
|
||||
@plugin_data(payload_type=RequestInvokeTool)
|
||||
def post(
|
||||
self,
|
||||
user_model: Account | EndUser,
|
||||
tenant_model: Tenant,
|
||||
payload: RequestInvokeTool,
|
||||
cli_context: CliContext,
|
||||
):
|
||||
tool_type = ToolProviderType.value_of(payload.tool_type)
|
||||
|
||||
request = ToolInvocationRequest(
|
||||
tool_type=tool_type,
|
||||
provider=payload.provider,
|
||||
tool_name=payload.tool,
|
||||
credential_id=payload.credential_id,
|
||||
)
|
||||
if cli_context.tool_access and not cli_context.tool_access.is_allowed(request):
|
||||
abort(403, description=f"Access denied for tool: {payload.provider}/{payload.tool}")
|
||||
|
||||
def generator():
|
||||
return PluginToolBackwardsInvocation.convert_to_event_stream(
|
||||
PluginToolBackwardsInvocation.invoke_tool(
|
||||
tenant_id=tenant_model.id,
|
||||
user_id=user_model.id,
|
||||
tool_type=tool_type,
|
||||
provider=payload.provider,
|
||||
tool_name=payload.tool,
|
||||
tool_parameters=payload.tool_parameters,
|
||||
credential_id=payload.credential_id,
|
||||
),
|
||||
)
|
||||
|
||||
return length_prefixed_response(0xF, generator())
|
||||
|
||||
|
||||
@cli_api_ns.route("/invoke/app")
|
||||
class CliInvokeAppApi(Resource):
|
||||
@cli_api_only
|
||||
@get_cli_user_tenant
|
||||
@setup_required
|
||||
@plugin_data(payload_type=RequestInvokeApp)
|
||||
def post(
|
||||
self,
|
||||
user_model: Account | EndUser,
|
||||
tenant_model: Tenant,
|
||||
payload: RequestInvokeApp,
|
||||
cli_context: CliContext,
|
||||
):
|
||||
response = PluginAppBackwardsInvocation.invoke_app(
|
||||
app_id=payload.app_id,
|
||||
user_id=user_model.id,
|
||||
tenant_id=tenant_model.id,
|
||||
conversation_id=payload.conversation_id,
|
||||
query=payload.query,
|
||||
stream=payload.response_mode == "streaming",
|
||||
inputs=payload.inputs,
|
||||
files=payload.files,
|
||||
)
|
||||
|
||||
return length_prefixed_response(0xF, PluginAppBackwardsInvocation.convert_to_event_stream(response))
|
||||
|
||||
|
||||
@cli_api_ns.route("/upload/file/request")
|
||||
class CliUploadFileRequestApi(Resource):
|
||||
@cli_api_only
|
||||
@get_cli_user_tenant
|
||||
@setup_required
|
||||
@plugin_data(payload_type=RequestRequestUploadFile)
|
||||
def post(
|
||||
self,
|
||||
user_model: Account | EndUser,
|
||||
tenant_model: Tenant,
|
||||
payload: RequestRequestUploadFile,
|
||||
cli_context: CliContext,
|
||||
):
|
||||
url = get_signed_file_url(
|
||||
upload_file_id=f"{tenant_model.id}_{user_model.id}_{payload.filename}",
|
||||
tenant_id=tenant_model.id,
|
||||
)
|
||||
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
|
||||
|
||||
|
||||
@cli_api_ns.route("/fetch/tools/batch")
|
||||
class CliFetchToolsBatchApi(Resource):
|
||||
@cli_api_only
|
||||
@get_cli_user_tenant
|
||||
@setup_required
|
||||
@plugin_data(payload_type=FetchToolBatchRequest)
|
||||
def post(
|
||||
self,
|
||||
user_model: Account | EndUser,
|
||||
tenant_model: Tenant,
|
||||
payload: FetchToolBatchRequest,
|
||||
cli_context: CliContext,
|
||||
):
|
||||
tools: list[dict] = []
|
||||
|
||||
for item in payload.tools:
|
||||
provider_type = ToolProviderType.value_of(item.tool_type)
|
||||
|
||||
request = ToolInvocationRequest(
|
||||
tool_type=provider_type,
|
||||
provider=item.tool_provider,
|
||||
tool_name=item.tool_name,
|
||||
credential_id=item.credential_id,
|
||||
)
|
||||
if cli_context.tool_access and not cli_context.tool_access.is_allowed(request):
|
||||
abort(403, description=f"Access denied for tool: {item.tool_provider}/{item.tool_name}")
|
||||
|
||||
try:
|
||||
tool_runtime = ToolManager.get_tool_runtime(
|
||||
tenant_id=tenant_model.id,
|
||||
provider_type=provider_type,
|
||||
provider_id=item.tool_provider,
|
||||
tool_name=item.tool_name,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
credential_id=item.credential_id,
|
||||
)
|
||||
tool_config = DifyCliToolConfig.create_from_tool(tool_runtime)
|
||||
tools.append(tool_config.model_dump())
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return BaseBackwardsInvocationResponse(data={"tools": tools}).model_dump()
|
||||
@@ -1,137 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import current_app, g, request
|
||||
from flask_login import user_logged_in
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.session.cli_api import CliApiSession, CliContext
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from models.account import Tenant
|
||||
from models.model import DefaultEndUserSessionID, EndUser
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class TenantUserPayload(BaseModel):
|
||||
tenant_id: str
|
||||
user_id: str
|
||||
|
||||
|
||||
def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
"""
|
||||
Get current user
|
||||
|
||||
NOTE: user_id is not trusted, it could be maliciously set to any value.
|
||||
As a result, it could only be considered as an end user id.
|
||||
"""
|
||||
if not user_id:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
user_model = None
|
||||
|
||||
if is_anonymous:
|
||||
user_model = (
|
||||
session.query(EndUser)
|
||||
.where(
|
||||
EndUser.session_id == user_id,
|
||||
EndUser.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
user_model = (
|
||||
session.query(EndUser)
|
||||
.where(
|
||||
EndUser.id == user_id,
|
||||
EndUser.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not user_model:
|
||||
user_model = EndUser(
|
||||
tenant_id=tenant_id,
|
||||
type="service_api",
|
||||
is_anonymous=is_anonymous,
|
||||
session_id=user_id,
|
||||
)
|
||||
session.add(user_model)
|
||||
session.commit()
|
||||
session.refresh(user_model)
|
||||
|
||||
except Exception:
|
||||
raise ValueError("user not found")
|
||||
|
||||
return user_model
|
||||
|
||||
|
||||
def get_cli_user_tenant(view_func: Callable[P, R]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
session: CliApiSession | None = getattr(g, "cli_api_session", None)
|
||||
if session is None:
|
||||
raise ValueError("session not found")
|
||||
|
||||
user_id = session.user_id
|
||||
tenant_id = session.tenant_id
|
||||
cli_context = CliContext.model_validate(session.context)
|
||||
|
||||
if not user_id:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
|
||||
try:
|
||||
tenant_model = (
|
||||
db.session.query(Tenant)
|
||||
.where(
|
||||
Tenant.id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
except Exception:
|
||||
raise ValueError("tenant not found")
|
||||
|
||||
if not tenant_model:
|
||||
raise ValueError("tenant not found")
|
||||
|
||||
kwargs["tenant_model"] = tenant_model
|
||||
kwargs["user_model"] = get_user(tenant_id, user_id)
|
||||
kwargs["cli_context"] = cli_context
|
||||
|
||||
current_app.login_manager._update_request_context_with_user(kwargs["user_model"]) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
|
||||
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
|
||||
def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
|
||||
def decorator(view_func: Callable[P, R]):
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
try:
|
||||
data = request.get_json()
|
||||
except Exception:
|
||||
raise ValueError("invalid json")
|
||||
|
||||
try:
|
||||
payload = payload_type.model_validate(data)
|
||||
except Exception as e:
|
||||
raise ValueError(f"invalid payload: {str(e)}")
|
||||
|
||||
kwargs["payload"] = payload
|
||||
return view_func(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
if view is None:
|
||||
return decorator
|
||||
else:
|
||||
return decorator(view)
|
||||
@@ -1,56 +0,0 @@
|
||||
import hashlib
|
||||
import hmac
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import abort, g, request
|
||||
|
||||
from core.session.cli_api import CliApiSessionManager
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
SIGNATURE_TTL_SECONDS = 300
|
||||
|
||||
|
||||
def _verify_signature(session_secret: str, timestamp: str, body: bytes, signature: str) -> bool:
|
||||
expected = hmac.new(
|
||||
session_secret.encode(),
|
||||
f"{timestamp}.".encode() + body,
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
return hmac.compare_digest(f"sha256={expected}", signature)
|
||||
|
||||
|
||||
def cli_api_only(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
session_id = request.headers.get("X-Cli-Api-Session-Id")
|
||||
timestamp = request.headers.get("X-Cli-Api-Timestamp")
|
||||
signature = request.headers.get("X-Cli-Api-Signature")
|
||||
|
||||
if not session_id or not timestamp or not signature:
|
||||
abort(401)
|
||||
|
||||
try:
|
||||
ts = int(timestamp)
|
||||
if abs(time.time() - ts) > SIGNATURE_TTL_SECONDS:
|
||||
abort(401)
|
||||
except ValueError:
|
||||
abort(401)
|
||||
|
||||
session = CliApiSessionManager().get(session_id)
|
||||
if not session:
|
||||
abort(401)
|
||||
|
||||
body = request.get_data()
|
||||
if not _verify_signature(session.secret, timestamp, body, signature):
|
||||
abort(401)
|
||||
|
||||
g.cli_api_session = session
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
63
api/controllers/common/controller_schemas.py
Normal file
63
api/controllers/common/controller_schemas.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
|
||||
# --- Conversation schemas ---
|
||||
|
||||
|
||||
class ConversationRenamePayload(BaseModel):
|
||||
name: str | None = None
|
||||
auto_generate: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_name_requirement(self):
|
||||
if not self.auto_generate:
|
||||
if self.name is None or not self.name.strip():
|
||||
raise ValueError("name is required when auto_generate is false")
|
||||
return self
|
||||
|
||||
|
||||
# --- Message schemas ---
|
||||
|
||||
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: UUIDStrOrEmpty
|
||||
first_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
rating: Literal["like", "dislike"] | None = None
|
||||
content: str | None = None
|
||||
|
||||
|
||||
# --- Saved message schemas ---
|
||||
|
||||
|
||||
class SavedMessageListQuery(BaseModel):
|
||||
last_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class SavedMessageCreatePayload(BaseModel):
|
||||
message_id: UUIDStrOrEmpty
|
||||
|
||||
|
||||
# --- Workflow schemas ---
|
||||
|
||||
|
||||
class WorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
# --- Audio schemas ---
|
||||
|
||||
|
||||
class TextToAudioPayload(BaseModel):
|
||||
message_id: str | None = None
|
||||
voice: str | None = None
|
||||
text: str | None = None
|
||||
streaming: bool | None = None
|
||||
@@ -41,7 +41,6 @@ from . import (
|
||||
init_validate,
|
||||
notification,
|
||||
ping,
|
||||
sandbox_files,
|
||||
setup,
|
||||
spec,
|
||||
version,
|
||||
@@ -53,7 +52,6 @@ from .app import (
|
||||
agent,
|
||||
annotation,
|
||||
app,
|
||||
app_asset,
|
||||
audio,
|
||||
completion,
|
||||
conversation,
|
||||
@@ -64,7 +62,6 @@ from .app import (
|
||||
model_config,
|
||||
ops_trace,
|
||||
site,
|
||||
skills,
|
||||
statistic,
|
||||
workflow,
|
||||
workflow_app_log,
|
||||
@@ -133,7 +130,6 @@ from .workspace import (
|
||||
model_providers,
|
||||
models,
|
||||
plugin,
|
||||
sandbox_providers,
|
||||
tool_providers,
|
||||
trigger_providers,
|
||||
workspace,
|
||||
|
||||
@@ -51,7 +51,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||
)
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion", "agent"]
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||
|
||||
register_enum_models(console_ns, IconType)
|
||||
|
||||
@@ -61,7 +61,7 @@ _logger = logging.getLogger(__name__)
|
||||
class AppListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
|
||||
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "agent", "channel", "all"] = Field(
|
||||
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
|
||||
default="all", description="App mode filter"
|
||||
)
|
||||
name: str | None = Field(default=None, description="Filter by app name")
|
||||
@@ -93,9 +93,7 @@ class AppListQuery(BaseModel):
|
||||
class CreateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion", "agent"] = Field(
|
||||
..., description="App mode"
|
||||
)
|
||||
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
|
||||
icon_type: IconType | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@@ -1,333 +0,0 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
AppAssetNodeNotFoundError,
|
||||
AppAssetPathConflictError,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.app.entities.app_asset_entities import BatchUploadNode
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from services.app_asset_service import AppAssetService
|
||||
from services.errors.app_asset import (
|
||||
AppAssetNodeNotFoundError as ServiceNodeNotFoundError,
|
||||
)
|
||||
from services.errors.app_asset import (
|
||||
AppAssetParentNotFoundError,
|
||||
)
|
||||
from services.errors.app_asset import (
|
||||
AppAssetPathConflictError as ServicePathConflictError,
|
||||
)
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class CreateFolderPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
parent_id: str | None = None
|
||||
|
||||
|
||||
class CreateFilePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
parent_id: str | None = None
|
||||
|
||||
@field_validator("name", mode="before")
|
||||
@classmethod
|
||||
def strip_name(cls, v: str) -> str:
|
||||
return v.strip() if isinstance(v, str) else v
|
||||
|
||||
@field_validator("parent_id", mode="before")
|
||||
@classmethod
|
||||
def empty_to_none(cls, v: str | None) -> str | None:
|
||||
return v or None
|
||||
|
||||
|
||||
class GetUploadUrlPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
size: int = Field(..., ge=0)
|
||||
parent_id: str | None = None
|
||||
|
||||
@field_validator("name", mode="before")
|
||||
@classmethod
|
||||
def strip_name(cls, v: str) -> str:
|
||||
return v.strip() if isinstance(v, str) else v
|
||||
|
||||
@field_validator("parent_id", mode="before")
|
||||
@classmethod
|
||||
def empty_to_none(cls, v: str | None) -> str | None:
|
||||
return v or None
|
||||
|
||||
|
||||
class BatchUploadPayload(BaseModel):
|
||||
children: list[BatchUploadNode] = Field(..., min_length=1)
|
||||
parent_id: str | None = None
|
||||
|
||||
@field_validator("parent_id", mode="before")
|
||||
@classmethod
|
||||
def empty_to_none(cls, v: str | None) -> str | None:
|
||||
return v or None
|
||||
|
||||
|
||||
class UpdateFileContentPayload(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class RenameNodePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
|
||||
|
||||
class MoveNodePayload(BaseModel):
|
||||
parent_id: str | None = None
|
||||
|
||||
|
||||
class ReorderNodePayload(BaseModel):
|
||||
after_node_id: str | None = Field(default=None, description="Place after this node, None for first position")
|
||||
|
||||
|
||||
def reg(cls: type[BaseModel]) -> None:
|
||||
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
|
||||
reg(CreateFolderPayload)
|
||||
reg(CreateFilePayload)
|
||||
reg(GetUploadUrlPayload)
|
||||
reg(BatchUploadNode)
|
||||
reg(BatchUploadPayload)
|
||||
reg(UpdateFileContentPayload)
|
||||
reg(RenameNodePayload)
|
||||
reg(MoveNodePayload)
|
||||
reg(ReorderNodePayload)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/tree")
|
||||
class AppAssetTreeResource(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
tree = AppAssetService.get_asset_tree(app_model, current_user.id)
|
||||
return {"children": [view.model_dump() for view in tree.transform()]}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/folders")
|
||||
class AppAssetFolderResource(Resource):
|
||||
@console_ns.expect(console_ns.models[CreateFolderPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = CreateFolderPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
node = AppAssetService.create_folder(app_model, current_user.id, payload.name, payload.parent_id)
|
||||
return node.model_dump(), 201
|
||||
except AppAssetParentNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
except ServicePathConflictError:
|
||||
raise AppAssetPathConflictError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/files/<string:node_id>")
|
||||
class AppAssetFileDetailResource(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, node_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
try:
|
||||
content = AppAssetService.get_file_content(app_model, current_user.id, node_id)
|
||||
return {"content": content.decode("utf-8", errors="replace")}
|
||||
except ServiceNodeNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
|
||||
@console_ns.expect(console_ns.models[UpdateFileContentPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def put(self, app_model: App, node_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
file = request.files.get("file")
|
||||
if file:
|
||||
content = file.read()
|
||||
else:
|
||||
payload = UpdateFileContentPayload.model_validate(console_ns.payload or {})
|
||||
content = payload.content.encode("utf-8")
|
||||
|
||||
try:
|
||||
node = AppAssetService.update_file_content(app_model, current_user.id, node_id, content)
|
||||
return node.model_dump()
|
||||
except ServiceNodeNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/nodes/<string:node_id>")
|
||||
class AppAssetNodeResource(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def delete(self, app_model: App, node_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
try:
|
||||
AppAssetService.delete_node(app_model, current_user.id, node_id)
|
||||
return {"result": "success"}, 200
|
||||
except ServiceNodeNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/nodes/<string:node_id>/rename")
|
||||
class AppAssetNodeRenameResource(Resource):
|
||||
@console_ns.expect(console_ns.models[RenameNodePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = RenameNodePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
node = AppAssetService.rename_node(app_model, current_user.id, node_id, payload.name)
|
||||
return node.model_dump()
|
||||
except ServiceNodeNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
except ServicePathConflictError:
|
||||
raise AppAssetPathConflictError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/nodes/<string:node_id>/move")
|
||||
class AppAssetNodeMoveResource(Resource):
|
||||
@console_ns.expect(console_ns.models[MoveNodePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = MoveNodePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
node = AppAssetService.move_node(app_model, current_user.id, node_id, payload.parent_id)
|
||||
return node.model_dump()
|
||||
except ServiceNodeNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
except AppAssetParentNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
except ServicePathConflictError:
|
||||
raise AppAssetPathConflictError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/nodes/<string:node_id>/reorder")
|
||||
class AppAssetNodeReorderResource(Resource):
|
||||
@console_ns.expect(console_ns.models[ReorderNodePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App, node_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = ReorderNodePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
node = AppAssetService.reorder_node(app_model, current_user.id, node_id, payload.after_node_id)
|
||||
return node.model_dump()
|
||||
except ServiceNodeNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/files/<string:node_id>/download-url")
|
||||
class AppAssetFileDownloadUrlResource(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def get(self, app_model: App, node_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
try:
|
||||
download_url = AppAssetService.get_file_download_url(app_model, current_user.id, node_id)
|
||||
return {"download_url": download_url}
|
||||
except ServiceNodeNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/files/upload")
|
||||
class AppAssetFileUploadUrlResource(Resource):
|
||||
@console_ns.expect(console_ns.models[GetUploadUrlPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = GetUploadUrlPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
node, upload_url = AppAssetService.get_file_upload_url(
|
||||
app_model, current_user.id, payload.name, payload.size, payload.parent_id
|
||||
)
|
||||
return {"node": node.model_dump(), "upload_url": upload_url}, 201
|
||||
except AppAssetParentNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
except ServicePathConflictError:
|
||||
raise AppAssetPathConflictError()
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/assets/batch-upload")
|
||||
class AppAssetBatchUploadResource(Resource):
|
||||
@console_ns.expect(console_ns.models[BatchUploadPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
Create nodes from tree structure and return upload URLs.
|
||||
|
||||
Input:
|
||||
{
|
||||
"parent_id": "optional-target-folder-id",
|
||||
"children": [
|
||||
{"name": "folder1", "node_type": "folder", "children": [
|
||||
{"name": "file1.txt", "node_type": "file", "size": 1024}
|
||||
]},
|
||||
{"name": "root.txt", "node_type": "file", "size": 512}
|
||||
]
|
||||
}
|
||||
|
||||
Output:
|
||||
{
|
||||
"children": [
|
||||
{"id": "xxx", "name": "folder1", "node_type": "folder", "children": [
|
||||
{"id": "yyy", "name": "file1.txt", "node_type": "file", "size": 1024, "upload_url": "..."}
|
||||
]},
|
||||
{"id": "zzz", "name": "root.txt", "node_type": "file", "size": 512, "upload_url": "..."}
|
||||
]
|
||||
}
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = BatchUploadPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
result_children = AppAssetService.batch_create_from_tree(
|
||||
app_model,
|
||||
current_user.id,
|
||||
payload.children,
|
||||
parent_id=payload.parent_id,
|
||||
)
|
||||
return {"children": [child.model_dump() for child in result_children]}, 201
|
||||
except AppAssetParentNotFoundError:
|
||||
raise AppAssetNodeNotFoundError()
|
||||
except ServicePathConflictError:
|
||||
raise AppAssetPathConflictError()
|
||||
@@ -161,7 +161,7 @@ class ChatMessageApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
args_model = ChatMessagePayload.model_validate(console_ns.payload)
|
||||
@@ -215,7 +215,7 @@ class ChatMessageStopApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def post(self, app_model, task_id):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
@@ -121,21 +121,3 @@ class NeedAddIdsError(BaseHTTPException):
|
||||
error_code = "need_add_ids"
|
||||
description = "Need to add ids."
|
||||
code = 400
|
||||
|
||||
|
||||
class AppAssetNodeNotFoundError(BaseHTTPException):
|
||||
error_code = "app_asset_node_not_found"
|
||||
description = "App asset node not found."
|
||||
code = 404
|
||||
|
||||
|
||||
class AppAssetFileRequiredError(BaseHTTPException):
|
||||
error_code = "app_asset_file_required"
|
||||
description = "File is required."
|
||||
code = 400
|
||||
|
||||
|
||||
class AppAssetPathConflictError(BaseHTTPException):
|
||||
error_code = "app_asset_path_conflict"
|
||||
description = "Path already exists."
|
||||
code = 409
|
||||
|
||||
@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import exists, func, select
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MessageFeedbackPayload as _MessageFeedbackPayloadBase
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
@@ -59,10 +60,8 @@ class ChatMessagesQuery(BaseModel):
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
class MessageFeedbackPayload(_MessageFeedbackPayloadBase):
|
||||
message_id: str = Field(..., description="Message ID")
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
||||
content: str | None = Field(default=None, description="Feedback content")
|
||||
|
||||
@field_validator("message_id")
|
||||
@classmethod
|
||||
@@ -238,7 +237,7 @@ class ChatMessageListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@setup_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(message_infinite_scroll_pagination_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
@@ -394,7 +393,7 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
def get(self, app_model, message_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
message_id = str(message_id)
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, current_account_with_tenant, setup_required
|
||||
from libs.login import login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from services.skill_service import SkillService
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/llm/skills")
|
||||
class NodeSkillsApi(Resource):
|
||||
"""Extract tool dependencies from an LLM node's skill prompts.
|
||||
|
||||
The client sends the full node ``data`` object in the request body.
|
||||
The server real-time builds a ``SkillBundle`` from the current draft
|
||||
``.md`` assets and resolves transitive tool dependencies — no cached
|
||||
bundle is used.
|
||||
"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
node_data = request.get_json(force=True)
|
||||
if not isinstance(node_data, dict):
|
||||
return {"tool_dependencies": []}
|
||||
|
||||
tool_deps = SkillService.extract_tool_dependencies(
|
||||
app=app_model,
|
||||
node_data=node_data,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
return {"tool_dependencies": [d.model_dump() for d in tool_deps]}
|
||||
@@ -221,7 +221,7 @@ class DraftWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
@@ -241,7 +241,7 @@ class DraftWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@console_ns.doc("sync_draft_workflow")
|
||||
@console_ns.doc(description="Sync draft workflow configuration")
|
||||
@console_ns.expect(console_ns.models[SyncDraftWorkflowPayload.__name__])
|
||||
@@ -325,7 +325,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
@@ -371,7 +371,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
@@ -447,7 +447,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
@@ -549,7 +549,7 @@ class AdvancedChatDraftHumanInputFormPreviewApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
@@ -578,7 +578,7 @@ class AdvancedChatDraftHumanInputFormRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
@@ -733,7 +733,7 @@ class WorkflowTaskStopApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, task_id: str):
|
||||
"""
|
||||
@@ -761,7 +761,7 @@ class DraftWorkflowNodeRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_node_execution_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
@@ -807,7 +807,7 @@ class PublishedWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
@@ -825,7 +825,7 @@ class PublishedWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
@@ -869,7 +869,7 @@ class DefaultBlockConfigsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
@@ -891,7 +891,7 @@ class DefaultBlockConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App, block_type: str):
|
||||
"""
|
||||
@@ -956,7 +956,7 @@ class PublishedAllWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_pagination_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
@@ -1005,7 +1005,7 @@ class DraftWorkflowRestoreApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, workflow_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@@ -1043,7 +1043,7 @@ class WorkflowByIdApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_model)
|
||||
@edit_permission_required
|
||||
def patch(self, app_model: App, workflow_id: str):
|
||||
@@ -1083,7 +1083,7 @@ class WorkflowByIdApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, workflow_id: str):
|
||||
"""
|
||||
@@ -1118,7 +1118,7 @@ class DraftWorkflowNodeLastRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_node_execution_model)
|
||||
def get(self, app_model: App, node_id: str):
|
||||
srv = WorkflowService()
|
||||
|
||||
@@ -1,322 +0,0 @@
|
||||
import logging
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.member_fields import AccountWithRole
|
||||
from fields.workflow_comment_fields import (
|
||||
workflow_comment_basic_fields,
|
||||
workflow_comment_create_fields,
|
||||
workflow_comment_detail_fields,
|
||||
workflow_comment_reply_create_fields,
|
||||
workflow_comment_reply_update_fields,
|
||||
workflow_comment_resolve_fields,
|
||||
workflow_comment_update_fields,
|
||||
)
|
||||
from libs.login import current_user, login_required
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.workflow_comment_service import WorkflowCommentService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowCommentCreatePayload(BaseModel):
|
||||
position_x: float = Field(..., description="Comment X position")
|
||||
position_y: float = Field(..., description="Comment Y position")
|
||||
content: str = Field(..., description="Comment content")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentUpdatePayload(BaseModel):
|
||||
content: str = Field(..., description="Comment content")
|
||||
position_x: float | None = Field(default=None, description="Comment X position")
|
||||
position_y: float | None = Field(default=None, description="Comment Y position")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentReplyCreatePayload(BaseModel):
|
||||
content: str = Field(..., description="Reply content")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentReplyUpdatePayload(BaseModel):
|
||||
content: str = Field(..., description="Reply content")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentMentionUsersResponse(BaseModel):
|
||||
users: list[AccountWithRole] = Field(description="Mentionable users")
|
||||
|
||||
|
||||
for model in (
|
||||
WorkflowCommentCreatePayload,
|
||||
WorkflowCommentUpdatePayload,
|
||||
WorkflowCommentReplyCreatePayload,
|
||||
WorkflowCommentReplyUpdatePayload,
|
||||
):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
for model in (AccountWithRole, WorkflowCommentMentionUsersResponse):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields)
|
||||
workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields)
|
||||
workflow_comment_create_model = console_ns.model("WorkflowCommentCreate", workflow_comment_create_fields)
|
||||
workflow_comment_update_model = console_ns.model("WorkflowCommentUpdate", workflow_comment_update_fields)
|
||||
workflow_comment_resolve_model = console_ns.model("WorkflowCommentResolve", workflow_comment_resolve_fields)
|
||||
workflow_comment_reply_create_model = console_ns.model(
|
||||
"WorkflowCommentReplyCreate", workflow_comment_reply_create_fields
|
||||
)
|
||||
workflow_comment_reply_update_model = console_ns.model(
|
||||
"WorkflowCommentReplyUpdate", workflow_comment_reply_update_fields
|
||||
)
|
||||
workflow_comment_mention_users_model = console_ns.models[WorkflowCommentMentionUsersResponse.__name__]
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments")
|
||||
class WorkflowCommentListApi(Resource):
|
||||
"""API for listing and creating workflow comments."""
|
||||
|
||||
@console_ns.doc("list_workflow_comments")
|
||||
@console_ns.doc(description="Get all comments for a workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Comments retrieved successfully", workflow_comment_basic_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_basic_model, envelope="data")
|
||||
def get(self, app_model: App):
|
||||
"""Get all comments for a workflow."""
|
||||
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
|
||||
|
||||
return comments
|
||||
|
||||
@console_ns.doc("create_workflow_comment")
|
||||
@console_ns.doc(description="Create a new workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentCreatePayload.__name__])
|
||||
@console_ns.response(201, "Comment created successfully", workflow_comment_create_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_create_model)
|
||||
def post(self, app_model: App):
|
||||
"""Create a new workflow comment."""
|
||||
payload = WorkflowCommentCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.create_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
created_by=current_user.id,
|
||||
content=payload.content,
|
||||
position_x=payload.position_x,
|
||||
position_y=payload.position_y,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
|
||||
class WorkflowCommentDetailApi(Resource):
|
||||
"""API for managing individual workflow comments."""
|
||||
|
||||
@console_ns.doc("get_workflow_comment")
|
||||
@console_ns.doc(description="Get a specific workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(200, "Comment retrieved successfully", workflow_comment_detail_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_detail_model)
|
||||
def get(self, app_model: App, comment_id: str):
|
||||
"""Get a specific workflow comment."""
|
||||
comment = WorkflowCommentService.get_comment(
|
||||
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
|
||||
)
|
||||
|
||||
return comment
|
||||
|
||||
@console_ns.doc("update_workflow_comment")
|
||||
@console_ns.doc(description="Update a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Comment updated successfully", workflow_comment_update_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_update_model)
|
||||
def put(self, app_model: App, comment_id: str):
|
||||
"""Update a workflow comment."""
|
||||
payload = WorkflowCommentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.update_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
content=payload.content,
|
||||
position_x=payload.position_x,
|
||||
position_y=payload.position_y,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@console_ns.doc("delete_workflow_comment")
|
||||
@console_ns.doc(description="Delete a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(204, "Comment deleted successfully")
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def delete(self, app_model: App, comment_id: str):
|
||||
"""Delete a workflow comment."""
|
||||
WorkflowCommentService.delete_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
|
||||
class WorkflowCommentResolveApi(Resource):
|
||||
"""API for resolving and reopening workflow comments."""
|
||||
|
||||
@console_ns.doc("resolve_workflow_comment")
|
||||
@console_ns.doc(description="Resolve a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(200, "Comment resolved successfully", workflow_comment_resolve_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_resolve_model)
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Resolve a workflow comment."""
|
||||
comment = WorkflowCommentService.resolve_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return comment
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
|
||||
class WorkflowCommentReplyApi(Resource):
|
||||
"""API for managing comment replies."""
|
||||
|
||||
@console_ns.doc("create_workflow_comment_reply")
|
||||
@console_ns.doc(description="Add a reply to a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentReplyCreatePayload.__name__])
|
||||
@console_ns.response(201, "Reply created successfully", workflow_comment_reply_create_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_reply_create_model)
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Add a reply to a workflow comment."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
payload = WorkflowCommentReplyCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.create_reply(
|
||||
comment_id=comment_id,
|
||||
content=payload.content,
|
||||
created_by=current_user.id,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>")
|
||||
class WorkflowCommentReplyDetailApi(Resource):
|
||||
"""API for managing individual comment replies."""
|
||||
|
||||
@console_ns.doc("update_workflow_comment_reply")
|
||||
@console_ns.doc(description="Update a comment reply")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentReplyUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Reply updated successfully", workflow_comment_reply_update_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_reply_update_model)
|
||||
def put(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Update a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
payload = WorkflowCommentReplyUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
reply = WorkflowCommentService.update_reply(
|
||||
reply_id=reply_id,
|
||||
user_id=current_user.id,
|
||||
content=payload.content,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return reply
|
||||
|
||||
@console_ns.doc("delete_workflow_comment_reply")
|
||||
@console_ns.doc(description="Delete a comment reply")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
|
||||
@console_ns.response(204, "Reply deleted successfully")
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def delete(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Delete a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
WorkflowCommentService.delete_reply(reply_id=reply_id, user_id=current_user.id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/mention-users")
|
||||
class WorkflowCommentMentionUsersApi(Resource):
|
||||
"""API for getting mentionable users for workflow comments."""
|
||||
|
||||
@console_ns.doc("workflow_comment_mention_users")
|
||||
@console_ns.doc(description="Get all users in current tenant for mentions")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Mentionable users retrieved successfully", workflow_comment_mention_users_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App):
|
||||
"""Get all users in current tenant for mentions."""
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
|
||||
response = WorkflowCommentMentionUsersResponse(users=member_models)
|
||||
return response.model_dump(mode="json"), 200
|
||||
@@ -208,7 +208,7 @@ def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@wraps(f)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
@@ -207,7 +207,7 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(advanced_chat_workflow_run_pagination_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
@@ -305,7 +305,7 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@marshal_with(workflow_run_count_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
@@ -349,7 +349,7 @@ class WorkflowRunListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_pagination_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
@@ -397,7 +397,7 @@ class WorkflowRunCountApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_count_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
@@ -434,7 +434,7 @@ class WorkflowRunDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_detail_model)
|
||||
def get(self, app_model: App, run_id):
|
||||
"""
|
||||
@@ -458,7 +458,7 @@ class WorkflowRunNodeExecutionListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@marshal_with(workflow_run_node_execution_list_model)
|
||||
def get(self, app_model: App, run_id):
|
||||
"""
|
||||
|
||||
@@ -2,10 +2,10 @@ import logging
|
||||
|
||||
from flask import request
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import TextToAudioPayload
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
@@ -32,14 +32,6 @@ from .. import console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TextToAudioPayload(BaseModel):
|
||||
message_id: str | None = None
|
||||
voice: str | None = None
|
||||
text: str | None = None
|
||||
streaming: bool | None = Field(default=None, description="Enable streaming response")
|
||||
|
||||
|
||||
register_schema_model(console_ns, TextToAudioPayload)
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, TypeAdapter, model_validator
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import ConversationRenamePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.explore.error import NotChatAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
@@ -32,18 +33,6 @@ class ConversationListQuery(BaseModel):
|
||||
pinned: bool | None = None
|
||||
|
||||
|
||||
class ConversationRenamePayload(BaseModel):
|
||||
name: str | None = None
|
||||
auto_generate: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_name_requirement(self):
|
||||
if not self.auto_generate:
|
||||
if self.name is None or not self.name.strip():
|
||||
raise ValueError("name is required when auto_generate is false")
|
||||
return self
|
||||
|
||||
|
||||
register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayload)
|
||||
|
||||
|
||||
|
||||
@@ -3,9 +3,10 @@ from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.app.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
@@ -25,7 +26,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
|
||||
from libs import helper
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import AppMode
|
||||
@@ -44,17 +44,6 @@ from .. import console_ns
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: UUIDStrOrEmpty
|
||||
first_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
rating: Literal["like", "dislike"] | None = None
|
||||
content: str | None = None
|
||||
|
||||
|
||||
class MoreLikeThisQuery(BaseModel):
|
||||
response_mode: Literal["blocking", "streaming"]
|
||||
|
||||
|
||||
@@ -1,28 +1,18 @@
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from pydantic import TypeAdapter
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.explore.error import NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
|
||||
class SavedMessageListQuery(BaseModel):
|
||||
last_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class SavedMessageCreatePayload(BaseModel):
|
||||
message_id: UUIDStrOrEmpty
|
||||
|
||||
|
||||
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.common.controller_schemas import WorkflowRunPayload
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
@@ -34,12 +33,6 @@ from .. import console_ns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
files: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
register_schema_model(console_ns, WorkflowRunPayload)
|
||||
|
||||
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.sandbox.sandbox_file_service import SandboxFileService
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class SandboxFileListQuery(BaseModel):
|
||||
path: str | None = Field(default=None, description="Workspace relative path")
|
||||
recursive: bool = Field(default=False, description="List recursively")
|
||||
|
||||
|
||||
class SandboxFileDownloadRequest(BaseModel):
|
||||
path: str = Field(..., description="Workspace relative file path")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
SandboxFileListQuery.__name__,
|
||||
SandboxFileListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
SandboxFileDownloadRequest.__name__,
|
||||
SandboxFileDownloadRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
SANDBOX_FILE_NODE_FIELDS = {
|
||||
"path": fields.String,
|
||||
"is_dir": fields.Boolean,
|
||||
"size": fields.Raw,
|
||||
"mtime": fields.Raw,
|
||||
"extension": fields.String,
|
||||
}
|
||||
|
||||
|
||||
SANDBOX_FILE_DOWNLOAD_TICKET_FIELDS = {
|
||||
"download_url": fields.String,
|
||||
"expires_in": fields.Integer,
|
||||
"export_id": fields.String,
|
||||
}
|
||||
|
||||
|
||||
sandbox_file_node_model = console_ns.model("SandboxFileNode", SANDBOX_FILE_NODE_FIELDS)
|
||||
sandbox_file_download_ticket_model = console_ns.model("SandboxFileDownloadTicket", SANDBOX_FILE_DOWNLOAD_TICKET_FIELDS)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/sandbox/files")
|
||||
class SandboxFilesApi(Resource):
|
||||
"""List sandbox files for the current user.
|
||||
|
||||
The sandbox_id is derived from the current user's ID, as each user has
|
||||
their own sandbox workspace per app.
|
||||
"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[SandboxFileListQuery.__name__])
|
||||
@console_ns.marshal_list_with(sandbox_file_node_model)
|
||||
def get(self, app_id: str):
|
||||
args = SandboxFileListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore[arg-type]
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
sandbox_id = account.id
|
||||
return jsonable_encoder(
|
||||
SandboxFileService.list_files(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
sandbox_id=sandbox_id,
|
||||
path=args.path,
|
||||
recursive=args.recursive,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<string:app_id>/sandbox/files/download")
|
||||
class SandboxFileDownloadApi(Resource):
|
||||
"""Download a sandbox file for the current user.
|
||||
|
||||
The sandbox_id is derived from the current user's ID, as each user has
|
||||
their own sandbox workspace per app.
|
||||
"""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[SandboxFileDownloadRequest.__name__])
|
||||
@console_ns.marshal_with(sandbox_file_download_ticket_model)
|
||||
def post(self, app_id: str):
|
||||
payload = SandboxFileDownloadRequest.model_validate(console_ns.payload or {})
|
||||
account, tenant_id = current_account_with_tenant()
|
||||
sandbox_id = account.id
|
||||
res = SandboxFileService.download_file(
|
||||
tenant_id=tenant_id, app_id=app_id, sandbox_id=sandbox_id, path=payload.path
|
||||
)
|
||||
return jsonable_encoder(res)
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
@@ -1,119 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from flask import Request as FlaskRequest
|
||||
|
||||
from extensions.ext_socketio import sio
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_access_token
|
||||
from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository
|
||||
from services.account_service import AccountService
|
||||
from services.workflow_collaboration_service import WorkflowCollaborationService
|
||||
|
||||
repository = WorkflowCollaborationRepository()
|
||||
collaboration_service = WorkflowCollaborationService(repository, sio)
|
||||
|
||||
|
||||
def _sio_on(event: str) -> Callable[[Callable[..., object]], Callable[..., object]]:
|
||||
return cast(Callable[[Callable[..., object]], Callable[..., object]], sio.on(event))
|
||||
|
||||
|
||||
@_sio_on("connect")
|
||||
def socket_connect(sid, environ, auth):
|
||||
"""
|
||||
WebSocket connect event, do authentication here.
|
||||
"""
|
||||
try:
|
||||
request_environ = FlaskRequest(environ)
|
||||
token = extract_access_token(request_environ)
|
||||
except Exception:
|
||||
logging.exception("Failed to extract token")
|
||||
token = None
|
||||
|
||||
if not token:
|
||||
logging.warning("Socket connect rejected: missing token (sid=%s)", sid)
|
||||
return False
|
||||
|
||||
try:
|
||||
decoded = PassportService().verify(token)
|
||||
user_id = decoded.get("user_id")
|
||||
if not user_id:
|
||||
logging.warning("Socket connect rejected: missing user_id (sid=%s)", sid)
|
||||
return False
|
||||
|
||||
with sio.app.app_context():
|
||||
user = AccountService.load_logged_in_account(account_id=user_id)
|
||||
if not user:
|
||||
logging.warning("Socket connect rejected: user not found (user_id=%s, sid=%s)", user_id, sid)
|
||||
return False
|
||||
if not user.has_edit_permission:
|
||||
logging.warning("Socket connect rejected: no edit permission (user_id=%s, sid=%s)", user_id, sid)
|
||||
return False
|
||||
|
||||
collaboration_service.save_session(sid, user)
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
logging.exception("Socket authentication failed")
|
||||
return False
|
||||
|
||||
|
||||
@_sio_on("user_connect")
|
||||
def handle_user_connect(sid, data):
|
||||
"""
|
||||
Handle user connect event. Each session (tab) is treated as an independent collaborator.
|
||||
"""
|
||||
workflow_id = data.get("workflow_id")
|
||||
if not workflow_id:
|
||||
return {"msg": "workflow_id is required"}, 400
|
||||
|
||||
result = collaboration_service.register_session(workflow_id, sid)
|
||||
if not result:
|
||||
return {"msg": "unauthorized"}, 401
|
||||
|
||||
user_id, is_leader = result
|
||||
return {"msg": "connected", "user_id": user_id, "sid": sid, "isLeader": is_leader}
|
||||
|
||||
|
||||
@_sio_on("disconnect")
|
||||
def handle_disconnect(sid):
|
||||
"""
|
||||
Handle session disconnect event. Remove the specific session from online users.
|
||||
"""
|
||||
collaboration_service.disconnect_session(sid)
|
||||
|
||||
|
||||
@_sio_on("collaboration_event")
|
||||
def handle_collaboration_event(sid, data):
|
||||
"""
|
||||
Handle general collaboration events, include:
|
||||
1. mouse_move
|
||||
2. vars_and_features_update
|
||||
3. sync_request (ask leader to update graph)
|
||||
4. app_state_update
|
||||
5. mcp_server_update
|
||||
6. workflow_update
|
||||
7. comments_update
|
||||
8. node_panel_presence
|
||||
9. skill_file_active
|
||||
10. skill_sync_request
|
||||
11. skill_resync_request
|
||||
"""
|
||||
return collaboration_service.relay_collaboration_event(sid, data)
|
||||
|
||||
|
||||
@_sio_on("graph_event")
|
||||
def handle_graph_event(sid, data):
|
||||
"""
|
||||
Handle graph events - simple broadcast relay.
|
||||
"""
|
||||
return collaboration_service.relay_graph_event(sid, data)
|
||||
|
||||
|
||||
@_sio_on("skill_event")
|
||||
def handle_skill_event(sid, data):
|
||||
"""
|
||||
Handle skill events - simple broadcast relay.
|
||||
"""
|
||||
return collaboration_service.relay_skill_event(sid, data)
|
||||
@@ -1,67 +0,0 @@
|
||||
import json
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import App
|
||||
from models.workflow import Workflow
|
||||
from services.app_dsl_service import AppDslService
|
||||
|
||||
|
||||
class DSLPredictRequest(BaseModel):
|
||||
app_id: str
|
||||
current_node_id: str
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/dsl/predict")
|
||||
class DSLPredictApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, _ = current_account_with_tenant()
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
args = DSLPredictRequest.model_validate(request.get_json())
|
||||
|
||||
app_id: str = args.app_id
|
||||
current_node_id: str = args.current_node_id
|
||||
|
||||
with Session(db.engine) as session:
|
||||
app = session.query(App).filter_by(id=app_id).first()
|
||||
workflow = session.query(Workflow).filter_by(app_id=app_id, version=Workflow.VERSION_DRAFT).first()
|
||||
|
||||
if not app:
|
||||
raise ValueError("App not found")
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
try:
|
||||
i = 0
|
||||
for node_id, _ in workflow.walk_nodes():
|
||||
if node_id == current_node_id:
|
||||
break
|
||||
i += 1
|
||||
|
||||
dsl = yaml.safe_load(AppDslService.export_dsl(app_model=app))
|
||||
|
||||
response = httpx.post(
|
||||
"http://spark-832c:8000/predict",
|
||||
json={"graph_data": dsl, "source_node_index": i},
|
||||
)
|
||||
return {
|
||||
"nodes": json.loads(response.json()),
|
||||
}
|
||||
except PluginPermissionDeniedError as e:
|
||||
raise ValueError(e.description) from e
|
||||
@@ -1,104 +0,0 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.sandbox.sandbox_provider_service import SandboxProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SandboxProviderConfigRequest(BaseModel):
|
||||
config: dict
|
||||
activate: bool = False
|
||||
|
||||
|
||||
class SandboxProviderActivateRequest(BaseModel):
|
||||
type: str
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/sandbox-providers")
|
||||
class SandboxProviderListApi(Resource):
|
||||
@console_ns.doc("list_sandbox_providers")
|
||||
@console_ns.doc(description="Get list of available sandbox providers with configuration status")
|
||||
@console_ns.response(200, "Success", fields.List(fields.Raw(description="Sandbox provider information")))
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
providers = SandboxProviderService.list_providers(current_tenant_id)
|
||||
return jsonable_encoder([p.model_dump() for p in providers])
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/sandbox-provider/<string:provider_type>/config")
|
||||
class SandboxProviderConfigApi(Resource):
|
||||
@console_ns.doc("save_sandbox_provider_config")
|
||||
@console_ns.doc(description="Save or update configuration for a sandbox provider")
|
||||
@console_ns.response(200, "Success")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_type: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
args = SandboxProviderConfigRequest.model_validate(request.get_json())
|
||||
|
||||
try:
|
||||
result = SandboxProviderService.save_config(
|
||||
tenant_id=current_tenant_id,
|
||||
provider_type=provider_type,
|
||||
config=args.config,
|
||||
activate=args.activate,
|
||||
)
|
||||
return result
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
@console_ns.doc("delete_sandbox_provider_config")
|
||||
@console_ns.doc(description="Delete configuration for a sandbox provider")
|
||||
@console_ns.response(200, "Success")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider_type: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
result = SandboxProviderService.delete_config(
|
||||
tenant_id=current_tenant_id,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
return result
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/sandbox-provider/<string:provider_type>/activate")
|
||||
class SandboxProviderActivateApi(Resource):
|
||||
"""Activate a sandbox provider."""
|
||||
|
||||
@console_ns.doc("activate_sandbox_provider")
|
||||
@console_ns.doc(description="Activate a sandbox provider for the current workspace")
|
||||
@console_ns.response(200, "Success")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider_type: str):
|
||||
"""Activate a sandbox provider."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
args = SandboxProviderActivateRequest.model_validate(request.get_json())
|
||||
result = SandboxProviderService.activate_provider(
|
||||
tenant_id=current_tenant_id,
|
||||
provider_type=provider_type,
|
||||
type=args.type,
|
||||
)
|
||||
return result
|
||||
except ValueError as e:
|
||||
return {"message": str(e)}, 400
|
||||
@@ -1,80 +0,0 @@
|
||||
"""Token-based file proxy controller for storage operations.
|
||||
|
||||
This controller handles file download and upload operations using opaque UUID tokens.
|
||||
The token maps to the real storage key in Redis, so the actual storage path is never
|
||||
exposed in the URL.
|
||||
|
||||
Routes:
|
||||
GET /files/storage-files/{token} - Download a file
|
||||
PUT /files/storage-files/{token} - Upload a file
|
||||
|
||||
The operation type (download/upload) is determined by the ticket stored in Redis,
|
||||
not by the HTTP method. This ensures a download ticket cannot be used for upload
|
||||
and vice versa.
|
||||
"""
|
||||
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import Forbidden, NotFound, RequestEntityTooLarge
|
||||
|
||||
from controllers.files import files_ns
|
||||
from extensions.ext_storage import storage
|
||||
from services.storage_ticket_service import StorageTicketService
|
||||
|
||||
|
||||
@files_ns.route("/storage-files/<string:token>")
|
||||
class StorageFilesApi(Resource):
|
||||
"""Handle file operations through token-based URLs."""
|
||||
|
||||
def get(self, token: str):
|
||||
"""Download a file using a token.
|
||||
|
||||
The ticket must have op="download", otherwise returns 403.
|
||||
"""
|
||||
ticket = StorageTicketService.get_ticket(token)
|
||||
if ticket is None:
|
||||
raise Forbidden("Invalid or expired token")
|
||||
|
||||
if ticket.op != "download":
|
||||
raise Forbidden("This token is not valid for download")
|
||||
|
||||
try:
|
||||
generator = storage.load_stream(ticket.storage_key)
|
||||
except FileNotFoundError:
|
||||
raise NotFound("File not found")
|
||||
|
||||
filename = ticket.filename or ticket.storage_key.rsplit("/", 1)[-1]
|
||||
encoded_filename = quote(filename)
|
||||
|
||||
return Response(
|
||||
generator,
|
||||
mimetype="application/octet-stream",
|
||||
direct_passthrough=True,
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
|
||||
},
|
||||
)
|
||||
|
||||
def put(self, token: str):
|
||||
"""Upload a file using a token.
|
||||
|
||||
The ticket must have op="upload", otherwise returns 403.
|
||||
If the request body exceeds max_bytes, returns 413.
|
||||
"""
|
||||
ticket = StorageTicketService.get_ticket(token)
|
||||
if ticket is None:
|
||||
raise Forbidden("Invalid or expired token")
|
||||
|
||||
if ticket.op != "upload":
|
||||
raise Forbidden("This token is not valid for upload")
|
||||
|
||||
content = request.get_data()
|
||||
|
||||
if ticket.max_bytes is not None and len(content) > ticket.max_bytes:
|
||||
raise RequestEntityTooLarge(f"Upload exceeds maximum size of {ticket.max_bytes} bytes")
|
||||
|
||||
storage.save(ticket.storage_key, content)
|
||||
|
||||
return Response(status=204)
|
||||
@@ -194,7 +194,7 @@ class ChatApi(Resource):
|
||||
Supports conversation management and both blocking and streaming response modes.
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
payload = ChatRequestPayload.model_validate(service_api_ns.payload or {})
|
||||
@@ -258,7 +258,7 @@ class ChatStopApi(Resource):
|
||||
def post(self, app_model: App, end_user: EndUser, task_id: str):
|
||||
"""Stop a running chat message generation."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
AppTaskService.stop_task(
|
||||
|
||||
@@ -2,11 +2,12 @@ from typing import Any, Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import ConversationRenamePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
@@ -34,18 +35,6 @@ class ConversationListQuery(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class ConversationRenamePayload(BaseModel):
|
||||
name: str | None = Field(default=None, description="New conversation name (required if auto_generate is false)")
|
||||
auto_generate: bool = Field(default=False, description="Auto-generate conversation name")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_name_requirement(self):
|
||||
if not self.auto_generate:
|
||||
if self.name is None or not self.name.strip():
|
||||
raise ValueError("name is required when auto_generate is false")
|
||||
return self
|
||||
|
||||
|
||||
class ConversationVariablesQuery(BaseModel):
|
||||
last_id: UUIDStrOrEmpty | None = Field(default=None, description="Last variable ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
|
||||
@@ -109,7 +98,7 @@ class ConversationApi(Resource):
|
||||
Supports pagination using last_id and limit parameters.
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
query_args = ConversationListQuery.model_validate(request.args.to_dict())
|
||||
@@ -153,7 +142,7 @@ class ConversationDetailApi(Resource):
|
||||
def delete(self, app_model: App, end_user: EndUser, c_id):
|
||||
"""Delete a specific conversation."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@@ -182,7 +171,7 @@ class ConversationRenameApi(Resource):
|
||||
def post(self, app_model: App, end_user: EndUser, c_id):
|
||||
"""Rename a conversation or auto-generate a name."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@@ -224,7 +213,7 @@ class ConversationVariablesApi(Resource):
|
||||
"""
|
||||
# conversational variable only for chat app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@@ -263,7 +252,7 @@ class ConversationVariableDetailApi(Resource):
|
||||
The value must match the variable's expected type.
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@@ -7,6 +6,7 @@ from pydantic import BaseModel, Field, TypeAdapter
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import MessageFeedbackPayload, MessageListQuery
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
@@ -14,7 +14,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.errors.message import (
|
||||
@@ -27,17 +26,6 @@ from services.message_service import MessageService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: UUIDStrOrEmpty
|
||||
first_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
||||
content: str | None = Field(default=None, description="Feedback content")
|
||||
|
||||
|
||||
class FeedbackListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, description="Page number")
|
||||
limit: int = Field(default=20, ge=1, le=101, description="Number of feedbacks per page")
|
||||
@@ -65,7 +53,7 @@ class MessageListApi(Resource):
|
||||
Retrieves messages with pagination support using first_id.
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
query_args = MessageListQuery.model_validate(request.args.to_dict())
|
||||
@@ -170,7 +158,7 @@ class MessageSuggestedApi(Resource):
|
||||
"""
|
||||
message_id = str(message_id)
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
from typing import Literal
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
@@ -11,6 +11,7 @@ from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
from controllers.common.controller_schemas import WorkflowRunPayload as WorkflowRunPayloadBase
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import (
|
||||
@@ -46,9 +47,7 @@ from services.workflow_app_service import WorkflowAppService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
files: list[dict[str, Any]] | None = None
|
||||
class WorkflowRunPayload(WorkflowRunPayloadBase):
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
|
||||
|
||||
|
||||
@@ -3,10 +3,11 @@ import logging
|
||||
from flask import request
|
||||
from flask_restx import fields, marshal_with
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic import field_validator
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.common.controller_schemas import TextToAudioPayload as TextToAudioPayloadBase
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import (
|
||||
AppUnavailableError,
|
||||
@@ -34,12 +35,7 @@ from services.errors.audio import (
|
||||
from ..common.schema import register_schema_models
|
||||
|
||||
|
||||
class TextToAudioPayload(BaseModel):
|
||||
message_id: str | None = None
|
||||
voice: str | None = None
|
||||
text: str | None = None
|
||||
streaming: bool | None = None
|
||||
|
||||
class TextToAudioPayload(TextToAudioPayloadBase):
|
||||
@field_validator("message_id")
|
||||
@classmethod
|
||||
def validate_message_id(cls, value: str | None) -> str | None:
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import ConversationRenamePayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotChatAppError
|
||||
@@ -37,18 +38,6 @@ class ConversationListQuery(BaseModel):
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class ConversationRenamePayload(BaseModel):
|
||||
name: str | None = None
|
||||
auto_generate: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_name_requirement(self):
|
||||
if not self.auto_generate:
|
||||
if self.name is None or not self.name.strip():
|
||||
raise ValueError("name is required when auto_generate is false")
|
||||
return self
|
||||
|
||||
|
||||
register_schema_models(web_ns, ConversationListQuery, ConversationRenamePayload)
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
from controllers.common.controller_schemas import MessageFeedbackPayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import (
|
||||
@@ -53,11 +54,6 @@ class MessageListQuery(BaseModel):
|
||||
return uuid_value(value)
|
||||
|
||||
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
||||
content: str | None = Field(default=None, description="Feedback content")
|
||||
|
||||
|
||||
class MessageMoreLikeThisQuery(BaseModel):
|
||||
response_mode: Literal["blocking", "streaming"] = Field(
|
||||
description="Response mode",
|
||||
|
||||
@@ -1,27 +1,17 @@
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from pydantic import TypeAdapter
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.controller_schemas import SavedMessageCreatePayload, SavedMessageListQuery
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotCompletionAppError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
|
||||
class SavedMessageListQuery(BaseModel):
|
||||
last_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class SavedMessageCreatePayload(BaseModel):
|
||||
message_id: UUIDStrOrEmpty
|
||||
|
||||
|
||||
register_schema_models(web_ns, SavedMessageListQuery, SavedMessageCreatePayload)
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
from controllers.common.controller_schemas import WorkflowRunPayload
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import (
|
||||
@@ -30,12 +29,6 @@ from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
|
||||
class WorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any] = Field(description="Input variables for the workflow")
|
||||
files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed by the workflow")
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
register_schema_models(web_ns, WorkflowRunPayload)
|
||||
|
||||
@@ -1,380 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentResult
|
||||
from core.agent.patterns.strategy_factory import StrategyFactory
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from graphon.file import file_manager
|
||||
from graphon.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentAppRunner(BaseAgentRunner):
|
||||
def _create_tool_invoke_hook(self, message: Message):
|
||||
"""
|
||||
Create a tool invoke hook that uses ToolEngine.agent_invoke.
|
||||
This hook handles file creation and returns proper meta information.
|
||||
"""
|
||||
# Get trace manager from app generate entity
|
||||
trace_manager = self.application_generate_entity.trace_manager
|
||||
|
||||
def tool_invoke_hook(
|
||||
tool: Tool, tool_args: dict[str, Any], tool_name: str
|
||||
) -> tuple[str, list[str], ToolInvokeMeta]:
|
||||
"""Hook that uses agent_invoke for proper file and meta handling."""
|
||||
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
|
||||
tool=tool,
|
||||
tool_parameters=tool_args,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
message=message,
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
agent_tool_callback=self.agent_callback,
|
||||
trace_manager=trace_manager,
|
||||
app_id=self.application_generate_entity.app_config.app_id,
|
||||
message_id=message.id,
|
||||
conversation_id=self.conversation.id,
|
||||
)
|
||||
|
||||
# Publish files and track IDs
|
||||
for message_file_id in message_files:
|
||||
self.queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
self._current_message_file_ids.append(message_file_id)
|
||||
|
||||
return tool_invoke_response, message_files, tool_invoke_meta
|
||||
|
||||
return tool_invoke_hook
|
||||
|
||||
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
Run Agent application
|
||||
"""
|
||||
self.query = query
|
||||
app_generate_entity = self.application_generate_entity
|
||||
|
||||
app_config = self.app_config
|
||||
assert app_config is not None, "app_config is required"
|
||||
assert app_config.agent is not None, "app_config.agent is required"
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, _ = self._init_prompt_tools()
|
||||
|
||||
assert app_config.agent
|
||||
|
||||
# Create tool invoke hook for agent_invoke
|
||||
tool_invoke_hook = self._create_tool_invoke_hook(message)
|
||||
|
||||
# Get instruction for ReAct strategy
|
||||
instruction = self.app_config.prompt_template.simple_prompt_template or ""
|
||||
|
||||
# Use factory to create appropriate strategy
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=self.model_features,
|
||||
model_instance=self.model_instance,
|
||||
tools=list(tool_instances.values()),
|
||||
files=list(self.files),
|
||||
max_iterations=app_config.agent.max_iteration,
|
||||
context=self.build_execution_context(),
|
||||
agent_strategy=self.config.strategy,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
instruction=instruction,
|
||||
)
|
||||
|
||||
# Initialize state variables
|
||||
current_agent_thought_id: str | None = None
|
||||
has_published_thought = False
|
||||
current_tool_name: str | None = None
|
||||
self._current_message_file_ids: list[str] = []
|
||||
|
||||
# organize prompt messages
|
||||
prompt_messages = self._organize_prompt_messages()
|
||||
|
||||
# Run strategy
|
||||
generator = strategy.run(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=app_generate_entity.model_conf.parameters,
|
||||
stop=app_generate_entity.model_conf.stop,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Consume generator and collect result
|
||||
result: AgentResult | None = None
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
output = next(generator)
|
||||
except StopIteration as e:
|
||||
# Generator finished, get the return value
|
||||
result = e.value
|
||||
break
|
||||
|
||||
if isinstance(output, LLMResultChunk):
|
||||
# Handle LLM chunk
|
||||
if current_agent_thought_id and not has_published_thought:
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
has_published_thought = True
|
||||
|
||||
yield output
|
||||
|
||||
elif isinstance(output, AgentLog):
|
||||
# Handle Agent Log using log_type for type-safe dispatch
|
||||
if output.status == AgentLog.LogStatus.START:
|
||||
if output.log_type == AgentLog.LogType.ROUND:
|
||||
# Start of a new round
|
||||
message_file_ids: list[str] = []
|
||||
current_agent_thought_id = self.create_agent_thought(
|
||||
message_id=message.id,
|
||||
message="",
|
||||
tool_name="",
|
||||
tool_input="",
|
||||
messages_ids=message_file_ids,
|
||||
)
|
||||
has_published_thought = False
|
||||
|
||||
elif output.log_type == AgentLog.LogType.TOOL_CALL:
|
||||
if current_agent_thought_id is None:
|
||||
continue
|
||||
|
||||
# Tool call start - extract data from structured fields
|
||||
current_tool_name = output.data.get("tool_name", "")
|
||||
tool_input = output.data.get("tool_args", {})
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=current_agent_thought_id,
|
||||
tool_name=current_tool_name,
|
||||
tool_input=tool_input,
|
||||
thought=None,
|
||||
observation=None,
|
||||
tool_invoke_meta=None,
|
||||
answer=None,
|
||||
messages_ids=[],
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
elif output.status == AgentLog.LogStatus.SUCCESS:
|
||||
if output.log_type == AgentLog.LogType.THOUGHT:
|
||||
if current_agent_thought_id is None:
|
||||
continue
|
||||
|
||||
thought_text = output.data.get("thought")
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=current_agent_thought_id,
|
||||
tool_name=None,
|
||||
tool_input=None,
|
||||
thought=thought_text,
|
||||
observation=None,
|
||||
tool_invoke_meta=None,
|
||||
answer=None,
|
||||
messages_ids=[],
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
elif output.log_type == AgentLog.LogType.TOOL_CALL:
|
||||
if current_agent_thought_id is None:
|
||||
continue
|
||||
|
||||
# Tool call finished
|
||||
tool_output = output.data.get("output")
|
||||
# Get meta from strategy output (now properly populated)
|
||||
tool_meta = output.data.get("meta")
|
||||
|
||||
# Wrap tool_meta with tool_name as key (required by agent_service)
|
||||
if tool_meta and current_tool_name:
|
||||
tool_meta = {current_tool_name: tool_meta}
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=current_agent_thought_id,
|
||||
tool_name=None,
|
||||
tool_input=None,
|
||||
thought=None,
|
||||
observation=tool_output,
|
||||
tool_invoke_meta=tool_meta,
|
||||
answer=None,
|
||||
messages_ids=self._current_message_file_ids,
|
||||
)
|
||||
# Clear message file ids after saving
|
||||
self._current_message_file_ids = []
|
||||
current_tool_name = None
|
||||
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
elif output.log_type == AgentLog.LogType.ROUND:
|
||||
if current_agent_thought_id is None:
|
||||
continue
|
||||
|
||||
# Round finished - save LLM usage and answer
|
||||
llm_usage = output.metadata.get(AgentLog.LogMetadata.LLM_USAGE)
|
||||
llm_result = output.data.get("llm_result")
|
||||
final_answer = output.data.get("final_answer")
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=current_agent_thought_id,
|
||||
tool_name=None,
|
||||
tool_input=None,
|
||||
thought=llm_result,
|
||||
observation=None,
|
||||
tool_invoke_meta=None,
|
||||
answer=final_answer,
|
||||
messages_ids=[],
|
||||
llm_usage=llm_usage,
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
# Re-raise any other exceptions
|
||||
raise
|
||||
|
||||
# Process final result
|
||||
if isinstance(result, AgentResult):
|
||||
final_answer = result.text
|
||||
usage = result.usage or LLMUsage.empty_usage()
|
||||
|
||||
# Publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=self.model_instance.model_name,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=final_answer),
|
||||
usage=usage,
|
||||
system_fingerprint="",
|
||||
)
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Initialize system message
|
||||
"""
|
||||
if not prompt_template:
|
||||
return prompt_messages or []
|
||||
|
||||
prompt_messages = prompt_messages or []
|
||||
|
||||
if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
|
||||
prompt_messages[0] = SystemPromptMessage(content=prompt_template)
|
||||
return prompt_messages
|
||||
|
||||
if not prompt_messages:
|
||||
return [SystemPromptMessage(content=prompt_template)]
|
||||
|
||||
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
|
||||
return prompt_messages
|
||||
|
||||
def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize user query
|
||||
"""
|
||||
if self.files:
|
||||
# get image detail config
|
||||
image_detail_config = (
|
||||
self.application_generate_entity.file_upload_config.image_config.detail
|
||||
if (
|
||||
self.application_generate_entity.file_upload_config
|
||||
and self.application_generate_entity.file_upload_config.image_config
|
||||
)
|
||||
else None
|
||||
)
|
||||
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
for file in self.files:
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(
|
||||
file,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=query))
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
As for now, gpt supports both fc and vision at the first iteration.
|
||||
We need to remove the image messages from the prompt messages at the first iteration.
|
||||
"""
|
||||
prompt_messages = deepcopy(prompt_messages)
|
||||
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message, UserPromptMessage):
|
||||
if isinstance(prompt_message.content, list):
|
||||
prompt_message.content = "\n".join(
|
||||
[
|
||||
content.data
|
||||
if content.type == PromptMessageContentType.TEXT
|
||||
else "[image]"
|
||||
if content.type == PromptMessageContentType.IMAGE
|
||||
else "[file]"
|
||||
for content in prompt_message.content
|
||||
]
|
||||
)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _organize_prompt_messages(self):
|
||||
# For ReAct strategy, use the agent prompt template
|
||||
if self.config.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT and self.config.prompt:
|
||||
prompt_template = self.config.prompt.first_prompt
|
||||
else:
|
||||
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
|
||||
|
||||
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
|
||||
query_prompt_messages = self._organize_user_query(self.query or "", [])
|
||||
|
||||
self.history_prompt_messages = AgentHistoryPromptTransform(
|
||||
model_config=self.model_config,
|
||||
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
|
||||
history_messages=self.history_prompt_messages,
|
||||
memory=self.memory,
|
||||
).get_prompt()
|
||||
|
||||
prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
|
||||
if len(self._current_thoughts) != 0:
|
||||
# clear messages after the first iteration
|
||||
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
|
||||
return prompt_messages
|
||||
@@ -1,5 +1,3 @@
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Any, Union
|
||||
|
||||
@@ -94,80 +92,3 @@ class AgentInvokeMessage(ToolInvokeMessage):
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ExecutionContext(BaseModel):
|
||||
"""Execution context containing trace and audit information.
|
||||
|
||||
Carries IDs and metadata needed for tracing, auditing, and correlation
|
||||
but not part of the core business logic.
|
||||
"""
|
||||
|
||||
user_id: str | None = None
|
||||
app_id: str | None = None
|
||||
conversation_id: str | None = None
|
||||
message_id: str | None = None
|
||||
tenant_id: str | None = None
|
||||
node_id: str | None = None
|
||||
|
||||
@classmethod
|
||||
def create_minimal(cls, user_id: str | None = None) -> "ExecutionContext":
|
||||
return cls(user_id=user_id)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"user_id": self.user_id,
|
||||
"app_id": self.app_id,
|
||||
"conversation_id": self.conversation_id,
|
||||
"message_id": self.message_id,
|
||||
"tenant_id": self.tenant_id,
|
||||
}
|
||||
|
||||
def with_updates(self, **kwargs) -> "ExecutionContext":
|
||||
data = self.to_dict()
|
||||
data.update(kwargs)
|
||||
return ExecutionContext(**{k: v for k, v in data.items() if k in ExecutionContext.model_fields})
|
||||
|
||||
|
||||
class AgentLog(BaseModel):
|
||||
"""Structured log entry for agent execution tracing."""
|
||||
|
||||
class LogType(StrEnum):
|
||||
ROUND = "round"
|
||||
THOUGHT = "thought"
|
||||
TOOL_CALL = "tool_call"
|
||||
|
||||
class LogMetadata(StrEnum):
|
||||
STARTED_AT = "started_at"
|
||||
FINISHED_AT = "finished_at"
|
||||
ELAPSED_TIME = "elapsed_time"
|
||||
TOTAL_PRICE = "total_price"
|
||||
TOTAL_TOKENS = "total_tokens"
|
||||
PROVIDER = "provider"
|
||||
CURRENCY = "currency"
|
||||
LLM_USAGE = "llm_usage"
|
||||
ICON = "icon"
|
||||
ICON_DARK = "icon_dark"
|
||||
|
||||
class LogStatus(StrEnum):
|
||||
START = "start"
|
||||
ERROR = "error"
|
||||
SUCCESS = "success"
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
label: str = Field(...)
|
||||
log_type: LogType = Field(...)
|
||||
parent_id: str | None = Field(default=None)
|
||||
error: str | None = Field(default=None)
|
||||
status: LogStatus = Field(...)
|
||||
data: Mapping[str, Any] = Field(...)
|
||||
metadata: Mapping[LogMetadata, Any] = Field(default={})
|
||||
|
||||
|
||||
class AgentResult(BaseModel):
|
||||
"""Agent execution result."""
|
||||
|
||||
text: str = Field(default="")
|
||||
files: list[Any] = Field(default_factory=list)
|
||||
usage: Any | None = Field(default=None)
|
||||
finish_reason: str | None = Field(default=None)
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Agent patterns module.
|
||||
|
||||
This module provides different strategies for agent execution:
|
||||
- FunctionCallStrategy: Uses native function/tool calling
|
||||
- ReActStrategy: Uses ReAct (Reasoning + Acting) approach
|
||||
- StrategyFactory: Factory for creating strategies based on model features
|
||||
"""
|
||||
|
||||
from .base import AgentPattern
|
||||
from .function_call import FunctionCallStrategy
|
||||
from .react import ReActStrategy
|
||||
from .strategy_factory import StrategyFactory
|
||||
|
||||
__all__ = [
|
||||
"AgentPattern",
|
||||
"FunctionCallStrategy",
|
||||
"ReActStrategy",
|
||||
"StrategyFactory",
|
||||
]
|
||||
@@ -1,506 +0,0 @@
|
||||
"""Base class for agent strategies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult, ExecutionContext
|
||||
from core.model_manager import ModelInstance
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMeta
|
||||
from graphon.file import File
|
||||
from graphon.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
)
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.model_runtime.entities.message_entities import TextPromptMessageContent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
# Type alias for tool invoke hook
|
||||
# Returns: (response_content, message_file_ids, tool_invoke_meta)
|
||||
ToolInvokeHook = Callable[["Tool", dict[str, Any], str], tuple[str, list[str], ToolInvokeMeta]]
|
||||
|
||||
|
||||
class AgentPattern(ABC):
|
||||
"""Base class for agent execution strategies."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_instance: ModelInstance,
|
||||
tools: list[Tool],
|
||||
context: ExecutionContext,
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
files: list[File] = [],
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
):
|
||||
"""Initialize the agent strategy."""
|
||||
self.model_instance = model_instance
|
||||
self.tools = tools
|
||||
self.context = context
|
||||
self.max_iterations = min(max_iterations, 99) # Cap at 99 iterations
|
||||
self.workflow_call_depth = workflow_call_depth
|
||||
self.files: list[File] = files
|
||||
self.tool_invoke_hook = tool_invoke_hook
|
||||
|
||||
@abstractmethod
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the agent strategy."""
|
||||
pass
|
||||
|
||||
def _accumulate_usage(self, total_usage: dict[str, Any], delta_usage: LLMUsage) -> None:
|
||||
"""Accumulate LLM usage statistics."""
|
||||
if not total_usage.get("usage"):
|
||||
# Create a copy to avoid modifying the original
|
||||
total_usage["usage"] = LLMUsage(
|
||||
prompt_tokens=delta_usage.prompt_tokens,
|
||||
prompt_unit_price=delta_usage.prompt_unit_price,
|
||||
prompt_price_unit=delta_usage.prompt_price_unit,
|
||||
prompt_price=delta_usage.prompt_price,
|
||||
completion_tokens=delta_usage.completion_tokens,
|
||||
completion_unit_price=delta_usage.completion_unit_price,
|
||||
completion_price_unit=delta_usage.completion_price_unit,
|
||||
completion_price=delta_usage.completion_price,
|
||||
total_tokens=delta_usage.total_tokens,
|
||||
total_price=delta_usage.total_price,
|
||||
currency=delta_usage.currency,
|
||||
latency=delta_usage.latency,
|
||||
)
|
||||
else:
|
||||
current: LLMUsage = total_usage["usage"]
|
||||
current.prompt_tokens += delta_usage.prompt_tokens
|
||||
current.completion_tokens += delta_usage.completion_tokens
|
||||
current.total_tokens += delta_usage.total_tokens
|
||||
current.prompt_price += delta_usage.prompt_price
|
||||
current.completion_price += delta_usage.completion_price
|
||||
current.total_price += delta_usage.total_price
|
||||
|
||||
def _extract_content(self, content: Any) -> str:
|
||||
"""Extract text content from message content."""
|
||||
if isinstance(content, list):
|
||||
# Content items are PromptMessageContentUnionTypes
|
||||
text_parts = []
|
||||
for c in content:
|
||||
# Check if it's a TextPromptMessageContent (which has data attribute)
|
||||
if isinstance(c, TextPromptMessageContent):
|
||||
text_parts.append(c.data)
|
||||
return "".join(text_parts)
|
||||
return str(content)
|
||||
|
||||
def _has_tool_calls(self, chunk: LLMResultChunk) -> bool:
|
||||
"""Check if chunk contains tool calls."""
|
||||
# LLMResultChunk always has delta attribute
|
||||
return bool(chunk.delta.message and chunk.delta.message.tool_calls)
|
||||
|
||||
def _has_tool_calls_result(self, result: LLMResult) -> bool:
|
||||
"""Check if result contains tool calls (non-streaming)."""
|
||||
# LLMResult always has message attribute
|
||||
return bool(result.message and result.message.tool_calls)
|
||||
|
||||
def _extract_tool_calls(self, chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""Extract tool calls from streaming chunk."""
|
||||
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
|
||||
if chunk.delta.message and chunk.delta.message.tool_calls:
|
||||
for tool_call in chunk.delta.message.tool_calls:
|
||||
if tool_call.function:
|
||||
try:
|
||||
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
|
||||
return tool_calls
|
||||
|
||||
def _extract_tool_calls_result(self, result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""Extract tool calls from non-streaming result."""
|
||||
tool_calls = []
|
||||
if result.message and result.message.tool_calls:
|
||||
for tool_call in result.message.tool_calls:
|
||||
if tool_call.function:
|
||||
try:
|
||||
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
|
||||
return tool_calls
|
||||
|
||||
def _extract_text_from_message(self, message: PromptMessage) -> str:
|
||||
"""Extract text content from a prompt message."""
|
||||
# PromptMessage always has content attribute
|
||||
content = message.content
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, list):
|
||||
# Extract text from content list
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, TextPromptMessageContent):
|
||||
text_parts.append(item.data)
|
||||
return " ".join(text_parts)
|
||||
return ""
|
||||
|
||||
def _get_tool_metadata(self, tool_instance: Tool) -> dict[AgentLog.LogMetadata, Any]:
|
||||
"""Get metadata for a tool including provider and icon info."""
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
metadata: dict[AgentLog.LogMetadata, Any] = {}
|
||||
if tool_instance.entity and tool_instance.entity.identity:
|
||||
identity = tool_instance.entity.identity
|
||||
if identity.provider:
|
||||
metadata[AgentLog.LogMetadata.PROVIDER] = identity.provider
|
||||
|
||||
# Get icon using ToolManager for proper URL generation
|
||||
tenant_id = self.context.tenant_id
|
||||
if tenant_id and identity.provider:
|
||||
try:
|
||||
provider_type = tool_instance.tool_provider_type()
|
||||
icon = ToolManager.get_tool_icon(tenant_id, provider_type, identity.provider)
|
||||
if isinstance(icon, str):
|
||||
metadata[AgentLog.LogMetadata.ICON] = icon
|
||||
elif isinstance(icon, dict):
|
||||
# Handle icon dict with background/content or light/dark variants
|
||||
metadata[AgentLog.LogMetadata.ICON] = icon
|
||||
except Exception:
|
||||
# Fallback to identity.icon if ToolManager fails
|
||||
if identity.icon:
|
||||
metadata[AgentLog.LogMetadata.ICON] = identity.icon
|
||||
elif identity.icon:
|
||||
metadata[AgentLog.LogMetadata.ICON] = identity.icon
|
||||
return metadata
|
||||
|
||||
def _create_log(
|
||||
self,
|
||||
label: str,
|
||||
log_type: AgentLog.LogType,
|
||||
status: AgentLog.LogStatus,
|
||||
data: dict[str, Any] | None = None,
|
||||
parent_id: str | None = None,
|
||||
extra_metadata: dict[AgentLog.LogMetadata, Any] | None = None,
|
||||
) -> AgentLog:
|
||||
"""Create a new AgentLog with standard metadata."""
|
||||
metadata: dict[AgentLog.LogMetadata, Any] = {
|
||||
AgentLog.LogMetadata.STARTED_AT: time.perf_counter(),
|
||||
}
|
||||
if extra_metadata:
|
||||
metadata.update(extra_metadata)
|
||||
|
||||
return AgentLog(
|
||||
label=label,
|
||||
log_type=log_type,
|
||||
status=status,
|
||||
data=data or {},
|
||||
parent_id=parent_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _finish_log(
|
||||
self,
|
||||
log: AgentLog,
|
||||
data: dict[str, Any] | None = None,
|
||||
usage: LLMUsage | None = None,
|
||||
) -> AgentLog:
|
||||
"""Finish an AgentLog by updating its status and metadata."""
|
||||
log.status = AgentLog.LogStatus.SUCCESS
|
||||
|
||||
if data is not None:
|
||||
log.data = data
|
||||
|
||||
# Calculate elapsed time
|
||||
started_at = log.metadata.get(AgentLog.LogMetadata.STARTED_AT, time.perf_counter())
|
||||
finished_at = time.perf_counter()
|
||||
|
||||
# Update metadata
|
||||
log.metadata = {
|
||||
**log.metadata,
|
||||
AgentLog.LogMetadata.FINISHED_AT: finished_at,
|
||||
# Calculate elapsed time in seconds
|
||||
AgentLog.LogMetadata.ELAPSED_TIME: round(finished_at - started_at, 4),
|
||||
}
|
||||
|
||||
# Add usage information if provided
|
||||
if usage:
|
||||
log.metadata.update(
|
||||
{
|
||||
AgentLog.LogMetadata.TOTAL_PRICE: usage.total_price,
|
||||
AgentLog.LogMetadata.CURRENCY: usage.currency,
|
||||
AgentLog.LogMetadata.TOTAL_TOKENS: usage.total_tokens,
|
||||
AgentLog.LogMetadata.LLM_USAGE: usage,
|
||||
}
|
||||
)
|
||||
|
||||
return log
|
||||
|
||||
def _replace_file_references(self, tool_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Replace file references in tool arguments with actual File objects.
|
||||
|
||||
Args:
|
||||
tool_args: Dictionary of tool arguments
|
||||
|
||||
Returns:
|
||||
Updated tool arguments with file references replaced
|
||||
"""
|
||||
# Process each argument in the dictionary
|
||||
processed_args: dict[str, Any] = {}
|
||||
for key, value in tool_args.items():
|
||||
processed_args[key] = self._process_file_reference(value)
|
||||
return processed_args
|
||||
|
||||
def _process_file_reference(self, data: Any) -> Any:
|
||||
"""
|
||||
Recursively process data to replace file references.
|
||||
Supports both single file [File: file_id] and multiple files [Files: file_id1, file_id2, ...].
|
||||
|
||||
Args:
|
||||
data: The data to process (can be dict, list, str, or other types)
|
||||
|
||||
Returns:
|
||||
Processed data with file references replaced
|
||||
"""
|
||||
single_file_pattern = re.compile(r"^\[File:\s*([^\]]+)\]$")
|
||||
multiple_files_pattern = re.compile(r"^\[Files:\s*([^\]]+)\]$")
|
||||
|
||||
if isinstance(data, dict):
|
||||
# Process dictionary recursively
|
||||
return {key: self._process_file_reference(value) for key, value in data.items()}
|
||||
elif isinstance(data, list):
|
||||
# Process list recursively
|
||||
return [self._process_file_reference(item) for item in data]
|
||||
elif isinstance(data, str):
|
||||
# Check for single file pattern [File: file_id]
|
||||
single_match = single_file_pattern.match(data.strip())
|
||||
if single_match:
|
||||
file_id = single_match.group(1).strip()
|
||||
# Find the file in self.files
|
||||
for file in self.files:
|
||||
if file.id and str(file.id) == file_id:
|
||||
return file
|
||||
# If file not found, return original value
|
||||
return data
|
||||
|
||||
# Check for multiple files pattern [Files: file_id1, file_id2, ...]
|
||||
multiple_match = multiple_files_pattern.match(data.strip())
|
||||
if multiple_match:
|
||||
file_ids_str = multiple_match.group(1).strip()
|
||||
# Split by comma and strip whitespace
|
||||
file_ids = [fid.strip() for fid in file_ids_str.split(",")]
|
||||
|
||||
# Find all matching files
|
||||
matched_files: list[File] = []
|
||||
for file_id in file_ids:
|
||||
for file in self.files:
|
||||
if file.id and str(file.id) == file_id:
|
||||
matched_files.append(file)
|
||||
break
|
||||
|
||||
# Return list of files if any were found, otherwise return original
|
||||
return matched_files or data
|
||||
|
||||
return data
|
||||
else:
|
||||
# Return other types as-is
|
||||
return data
|
||||
|
||||
def _create_text_chunk(self, text: str, prompt_messages: list[PromptMessage]) -> LLMResultChunk:
|
||||
"""Create a text chunk for streaming."""
|
||||
return LLMResultChunk(
|
||||
model=self.model_instance.model_name,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=text),
|
||||
usage=None,
|
||||
),
|
||||
system_fingerprint="",
|
||||
)
|
||||
|
||||
def _invoke_tool(
|
||||
self,
|
||||
tool_instance: Tool,
|
||||
tool_args: dict[str, Any],
|
||||
tool_name: str,
|
||||
) -> tuple[str, list[File], ToolInvokeMeta | None]:
|
||||
"""
|
||||
Invoke a tool and collect its response.
|
||||
|
||||
Args:
|
||||
tool_instance: The tool instance to invoke
|
||||
tool_args: Tool arguments
|
||||
tool_name: Name of the tool
|
||||
|
||||
Returns:
|
||||
Tuple of (response_content, tool_files, tool_invoke_meta)
|
||||
"""
|
||||
# Process tool_args to replace file references with actual File objects
|
||||
tool_args = self._replace_file_references(tool_args)
|
||||
|
||||
# If a tool invoke hook is set, use it instead of generic_invoke
|
||||
if self.tool_invoke_hook:
|
||||
response_content, _, tool_invoke_meta = self.tool_invoke_hook(tool_instance, tool_args, tool_name)
|
||||
# Note: message_file_ids are stored in DB, we don't convert them to File objects here
|
||||
# The caller (AgentAppRunner) handles file publishing
|
||||
return response_content, [], tool_invoke_meta
|
||||
|
||||
# Default: use generic_invoke for workflow scenarios
|
||||
# Import here to avoid circular import
|
||||
from core.tools.tool_engine import DifyWorkflowCallbackHandler, ToolEngine
|
||||
|
||||
tool_response = ToolEngine.generic_invoke(
|
||||
tool=tool_instance,
|
||||
tool_parameters=tool_args,
|
||||
user_id=self.context.user_id or "",
|
||||
workflow_tool_callback=DifyWorkflowCallbackHandler(),
|
||||
workflow_call_depth=self.workflow_call_depth,
|
||||
app_id=self.context.app_id,
|
||||
conversation_id=self.context.conversation_id,
|
||||
message_id=self.context.message_id,
|
||||
)
|
||||
|
||||
# Collect response and files
|
||||
response_content = ""
|
||||
tool_files: list[File] = []
|
||||
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(response.message, ToolInvokeMessage.TextMessage)
|
||||
response_content += response.message.text
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
# Handle link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Link: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
# Handle image URL messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Image: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK:
|
||||
# Handle image link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Image: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BINARY_LINK:
|
||||
# Handle binary file link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
filename = response.meta.get("filename", "file") if response.meta else "file"
|
||||
response_content += f"[File: {filename} - {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.JSON:
|
||||
# Handle JSON messages
|
||||
if isinstance(response.message, ToolInvokeMessage.JsonMessage):
|
||||
response_content += json.dumps(response.message.json_object, ensure_ascii=False, indent=2)
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# Handle blob messages - convert to text representation
|
||||
if isinstance(response.message, ToolInvokeMessage.BlobMessage):
|
||||
mime_type = (
|
||||
response.meta.get("mime_type", "application/octet-stream")
|
||||
if response.meta
|
||||
else "application/octet-stream"
|
||||
)
|
||||
size = len(response.message.blob)
|
||||
response_content += f"[Binary data: {mime_type}, size: {size} bytes]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
# Handle variable messages
|
||||
if isinstance(response.message, ToolInvokeMessage.VariableMessage):
|
||||
var_name = response.message.variable_name
|
||||
var_value = response.message.variable_value
|
||||
if isinstance(var_value, str):
|
||||
response_content += var_value
|
||||
else:
|
||||
response_content += f"[Variable {var_name}: {json.dumps(var_value, ensure_ascii=False)}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
|
||||
# Handle blob chunk messages - these are parts of a larger blob
|
||||
if isinstance(response.message, ToolInvokeMessage.BlobChunkMessage):
|
||||
response_content += f"[Blob chunk {response.message.sequence}: {len(response.message.blob)} bytes]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES:
|
||||
# Handle retriever resources messages
|
||||
if isinstance(response.message, ToolInvokeMessage.RetrieverResourceMessage):
|
||||
response_content += response.message.context
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.FILE:
|
||||
# Extract file from meta
|
||||
if response.meta and "file" in response.meta:
|
||||
file = response.meta["file"]
|
||||
if isinstance(file, File):
|
||||
# Check if file is for model or tool output
|
||||
if response.meta.get("target") == "self":
|
||||
# File is for model - add to files for next prompt
|
||||
self.files.append(file)
|
||||
response_content += f"File '{file.filename}' has been loaded into your context."
|
||||
else:
|
||||
# File is tool output
|
||||
tool_files.append(file)
|
||||
|
||||
return response_content, tool_files, None
|
||||
|
||||
def _validate_tool_args(self, tool_instance: Tool, tool_args: dict[str, Any]) -> str | None:
|
||||
"""Validate tool arguments against the tool's required parameters.
|
||||
|
||||
Checks that all required LLM-facing parameters are present and non-empty
|
||||
before actual execution, preventing wasted tool invocations when the model
|
||||
generates calls with missing arguments (e.g. empty ``{}``).
|
||||
|
||||
Returns:
|
||||
Error message if validation fails, None if all required parameters are satisfied.
|
||||
"""
|
||||
prompt_tool = tool_instance.to_prompt_message_tool()
|
||||
required_params: list[str] = prompt_tool.parameters.get("required", [])
|
||||
|
||||
if not required_params:
|
||||
return None
|
||||
|
||||
missing = [
|
||||
p
|
||||
for p in required_params
|
||||
if p not in tool_args
|
||||
or tool_args[p] is None
|
||||
or (isinstance(tool_args[p], str) and not tool_args[p].strip())
|
||||
]
|
||||
|
||||
if not missing:
|
||||
return None
|
||||
|
||||
return (
|
||||
f"Missing required parameter(s): {', '.join(missing)}. "
|
||||
f"Please provide all required parameters before calling this tool."
|
||||
)
|
||||
|
||||
def _find_tool_by_name(self, tool_name: str) -> Tool | None:
|
||||
"""Find a tool instance by its name."""
|
||||
for tool in self.tools:
|
||||
if tool.entity.identity.name == tool_name:
|
||||
return tool
|
||||
return None
|
||||
|
||||
def _convert_tools_to_prompt_format(self) -> list[PromptMessageTool]:
|
||||
"""Convert tools to prompt message format."""
|
||||
prompt_tools: list[PromptMessageTool] = []
|
||||
for tool in self.tools:
|
||||
prompt_tools.append(tool.to_prompt_message_tool())
|
||||
return prompt_tools
|
||||
|
||||
def _update_usage_with_empty(self, llm_usage: dict[str, Any]) -> None:
|
||||
"""Initialize usage tracking with empty usage if not set."""
|
||||
if "usage" not in llm_usage or llm_usage["usage"] is None:
|
||||
llm_usage["usage"] = LLMUsage.empty_usage()
|
||||
@@ -1,358 +0,0 @@
|
||||
"""Function Call strategy implementation.
|
||||
|
||||
Implements the Function Call agent pattern where the LLM uses native tool-calling
|
||||
capability to invoke tools. Includes pre-execution parameter validation that
|
||||
intercepts invalid calls (e.g. empty arguments) before they reach tool backends,
|
||||
and avoids counting purely-invalid rounds against the iteration budget.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Union
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from graphon.file import File
|
||||
from graphon.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
ToolPromptMessage,
|
||||
)
|
||||
|
||||
from .base import AgentPattern
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FunctionCallStrategy(AgentPattern):
|
||||
"""Function Call strategy using model's native tool calling capability."""
|
||||
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the function call agent strategy."""
|
||||
# Convert tools to prompt format
|
||||
prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format()
|
||||
|
||||
# Initialize tracking
|
||||
iteration_step: int = 1
|
||||
max_iterations: int = self.max_iterations + 1
|
||||
function_call_state: bool = True
|
||||
total_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy
|
||||
final_text: str = ""
|
||||
finish_reason: str | None = None
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
# Consecutive rounds where ALL tool calls failed parameter validation.
|
||||
# When this happens the round is "free" (iteration_step not incremented)
|
||||
# up to a safety cap to prevent infinite loops.
|
||||
consecutive_validation_failures: int = 0
|
||||
max_validation_retries: int = 3
|
||||
|
||||
while function_call_state and iteration_step <= max_iterations:
|
||||
function_call_state = False
|
||||
round_log = self._create_log(
|
||||
label=f"ROUND {iteration_step}",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
yield round_log
|
||||
# On last iteration, remove tools to force final answer
|
||||
current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools
|
||||
model_log = self._create_log(
|
||||
label=f"{self.model_instance.model_name} Thought",
|
||||
log_type=AgentLog.LogType.THOUGHT,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata={
|
||||
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
|
||||
},
|
||||
)
|
||||
yield model_log
|
||||
|
||||
# Track usage for this round only
|
||||
round_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
|
||||
# Invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
|
||||
prompt_messages=messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=current_tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
# Process response
|
||||
tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks(
|
||||
chunks, round_usage, model_log
|
||||
)
|
||||
messages.append(self._create_assistant_message(response_content, tool_calls))
|
||||
|
||||
# Accumulate to total usage
|
||||
round_usage_value = round_usage.get("usage")
|
||||
if round_usage_value:
|
||||
self._accumulate_usage(total_usage, round_usage_value)
|
||||
|
||||
# Update final text if no tool calls (this is likely the final answer)
|
||||
if not tool_calls:
|
||||
final_text = response_content
|
||||
|
||||
# Update finish reason
|
||||
if chunk_finish_reason:
|
||||
finish_reason = chunk_finish_reason
|
||||
|
||||
# Process tool calls
|
||||
tool_outputs: dict[str, str] = {}
|
||||
all_validation_errors: bool = True
|
||||
if tool_calls:
|
||||
function_call_state = True
|
||||
# Execute tools (with pre-execution parameter validation)
|
||||
for tool_call_id, tool_name, tool_args in tool_calls:
|
||||
tool_response, tool_files, _, is_validation_error = yield from self._handle_tool_call(
|
||||
tool_name, tool_args, tool_call_id, messages, round_log
|
||||
)
|
||||
tool_outputs[tool_name] = tool_response
|
||||
output_files.extend(tool_files)
|
||||
if not is_validation_error:
|
||||
all_validation_errors = False
|
||||
else:
|
||||
all_validation_errors = False
|
||||
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
data={
|
||||
"llm_result": response_content,
|
||||
"tool_calls": [
|
||||
{"name": tc[1], "args": tc[2], "output": tool_outputs.get(tc[1], "")} for tc in tool_calls
|
||||
]
|
||||
if tool_calls
|
||||
else [],
|
||||
"final_answer": final_text if not function_call_state else None,
|
||||
},
|
||||
usage=round_usage.get("usage"),
|
||||
)
|
||||
|
||||
# Skip iteration counter when every tool call in this round failed validation,
|
||||
# giving the model a free retry — but cap retries to prevent infinite loops.
|
||||
if tool_calls and all_validation_errors:
|
||||
consecutive_validation_failures += 1
|
||||
if consecutive_validation_failures >= max_validation_retries:
|
||||
logger.warning(
|
||||
"Agent hit %d consecutive validation-only rounds, forcing iteration increment",
|
||||
consecutive_validation_failures,
|
||||
)
|
||||
iteration_step += 1
|
||||
consecutive_validation_failures = 0
|
||||
else:
|
||||
logger.info(
|
||||
"All tool calls failed validation (attempt %d/%d), not counting iteration",
|
||||
consecutive_validation_failures,
|
||||
max_validation_retries,
|
||||
)
|
||||
else:
|
||||
consecutive_validation_failures = 0
|
||||
iteration_step += 1
|
||||
|
||||
# Return final result
|
||||
from core.agent.entities import AgentResult
|
||||
|
||||
return AgentResult(
|
||||
text=final_text,
|
||||
files=output_files,
|
||||
usage=total_usage.get("usage") or LLMUsage.empty_usage(),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
|
||||
def _handle_chunks(
|
||||
self,
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||
llm_usage: dict[str, LLMUsage | None],
|
||||
start_log: AgentLog,
|
||||
) -> Generator[
|
||||
LLMResultChunk | AgentLog,
|
||||
None,
|
||||
tuple[list[tuple[str, str, dict[str, Any]]], str, str | None],
|
||||
]:
|
||||
"""Handle LLM response chunks and extract tool calls and content.
|
||||
|
||||
Returns a tuple of (tool_calls, response_content, finish_reason).
|
||||
"""
|
||||
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
|
||||
response_content: str = ""
|
||||
finish_reason: str | None = None
|
||||
if not isinstance(chunks, LLMResult):
|
||||
# Streaming response
|
||||
for chunk in chunks:
|
||||
# Extract tool calls
|
||||
if self._has_tool_calls(chunk):
|
||||
tool_calls.extend(self._extract_tool_calls(chunk))
|
||||
|
||||
# Extract content
|
||||
if chunk.delta.message and chunk.delta.message.content:
|
||||
response_content += self._extract_content(chunk.delta.message.content)
|
||||
|
||||
# Track usage
|
||||
if chunk.delta.usage:
|
||||
self._accumulate_usage(llm_usage, chunk.delta.usage)
|
||||
|
||||
# Capture finish reason
|
||||
if chunk.delta.finish_reason:
|
||||
finish_reason = chunk.delta.finish_reason
|
||||
|
||||
yield chunk
|
||||
else:
|
||||
# Non-streaming response
|
||||
result: LLMResult = chunks
|
||||
|
||||
if self._has_tool_calls_result(result):
|
||||
tool_calls.extend(self._extract_tool_calls_result(result))
|
||||
|
||||
if result.message and result.message.content:
|
||||
response_content += self._extract_content(result.message.content)
|
||||
|
||||
if result.usage:
|
||||
self._accumulate_usage(llm_usage, result.usage)
|
||||
|
||||
# Convert to streaming format
|
||||
yield LLMResultChunk(
|
||||
model=result.model,
|
||||
prompt_messages=result.prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
|
||||
)
|
||||
yield self._finish_log(
|
||||
start_log,
|
||||
data={
|
||||
"result": response_content,
|
||||
},
|
||||
usage=llm_usage.get("usage"),
|
||||
)
|
||||
return tool_calls, response_content, finish_reason
|
||||
|
||||
def _create_assistant_message(
|
||||
self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None
|
||||
) -> AssistantPromptMessage:
|
||||
"""Create assistant message with tool calls."""
|
||||
if tool_calls is None:
|
||||
return AssistantPromptMessage(content=content)
|
||||
return AssistantPromptMessage(
|
||||
content=content or "",
|
||||
tool_calls=[
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=tc[0],
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tc[1], arguments=json.dumps(tc[2])),
|
||||
)
|
||||
for tc in tool_calls
|
||||
],
|
||||
)
|
||||
|
||||
def _handle_tool_call(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_args: dict[str, Any],
|
||||
tool_call_id: str,
|
||||
messages: list[PromptMessage],
|
||||
round_log: AgentLog,
|
||||
) -> Generator[AgentLog, None, tuple[str, list[File], ToolInvokeMeta | None, bool]]:
|
||||
"""Handle a single tool call and return response with files, meta, and validation status.
|
||||
|
||||
Validates required parameters before execution. When validation fails the tool
|
||||
is never invoked — a synthetic error is fed back to the model so it can self-correct
|
||||
without consuming a real iteration.
|
||||
|
||||
Returns:
|
||||
(response_content, tool_files, tool_invoke_meta, is_validation_error).
|
||||
``is_validation_error`` is True when the call was rejected due to missing
|
||||
required parameters, allowing the caller to skip the iteration counter.
|
||||
"""
|
||||
# Find tool
|
||||
tool_instance = self._find_tool_by_name(tool_name)
|
||||
if not tool_instance:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
|
||||
# Get tool metadata (provider, icon, etc.)
|
||||
tool_metadata = self._get_tool_metadata(tool_instance)
|
||||
|
||||
# Create tool call log
|
||||
tool_call_log = self._create_log(
|
||||
label=f"CALL {tool_name}",
|
||||
log_type=AgentLog.LogType.TOOL_CALL,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_args": tool_args,
|
||||
},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata=tool_metadata,
|
||||
)
|
||||
yield tool_call_log
|
||||
|
||||
# Validate required parameters before execution to avoid wasted invocations
|
||||
validation_error = self._validate_tool_args(tool_instance, tool_args)
|
||||
if validation_error:
|
||||
tool_call_log.status = AgentLog.LogStatus.ERROR
|
||||
tool_call_log.error = validation_error
|
||||
tool_call_log.data = {**tool_call_log.data, "error": validation_error}
|
||||
yield tool_call_log
|
||||
|
||||
messages.append(ToolPromptMessage(content=validation_error, tool_call_id=tool_call_id, name=tool_name))
|
||||
return validation_error, [], None, True
|
||||
|
||||
# Invoke tool using base class method with error handling
|
||||
try:
|
||||
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args, tool_name)
|
||||
|
||||
yield self._finish_log(
|
||||
tool_call_log,
|
||||
data={
|
||||
**tool_call_log.data,
|
||||
"output": response_content,
|
||||
"files": len(tool_files),
|
||||
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
|
||||
},
|
||||
)
|
||||
final_content = response_content or "Tool executed successfully"
|
||||
# Add tool response to messages
|
||||
messages.append(
|
||||
ToolPromptMessage(
|
||||
content=final_content,
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_name,
|
||||
)
|
||||
)
|
||||
return response_content, tool_files, tool_invoke_meta, False
|
||||
except Exception as e:
|
||||
# Tool invocation failed, yield error log
|
||||
error_message = str(e)
|
||||
tool_call_log.status = AgentLog.LogStatus.ERROR
|
||||
tool_call_log.error = error_message
|
||||
tool_call_log.data = {
|
||||
**tool_call_log.data,
|
||||
"error": error_message,
|
||||
}
|
||||
yield tool_call_log
|
||||
|
||||
# Add error message to conversation
|
||||
error_content = f"Tool execution failed: {error_message}"
|
||||
messages.append(
|
||||
ToolPromptMessage(
|
||||
content=error_content,
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_name,
|
||||
)
|
||||
)
|
||||
return error_content, [], None, False
|
||||
@@ -1,418 +0,0 @@
|
||||
"""ReAct strategy implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext
|
||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||
from core.model_manager import ModelInstance
|
||||
from graphon.file import File
|
||||
from graphon.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
PromptMessage,
|
||||
SystemPromptMessage,
|
||||
)
|
||||
|
||||
from .base import AgentPattern, ToolInvokeHook
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
|
||||
class ReActStrategy(AgentPattern):
|
||||
"""ReAct strategy using reasoning and acting approach."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_instance: ModelInstance,
|
||||
tools: list[Tool],
|
||||
context: ExecutionContext,
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
files: list[File] = [],
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
instruction: str = "",
|
||||
):
|
||||
"""Initialize the ReAct strategy with instruction support."""
|
||||
super().__init__(
|
||||
model_instance=model_instance,
|
||||
tools=tools,
|
||||
context=context,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
files=files,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
self.instruction = instruction
|
||||
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the ReAct agent strategy."""
|
||||
# Initialize tracking
|
||||
agent_scratchpad: list[AgentScratchpadUnit] = []
|
||||
iteration_step: int = 1
|
||||
max_iterations: int = self.max_iterations + 1
|
||||
react_state: bool = True
|
||||
total_usage: dict[str, Any] = {"usage": None}
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
final_text: str = ""
|
||||
finish_reason: str | None = None
|
||||
|
||||
# Add "Observation" to stop sequences
|
||||
if "Observation" not in stop:
|
||||
stop = stop.copy()
|
||||
stop.append("Observation")
|
||||
|
||||
while react_state and iteration_step <= max_iterations:
|
||||
react_state = False
|
||||
round_log = self._create_log(
|
||||
label=f"ROUND {iteration_step}",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
yield round_log
|
||||
|
||||
# Build prompt with/without tools based on iteration
|
||||
include_tools = iteration_step < max_iterations
|
||||
current_messages = self._build_prompt_with_react_format(
|
||||
prompt_messages, agent_scratchpad, include_tools, self.instruction
|
||||
)
|
||||
|
||||
model_log = self._create_log(
|
||||
label=f"{self.model_instance.model_name} Thought",
|
||||
log_type=AgentLog.LogType.THOUGHT,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata={
|
||||
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
|
||||
},
|
||||
)
|
||||
yield model_log
|
||||
|
||||
# Track usage for this round only
|
||||
round_usage: dict[str, Any] = {"usage": None}
|
||||
|
||||
# Use current messages directly (files are handled by base class if needed)
|
||||
messages_to_use = current_messages
|
||||
|
||||
# Invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
|
||||
prompt_messages=messages_to_use,
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
# Process response
|
||||
scratchpad, chunk_finish_reason = yield from self._handle_chunks(
|
||||
chunks, round_usage, model_log, current_messages
|
||||
)
|
||||
agent_scratchpad.append(scratchpad)
|
||||
|
||||
# Accumulate to total usage
|
||||
round_usage_value = round_usage.get("usage")
|
||||
if round_usage_value:
|
||||
self._accumulate_usage(total_usage, round_usage_value)
|
||||
|
||||
# Update finish reason
|
||||
if chunk_finish_reason:
|
||||
finish_reason = chunk_finish_reason
|
||||
|
||||
# Check if we have an action to execute
|
||||
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
|
||||
react_state = True
|
||||
# Execute tool
|
||||
observation, tool_files = yield from self._handle_tool_call(
|
||||
scratchpad.action, current_messages, round_log
|
||||
)
|
||||
scratchpad.observation = observation
|
||||
# Track files produced by tools
|
||||
output_files.extend(tool_files)
|
||||
|
||||
# Add observation to scratchpad for display
|
||||
yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages)
|
||||
else:
|
||||
# Extract final answer
|
||||
if scratchpad.action and scratchpad.action.action_input:
|
||||
final_answer = scratchpad.action.action_input
|
||||
if isinstance(final_answer, dict):
|
||||
final_answer = json.dumps(final_answer, ensure_ascii=False)
|
||||
final_text = str(final_answer)
|
||||
elif scratchpad.thought:
|
||||
# If no action but we have thought, use thought as final answer
|
||||
final_text = scratchpad.thought
|
||||
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
data={
|
||||
"thought": scratchpad.thought,
|
||||
"action": scratchpad.action_str if scratchpad.action else None,
|
||||
"observation": scratchpad.observation or None,
|
||||
"final_answer": final_text if not react_state else None,
|
||||
},
|
||||
usage=round_usage.get("usage"),
|
||||
)
|
||||
iteration_step += 1
|
||||
|
||||
# Return final result
|
||||
|
||||
from core.agent.entities import AgentResult
|
||||
|
||||
return AgentResult(
|
||||
text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason
|
||||
)
|
||||
|
||||
def _build_prompt_with_react_format(
|
||||
self,
|
||||
original_messages: list[PromptMessage],
|
||||
agent_scratchpad: list[AgentScratchpadUnit],
|
||||
include_tools: bool = True,
|
||||
instruction: str = "",
|
||||
) -> list[PromptMessage]:
|
||||
"""Build prompt messages with ReAct format."""
|
||||
# Copy messages to avoid modifying original
|
||||
messages = list(original_messages)
|
||||
|
||||
# Find and update the system prompt that should already exist
|
||||
system_prompt_found = False
|
||||
for i, msg in enumerate(messages):
|
||||
if isinstance(msg, SystemPromptMessage):
|
||||
system_prompt_found = True
|
||||
# The system prompt from frontend already has the template, just replace placeholders
|
||||
|
||||
# Format tools
|
||||
tools_str = ""
|
||||
tool_names = []
|
||||
if include_tools and self.tools:
|
||||
# Convert tools to prompt message tools format
|
||||
prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools]
|
||||
tool_names = [tool.name for tool in prompt_tools]
|
||||
|
||||
# Format tools as JSON for comprehensive information
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
tools_str = json.dumps(jsonable_encoder(prompt_tools), indent=2)
|
||||
tool_names_str = ", ".join(f'"{name}"' for name in tool_names)
|
||||
else:
|
||||
tools_str = "No tools available"
|
||||
tool_names_str = ""
|
||||
|
||||
# Replace placeholders in the existing system prompt
|
||||
updated_content = msg.content
|
||||
assert isinstance(updated_content, str)
|
||||
updated_content = updated_content.replace("{{instruction}}", instruction)
|
||||
updated_content = updated_content.replace("{{tools}}", tools_str)
|
||||
updated_content = updated_content.replace("{{tool_names}}", tool_names_str)
|
||||
|
||||
# Create new SystemPromptMessage with updated content
|
||||
messages[i] = SystemPromptMessage(content=updated_content)
|
||||
break
|
||||
|
||||
# If no system prompt found, that's unexpected but add scratchpad anyway
|
||||
if not system_prompt_found:
|
||||
# This shouldn't happen if frontend is working correctly
|
||||
pass
|
||||
|
||||
# Format agent scratchpad
|
||||
scratchpad_str = ""
|
||||
if agent_scratchpad:
|
||||
scratchpad_parts: list[str] = []
|
||||
for unit in agent_scratchpad:
|
||||
if unit.thought:
|
||||
scratchpad_parts.append(f"Thought: {unit.thought}")
|
||||
if unit.action_str:
|
||||
scratchpad_parts.append(f"Action:\n```\n{unit.action_str}\n```")
|
||||
if unit.observation:
|
||||
scratchpad_parts.append(f"Observation: {unit.observation}")
|
||||
scratchpad_str = "\n".join(scratchpad_parts)
|
||||
|
||||
# If there's a scratchpad, append it to the last message
|
||||
if scratchpad_str:
|
||||
messages.append(AssistantPromptMessage(content=scratchpad_str))
|
||||
|
||||
return messages
|
||||
|
||||
def _handle_chunks(
|
||||
self,
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||
llm_usage: dict[str, Any],
|
||||
model_log: AgentLog,
|
||||
current_messages: list[PromptMessage],
|
||||
) -> Generator[
|
||||
LLMResultChunk | AgentLog,
|
||||
None,
|
||||
tuple[AgentScratchpadUnit, str | None],
|
||||
]:
|
||||
"""Handle LLM response chunks and extract action/thought.
|
||||
|
||||
Returns a tuple of (scratchpad_unit, finish_reason).
|
||||
"""
|
||||
usage_dict: dict[str, Any] = {}
|
||||
|
||||
# Convert non-streaming to streaming format if needed
|
||||
if isinstance(chunks, LLMResult):
|
||||
result = chunks
|
||||
|
||||
def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
|
||||
yield LLMResultChunk(
|
||||
model=result.model,
|
||||
prompt_messages=result.prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=result.message,
|
||||
usage=result.usage,
|
||||
finish_reason=None,
|
||||
),
|
||||
system_fingerprint=result.system_fingerprint or "",
|
||||
)
|
||||
|
||||
streaming_chunks = result_to_chunks()
|
||||
else:
|
||||
streaming_chunks = chunks
|
||||
|
||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict)
|
||||
|
||||
# Initialize scratchpad unit
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="",
|
||||
thought="",
|
||||
action_str="",
|
||||
observation="",
|
||||
action=None,
|
||||
)
|
||||
|
||||
finish_reason: str | None = None
|
||||
|
||||
# Process chunks
|
||||
for chunk in react_chunks:
|
||||
if isinstance(chunk, AgentScratchpadUnit.Action):
|
||||
# Action detected
|
||||
action_str = json.dumps(chunk.model_dump())
|
||||
scratchpad.agent_response = (scratchpad.agent_response or "") + action_str
|
||||
scratchpad.action_str = action_str
|
||||
scratchpad.action = chunk
|
||||
|
||||
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
|
||||
else:
|
||||
# Text chunk
|
||||
chunk_text = str(chunk)
|
||||
scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text
|
||||
scratchpad.thought = (scratchpad.thought or "") + chunk_text
|
||||
|
||||
yield self._create_text_chunk(chunk_text, current_messages)
|
||||
|
||||
# Update usage
|
||||
if usage_dict.get("usage"):
|
||||
if llm_usage.get("usage"):
|
||||
self._accumulate_usage(llm_usage, usage_dict["usage"])
|
||||
else:
|
||||
llm_usage["usage"] = usage_dict["usage"]
|
||||
|
||||
# Clean up thought
|
||||
scratchpad.thought = (scratchpad.thought or "").strip() or "I am thinking about how to help you"
|
||||
|
||||
# Finish model log
|
||||
yield self._finish_log(
|
||||
model_log,
|
||||
data={
|
||||
"thought": scratchpad.thought,
|
||||
"action": scratchpad.action_str if scratchpad.action else None,
|
||||
},
|
||||
usage=llm_usage.get("usage"),
|
||||
)
|
||||
|
||||
return scratchpad, finish_reason
|
||||
|
||||
def _handle_tool_call(
|
||||
self,
|
||||
action: AgentScratchpadUnit.Action,
|
||||
prompt_messages: list[PromptMessage],
|
||||
round_log: AgentLog,
|
||||
) -> Generator[AgentLog, None, tuple[str, list[File]]]:
|
||||
"""Handle tool call and return observation with files."""
|
||||
tool_name = action.action_name
|
||||
tool_args: dict[str, Any] | str = action.action_input
|
||||
|
||||
# Find tool instance first to get metadata
|
||||
tool_instance = self._find_tool_by_name(tool_name)
|
||||
tool_metadata = self._get_tool_metadata(tool_instance) if tool_instance else {}
|
||||
|
||||
# Start tool log with tool metadata
|
||||
tool_log = self._create_log(
|
||||
label=f"CALL {tool_name}",
|
||||
log_type=AgentLog.LogType.TOOL_CALL,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={
|
||||
"tool_name": tool_name,
|
||||
"tool_args": tool_args,
|
||||
},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata=tool_metadata,
|
||||
)
|
||||
yield tool_log
|
||||
|
||||
if not tool_instance:
|
||||
# Finish tool log with error
|
||||
yield self._finish_log(
|
||||
tool_log,
|
||||
data={
|
||||
**tool_log.data,
|
||||
"error": f"Tool {tool_name} not found",
|
||||
},
|
||||
)
|
||||
return f"Tool {tool_name} not found", []
|
||||
|
||||
# Ensure tool_args is a dict
|
||||
tool_args_dict: dict[str, Any]
|
||||
if isinstance(tool_args, str):
|
||||
try:
|
||||
tool_args_dict = json.loads(tool_args)
|
||||
except json.JSONDecodeError:
|
||||
tool_args_dict = {"input": tool_args}
|
||||
elif not isinstance(tool_args, dict):
|
||||
tool_args_dict = {"input": str(tool_args)}
|
||||
else:
|
||||
tool_args_dict = tool_args
|
||||
|
||||
# Invoke tool using base class method with error handling
|
||||
try:
|
||||
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args_dict, tool_name)
|
||||
|
||||
# Finish tool log
|
||||
yield self._finish_log(
|
||||
tool_log,
|
||||
data={
|
||||
**tool_log.data,
|
||||
"output": response_content,
|
||||
"files": len(tool_files),
|
||||
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
|
||||
},
|
||||
)
|
||||
|
||||
return response_content or "Tool executed successfully", tool_files
|
||||
except Exception as e:
|
||||
# Tool invocation failed, yield error log
|
||||
error_message = str(e)
|
||||
tool_log.status = AgentLog.LogStatus.ERROR
|
||||
tool_log.error = error_message
|
||||
tool_log.data = {
|
||||
**tool_log.data,
|
||||
"error": error_message,
|
||||
}
|
||||
yield tool_log
|
||||
|
||||
return f"Tool execution failed: {error_message}", []
|
||||
@@ -1,108 +0,0 @@
|
||||
"""Strategy factory for creating agent strategies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.agent.entities import AgentEntity, ExecutionContext
|
||||
from core.model_manager import ModelInstance
|
||||
from graphon.file.models import File
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature
|
||||
|
||||
from .base import AgentPattern, ToolInvokeHook
|
||||
from .function_call import FunctionCallStrategy
|
||||
from .react import ReActStrategy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
|
||||
class StrategyFactory:
|
||||
"""Factory for creating agent strategies based on model features."""
|
||||
|
||||
# Tool calling related features
|
||||
TOOL_CALL_FEATURES = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL}
|
||||
|
||||
@staticmethod
|
||||
def create_strategy(
|
||||
model_features: list[ModelFeature],
|
||||
model_instance: ModelInstance,
|
||||
context: ExecutionContext,
|
||||
tools: list[Tool],
|
||||
files: list[File],
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
agent_strategy: AgentEntity.Strategy | None = None,
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
instruction: str = "",
|
||||
) -> AgentPattern:
|
||||
"""
|
||||
Create an appropriate strategy based on model features.
|
||||
|
||||
Args:
|
||||
model_features: List of model features/capabilities
|
||||
model_instance: Model instance to use
|
||||
context: Execution context containing trace/audit information
|
||||
tools: Available tools
|
||||
files: Available files
|
||||
max_iterations: Maximum iterations for the strategy
|
||||
workflow_call_depth: Depth of workflow calls
|
||||
agent_strategy: Optional explicit strategy override
|
||||
tool_invoke_hook: Optional hook for custom tool invocation (e.g., agent_invoke)
|
||||
instruction: Optional instruction for ReAct strategy
|
||||
|
||||
Returns:
|
||||
AgentStrategy instance
|
||||
"""
|
||||
|
||||
# If explicit strategy is provided and it's Function Calling, try to use it if supported
|
||||
if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING:
|
||||
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
|
||||
return FunctionCallStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
# Fallback to ReAct if FC is requested but not supported
|
||||
|
||||
# If explicit strategy is Chain of Thought (ReAct)
|
||||
if agent_strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
|
||||
return ReActStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
instruction=instruction,
|
||||
)
|
||||
|
||||
# Default auto-selection logic
|
||||
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
|
||||
# Model supports native function calling
|
||||
return FunctionCallStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
else:
|
||||
# Use ReAct strategy for models without function calling
|
||||
return ReActStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
instruction=instruction,
|
||||
)
|
||||
@@ -177,14 +177,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
# always enable retriever resource in debugger mode
|
||||
app_config.additional_features.show_retrieve_source = True # type: ignore
|
||||
|
||||
# Resolve parent_message_id for thread continuity
|
||||
if invoke_from == InvokeFrom.SERVICE_API:
|
||||
parent_message_id: str | None = UUID_NIL
|
||||
else:
|
||||
parent_message_id = args.get("parent_message_id")
|
||||
if not parent_message_id and conversation:
|
||||
parent_message_id = self._resolve_latest_message_id(conversation.id)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = AdvancedChatAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
@@ -196,7 +188,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
),
|
||||
query=query,
|
||||
files=list(file_objs),
|
||||
parent_message_id=parent_message_id,
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
@@ -697,17 +689,3 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
else:
|
||||
logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def _resolve_latest_message_id(conversation_id: str) -> str | None:
|
||||
"""Auto-resolve parent_message_id to the latest message when client doesn't provide one."""
|
||||
from sqlalchemy import select
|
||||
|
||||
stmt = (
|
||||
select(Message.id)
|
||||
.where(Message.conversation_id == conversation_id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
latest_id = db.session.scalar(stmt)
|
||||
return str(latest_id) if latest_id else None
|
||||
|
||||
@@ -246,10 +246,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
for layer in self._graph_engine_layers:
|
||||
workflow_entry.graph_engine.layer(layer)
|
||||
|
||||
if hasattr(self, '_sandbox') and self._sandbox is not None:
|
||||
from core.app.layers.sandbox_layer import SandboxLayer
|
||||
workflow_entry.graph_engine.layer(SandboxLayer(self._sandbox))
|
||||
|
||||
generator = workflow_entry.run()
|
||||
|
||||
for event in generator:
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.agent.agent_app_runner import AgentAppRunner
|
||||
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
|
||||
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
|
||||
from core.agent.entities import AgentEntity
|
||||
from core.agent.fc_agent_runner import FunctionCallAgentRunner
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
@@ -191,7 +194,22 @@ class AgentChatAppRunner(AppRunner):
|
||||
raise ValueError("Message not found")
|
||||
db.session.close()
|
||||
|
||||
runner = AgentAppRunner(
|
||||
runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
|
||||
# start agent runner
|
||||
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
|
||||
# check LLM mode
|
||||
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT:
|
||||
runner_cls = CotChatAgentRunner
|
||||
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION:
|
||||
runner_cls = CotCompletionAgentRunner
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}")
|
||||
elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
|
||||
runner_cls = FunctionCallAgentRunner
|
||||
else:
|
||||
raise ValueError(f"Invalid agent strategy: {agent_entity.strategy}")
|
||||
|
||||
runner = runner_cls(
|
||||
tenant_id=app_config.tenant_id,
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=conversation_result,
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
"""Legacy Response Adapter for transparent upgrade.
|
||||
|
||||
When old apps (chat/completion/agent-chat) run through the Agent V2
|
||||
workflow engine via transparent upgrade, the SSE events are in workflow
|
||||
format (workflow_started, node_started, etc.). This adapter filters out
|
||||
workflow-specific events and passes through only the events that old
|
||||
clients expect (message, message_end, etc.).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
WORKFLOW_ONLY_EVENTS = frozenset({
|
||||
"workflow_started",
|
||||
"workflow_finished",
|
||||
"node_started",
|
||||
"node_finished",
|
||||
"iteration_started",
|
||||
"iteration_next",
|
||||
"iteration_completed",
|
||||
})
|
||||
|
||||
|
||||
def adapt_workflow_stream_for_legacy(
|
||||
stream: Generator[str, None, None],
|
||||
) -> Generator[str, None, None]:
|
||||
"""Filter workflow-specific SSE events from a streaming response.
|
||||
|
||||
Passes through message, message_end, agent_log, error, ping events.
|
||||
Suppresses workflow_started, workflow_finished, node_started, node_finished.
|
||||
|
||||
This makes the SSE stream look more like what old easy-UI apps produce,
|
||||
while still carrying the actual LLM response content.
|
||||
"""
|
||||
for chunk in stream:
|
||||
if not chunk or not chunk.strip():
|
||||
yield chunk
|
||||
continue
|
||||
|
||||
try:
|
||||
if chunk.startswith("data: "):
|
||||
data = json.loads(chunk[6:])
|
||||
event = data.get("event", "")
|
||||
if event in WORKFLOW_ONLY_EVENTS:
|
||||
continue
|
||||
yield chunk
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
yield chunk
|
||||
@@ -170,10 +170,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
for layer in self._graph_engine_layers:
|
||||
workflow_entry.graph_engine.layer(layer)
|
||||
|
||||
if hasattr(self, '_sandbox') and self._sandbox is not None:
|
||||
from core.app.layers.sandbox_layer import SandboxLayer
|
||||
workflow_entry.graph_engine.layer(SandboxLayer(self._sandbox))
|
||||
|
||||
generator = workflow_entry.run()
|
||||
|
||||
for event in generator:
|
||||
|
||||
@@ -104,89 +104,6 @@ class WorkflowBasedAppRunner:
|
||||
return UserFrom.ACCOUNT
|
||||
return UserFrom.END_USER
|
||||
|
||||
@staticmethod
|
||||
def _resolve_sandbox_context(tenant_id: str, user_id: str, app_id: str) -> dict[str, Any] | None:
|
||||
"""Create a sandbox and inject it into run_context if a provider is configured
|
||||
AND the DifyCli binary is available for the current platform."""
|
||||
try:
|
||||
from core.app.entities.app_invoke_entities import DIFY_SANDBOX_CONTEXT_KEY
|
||||
from core.sandbox.bash.dify_cli import DifyCliLocator
|
||||
from core.sandbox.builder import SandboxBuilder
|
||||
from core.sandbox.entities.sandbox_type import SandboxType
|
||||
from core.sandbox.storage.noop_storage import NoopSandboxStorage
|
||||
from core.virtual_environment.__base.entities import Arch, OperatingSystem
|
||||
from platform import machine, system as os_system
|
||||
from services.sandbox.sandbox_provider_service import SandboxProviderService
|
||||
|
||||
provider = SandboxProviderService.get_sandbox_provider(tenant_id)
|
||||
sandbox_type = SandboxType(provider.provider_type)
|
||||
|
||||
if sandbox_type == SandboxType.LOCAL:
|
||||
logger.debug("[SANDBOX] Local provider not supported under gevent worker, skipping")
|
||||
return None
|
||||
|
||||
os_name = os_system().lower()
|
||||
arch_name = machine().lower()
|
||||
os_enum = OperatingSystem.LINUX if os_name == "linux" else OperatingSystem.DARWIN
|
||||
arch_enum = Arch.ARM64 if arch_name in ("arm64", "aarch64") else Arch.AMD64
|
||||
cli_binary = DifyCliLocator().resolve(os_enum, arch_enum)
|
||||
|
||||
# Also resolve linux binary for Docker containers
|
||||
cli_binary_linux = None
|
||||
if os_name != "linux":
|
||||
try:
|
||||
cli_binary_linux = DifyCliLocator().resolve(OperatingSystem.LINUX, arch_enum)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
from core.sandbox.builder import _get_sandbox_class
|
||||
from core.virtual_environment.__base.helpers import submit_command, with_connection, pipeline
|
||||
vm_class = _get_sandbox_class(SandboxType(provider.provider_type))
|
||||
vm = vm_class(
|
||||
tenant_id=tenant_id,
|
||||
options=provider.config or {},
|
||||
environments={},
|
||||
user_id=user_id,
|
||||
)
|
||||
vm.open_enviroment()
|
||||
|
||||
from core.sandbox.sandbox import Sandbox
|
||||
sandbox = Sandbox(
|
||||
vm=vm,
|
||||
storage=NoopSandboxStorage(),
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
assets_id=app_id,
|
||||
)
|
||||
|
||||
from core.sandbox.entities.config import DifyCli as DifyCliPaths
|
||||
from io import BytesIO
|
||||
cli_paths = DifyCliPaths(sandbox.id)
|
||||
vm_binary = cli_binary_linux if (vm.metadata.os == OperatingSystem.LINUX and cli_binary_linux) else cli_binary
|
||||
with open(vm_binary.path, "rb") as f:
|
||||
pipeline(vm).add(["mkdir", "-p", cli_paths.bin_dir]).execute(raise_on_error=True)
|
||||
vm.upload_file(cli_paths.bin_path, BytesIO(f.read()))
|
||||
with with_connection(vm) as conn:
|
||||
submit_command(vm, conn, ["chmod", "+x", cli_paths.bin_path]).result(timeout=10)
|
||||
logger.info("[SANDBOX] CLI binary uploaded to container: %s", cli_paths.bin_path)
|
||||
|
||||
sandbox.mount()
|
||||
sandbox.mark_ready()
|
||||
|
||||
logger.info("[SANDBOX] Created sandbox for tenant=%s, provider=%s", tenant_id, provider.provider_type)
|
||||
return {DIFY_SANDBOX_CONTEXT_KEY: sandbox}
|
||||
except FileNotFoundError:
|
||||
logger.debug("[SANDBOX] DifyCli binary not found, skipping sandbox creation")
|
||||
return None
|
||||
except Exception:
|
||||
logger.warning("[SANDBOX] Failed to create sandbox", exc_info=True)
|
||||
return None
|
||||
|
||||
def _build_sandbox_layer(self) -> GraphEngineLayer | None:
|
||||
"""Build a SandboxLayer if sandbox exists in _graph_engine_layers context."""
|
||||
return None
|
||||
|
||||
def _init_graph(
|
||||
self,
|
||||
graph_config: Mapping[str, Any],
|
||||
@@ -210,13 +127,7 @@ class WorkflowBasedAppRunner:
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
extra_context = self._resolve_sandbox_context(tenant_id or "", user_id, self._app_id)
|
||||
if extra_context:
|
||||
from core.app.entities.app_invoke_entities import DIFY_SANDBOX_CONTEXT_KEY
|
||||
self._sandbox = extra_context.get(DIFY_SANDBOX_CONTEXT_KEY)
|
||||
else:
|
||||
self._sandbox = None
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
workflow_id=workflow_id,
|
||||
graph_config=graph_config,
|
||||
@@ -226,11 +137,12 @@ class WorkflowBasedAppRunner:
|
||||
user_id=user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
extra_context=extra_context,
|
||||
),
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# Use the provided graph_runtime_state for consistent state management
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
|
||||
@@ -1,352 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from collections.abc import Generator
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AssetNodeType(StrEnum):
|
||||
FILE = "file"
|
||||
FOLDER = "folder"
|
||||
|
||||
|
||||
class AppAssetNode(BaseModel):
|
||||
id: str = Field(description="Unique identifier for the node")
|
||||
node_type: AssetNodeType = Field(description="Type of node: file or folder")
|
||||
name: str = Field(description="Name of the file or folder")
|
||||
parent_id: str | None = Field(default=None, description="Parent folder ID, None for root level")
|
||||
order: int = Field(default=0, description="Sort order within parent folder, lower values first")
|
||||
extension: str = Field(default="", description="File extension without dot, empty for folders")
|
||||
size: int = Field(default=0, description="File size in bytes, 0 for folders")
|
||||
|
||||
@classmethod
|
||||
def create_folder(cls, node_id: str, name: str, parent_id: str | None = None) -> AppAssetNode:
|
||||
return cls(id=node_id, node_type=AssetNodeType.FOLDER, name=name, parent_id=parent_id)
|
||||
|
||||
@classmethod
|
||||
def create_file(cls, node_id: str, name: str, parent_id: str | None = None, size: int = 0) -> AppAssetNode:
|
||||
return cls(
|
||||
id=node_id,
|
||||
node_type=AssetNodeType.FILE,
|
||||
name=name,
|
||||
parent_id=parent_id,
|
||||
extension=name.rsplit(".", 1)[-1] if "." in name else "",
|
||||
size=size,
|
||||
)
|
||||
|
||||
|
||||
class AppAssetNodeView(BaseModel):
|
||||
id: str = Field(description="Unique identifier for the node")
|
||||
node_type: str = Field(description="Type of node: 'file' or 'folder'")
|
||||
name: str = Field(description="Name of the file or folder")
|
||||
path: str = Field(description="Full path from root, e.g. '/folder/file.txt'")
|
||||
extension: str = Field(default="", description="File extension without dot")
|
||||
size: int = Field(default=0, description="File size in bytes")
|
||||
children: list[AppAssetNodeView] = Field(default_factory=list, description="Child nodes for folders")
|
||||
|
||||
|
||||
class BatchUploadNode(BaseModel):
|
||||
"""Structure for batch upload_url tree nodes, used for both input and output."""
|
||||
|
||||
name: str
|
||||
node_type: AssetNodeType
|
||||
size: int = 0
|
||||
children: list[BatchUploadNode] = []
|
||||
id: str | None = None
|
||||
upload_url: str | None = None
|
||||
|
||||
def to_app_asset_nodes(self, parent_id: str | None = None) -> list[AppAssetNode]:
|
||||
"""
|
||||
Generate IDs when missing and convert to AppAssetNode list.
|
||||
Mutates self to set id field when it is not set.
|
||||
"""
|
||||
from uuid import uuid4
|
||||
|
||||
self.id = self.id or str(uuid4())
|
||||
nodes: list[AppAssetNode] = []
|
||||
|
||||
if self.node_type == AssetNodeType.FOLDER:
|
||||
nodes.append(AppAssetNode.create_folder(self.id, self.name, parent_id))
|
||||
for child in self.children:
|
||||
nodes.extend(child.to_app_asset_nodes(self.id))
|
||||
else:
|
||||
nodes.append(AppAssetNode.create_file(self.id, self.name, parent_id, self.size))
|
||||
|
||||
return nodes
|
||||
|
||||
|
||||
class TreeNodeNotFoundError(Exception):
|
||||
"""Tree internal: node not found"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TreeParentNotFoundError(Exception):
|
||||
"""Tree internal: parent folder not found"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TreePathConflictError(Exception):
|
||||
"""Tree internal: path already exists"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AppAssetFileTree(BaseModel):
|
||||
"""
|
||||
File tree structure for app assets using adjacency list pattern.
|
||||
|
||||
Design:
|
||||
- Storage: Flat list with parent_id references (adjacency list)
|
||||
- Path: Computed dynamically via get_path(), not stored
|
||||
- Order: Integer field for user-defined sorting within each folder
|
||||
- API response: transform() builds nested tree with computed paths
|
||||
|
||||
Why adjacency list over nested tree or materialized path:
|
||||
- Simpler CRUD: move/rename only updates one node's parent_id
|
||||
- No path cascade: renaming parent doesn't require updating all descendants
|
||||
- JSON-friendly: flat list serializes cleanly to database JSON column
|
||||
- Trade-off: path lookup is O(depth), acceptable for typical file trees
|
||||
"""
|
||||
|
||||
nodes: list[AppAssetNode] = Field(default_factory=list, description="Flat list of all nodes in the tree")
|
||||
|
||||
def ensure_unique_name(
|
||||
self,
|
||||
parent_id: str | None,
|
||||
name: str,
|
||||
*,
|
||||
is_file: bool,
|
||||
extra_taken: set[str] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Return a sibling-unique name by appending numeric suffixes when needed.
|
||||
|
||||
The suffix format is " <n>" (e.g. "report 1", "report 2"). For files,
|
||||
the suffix is inserted before the extension.
|
||||
"""
|
||||
taken = extra_taken or set()
|
||||
if not self.has_child_named(parent_id, name) and name not in taken:
|
||||
return name
|
||||
suffix_index = 1
|
||||
while True:
|
||||
candidate = self._apply_name_suffix(name, suffix_index, is_file=is_file)
|
||||
if not self.has_child_named(parent_id, candidate) and candidate not in taken:
|
||||
return candidate
|
||||
suffix_index += 1
|
||||
|
||||
@staticmethod
|
||||
def _apply_name_suffix(name: str, suffix_index: int, *, is_file: bool) -> str:
|
||||
if not is_file:
|
||||
return f"{name} {suffix_index}"
|
||||
stem, extension = os.path.splitext(name)
|
||||
return f"{stem} {suffix_index}{extension}"
|
||||
|
||||
def get(self, node_id: str) -> AppAssetNode | None:
|
||||
return next((n for n in self.nodes if n.id == node_id), None)
|
||||
|
||||
def get_children(self, parent_id: str | None) -> list[AppAssetNode]:
|
||||
return [n for n in self.nodes if n.parent_id == parent_id]
|
||||
|
||||
def has_child_named(self, parent_id: str | None, name: str) -> bool:
|
||||
return any(n.name == name and n.parent_id == parent_id for n in self.nodes)
|
||||
|
||||
def get_path(self, node_id: str) -> str:
|
||||
node = self.get(node_id)
|
||||
if not node:
|
||||
raise TreeNodeNotFoundError(node_id)
|
||||
parts: list[str] = []
|
||||
current: AppAssetNode | None = node
|
||||
while current:
|
||||
parts.append(current.name)
|
||||
current = self.get(current.parent_id) if current.parent_id else None
|
||||
return "/".join(reversed(parts))
|
||||
|
||||
def relative_path(self, a: AppAssetNode, b: AppAssetNode) -> str:
|
||||
"""
|
||||
Calculate relative path from node a to node b for Markdown references.
|
||||
Path is computed from a's parent directory (where the file resides).
|
||||
|
||||
Examples:
|
||||
/foo/a.md -> /foo/b.md => ./b.md
|
||||
/foo/a.md -> /foo/sub/b.md => ./sub/b.md
|
||||
/foo/sub/a.md -> /foo/b.md => ../b.md
|
||||
/foo/sub/deep/a.md -> /foo/b.md => ../../b.md
|
||||
"""
|
||||
|
||||
def get_ancestor_ids(node_id: str | None) -> list[str]:
|
||||
chain: list[str] = []
|
||||
current_id = node_id
|
||||
while current_id:
|
||||
chain.append(current_id)
|
||||
node = self.get(current_id)
|
||||
current_id = node.parent_id if node else None
|
||||
return chain
|
||||
|
||||
a_dir_ancestors = get_ancestor_ids(a.parent_id)
|
||||
b_ancestors = [b.id] + get_ancestor_ids(b.parent_id)
|
||||
a_dir_set = set(a_dir_ancestors)
|
||||
|
||||
lca_id: str | None = None
|
||||
lca_index_in_b = -1
|
||||
for idx, ancestor_id in enumerate(b_ancestors):
|
||||
if ancestor_id in a_dir_set or (a.parent_id is None and b_ancestors[idx:] == []):
|
||||
lca_id = ancestor_id
|
||||
lca_index_in_b = idx
|
||||
break
|
||||
|
||||
if a.parent_id is None:
|
||||
steps_up = 0
|
||||
lca_index_in_b = len(b_ancestors)
|
||||
elif lca_id is None:
|
||||
steps_up = len(a_dir_ancestors)
|
||||
lca_index_in_b = len(b_ancestors)
|
||||
else:
|
||||
steps_up = 0
|
||||
for ancestor_id in a_dir_ancestors:
|
||||
if ancestor_id == lca_id:
|
||||
break
|
||||
steps_up += 1
|
||||
|
||||
path_down: list[str] = []
|
||||
for i in range(lca_index_in_b - 1, -1, -1):
|
||||
node = self.get(b_ancestors[i])
|
||||
if node:
|
||||
path_down.append(node.name)
|
||||
|
||||
if steps_up == 0:
|
||||
return "./" + "/".join(path_down)
|
||||
|
||||
parts: list[str] = [".."] * steps_up + path_down
|
||||
return "/".join(parts)
|
||||
|
||||
def get_descendant_ids(self, node_id: str) -> list[str]:
|
||||
result: list[str] = []
|
||||
stack = [node_id]
|
||||
while stack:
|
||||
current_id = stack.pop()
|
||||
for child in self.nodes:
|
||||
if child.parent_id == current_id:
|
||||
result.append(child.id)
|
||||
stack.append(child.id)
|
||||
return result
|
||||
|
||||
def add(self, node: AppAssetNode) -> AppAssetNode:
|
||||
if self.get(node.id):
|
||||
raise TreePathConflictError(node.id)
|
||||
if self.has_child_named(node.parent_id, node.name):
|
||||
raise TreePathConflictError(node.name)
|
||||
if node.parent_id:
|
||||
parent = self.get(node.parent_id)
|
||||
if not parent or parent.node_type != AssetNodeType.FOLDER:
|
||||
raise TreeParentNotFoundError(node.parent_id)
|
||||
siblings = self.get_children(node.parent_id)
|
||||
node.order = max((s.order for s in siblings), default=-1) + 1
|
||||
self.nodes.append(node)
|
||||
return node
|
||||
|
||||
def update(self, node_id: str, size: int) -> AppAssetNode:
|
||||
node = self.get(node_id)
|
||||
if not node or node.node_type != AssetNodeType.FILE:
|
||||
raise TreeNodeNotFoundError(node_id)
|
||||
node.size = size
|
||||
return node
|
||||
|
||||
def rename(self, node_id: str, new_name: str) -> AppAssetNode:
|
||||
node = self.get(node_id)
|
||||
if not node:
|
||||
raise TreeNodeNotFoundError(node_id)
|
||||
if node.name != new_name and self.has_child_named(node.parent_id, new_name):
|
||||
raise TreePathConflictError(new_name)
|
||||
node.name = new_name
|
||||
if node.node_type == AssetNodeType.FILE:
|
||||
node.extension = new_name.rsplit(".", 1)[-1] if "." in new_name else ""
|
||||
return node
|
||||
|
||||
def move(self, node_id: str, new_parent_id: str | None) -> AppAssetNode:
|
||||
node = self.get(node_id)
|
||||
if not node:
|
||||
raise TreeNodeNotFoundError(node_id)
|
||||
if new_parent_id:
|
||||
parent = self.get(new_parent_id)
|
||||
if not parent or parent.node_type != AssetNodeType.FOLDER:
|
||||
raise TreeParentNotFoundError(new_parent_id)
|
||||
if self.has_child_named(new_parent_id, node.name):
|
||||
raise TreePathConflictError(node.name)
|
||||
node.parent_id = new_parent_id
|
||||
siblings = self.get_children(new_parent_id)
|
||||
node.order = max((s.order for s in siblings if s.id != node_id), default=-1) + 1
|
||||
return node
|
||||
|
||||
def reorder(self, node_id: str, after_node_id: str | None) -> AppAssetNode:
|
||||
node = self.get(node_id)
|
||||
if not node:
|
||||
raise TreeNodeNotFoundError(node_id)
|
||||
|
||||
siblings = sorted(self.get_children(node.parent_id), key=lambda x: x.order)
|
||||
siblings = [s for s in siblings if s.id != node_id]
|
||||
|
||||
if after_node_id is None:
|
||||
insert_idx = 0
|
||||
else:
|
||||
after_node = self.get(after_node_id)
|
||||
if not after_node or after_node.parent_id != node.parent_id:
|
||||
raise TreeNodeNotFoundError(after_node_id)
|
||||
insert_idx = next((i for i, s in enumerate(siblings) if s.id == after_node_id), -1) + 1
|
||||
|
||||
siblings.insert(insert_idx, node)
|
||||
for idx, sibling in enumerate(siblings):
|
||||
sibling.order = idx
|
||||
|
||||
return node
|
||||
|
||||
def remove(self, node_id: str) -> list[str]:
|
||||
node = self.get(node_id)
|
||||
if not node:
|
||||
raise TreeNodeNotFoundError(node_id)
|
||||
ids_to_remove = [node_id] + self.get_descendant_ids(node_id)
|
||||
self.nodes = [n for n in self.nodes if n.id not in ids_to_remove]
|
||||
return ids_to_remove
|
||||
|
||||
def walk_files(self) -> Generator[AppAssetNode, None, None]:
|
||||
return (n for n in self.nodes if n.node_type == AssetNodeType.FILE)
|
||||
|
||||
def transform(self) -> list[AppAssetNodeView]:
|
||||
by_parent: dict[str | None, list[AppAssetNode]] = defaultdict(list)
|
||||
for n in self.nodes:
|
||||
by_parent[n.parent_id].append(n)
|
||||
|
||||
for children in by_parent.values():
|
||||
children.sort(key=lambda x: x.order)
|
||||
|
||||
paths: dict[str, str] = {}
|
||||
tree_views: dict[str, AppAssetNodeView] = {}
|
||||
|
||||
def build_view(node: AppAssetNode, parent_path: str) -> None:
|
||||
path = f"{parent_path}/{node.name}"
|
||||
paths[node.id] = path
|
||||
child_views: list[AppAssetNodeView] = []
|
||||
for child in by_parent.get(node.id, []):
|
||||
build_view(child, path)
|
||||
child_views.append(tree_views[child.id])
|
||||
tree_views[node.id] = AppAssetNodeView(
|
||||
id=node.id,
|
||||
node_type=node.node_type.value,
|
||||
name=node.name,
|
||||
path=path,
|
||||
extension=node.extension,
|
||||
size=node.size,
|
||||
children=child_views,
|
||||
)
|
||||
|
||||
for root_node in by_parent.get(None, []):
|
||||
build_view(root_node, "")
|
||||
|
||||
return [tree_views[n.id] for n in by_parent.get(None, [])]
|
||||
|
||||
def empty(self) -> bool:
|
||||
return len(self.nodes) == 0
|
||||
@@ -1,96 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.app.entities.app_asset_entities import AppAssetFileTree
|
||||
|
||||
# Constants
|
||||
BUNDLE_DSL_FILENAME_PATTERN = re.compile(r"^[^/]+\.ya?ml$")
|
||||
BUNDLE_MAX_SIZE = 50 * 1024 * 1024 # 50MB
|
||||
MANIFEST_FILENAME = "manifest.json"
|
||||
MANIFEST_SCHEMA_VERSION = "1.0"
|
||||
|
||||
|
||||
# Exceptions
|
||||
class BundleFormatError(Exception):
|
||||
"""Raised when bundle format is invalid."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ZipSecurityError(Exception):
|
||||
"""Raised when zip file contains security violations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Manifest DTOs
|
||||
class ManifestFileEntry(BaseModel):
|
||||
"""Maps node_id to file path in the bundle."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
node_id: str
|
||||
path: str
|
||||
|
||||
|
||||
class ManifestIntegrity(BaseModel):
|
||||
"""Basic integrity check fields."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
file_count: int
|
||||
|
||||
|
||||
class ManifestAppAssets(BaseModel):
|
||||
"""App assets section containing the full tree."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
tree: AppAssetFileTree
|
||||
|
||||
|
||||
class BundleManifest(BaseModel):
|
||||
"""
|
||||
Bundle manifest for app asset import/export.
|
||||
|
||||
Schema version 1.0:
|
||||
- dsl_filename: DSL file name in bundle root (e.g. "my_app.yml")
|
||||
- tree: Full AppAssetFileTree (files + folders) for 100% restoration including node IDs
|
||||
- files: Explicit node_id -> path mapping for file nodes only
|
||||
- integrity: Basic file_count validation
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
schema_version: str = Field(default=MANIFEST_SCHEMA_VERSION)
|
||||
generated_at: datetime = Field(default_factory=lambda: datetime.now(tz=UTC))
|
||||
dsl_filename: str = Field(description="DSL file name in bundle root")
|
||||
app_assets: ManifestAppAssets
|
||||
files: list[ManifestFileEntry]
|
||||
integrity: ManifestIntegrity
|
||||
|
||||
@property
|
||||
def assets_prefix(self) -> str:
|
||||
"""Assets directory name (DSL filename without extension)."""
|
||||
return self.dsl_filename.rsplit(".", 1)[0]
|
||||
|
||||
@classmethod
|
||||
def from_tree(cls, tree: AppAssetFileTree, dsl_filename: str) -> BundleManifest:
|
||||
"""Build manifest from an AppAssetFileTree."""
|
||||
files = [ManifestFileEntry(node_id=n.id, path=tree.get_path(n.id)) for n in tree.walk_files()]
|
||||
return cls(
|
||||
dsl_filename=dsl_filename,
|
||||
app_assets=ManifestAppAssets(tree=tree),
|
||||
files=files,
|
||||
integrity=ManifestIntegrity(file_count=len(files)),
|
||||
)
|
||||
|
||||
|
||||
# Export result
|
||||
class BundleExportResult(BaseModel):
|
||||
download_url: str = Field(description="Temporary download URL for the ZIP")
|
||||
filename: str = Field(description="Suggested filename for the ZIP")
|
||||
@@ -46,9 +46,6 @@ class InvokeFrom(StrEnum):
|
||||
return source_mapping.get(self, "dev")
|
||||
|
||||
|
||||
DIFY_SANDBOX_CONTEXT_KEY = "_dify_sandbox"
|
||||
|
||||
|
||||
class DifyRunContext(BaseModel):
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
"""
|
||||
LLM Generation Detail entities.
|
||||
|
||||
Defines the structure for storing and transmitting LLM generation details
|
||||
including reasoning content, tool calls, and their sequence.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ContentSegment(BaseModel):
|
||||
"""Represents a content segment in the generation sequence."""
|
||||
|
||||
type: Literal["content"] = "content"
|
||||
start: int = Field(..., description="Start position in the text")
|
||||
end: int = Field(..., description="End position in the text")
|
||||
|
||||
|
||||
class ReasoningSegment(BaseModel):
|
||||
"""Represents a reasoning segment in the generation sequence."""
|
||||
|
||||
type: Literal["reasoning"] = "reasoning"
|
||||
index: int = Field(..., description="Index into reasoning_content array")
|
||||
|
||||
|
||||
class ToolCallSegment(BaseModel):
|
||||
"""Represents a tool call segment in the generation sequence."""
|
||||
|
||||
type: Literal["tool_call"] = "tool_call"
|
||||
index: int = Field(..., description="Index into tool_calls array")
|
||||
|
||||
|
||||
SequenceSegment = ContentSegment | ReasoningSegment | ToolCallSegment
|
||||
|
||||
|
||||
class ToolCallDetail(BaseModel):
|
||||
"""Represents a tool call with its arguments and result."""
|
||||
|
||||
id: str = Field(default="", description="Unique identifier for the tool call")
|
||||
name: str = Field(..., description="Name of the tool")
|
||||
arguments: str = Field(default="", description="JSON string of tool arguments")
|
||||
result: str = Field(default="", description="Result from the tool execution")
|
||||
elapsed_time: float | None = Field(default=None, description="Elapsed time in seconds")
|
||||
icon: str | dict | None = Field(default=None, description="Icon of the tool")
|
||||
icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")
|
||||
|
||||
|
||||
class LLMGenerationDetailData(BaseModel):
|
||||
"""
|
||||
Domain model for LLM generation detail.
|
||||
|
||||
Contains the structured data for reasoning content, tool calls,
|
||||
and their display sequence.
|
||||
"""
|
||||
|
||||
reasoning_content: list[str] = Field(default_factory=list, description="List of reasoning segments")
|
||||
tool_calls: list[ToolCallDetail] = Field(default_factory=list, description="List of tool call details")
|
||||
sequence: list[SequenceSegment] = Field(default_factory=list, description="Display order of segments")
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""Check if there's any meaningful generation detail."""
|
||||
return not self.reasoning_content and not self.tool_calls
|
||||
|
||||
def to_response_dict(self) -> dict:
|
||||
"""Convert to dictionary for API response."""
|
||||
return {
|
||||
"reasoning_content": self.reasoning_content,
|
||||
"tool_calls": [tc.model_dump() for tc in self.tool_calls],
|
||||
"sequence": [seg.model_dump() for seg in self.sequence],
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
import logging
|
||||
|
||||
from core.sandbox import Sandbox
|
||||
from graphon.graph_engine.layers.base import GraphEngineLayer
|
||||
from graphon.graph_events.base import GraphEngineEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SandboxLayer(GraphEngineLayer):
|
||||
def __init__(self, sandbox: Sandbox) -> None:
|
||||
super().__init__()
|
||||
self._sandbox = sandbox
|
||||
|
||||
def on_graph_start(self) -> None:
|
||||
pass
|
||||
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
pass
|
||||
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
self._sandbox.release()
|
||||
@@ -1,13 +0,0 @@
|
||||
from .constants import AppAssetsAttrs
|
||||
from .entities import (
|
||||
AssetItem,
|
||||
SkillAsset,
|
||||
)
|
||||
from .storage import AssetPaths
|
||||
|
||||
__all__ = [
|
||||
"AppAssetsAttrs",
|
||||
"AssetItem",
|
||||
"AssetPaths",
|
||||
"SkillAsset",
|
||||
]
|
||||
@@ -1,180 +0,0 @@
|
||||
"""Unified content accessor for app asset nodes.
|
||||
|
||||
Accessor is scoped to a single app (tenant_id + app_id), not a single node.
|
||||
All methods accept an AppAssetNode parameter to identify the target.
|
||||
|
||||
CachedContentAccessor is the primary entry point:
|
||||
- Reads DB first, misses fall through to S3 with sync backfill.
|
||||
- Writes go to both DB and S3 (dual-write).
|
||||
- resolve_items() batch-enriches AssetItem lists with DB-cached content
|
||||
(extension-agnostic), so callers never need to filter by extension.
|
||||
- Wraps an internal _StorageAccessor for S3 I/O.
|
||||
|
||||
Collaborators:
|
||||
- services.asset_content_service.AssetContentService (DB layer)
|
||||
- core.app_assets.storage.AssetPaths (S3 key generation)
|
||||
- extensions.storage.cached_presign_storage.CachedPresignStorage (S3 I/O)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from core.app.entities.app_asset_entities import AppAssetNode
|
||||
from core.app_assets.entities.assets import AssetItem
|
||||
from core.app_assets.storage import AssetPaths
|
||||
from extensions.storage.cached_presign_storage import CachedPresignStorage
|
||||
from services.asset_content_service import AssetContentService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# S3-only implementation (internal, used as inner delegate)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _StorageAccessor:
|
||||
"""Reads/writes draft content via object storage (S3) only."""
|
||||
|
||||
_storage: CachedPresignStorage
|
||||
_tenant_id: str
|
||||
_app_id: str
|
||||
|
||||
def __init__(self, storage: CachedPresignStorage, tenant_id: str, app_id: str) -> None:
|
||||
self._storage = storage
|
||||
self._tenant_id = tenant_id
|
||||
self._app_id = app_id
|
||||
|
||||
def _key(self, node: AppAssetNode) -> str:
|
||||
return AssetPaths.draft(self._tenant_id, self._app_id, node.id)
|
||||
|
||||
def load(self, node: AppAssetNode) -> bytes:
|
||||
return self._storage.load_once(self._key(node))
|
||||
|
||||
def save(self, node: AppAssetNode, content: bytes) -> None:
|
||||
self._storage.save(self._key(node), content)
|
||||
|
||||
def delete(self, node: AppAssetNode) -> None:
|
||||
try:
|
||||
self._storage.delete(self._key(node))
|
||||
except Exception:
|
||||
logger.warning("Failed to delete storage key %s", self._key(node), exc_info=True)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DB-cached implementation (the public API)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CachedContentAccessor:
|
||||
"""App-level content accessor with DB read-through cache over S3.
|
||||
|
||||
Read path: DB first -> miss -> S3 fallback -> sync backfill DB
|
||||
Write path: DB upsert + S3 save (dual-write)
|
||||
Delete path: DB delete + S3 delete
|
||||
|
||||
bulk_load uses a single SQL query for all nodes, with S3 fallback per miss.
|
||||
|
||||
Usage:
|
||||
accessor = CachedContentAccessor(storage, tenant_id, app_id)
|
||||
content = accessor.load(node)
|
||||
accessor.save(node, content)
|
||||
results = accessor.bulk_load(nodes)
|
||||
"""
|
||||
|
||||
_inner: _StorageAccessor
|
||||
_tenant_id: str
|
||||
_app_id: str
|
||||
|
||||
def __init__(self, storage: CachedPresignStorage, tenant_id: str, app_id: str) -> None:
|
||||
self._inner = _StorageAccessor(storage, tenant_id, app_id)
|
||||
self._tenant_id = tenant_id
|
||||
self._app_id = app_id
|
||||
|
||||
def load(self, node: AppAssetNode) -> bytes:
|
||||
# 1. Try DB
|
||||
cached = AssetContentService.get(self._tenant_id, self._app_id, node.id)
|
||||
if cached is not None:
|
||||
return cached.encode("utf-8")
|
||||
|
||||
# 2. Fallback to S3
|
||||
data = self._inner.load(node)
|
||||
|
||||
# 3. Sync backfill DB
|
||||
AssetContentService.upsert(
|
||||
tenant_id=self._tenant_id,
|
||||
app_id=self._app_id,
|
||||
node_id=node.id,
|
||||
content=data.decode("utf-8"),
|
||||
size=len(data),
|
||||
)
|
||||
return data
|
||||
|
||||
def bulk_load(self, nodes: list[AppAssetNode]) -> dict[str, bytes]:
|
||||
"""Single SQL for all nodes, S3 fallback + backfill per miss."""
|
||||
result: dict[str, bytes] = {}
|
||||
node_ids = [n.id for n in nodes]
|
||||
cached = AssetContentService.get_many(self._tenant_id, self._app_id, node_ids)
|
||||
|
||||
for node in nodes:
|
||||
if node.id in cached:
|
||||
result[node.id] = cached[node.id].encode("utf-8")
|
||||
else:
|
||||
# S3 fallback + sync backfill
|
||||
data = self._inner.load(node)
|
||||
AssetContentService.upsert(
|
||||
tenant_id=self._tenant_id,
|
||||
app_id=self._app_id,
|
||||
node_id=node.id,
|
||||
content=data.decode("utf-8"),
|
||||
size=len(data),
|
||||
)
|
||||
result[node.id] = data
|
||||
return result
|
||||
|
||||
def save(self, node: AppAssetNode, content: bytes) -> None:
|
||||
# Dual-write: DB + S3
|
||||
AssetContentService.upsert(
|
||||
tenant_id=self._tenant_id,
|
||||
app_id=self._app_id,
|
||||
node_id=node.id,
|
||||
content=content.decode("utf-8"),
|
||||
size=len(content),
|
||||
)
|
||||
self._inner.save(node, content)
|
||||
|
||||
def resolve_items(self, items: list[AssetItem]) -> list[AssetItem]:
|
||||
"""Batch-enrich asset items with DB-cached content.
|
||||
|
||||
Queries by ``asset_id`` only — extension-agnostic. Items without
|
||||
a DB cache row keep their original *content* value (typically
|
||||
``None``), so only genuinely cached assets (e.g. ``.md`` skill
|
||||
documents) get populated.
|
||||
|
||||
This eliminates the need for callers to filter by file extension
|
||||
before deciding whether to read from the DB cache.
|
||||
"""
|
||||
if not items:
|
||||
return items
|
||||
|
||||
node_ids = [a.asset_id for a in items]
|
||||
cached = AssetContentService.get_many(self._tenant_id, self._app_id, node_ids)
|
||||
|
||||
if not cached:
|
||||
return items
|
||||
|
||||
return [
|
||||
AssetItem(
|
||||
asset_id=a.asset_id,
|
||||
path=a.path,
|
||||
file_name=a.file_name,
|
||||
extension=a.extension,
|
||||
storage_key=a.storage_key,
|
||||
content=cached[a.asset_id].encode("utf-8") if a.asset_id in cached else a.content,
|
||||
)
|
||||
for a in items
|
||||
]
|
||||
|
||||
def delete(self, node: AppAssetNode) -> None:
|
||||
AssetContentService.delete(self._tenant_id, self._app_id, node.id)
|
||||
self._inner.delete(node)
|
||||
@@ -1,12 +0,0 @@
|
||||
from .base import AssetBuilder, BuildContext
|
||||
from .file_builder import FileBuilder
|
||||
from .pipeline import AssetBuildPipeline
|
||||
from .skill_builder import SkillBuilder
|
||||
|
||||
__all__ = [
|
||||
"AssetBuildPipeline",
|
||||
"AssetBuilder",
|
||||
"BuildContext",
|
||||
"FileBuilder",
|
||||
"SkillBuilder",
|
||||
]
|
||||
@@ -1,20 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol
|
||||
|
||||
from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode
|
||||
from core.app_assets.entities import AssetItem
|
||||
|
||||
|
||||
@dataclass
|
||||
class BuildContext:
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
build_id: str
|
||||
|
||||
|
||||
class AssetBuilder(Protocol):
|
||||
def accept(self, node: AppAssetNode) -> bool: ...
|
||||
|
||||
def collect(self, node: AppAssetNode, path: str, ctx: BuildContext) -> None: ...
|
||||
|
||||
def build(self, tree: AppAssetFileTree, ctx: BuildContext) -> list[AssetItem]: ...
|
||||
@@ -1,30 +0,0 @@
|
||||
from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode
|
||||
from core.app_assets.entities import AssetItem
|
||||
from core.app_assets.storage import AssetPaths
|
||||
|
||||
from .base import BuildContext
|
||||
|
||||
|
||||
class FileBuilder:
|
||||
_nodes: list[tuple[AppAssetNode, str]]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._nodes = []
|
||||
|
||||
def accept(self, node: AppAssetNode) -> bool:
|
||||
return True
|
||||
|
||||
def collect(self, node: AppAssetNode, path: str, ctx: BuildContext) -> None:
|
||||
self._nodes.append((node, path))
|
||||
|
||||
def build(self, tree: AppAssetFileTree, ctx: BuildContext) -> list[AssetItem]:
|
||||
return [
|
||||
AssetItem(
|
||||
asset_id=node.id,
|
||||
path=path,
|
||||
file_name=node.name,
|
||||
extension=node.extension or "",
|
||||
storage_key=AssetPaths.draft(ctx.tenant_id, ctx.app_id, node.id),
|
||||
)
|
||||
for node, path in self._nodes
|
||||
]
|
||||
@@ -1,27 +0,0 @@
|
||||
from core.app.entities.app_asset_entities import AppAssetFileTree
|
||||
from core.app_assets.entities import AssetItem
|
||||
|
||||
from .base import AssetBuilder, BuildContext
|
||||
|
||||
|
||||
class AssetBuildPipeline:
|
||||
_builders: list[AssetBuilder]
|
||||
|
||||
def __init__(self, builders: list[AssetBuilder]) -> None:
|
||||
self._builders = builders
|
||||
|
||||
def build_all(self, tree: AppAssetFileTree, ctx: BuildContext) -> list[AssetItem]:
|
||||
# 1. Distribute: each node goes to first accepting builder
|
||||
for node in tree.walk_files():
|
||||
path = tree.get_path(node.id)
|
||||
for builder in self._builders:
|
||||
if builder.accept(node):
|
||||
builder.collect(node, path, ctx)
|
||||
break
|
||||
|
||||
# 2. Each builder builds its collected nodes
|
||||
results: list[AssetItem] = []
|
||||
for builder in self._builders:
|
||||
results.extend(builder.build(tree, ctx))
|
||||
|
||||
return results
|
||||
@@ -1,96 +0,0 @@
|
||||
"""Builder that compiles ``.md`` skill documents into resolved content.
|
||||
|
||||
The builder reads raw draft content from the DB-backed accessor, parses
|
||||
each into a ``SkillDocument``, assembles a ``SkillBundle`` (with
|
||||
transitive tool/file dependency resolution), and returns ``AssetItem``
|
||||
objects whose *content* field carries the resolved bytes in-process.
|
||||
|
||||
The assembled ``SkillBundle`` is persisted via ``SkillManager``
|
||||
(S3 + Redis) **and** retained on the ``bundle`` property so that
|
||||
callers (e.g. ``DraftAppAssetsInitializer``) can pass it directly to
|
||||
``sandbox.attrs`` without a redundant Redis/S3 round-trip.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from core.app.entities.app_asset_entities import AppAssetFileTree, AppAssetNode
|
||||
from core.app_assets.accessor import CachedContentAccessor
|
||||
from core.app_assets.entities import AssetItem
|
||||
from core.skill.assembler import SkillBundleAssembler
|
||||
from core.skill.entities.skill_bundle import SkillBundle
|
||||
from core.skill.entities.skill_document import SkillDocument
|
||||
|
||||
from .base import BuildContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SkillBuilder:
|
||||
_nodes: list[tuple[AppAssetNode, str]]
|
||||
_accessor: CachedContentAccessor
|
||||
_bundle: SkillBundle | None
|
||||
|
||||
def __init__(self, accessor: CachedContentAccessor) -> None:
|
||||
self._nodes = []
|
||||
self._accessor = accessor
|
||||
self._bundle = None
|
||||
|
||||
@property
|
||||
def bundle(self) -> SkillBundle | None:
|
||||
"""The ``SkillBundle`` produced by the last ``build()`` call, or *None*."""
|
||||
return self._bundle
|
||||
|
||||
def accept(self, node: AppAssetNode) -> bool:
|
||||
return node.extension == "md"
|
||||
|
||||
def collect(self, node: AppAssetNode, path: str, ctx: BuildContext) -> None:
|
||||
self._nodes.append((node, path))
|
||||
|
||||
def build(self, tree: AppAssetFileTree, ctx: BuildContext) -> list[AssetItem]:
|
||||
from core.skill.skill_manager import SkillManager
|
||||
|
||||
if not self._nodes:
|
||||
bundle = SkillBundle(assets_id=ctx.build_id, asset_tree=tree)
|
||||
SkillManager.save_bundle(ctx.tenant_id, ctx.app_id, ctx.build_id, bundle)
|
||||
self._bundle = bundle
|
||||
return []
|
||||
|
||||
# Batch-load all skill draft content in one DB query (with S3 fallback on miss).
|
||||
nodes_only = [node for node, _ in self._nodes]
|
||||
raw_contents = self._accessor.bulk_load(nodes_only)
|
||||
|
||||
# Parse documents — skip nodes whose draft content is still the empty
|
||||
# placeholder written at creation time.
|
||||
documents: dict[str, SkillDocument] = {}
|
||||
for node, _ in self._nodes:
|
||||
try:
|
||||
raw = raw_contents.get(node.id)
|
||||
if not raw:
|
||||
continue
|
||||
data = {"skill_id": node.id, **json.loads(raw)}
|
||||
documents[node.id] = SkillDocument.model_validate(data)
|
||||
except (FileNotFoundError, json.JSONDecodeError, TypeError, ValueError) as e:
|
||||
logger.exception("Failed to load or parse skill document for node %s", node.id)
|
||||
raise ValueError(f"Failed to load or parse skill document for node {node.id}") from e
|
||||
|
||||
bundle = SkillBundleAssembler(tree).assemble_bundle(documents, ctx.build_id)
|
||||
SkillManager.save_bundle(ctx.tenant_id, ctx.app_id, ctx.build_id, bundle)
|
||||
self._bundle = bundle
|
||||
|
||||
items: list[AssetItem] = []
|
||||
for node, path in self._nodes:
|
||||
skill = bundle.get(node.id)
|
||||
if skill is None:
|
||||
continue
|
||||
items.append(
|
||||
AssetItem(
|
||||
asset_id=node.id,
|
||||
path=path,
|
||||
file_name=node.name,
|
||||
extension=node.extension or "",
|
||||
storage_key="",
|
||||
content=skill.content.encode("utf-8"),
|
||||
)
|
||||
)
|
||||
return items
|
||||
@@ -1,8 +0,0 @@
|
||||
from core.app.entities.app_asset_entities import AppAssetFileTree
|
||||
from libs.attr_map import AttrKey
|
||||
|
||||
|
||||
class AppAssetsAttrs:
|
||||
# Skill artifact set
|
||||
FILE_TREE = AttrKey("file_tree", AppAssetFileTree)
|
||||
APP_ASSETS_ID = AttrKey("app_assets_id", str)
|
||||
@@ -1,20 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from core.app.entities.app_asset_entities import AppAssetFileTree, AssetNodeType
|
||||
from core.app_assets.entities import AssetItem
|
||||
from core.app_assets.storage import AssetPaths
|
||||
|
||||
|
||||
def tree_to_asset_items(tree: AppAssetFileTree, tenant_id: str, app_id: str) -> list[AssetItem]:
|
||||
"""Convert AppAssetFileTree to list of AssetItem for packaging."""
|
||||
return [
|
||||
AssetItem(
|
||||
asset_id=node.id,
|
||||
path=tree.get_path(node.id),
|
||||
file_name=node.name,
|
||||
extension=node.extension or "",
|
||||
storage_key=AssetPaths.draft(tenant_id, app_id, node.id),
|
||||
)
|
||||
for node in tree.nodes
|
||||
if node.node_type == AssetNodeType.FILE
|
||||
]
|
||||
@@ -1,7 +0,0 @@
|
||||
from .assets import AssetItem
|
||||
from .skill import SkillAsset
|
||||
|
||||
__all__ = [
|
||||
"AssetItem",
|
||||
"SkillAsset",
|
||||
]
|
||||
@@ -1,20 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class AssetItem:
|
||||
"""A single asset file produced by the build pipeline.
|
||||
|
||||
When *content* is set the payload is available in-process and can be
|
||||
written directly into a ZIP or uploaded to a sandbox VM without an
|
||||
extra S3 round-trip. When *content* is ``None`` the caller should
|
||||
fetch the bytes from *storage_key* (the traditional presigned-URL
|
||||
path).
|
||||
"""
|
||||
|
||||
asset_id: str
|
||||
path: str
|
||||
file_name: str
|
||||
extension: str
|
||||
storage_key: str
|
||||
content: bytes | None = field(default=None, repr=False)
|
||||
@@ -1,10 +0,0 @@
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from .assets import AssetItem
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillAsset(AssetItem):
|
||||
metadata: Mapping[str, Any] = field(default_factory=dict)
|
||||
@@ -1,68 +0,0 @@
|
||||
"""App assets storage key generation.
|
||||
|
||||
Provides AssetPaths facade for generating storage keys for app assets.
|
||||
Storage instances are obtained via AppAssetService.get_storage().
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
_BASE = "app_assets"
|
||||
|
||||
|
||||
def _check_uuid(value: str, name: str) -> None:
|
||||
try:
|
||||
UUID(value)
|
||||
except (ValueError, TypeError) as e:
|
||||
raise ValueError(f"{name} must be a valid UUID") from e
|
||||
|
||||
|
||||
class AssetPaths:
|
||||
"""Facade for generating app asset storage keys."""
|
||||
|
||||
@staticmethod
|
||||
def draft(tenant_id: str, app_id: str, node_id: str) -> str:
|
||||
"""app_assets/{tenant}/{app}/draft/{node_id}"""
|
||||
_check_uuid(tenant_id, "tenant_id")
|
||||
_check_uuid(app_id, "app_id")
|
||||
_check_uuid(node_id, "node_id")
|
||||
return f"{_BASE}/{tenant_id}/{app_id}/draft/{node_id}"
|
||||
|
||||
@staticmethod
|
||||
def build_zip(tenant_id: str, app_id: str, assets_id: str) -> str:
|
||||
"""app_assets/{tenant}/{app}/artifacts/{assets_id}.zip"""
|
||||
_check_uuid(tenant_id, "tenant_id")
|
||||
_check_uuid(app_id, "app_id")
|
||||
_check_uuid(assets_id, "assets_id")
|
||||
return f"{_BASE}/{tenant_id}/{app_id}/artifacts/{assets_id}.zip"
|
||||
|
||||
@staticmethod
|
||||
def skill_bundle(tenant_id: str, app_id: str, assets_id: str) -> str:
|
||||
"""app_assets/{tenant}/{app}/artifacts/{assets_id}/skill_artifact_set.json"""
|
||||
_check_uuid(tenant_id, "tenant_id")
|
||||
_check_uuid(app_id, "app_id")
|
||||
_check_uuid(assets_id, "assets_id")
|
||||
return f"{_BASE}/{tenant_id}/{app_id}/artifacts/{assets_id}/skill_artifact_set.json"
|
||||
|
||||
@staticmethod
|
||||
def source_zip(tenant_id: str, app_id: str, workflow_id: str) -> str:
|
||||
"""app_assets/{tenant}/{app}/sources/{workflow_id}.zip"""
|
||||
_check_uuid(tenant_id, "tenant_id")
|
||||
_check_uuid(app_id, "app_id")
|
||||
_check_uuid(workflow_id, "workflow_id")
|
||||
return f"{_BASE}/{tenant_id}/{app_id}/sources/{workflow_id}.zip"
|
||||
|
||||
@staticmethod
|
||||
def bundle_export(tenant_id: str, app_id: str, export_id: str) -> str:
|
||||
"""app_assets/{tenant}/{app}/bundle_exports/{export_id}.zip"""
|
||||
_check_uuid(tenant_id, "tenant_id")
|
||||
_check_uuid(app_id, "app_id")
|
||||
_check_uuid(export_id, "export_id")
|
||||
return f"{_BASE}/{tenant_id}/{app_id}/bundle_exports/{export_id}.zip"
|
||||
|
||||
@staticmethod
|
||||
def bundle_import(tenant_id: str, import_id: str) -> str:
|
||||
"""app_assets/{tenant}/imports/{import_id}.zip"""
|
||||
_check_uuid(tenant_id, "tenant_id")
|
||||
return f"{_BASE}/{tenant_id}/imports/{import_id}.zip"
|
||||
@@ -1 +0,0 @@
|
||||
# App bundle utilities - manifest-driven import/export handled by AppBundleService
|
||||
@@ -1,75 +0,0 @@
|
||||
"""
|
||||
Helper module for Creators Platform integration.
|
||||
|
||||
Provides functionality to upload DSL files to the Creators Platform
|
||||
and generate redirect URLs with OAuth authorization codes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
creators_platform_api_url = URL(str(dify_config.CREATORS_PLATFORM_API_URL))
|
||||
|
||||
|
||||
def upload_dsl(dsl_file_bytes: bytes, filename: str = "template.yaml") -> str:
|
||||
"""Upload a DSL file to the Creators Platform anonymous upload endpoint.
|
||||
|
||||
Args:
|
||||
dsl_file_bytes: Raw bytes of the DSL file (YAML or ZIP).
|
||||
filename: Original filename for the upload.
|
||||
|
||||
Returns:
|
||||
The claim_code string used to retrieve the DSL later.
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If the upload request fails.
|
||||
ValueError: If the response does not contain a valid claim_code.
|
||||
"""
|
||||
url = str(creators_platform_api_url / "api/v1/templates/anonymous-upload")
|
||||
response = httpx.post(url, files={"file": (filename, dsl_file_bytes)}, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
claim_code = data.get("data", {}).get("claim_code")
|
||||
if not claim_code:
|
||||
raise ValueError("Creators Platform did not return a valid claim_code")
|
||||
|
||||
return claim_code
|
||||
|
||||
|
||||
def get_redirect_url(user_account_id: str, claim_code: str) -> str:
|
||||
"""Generate the redirect URL to the Creators Platform frontend.
|
||||
|
||||
Redirects to the Creators Platform root page with the dsl_claim_code.
|
||||
If CREATORS_PLATFORM_OAUTH_CLIENT_ID is configured (Dify Cloud),
|
||||
also signs an OAuth authorization code so the frontend can
|
||||
automatically authenticate the user via the OAuth callback.
|
||||
|
||||
For self-hosted Dify without OAuth client_id configured, only the
|
||||
dsl_claim_code is passed and the user must log in manually.
|
||||
|
||||
Args:
|
||||
user_account_id: The Dify user account ID.
|
||||
claim_code: The claim_code obtained from upload_dsl().
|
||||
|
||||
Returns:
|
||||
The full redirect URL string.
|
||||
"""
|
||||
base_url = str(dify_config.CREATORS_PLATFORM_API_URL).rstrip("/")
|
||||
params: dict[str, str] = {"dsl_claim_code": claim_code}
|
||||
|
||||
client_id = str(dify_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID or "")
|
||||
if client_id:
|
||||
from services.oauth_server import OAuthServerService
|
||||
|
||||
oauth_code = OAuthServerService.sign_oauth_authorization_code(client_id, user_account_id)
|
||||
params["oauth_code"] = oauth_code
|
||||
|
||||
return f"{base_url}?{urlencode(params)}"
|
||||
@@ -1,62 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class VariableSelectorPayload(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
variable: str = Field(..., description="Variable name used in generated code")
|
||||
value_selector: list[str] = Field(..., description="Path to upstream node output, format: [node_id, output_name]")
|
||||
|
||||
|
||||
class CodeOutputPayload(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
type: str = Field(..., description="Output variable type")
|
||||
|
||||
|
||||
class CodeContextPayload(BaseModel):
|
||||
# From web/app/components/workflow/nodes/tool/components/context-generate-modal/index.tsx (code node snapshot).
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
code: str = Field(..., description="Existing code in the Code node")
|
||||
outputs: dict[str, CodeOutputPayload] | None = Field(
|
||||
default=None, description="Existing output definitions for the Code node"
|
||||
)
|
||||
variables: list[VariableSelectorPayload] | None = Field(
|
||||
default=None, description="Existing variable selectors used by the Code node"
|
||||
)
|
||||
|
||||
|
||||
class AvailableVarPayload(BaseModel):
|
||||
# From web/app/components/workflow/nodes/_base/hooks/use-available-var-list.ts (available variables).
|
||||
model_config = ConfigDict(extra="forbid", populate_by_name=True)
|
||||
|
||||
value_selector: list[str] = Field(..., description="Path to upstream node output")
|
||||
type: str = Field(..., description="Variable type, e.g. string, number, array[object]")
|
||||
description: str | None = Field(default=None, description="Optional variable description")
|
||||
node_id: str | None = Field(default=None, description="Source node ID")
|
||||
node_title: str | None = Field(default=None, description="Source node title")
|
||||
node_type: str | None = Field(default=None, description="Source node type")
|
||||
json_schema: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
alias="schema",
|
||||
description="Optional JSON schema for object variables",
|
||||
)
|
||||
|
||||
|
||||
class ParameterInfoPayload(BaseModel):
|
||||
# From web/app/components/workflow/nodes/tool/use-config.ts (ToolParameter metadata).
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
name: str = Field(..., description="Target parameter name")
|
||||
type: str = Field(default="string", description="Target parameter type")
|
||||
description: str = Field(default="", description="Parameter description")
|
||||
required: bool | None = Field(default=None, description="Whether the parameter is required")
|
||||
options: list[str] | None = Field(default=None, description="Allowed option values")
|
||||
min: float | None = Field(default=None, description="Minimum numeric value")
|
||||
max: float | None = Field(default=None, description="Maximum numeric value")
|
||||
default: str | int | float | bool | None = Field(default=None, description="Default value")
|
||||
multiple: bool | None = Field(default=None, description="Whether the parameter accepts multiple values")
|
||||
label: str | None = Field(default=None, description="Optional display label")
|
||||
@@ -1,67 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from graphon.variables.types import SegmentType
|
||||
|
||||
|
||||
class SuggestedQuestionsOutput(BaseModel):
|
||||
"""Output model for suggested questions generation."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
questions: list[str] = Field(
|
||||
min_length=3,
|
||||
max_length=3,
|
||||
description="Exactly 3 suggested follow-up questions for the user",
|
||||
)
|
||||
|
||||
|
||||
class VariableSelectorOutput(BaseModel):
|
||||
"""Variable selector mapping code variable to upstream node output.
|
||||
|
||||
Note: Separate from VariableSelector to ensure 'additionalProperties: false'
|
||||
in JSON schema for OpenAI/Azure strict mode.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
variable: str = Field(description="Variable name used in the generated code")
|
||||
value_selector: list[str] = Field(description="Path to upstream node output, format: [node_id, output_name]")
|
||||
|
||||
|
||||
class CodeNodeOutputItem(BaseModel):
|
||||
"""Single output variable definition.
|
||||
|
||||
Note: OpenAI/Azure strict mode requires 'additionalProperties: false' and
|
||||
does not support dynamic object keys, so outputs use array format.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
name: str = Field(description="Output variable name returned by the main function")
|
||||
type: SegmentType = Field(description="Data type of the output variable")
|
||||
|
||||
|
||||
class CodeNodeStructuredOutput(BaseModel):
|
||||
"""Structured output for code node generation."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
variables: list[VariableSelectorOutput] = Field(
|
||||
description="Input variables mapping code variables to upstream node outputs"
|
||||
)
|
||||
code: str = Field(description="Generated code with a main function that processes inputs and returns outputs")
|
||||
outputs: list[CodeNodeOutputItem] = Field(
|
||||
description="Output variable definitions specifying name and type for each return value"
|
||||
)
|
||||
message: str = Field(description="Brief explanation of what the generated code does")
|
||||
|
||||
|
||||
class InstructionModifyOutput(BaseModel):
|
||||
"""Output model for instruction-based prompt modification."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
modified: str = Field(description="The modified prompt content after applying the instruction")
|
||||
message: str = Field(description="Brief explanation of what changes were made")
|
||||
@@ -1,203 +0,0 @@
|
||||
"""
|
||||
File path detection and conversion for structured output.
|
||||
|
||||
This module provides utilities to:
|
||||
1. Detect sandbox file path fields in JSON Schema (format: "file-path")
|
||||
2. Adapt schemas to add file-path descriptions before model invocation
|
||||
3. Convert sandbox file path strings into File objects via a resolver
|
||||
"""
|
||||
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from graphon.file import File
|
||||
from graphon.variables.segments import ArrayFileSegment, FileSegment
|
||||
|
||||
FILE_PATH_FORMAT = "file-path"
|
||||
FILE_PATH_DESCRIPTION_SUFFIX = "this field contains a file path from the Dify sandbox"
|
||||
|
||||
|
||||
def is_file_path_property(schema: Mapping[str, Any]) -> bool:
|
||||
"""Check if a schema property represents a sandbox file path."""
|
||||
if schema.get("type") != "string":
|
||||
return False
|
||||
format_value = schema.get("format")
|
||||
if not isinstance(format_value, str):
|
||||
return False
|
||||
normalized_format = format_value.lower().replace("_", "-")
|
||||
return normalized_format == FILE_PATH_FORMAT
|
||||
|
||||
|
||||
def detect_file_path_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
|
||||
"""Recursively detect file path fields in a JSON schema."""
|
||||
file_path_fields: list[str] = []
|
||||
schema_type = schema.get("type")
|
||||
|
||||
if schema_type == "object":
|
||||
properties = schema.get("properties")
|
||||
if isinstance(properties, Mapping):
|
||||
properties_mapping = cast(Mapping[str, Any], properties)
|
||||
for prop_name, prop_schema in properties_mapping.items():
|
||||
if not isinstance(prop_schema, Mapping):
|
||||
continue
|
||||
prop_schema_mapping = cast(Mapping[str, Any], prop_schema)
|
||||
current_path = f"{path}.{prop_name}" if path else prop_name
|
||||
|
||||
if is_file_path_property(prop_schema_mapping):
|
||||
file_path_fields.append(current_path)
|
||||
else:
|
||||
file_path_fields.extend(detect_file_path_fields(prop_schema_mapping, current_path))
|
||||
|
||||
elif schema_type == "array":
|
||||
items_schema = schema.get("items")
|
||||
if not isinstance(items_schema, Mapping):
|
||||
return file_path_fields
|
||||
items_schema_mapping = cast(Mapping[str, Any], items_schema)
|
||||
array_path = f"{path}[*]" if path else "[*]"
|
||||
|
||||
if is_file_path_property(items_schema_mapping):
|
||||
file_path_fields.append(array_path)
|
||||
else:
|
||||
file_path_fields.extend(detect_file_path_fields(items_schema_mapping, array_path))
|
||||
|
||||
return file_path_fields
|
||||
|
||||
|
||||
def adapt_schema_for_sandbox_file_paths(schema: Mapping[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""Normalize sandbox file path fields and collect their JSON paths."""
|
||||
result = _deep_copy_value(schema)
|
||||
if not isinstance(result, dict):
|
||||
raise ValueError("structured_output_schema must be a JSON object")
|
||||
result_dict = cast(dict[str, Any], result)
|
||||
|
||||
file_path_fields: list[str] = []
|
||||
_adapt_schema_in_place(result_dict, path="", file_path_fields=file_path_fields)
|
||||
return result_dict, file_path_fields
|
||||
|
||||
|
||||
def convert_sandbox_file_paths_in_output(
|
||||
output: Mapping[str, Any],
|
||||
file_path_fields: Sequence[str],
|
||||
file_resolver: Callable[[str], File],
|
||||
) -> tuple[dict[str, Any], list[File]]:
|
||||
"""Convert sandbox file paths into File objects using the resolver."""
|
||||
if not file_path_fields:
|
||||
return dict(output), []
|
||||
|
||||
result = _deep_copy_value(output)
|
||||
if not isinstance(result, dict):
|
||||
raise ValueError("Structured output must be a JSON object")
|
||||
result_dict = cast(dict[str, Any], result)
|
||||
|
||||
files: list[File] = []
|
||||
for path in file_path_fields:
|
||||
_convert_path_in_place(result_dict, path.split("."), file_resolver, files)
|
||||
|
||||
return result_dict, files
|
||||
|
||||
|
||||
def _adapt_schema_in_place(schema: dict[str, Any], path: str, file_path_fields: list[str]) -> None:
|
||||
schema_type = schema.get("type")
|
||||
|
||||
if schema_type == "object":
|
||||
properties = schema.get("properties")
|
||||
if isinstance(properties, Mapping):
|
||||
properties_mapping = cast(Mapping[str, Any], properties)
|
||||
for prop_name, prop_schema in properties_mapping.items():
|
||||
if not isinstance(prop_schema, dict):
|
||||
continue
|
||||
prop_schema_dict = cast(dict[str, Any], prop_schema)
|
||||
current_path = f"{path}.{prop_name}" if path else prop_name
|
||||
|
||||
if is_file_path_property(prop_schema_dict):
|
||||
_normalize_file_path_schema(prop_schema_dict)
|
||||
file_path_fields.append(current_path)
|
||||
else:
|
||||
_adapt_schema_in_place(prop_schema_dict, current_path, file_path_fields)
|
||||
|
||||
elif schema_type == "array":
|
||||
items_schema = schema.get("items")
|
||||
if not isinstance(items_schema, dict):
|
||||
return
|
||||
items_schema_dict = cast(dict[str, Any], items_schema)
|
||||
array_path = f"{path}[*]" if path else "[*]"
|
||||
|
||||
if is_file_path_property(items_schema_dict):
|
||||
_normalize_file_path_schema(items_schema_dict)
|
||||
file_path_fields.append(array_path)
|
||||
else:
|
||||
_adapt_schema_in_place(items_schema_dict, array_path, file_path_fields)
|
||||
|
||||
|
||||
def _normalize_file_path_schema(schema: dict[str, Any]) -> None:
|
||||
schema["type"] = "string"
|
||||
schema["format"] = FILE_PATH_FORMAT
|
||||
description = schema.get("description", "")
|
||||
if description:
|
||||
if FILE_PATH_DESCRIPTION_SUFFIX not in description:
|
||||
schema["description"] = f"{description}\n{FILE_PATH_DESCRIPTION_SUFFIX}"
|
||||
else:
|
||||
schema["description"] = FILE_PATH_DESCRIPTION_SUFFIX
|
||||
|
||||
|
||||
def _deep_copy_value(value: Any) -> Any:
|
||||
if isinstance(value, Mapping):
|
||||
mapping = cast(Mapping[str, Any], value)
|
||||
return {key: _deep_copy_value(item) for key, item in mapping.items()}
|
||||
if isinstance(value, list):
|
||||
list_value = cast(list[Any], value)
|
||||
return [_deep_copy_value(item) for item in list_value]
|
||||
return value
|
||||
|
||||
|
||||
def _convert_path_in_place(
|
||||
obj: dict[str, Any],
|
||||
path_parts: list[str],
|
||||
file_resolver: Callable[[str], File],
|
||||
files: list[File],
|
||||
) -> None:
|
||||
if not path_parts:
|
||||
return
|
||||
|
||||
current = path_parts[0]
|
||||
remaining = path_parts[1:]
|
||||
|
||||
if current.endswith("[*]"):
|
||||
key = current[:-3] if current != "[*]" else ""
|
||||
target_value = obj.get(key) if key else obj
|
||||
|
||||
if isinstance(target_value, list):
|
||||
target_list = cast(list[Any], target_value)
|
||||
if remaining:
|
||||
for item in target_list:
|
||||
if isinstance(item, dict):
|
||||
item_dict = cast(dict[str, Any], item)
|
||||
_convert_path_in_place(item_dict, remaining, file_resolver, files)
|
||||
else:
|
||||
resolved_files: list[File] = []
|
||||
for item in target_list:
|
||||
if not isinstance(item, str):
|
||||
raise ValueError("File path must be a string")
|
||||
file = file_resolver(item)
|
||||
files.append(file)
|
||||
resolved_files.append(file)
|
||||
if key:
|
||||
obj[key] = ArrayFileSegment(value=resolved_files)
|
||||
return
|
||||
|
||||
if not remaining:
|
||||
if current not in obj:
|
||||
return
|
||||
value = obj[current]
|
||||
if value is None:
|
||||
obj[current] = None
|
||||
return
|
||||
if not isinstance(value, str):
|
||||
raise ValueError("File path must be a string")
|
||||
file = file_resolver(value)
|
||||
files.append(file)
|
||||
obj[current] = FileSegment(value=file)
|
||||
return
|
||||
|
||||
if current in obj and isinstance(obj[current], dict):
|
||||
_convert_path_in_place(obj[current], remaining, file_resolver, files)
|
||||
@@ -1,45 +0,0 @@
|
||||
"""Utility functions for LLM generator."""
|
||||
|
||||
from graphon.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageRole,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
||||
|
||||
def deserialize_prompt_messages(messages: list[dict]) -> list[PromptMessage]:
|
||||
"""
|
||||
Deserialize list of dicts to list[PromptMessage].
|
||||
|
||||
Expected format:
|
||||
[
|
||||
{"role": "user", "content": "..."},
|
||||
{"role": "assistant", "content": "..."},
|
||||
]
|
||||
"""
|
||||
result: list[PromptMessage] = []
|
||||
for msg in messages:
|
||||
role = PromptMessageRole.value_of(msg["role"])
|
||||
content = msg.get("content", "")
|
||||
|
||||
match role:
|
||||
case PromptMessageRole.USER:
|
||||
result.append(UserPromptMessage(content=content))
|
||||
case PromptMessageRole.ASSISTANT:
|
||||
result.append(AssistantPromptMessage(content=content))
|
||||
case PromptMessageRole.SYSTEM:
|
||||
result.append(SystemPromptMessage(content=content))
|
||||
case PromptMessageRole.TOOL:
|
||||
result.append(ToolPromptMessage(content=content, tool_call_id=msg.get("tool_call_id", "")))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def serialize_prompt_messages(messages: list[PromptMessage]) -> list[dict]:
|
||||
"""
|
||||
Serialize list[PromptMessage] to list of dicts.
|
||||
"""
|
||||
return [{"role": msg.role.value, "content": msg.content} for msg in messages]
|
||||
@@ -1,11 +0,0 @@
|
||||
from core.memory.base import BaseMemory
|
||||
from core.memory.node_token_buffer_memory import (
|
||||
NodeTokenBufferMemory,
|
||||
)
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
|
||||
__all__ = [
|
||||
"BaseMemory",
|
||||
"NodeTokenBufferMemory",
|
||||
"TokenBufferMemory",
|
||||
]
|
||||
@@ -1,82 +0,0 @@
|
||||
"""
|
||||
Base memory interfaces and types.
|
||||
|
||||
This module defines the common protocol for memory implementations.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
|
||||
from graphon.model_runtime.entities import ImagePromptMessageContent, PromptMessage
|
||||
|
||||
|
||||
class BaseMemory(ABC):
|
||||
"""
|
||||
Abstract base class for memory implementations.
|
||||
|
||||
Provides a common interface for both conversation-level and node-level memory.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_history_prompt_messages(
|
||||
self,
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""
|
||||
Get history prompt messages.
|
||||
|
||||
:param max_token_limit: Maximum tokens for history
|
||||
:param message_limit: Maximum number of messages
|
||||
:return: Sequence of PromptMessage for LLM context
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_history_prompt_text(
|
||||
self,
|
||||
human_prefix: str = "Human",
|
||||
ai_prefix: str = "Assistant",
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get history prompt as formatted text.
|
||||
|
||||
:param human_prefix: Prefix for human messages
|
||||
:param ai_prefix: Prefix for assistant messages
|
||||
:param max_token_limit: Maximum tokens for history
|
||||
:param message_limit: Maximum number of messages
|
||||
:return: Formatted history text
|
||||
"""
|
||||
from graphon.model_runtime.entities import (
|
||||
PromptMessageRole,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
|
||||
prompt_messages = self.get_history_prompt_messages(
|
||||
max_token_limit=max_token_limit,
|
||||
message_limit=message_limit,
|
||||
)
|
||||
|
||||
string_messages = []
|
||||
for m in prompt_messages:
|
||||
if m.role == PromptMessageRole.USER:
|
||||
role = human_prefix
|
||||
elif m.role == PromptMessageRole.ASSISTANT:
|
||||
role = ai_prefix
|
||||
else:
|
||||
continue
|
||||
|
||||
if isinstance(m.content, list):
|
||||
inner_msg = ""
|
||||
for content in m.content:
|
||||
if isinstance(content, TextPromptMessageContent):
|
||||
inner_msg += f"{content.data}\n"
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
inner_msg += "[image]\n"
|
||||
string_messages.append(f"{role}: {inner_msg.strip()}")
|
||||
else:
|
||||
message = f"{role}: {m.content}"
|
||||
string_messages.append(message)
|
||||
|
||||
return "\n".join(string_messages)
|
||||
@@ -1,196 +0,0 @@
|
||||
"""
|
||||
Node-level Token Buffer Memory for Chatflow.
|
||||
|
||||
This module provides node-scoped memory within a conversation.
|
||||
Each LLM node in a workflow can maintain its own independent conversation history.
|
||||
|
||||
Note: This is only available in Chatflow (advanced-chat mode) because it requires
|
||||
both conversation_id and node_id.
|
||||
|
||||
Design:
|
||||
- History is read directly from WorkflowNodeExecutionModel.outputs["context"]
|
||||
- No separate storage needed - the context is already saved during node execution
|
||||
- Thread tracking leverages Message table's parent_message_id structure
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.memory.base import BaseMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from graphon.file import file_manager
|
||||
from graphon.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
MultiModalPromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageRole,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NodeTokenBufferMemory(BaseMemory):
|
||||
"""
|
||||
Node-level Token Buffer Memory.
|
||||
|
||||
Provides node-scoped memory within a conversation. Each LLM node can maintain
|
||||
its own independent conversation history.
|
||||
|
||||
Key design: History is read directly from WorkflowNodeExecutionModel.outputs["context"],
|
||||
which is already saved during node execution. No separate storage needed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_id: str,
|
||||
conversation_id: str,
|
||||
node_id: str,
|
||||
tenant_id: str,
|
||||
model_instance: ModelInstance,
|
||||
):
|
||||
self.app_id = app_id
|
||||
self.conversation_id = conversation_id
|
||||
self.node_id = node_id
|
||||
self.tenant_id = tenant_id
|
||||
self.model_instance = model_instance
|
||||
|
||||
def _get_thread_workflow_run_ids(self) -> list[str]:
|
||||
"""
|
||||
Get workflow_run_ids for the current thread by querying Message table.
|
||||
Returns workflow_run_ids in chronological order (oldest first).
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = (
|
||||
select(Message)
|
||||
.where(Message.conversation_id == self.conversation_id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(500)
|
||||
)
|
||||
messages = list(session.scalars(stmt).all())
|
||||
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
# Extract thread messages using existing logic
|
||||
thread_messages = extract_thread_messages(messages)
|
||||
|
||||
# For newly created message, its answer is temporarily empty, skip it
|
||||
if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
|
||||
thread_messages.pop(0)
|
||||
|
||||
# Reverse to get chronological order, extract workflow_run_ids
|
||||
return [msg.workflow_run_id for msg in reversed(thread_messages) if msg.workflow_run_id]
|
||||
|
||||
def _deserialize_prompt_message(self, msg_dict: dict) -> PromptMessage:
|
||||
"""Deserialize a dict to PromptMessage based on role."""
|
||||
role = msg_dict.get("role")
|
||||
if role in (PromptMessageRole.USER, "user"):
|
||||
return UserPromptMessage.model_validate(msg_dict)
|
||||
elif role in (PromptMessageRole.ASSISTANT, "assistant"):
|
||||
return AssistantPromptMessage.model_validate(msg_dict)
|
||||
elif role in (PromptMessageRole.SYSTEM, "system"):
|
||||
return SystemPromptMessage.model_validate(msg_dict)
|
||||
elif role in (PromptMessageRole.TOOL, "tool"):
|
||||
return ToolPromptMessage.model_validate(msg_dict)
|
||||
else:
|
||||
return PromptMessage.model_validate(msg_dict)
|
||||
|
||||
def _deserialize_context(self, context_data: list[dict]) -> list[PromptMessage]:
|
||||
"""Deserialize context data from outputs to list of PromptMessage."""
|
||||
messages = []
|
||||
for msg_dict in context_data:
|
||||
try:
|
||||
msg = self._deserialize_prompt_message(msg_dict)
|
||||
msg = self._restore_multimodal_content(msg)
|
||||
messages.append(msg)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to deserialize prompt message: %s", e)
|
||||
return messages
|
||||
|
||||
def _restore_multimodal_content(self, message: PromptMessage) -> PromptMessage:
|
||||
"""
|
||||
Restore multimodal content (base64 or url) from file_ref.
|
||||
|
||||
When context is saved, base64_data is cleared to save storage space.
|
||||
This method restores the content by parsing file_ref (format: "method:id_or_url").
|
||||
"""
|
||||
content = message.content
|
||||
if content is None or isinstance(content, str):
|
||||
return message
|
||||
|
||||
# Process list content, restoring multimodal data from file references
|
||||
restored_content: list[PromptMessageContentUnionTypes] = []
|
||||
for item in content:
|
||||
if isinstance(item, MultiModalPromptMessageContent):
|
||||
# restore_multimodal_content preserves the concrete subclass type
|
||||
restored_item = file_manager.restore_multimodal_content(item)
|
||||
restored_content.append(cast(PromptMessageContentUnionTypes, restored_item))
|
||||
else:
|
||||
restored_content.append(item)
|
||||
|
||||
return message.model_copy(update={"content": restored_content})
|
||||
|
||||
def get_history_prompt_messages(
|
||||
self,
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""
|
||||
Retrieve history as PromptMessage sequence.
|
||||
History is read directly from the last completed node execution's outputs["context"].
|
||||
"""
|
||||
_ = message_limit # unused, kept for interface compatibility
|
||||
|
||||
thread_workflow_run_ids = self._get_thread_workflow_run_ids()
|
||||
if not thread_workflow_run_ids:
|
||||
return []
|
||||
|
||||
# Get the last completed workflow_run_id (contains accumulated context)
|
||||
last_run_id = thread_workflow_run_ids[-1]
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(WorkflowNodeExecutionModel).where(
|
||||
WorkflowNodeExecutionModel.workflow_run_id == last_run_id,
|
||||
WorkflowNodeExecutionModel.node_id == self.node_id,
|
||||
WorkflowNodeExecutionModel.status == "succeeded",
|
||||
)
|
||||
execution = session.scalars(stmt).first()
|
||||
|
||||
if not execution:
|
||||
return []
|
||||
|
||||
outputs = execution.outputs_dict
|
||||
if not outputs:
|
||||
return []
|
||||
|
||||
context_data = outputs.get("context")
|
||||
|
||||
if not context_data or not isinstance(context_data, list):
|
||||
return []
|
||||
|
||||
prompt_messages = self._deserialize_context(context_data)
|
||||
if not prompt_messages:
|
||||
return []
|
||||
|
||||
# Truncate by token limit
|
||||
try:
|
||||
current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||
while current_tokens > max_token_limit and len(prompt_messages) > 1:
|
||||
prompt_messages.pop(0)
|
||||
current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to count tokens for truncation: %s", e)
|
||||
|
||||
return prompt_messages
|
||||
@@ -63,7 +63,7 @@ class TokenBufferMemory:
|
||||
"""
|
||||
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
|
||||
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
|
||||
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT}:
|
||||
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
app = self.conversation.app
|
||||
if not app:
|
||||
raise ValueError("App not found for conversation")
|
||||
|
||||
@@ -209,7 +209,10 @@ class PluginInstaller(BasePluginClient):
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/decode/from_identifier",
|
||||
PluginDecodeResponse,
|
||||
params={"plugin_unique_identifier": plugin_unique_identifier},
|
||||
params={
|
||||
"plugin_unique_identifier": plugin_unique_identifier,
|
||||
"PluginUniqueIdentifier": plugin_unique_identifier, # compat with daemon <= 0.5.4
|
||||
},
|
||||
)
|
||||
|
||||
def fetch_plugin_installation_by_ids(
|
||||
|
||||
@@ -1,12 +1,20 @@
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.entities.event import DatasourceCompletedEvent, DatasourceErrorEvent, DatasourceProcessingEvent
|
||||
from core.rag.entities.index_entities import EconomySetting, EmbeddingSetting, IndexMethod
|
||||
from core.rag.entities.metadata_entities import Condition, MetadataFilteringCondition, SupportedComparisonOperator
|
||||
from core.rag.entities.processing_entities import ParentMode, PreProcessingRule, Rule, Segmentation
|
||||
from core.rag.entities.retrieval_settings import KeywordSetting, VectorSetting
|
||||
from core.rag.entities.retrieval_settings import KeywordSetting, VectorSetting, WeightedScoreConfig
|
||||
|
||||
__all__ = [
|
||||
"Condition",
|
||||
"DatasourceCompletedEvent",
|
||||
"DatasourceErrorEvent",
|
||||
"DatasourceProcessingEvent",
|
||||
"DocumentContext",
|
||||
"EconomySetting",
|
||||
"EmbeddingSetting",
|
||||
"IndexMethod",
|
||||
"KeywordSetting",
|
||||
"MetadataFilteringCondition",
|
||||
"ParentMode",
|
||||
@@ -16,4 +24,5 @@ __all__ = [
|
||||
"Segmentation",
|
||||
"SupportedComparisonOperator",
|
||||
"VectorSetting",
|
||||
"WeightedScoreConfig",
|
||||
]
|
||||
|
||||
30
api/core/rag/entities/index_entities.py
Normal file
30
api/core/rag/entities/index_entities.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class EmbeddingSetting(BaseModel):
|
||||
"""
|
||||
Embedding Setting.
|
||||
"""
|
||||
|
||||
embedding_provider_name: str
|
||||
embedding_model_name: str
|
||||
|
||||
|
||||
class EconomySetting(BaseModel):
|
||||
"""
|
||||
Economy Setting.
|
||||
"""
|
||||
|
||||
keyword_number: int
|
||||
|
||||
|
||||
class IndexMethod(BaseModel):
|
||||
"""
|
||||
Knowledge Index Setting.
|
||||
"""
|
||||
|
||||
indexing_technique: Literal["high_quality", "economy"]
|
||||
embedding_setting: EmbeddingSetting
|
||||
economy_setting: EconomySetting
|
||||
@@ -17,3 +17,12 @@ class KeywordSetting(BaseModel):
|
||||
"""
|
||||
|
||||
keyword_weight: float
|
||||
|
||||
|
||||
class WeightedScoreConfig(BaseModel):
|
||||
"""
|
||||
Weighted score Config.
|
||||
"""
|
||||
|
||||
vector_setting: VectorSetting
|
||||
keyword_setting: KeywordSetting
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .bash.dify_cli import (
|
||||
DifyCliBinary,
|
||||
DifyCliConfig,
|
||||
DifyCliEnvConfig,
|
||||
DifyCliLocator,
|
||||
DifyCliToolConfig,
|
||||
)
|
||||
from .bash.session import SandboxBashSession
|
||||
from .builder import SandboxBuilder, VMConfig
|
||||
from .entities import AppAssets, DifyCli, SandboxProviderApiEntity, SandboxType
|
||||
from .initializer import (
|
||||
AsyncSandboxInitializer,
|
||||
SandboxInitializeContext,
|
||||
SandboxInitializer,
|
||||
SyncSandboxInitializer,
|
||||
)
|
||||
from .initializer.app_assets_initializer import AppAssetsInitializer
|
||||
from .initializer.dify_cli_initializer import DifyCliInitializer
|
||||
from .initializer.draft_app_assets_initializer import DraftAppAssetsInitializer
|
||||
from .sandbox import Sandbox
|
||||
from .storage import ArchiveSandboxStorage, SandboxStorage
|
||||
from .utils.debug import sandbox_debug
|
||||
from .utils.encryption import create_sandbox_config_encrypter, masked_config
|
||||
|
||||
__all__ = [
|
||||
"AppAssets",
|
||||
"AppAssetsInitializer",
|
||||
"ArchiveSandboxStorage",
|
||||
"AsyncSandboxInitializer",
|
||||
"DifyCli",
|
||||
"DifyCliBinary",
|
||||
"DifyCliConfig",
|
||||
"DifyCliEnvConfig",
|
||||
"DifyCliInitializer",
|
||||
"DifyCliLocator",
|
||||
"DifyCliToolConfig",
|
||||
"DraftAppAssetsInitializer",
|
||||
"Sandbox",
|
||||
"SandboxBashSession",
|
||||
"SandboxBuilder",
|
||||
"SandboxInitializeContext",
|
||||
"SandboxInitializer",
|
||||
"SandboxProviderApiEntity",
|
||||
"SandboxStorage",
|
||||
"SandboxType",
|
||||
"SyncSandboxInitializer",
|
||||
"VMConfig",
|
||||
"create_sandbox_config_encrypter",
|
||||
"masked_config",
|
||||
"sandbox_debug",
|
||||
]
|
||||
|
||||
_LAZY_IMPORTS = {
|
||||
"AppAssets": ("core.sandbox.entities", "AppAssets"),
|
||||
"AppAssetsInitializer": ("core.sandbox.initializer.app_assets_initializer", "AppAssetsInitializer"),
|
||||
"AsyncSandboxInitializer": ("core.sandbox.initializer", "AsyncSandboxInitializer"),
|
||||
"ArchiveSandboxStorage": ("core.sandbox.storage", "ArchiveSandboxStorage"),
|
||||
"DifyCli": ("core.sandbox.entities", "DifyCli"),
|
||||
"DifyCliBinary": ("core.sandbox.bash.dify_cli", "DifyCliBinary"),
|
||||
"DifyCliConfig": ("core.sandbox.bash.dify_cli", "DifyCliConfig"),
|
||||
"DifyCliEnvConfig": ("core.sandbox.bash.dify_cli", "DifyCliEnvConfig"),
|
||||
"DifyCliInitializer": ("core.sandbox.initializer.dify_cli_initializer", "DifyCliInitializer"),
|
||||
"DifyCliLocator": ("core.sandbox.bash.dify_cli", "DifyCliLocator"),
|
||||
"DifyCliToolConfig": ("core.sandbox.bash.dify_cli", "DifyCliToolConfig"),
|
||||
"DraftAppAssetsInitializer": ("core.sandbox.initializer.draft_app_assets_initializer", "DraftAppAssetsInitializer"),
|
||||
"Sandbox": ("core.sandbox.sandbox", "Sandbox"),
|
||||
"SandboxBashSession": ("core.sandbox.bash.session", "SandboxBashSession"),
|
||||
"SandboxBuilder": ("core.sandbox.builder", "SandboxBuilder"),
|
||||
"SandboxInitializeContext": ("core.sandbox.initializer", "SandboxInitializeContext"),
|
||||
"SandboxInitializer": ("core.sandbox.initializer", "SandboxInitializer"),
|
||||
"SandboxManager": ("core.sandbox.manager", "SandboxManager"),
|
||||
"SandboxProviderApiEntity": ("core.sandbox.entities", "SandboxProviderApiEntity"),
|
||||
"SandboxStorage": ("core.sandbox.storage", "SandboxStorage"),
|
||||
"SandboxType": ("core.sandbox.entities", "SandboxType"),
|
||||
"SyncSandboxInitializer": ("core.sandbox.initializer", "SyncSandboxInitializer"),
|
||||
"VMConfig": ("core.sandbox.builder", "VMConfig"),
|
||||
"create_sandbox_config_encrypter": ("core.sandbox.utils.encryption", "create_sandbox_config_encrypter"),
|
||||
"masked_config": ("core.sandbox.utils.encryption", "masked_config"),
|
||||
"sandbox_debug": ("core.sandbox.utils.debug", "sandbox_debug"),
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
if name not in _LAZY_IMPORTS:
|
||||
raise AttributeError(f"module 'core.sandbox' has no attribute {name}")
|
||||
module_path, attr_name = _LAZY_IMPORTS[name]
|
||||
module = importlib.import_module(module_path)
|
||||
value = getattr(module, attr_name)
|
||||
globals()[name] = value
|
||||
return value
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user