Compare commits

..

49 Commits

Author SHA1 Message Date
Harry
f3761c26e9 Merge remote-tracking branch 'origin/main' into feat/llm-node-support-tools 2026-01-05 18:17:05 +08:00
-LAN-
a9e2c05a10 feat(graph-engine): add command to update variables at runtime (#30563)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-05 16:47:34 +08:00
-LAN-
6f8bd58e19 feat(graph-engine): make layer runtime state non-null and bound early (#30552)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-05 16:43:42 +08:00
scdeng
591ca05c84 feat(logstore): make graph field optional via env variable LOGSTORE… (#30554)
Co-authored-by: 阿永 <ayong.dy@alibaba-inc.com>
2026-01-05 16:12:41 +08:00
Stephen Zhou
a72044aa86 chore: fix lint in i18n (#30571) 2026-01-05 16:12:12 +08:00
Lework
34f3b288a7 chore(docker): update nltk data download process to include unstructured download_nltk_packages (#28876) 2026-01-05 15:50:33 +08:00
-LAN-
a99ac3fe0d refactor(models): Add mapped type hints to MessageAnnotation (#27751) 2026-01-05 15:50:03 +08:00
Stephen Zhou
52149c0d9b chore(web): add ESLint rules for i18n JSON validation (#30491) 2026-01-05 15:49:31 +08:00
wangxiaolei
631f999f65 refactor: use contains_any instead of Chaining where = where | f (#30559)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-01-05 15:48:31 +08:00
hsiong
be3ef9f050 fix: #30511 [Bug] knowledge_retrieval_node fails when using Rerank Model: "Working outside of application context" and add regression test (#30549)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-01-05 15:02:21 +08:00
dependabot[bot]
93a85ae98a chore(deps): bump @amplitude/analytics-browser from 2.31.4 to 2.33.1 in /web (#30538)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-05 15:05:04 +09:00
wangxiaolei
e3e19c437a fix: fix db env not work (#30541)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
2026-01-05 11:10:45 +08:00
hsiong
693daea474 fix: INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH settings (#30463) 2026-01-05 11:10:04 +08:00
wangxiaolei
bc317a0009 feat: return data_source_info and data_source_detail_dict (#29912)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-01-05 11:04:03 +08:00
Kimi WANG
c158dfa198 fix: support to change NEXT_PUBLIC_BASE_PATH env using --build-arg in docker build (#29836)
Co-authored-by: root <root@KIMI-DESKTOP-01.mchrcloud.com>
2026-01-05 11:03:12 +08:00
Maries
79913590ae fix(api): surface subscription deletion errors to users (#30333) 2026-01-05 11:02:04 +08:00
wangxiaolei
f1fff0a243 fix: fix WorkflowExecution.outputs containing non-JSON-serializable o… (#30464) 2026-01-05 10:57:23 +08:00
hsiong
4bb08b93d7 chore: update dockerignore (#30460) 2026-01-05 10:55:14 +08:00
wangxiaolei
d0564ac63c feat: add flask command file-usage (#30500) 2026-01-05 10:52:21 +08:00
-LAN-
eb321ad614 chore: Add a new rule for import lint (#30526) 2026-01-05 10:48:14 +08:00
dependabot[bot]
7128d71cf7 chore(deps-dev): bump intersystems-irispython from 5.3.0 to 5.3.1 in /api (#30540)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-05 10:47:39 +08:00
-LAN-
95edbad1c7 refactor(workflow): add Jinja2 renderer abstraction for template transform (#30535)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-05 10:46:37 +08:00
-LAN-
154abdd915 chore: Update PR template lint command (#30533) 2026-01-05 10:46:01 +08:00
longbingljw
c58a093fd1 docs: update comments in docker/.env.example (#30516)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
2026-01-04 21:52:03 +08:00
-LAN-
06ba40f016 refactor(code_node): implement DI for the code node (#30519) 2026-01-04 21:50:42 +08:00
sszaodian
2b838077e0 fix: when first setup after auto login error (#30523)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
Co-authored-by: maxin <maxin7@xiaomi.com>
Co-authored-by: 非法操作 <hjlarry@163.com>
2026-01-04 20:24:49 +08:00
wangxiaolei
473f8ef29c feat: skip rerank if only one dataset is retrieved (#30075)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2026-01-04 20:22:51 +08:00
wangxiaolei
96736144b9 feat: enhance squid config (#30146) 2026-01-04 19:59:41 +08:00
yyh
f167e87146 refactor(web): align signup mail submit and tests (#30456) 2026-01-04 19:59:06 +08:00
Joel
a562089e48 feat: add frontend code review skills (#30520)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
2026-01-04 19:57:09 +08:00
yyh
7d65d8048e feat: add Ralph Wiggum plugin support (#30525) 2026-01-04 19:56:02 +08:00
Coding On Star
c29cfd18f3 feat: revert model total credits (#30518) 2026-01-04 18:29:19 +08:00
Coding On Star
47b8e979e0 test: add unit tests for RagPipeline components (#30429)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
2026-01-04 18:04:49 +08:00
-LAN-
83648feedf chore: upgrade fickling to 0.1.6 (#30495) 2026-01-04 17:22:12 +08:00
Asuka Minato
2cef879209 refactor: more ns.model to BaseModel (#30445) 2026-01-04 17:12:28 +08:00
github-actions[bot]
151101aaf5 chore(i18n): translate i18n files based on en-US changes (#30508)
Co-authored-by: hyoban <38493346+hyoban@users.noreply.github.com>
2026-01-04 17:11:40 +08:00
Stephen Zhou
9aaa08e19f ci: fix translate, allow manual dispatch (#30505)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-01-04 16:34:23 +08:00
zhsama
d4baf078f7 fix(plugins): enhance search to match name, label and description (#30501) 2026-01-04 16:07:04 +08:00
Coding On Star
84cbf0526d feat: model total credits (#26942)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
2026-01-04 15:26:37 +08:00
Byron.wang
5362f69083 feat(refactoring): Support Structured Logging (JSON) (#30170)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
2026-01-04 11:46:46 +08:00
yyh
822374eca5 chore: integrate @tanstack/eslint-plugin-query and fix service layer lint errors (#30444) 2026-01-04 11:20:06 +08:00
yyh
815ae6c754 chore: remove redundant web/app/page.module.css (#30482) 2026-01-04 10:22:36 +08:00
longbingljw
9a22baf57d feat: optimize for migration versions (#28787)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
Co-authored-by: -LAN- <laipz8200@outlook.com>
2026-01-03 21:33:20 +09:00
非法操作
c1bb310183 chore: remove icon_large of models (#30466)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
Co-authored-by: zhsama <torvalds@linux.do>
2026-01-03 02:35:17 +09:00
非法操作
8f2aabf7bd chore: Standardized the OpenAI icon (#30471) 2026-01-03 02:34:17 +09:00
zxhlyh
d132abcdb4 merge main 2025-12-29 15:55:45 +08:00
zxhlyh
d60348572e feat: llm node support tools 2025-12-29 14:55:26 +08:00
zxhlyh
0cff94d90e Merge branch 'main' into feat/llm-node-support-tools 2025-12-25 13:45:49 +08:00
zxhlyh
a7859de625 feat: llm node support tools 2025-12-24 14:15:55 +08:00
294 changed files with 31292 additions and 13492 deletions

View File

@@ -3,6 +3,7 @@
"feature-dev@claude-plugins-official": true,
"context7@claude-plugins-official": true,
"typescript-lsp@claude-plugins-official": true,
"pyright-lsp@claude-plugins-official": true
"pyright-lsp@claude-plugins-official": true,
"ralph-wiggum@claude-plugins-official": true
}
}

View File

@@ -0,0 +1,73 @@
---
name: frontend-code-review
description: "Trigger when the user requests a review of frontend files (e.g., `.tsx`, `.ts`, `.js`). Support both pending-change reviews and focused file reviews while applying the checklist rules."
---
# Frontend Code Review
## Intent
Use this skill whenever the user asks to review frontend code (especially `.tsx`, `.ts`, or `.js` files). Support two review modes:
1. **Pending-change review** inspect staged/working-tree files slated for commit and flag checklist violations before submission.
2. **File-targeted review** review the specific file(s) the user names and report the relevant checklist findings.
Stick to the checklist below for every applicable file and mode.
## Checklist
See [references/code-quality.md](references/code-quality.md), [references/performance.md](references/performance.md), [references/business-logic.md](references/business-logic.md) for the living checklist split by category—treat it as the canonical set of rules to follow.
Flag each rule violation with urgency metadata so future reviewers can prioritize fixes.
## Review Process
1. Open the relevant component/module. Gather lines that relate to class names, React Flow hooks, prop memoization, and styling.
2. For each rule in the review point, note where the code deviates and capture a representative snippet.
3. Compose the review section per the template below. Group violations first by **Urgent** flag, then by category order (Code Quality, Performance, Business Logic).
## Required output
When invoked, the response must exactly follow one of the two templates:
### Template A (any findings)
```
# Code review
Found <N> urgent issues need to be fixed:
## 1 <brief description of bug>
FilePath: <path> line <line>
<relevant code snippet or pointer>
### Suggested fix
<brief description of suggested fix>
---
... (repeat for each urgent issue) ...
Found <M> suggestions for improvement:
## 1 <brief description of suggestion>
FilePath: <path> line <line>
<relevant code snippet or pointer>
### Suggested fix
<brief description of suggested fix>
---
... (repeat for each suggestion) ...
```
If there are no urgent issues, omit that section. If there are no suggestions, omit that section.
If the issue number is more than 10, summarize as "10+ urgent issues" or "10+ suggestions" and just output the first 10 issues.
Don't compress the blank lines between sections; keep them as-is for readability.
If you use Template A (i.e., there are issues to fix) and at least one issue requires code changes, append a brief follow-up question after the structured output asking whether the user wants you to apply the suggested fix(es). For example: "Would you like me to use the Suggested fix section to address these issues?"
### Template B (no issues)
```
## Code review
No issues found.
```

View File

@@ -0,0 +1,15 @@
# Rule Catalog — Business Logic
## Can't use workflowStore in Node components
IsUrgent: True
### Description
File path pattern of node components: `web/app/components/workflow/nodes/[nodeName]/node.tsx`
Node components are also used when creating a RAG Pipe from a template, but in that context there is no workflowStore Provider, which results in a blank screen. [This Issue](https://github.com/langgenius/dify/issues/29168) was caused by exactly this reason.
### Suggested Fix
Use `import { useNodes } from 'reactflow'` instead of `import useNodes from '@/app/components/workflow/store/workflow/use-nodes'`.

View File

@@ -0,0 +1,44 @@
# Rule Catalog — Code Quality
## Conditional class names use utility function
IsUrgent: True
Category: Code Quality
### Description
Ensure conditional CSS is handled via the shared `classNames` instead of custom ternaries, string concatenation, or template strings. Centralizing class logic keeps components consistent and easier to maintain.
### Suggested Fix
```ts
import { cn } from '@/utils/classnames'
const classNames = cn(isActive ? 'text-primary-600' : 'text-gray-500')
```
## Tailwind-first styling
IsUrgent: True
Category: Code Quality
### Description
Favor Tailwind CSS utility classes instead of adding new `.module.css` files unless a Tailwind combination cannot achieve the required styling. Keeping styles in Tailwind improves consistency and reduces maintenance overhead.
Update this file when adding, editing, or removing Code Quality rules so the catalog remains accurate.
## Classname ordering for easy overrides
### Description
When writing components, always place the incoming `className` prop after the components own class values so that downstream consumers can override or extend the styling. This keeps your components defaults but still lets external callers change or remove specific styles.
Example:
```tsx
import { cn } from '@/utils/classnames'
const Button = ({ className }) => {
return <div className={cn('bg-primary-600', className)}></div>
}
```

View File

@@ -0,0 +1,45 @@
# Rule Catalog — Performance
## React Flow data usage
IsUrgent: True
Category: Performance
### Description
When rendering React Flow, prefer `useNodes`/`useEdges` for UI consumption and rely on `useStoreApi` inside callbacks that mutate or read node/edge state. Avoid manually pulling Flow data outside of these hooks.
## Complex prop memoization
IsUrgent: True
Category: Performance
### Description
Wrap complex prop values (objects, arrays, maps) in `useMemo` prior to passing them into child components to guarantee stable references and prevent unnecessary renders.
Update this file when adding, editing, or removing Performance rules so the catalog remains accurate.
Wrong:
```tsx
<HeavyComp
config={{
provider: ...,
detail: ...
}}
/>
```
Right:
```tsx
const config = useMemo(() => ({
provider: ...,
detail: ...
}), [provider, detail]);
<HeavyComp
config={config}
/>
```

View File

@@ -20,4 +20,4 @@
- [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!)
- [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change.
- [x] I've updated the documentation accordingly.
- [x] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods
- [x] I ran `make lint` and `make type-check` (backend) and `cd web && npx lint-staged` (frontend) to appease the lint gods

View File

@@ -5,6 +5,7 @@ on:
branches: [main]
paths:
- 'web/i18n/en-US/*.json'
workflow_dispatch:
permissions:
contents: write
@@ -18,7 +19,8 @@ jobs:
run:
working-directory: web
steps:
- uses: actions/checkout@v6
# Keep use old checkout action version for https://github.com/peter-evans/create-pull-request/issues/4272
- uses: actions/checkout@v4
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
@@ -26,21 +28,28 @@ jobs:
- name: Check for file changes in i18n/en-US
id: check_files
run: |
git fetch origin "${{ github.event.before }}" || true
git fetch origin "${{ github.sha }}" || true
changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.json')
echo "Changed files: $changed_files"
if [ -n "$changed_files" ]; then
# Skip check for manual trigger, translate all files
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV
file_args=""
for file in $changed_files; do
filename=$(basename "$file" .json)
file_args="$file_args --file $filename"
done
echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
echo "File arguments: $file_args"
echo "FILE_ARGS=" >> $GITHUB_ENV
echo "Manual trigger: translating all files"
else
echo "FILES_CHANGED=false" >> $GITHUB_ENV
git fetch origin "${{ github.event.before }}" || true
git fetch origin "${{ github.sha }}" || true
changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.json')
echo "Changed files: $changed_files"
if [ -n "$changed_files" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV
file_args=""
for file in $changed_files; do
filename=$(basename "$file" .json)
file_args="$file_args --file $filename"
done
echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
echo "File arguments: $file_args"
else
echo "FILES_CHANGED=false" >> $GITHUB_ENV
fi
fi
- name: Install pnpm

1
.gitignore vendored
View File

@@ -235,3 +235,4 @@ scripts/stress-test/reports/
# settings
*.local.json
*.local.md

View File

@@ -60,9 +60,10 @@ check:
@echo "✅ Code check complete"
lint:
@echo "🔧 Running ruff format, check with fixes, and import linter..."
@echo "🔧 Running ruff format, check with fixes, import linter, and dotenv-linter..."
@uv run --project api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
@uv run --directory api --dev lint-imports
@uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example
@echo "✅ Linting complete"
type-check:
@@ -122,7 +123,7 @@ help:
@echo "Backend Code Quality:"
@echo " make format - Format code with ruff"
@echo " make check - Check code with ruff"
@echo " make lint - Format and fix code with ruff"
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
@echo " make type-check - Run type checking with basedpyright"
@echo " make test - Run backend unit tests"
@echo ""

View File

@@ -502,6 +502,8 @@ LOG_FILE_BACKUP_COUNT=5
LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S
# Log Timezone
LOG_TZ=UTC
# Log output format: text or json
LOG_OUTPUT_FORMAT=text
# Log format
LOG_FORMAT=%(asctime)s,%(msecs)d %(levelname)-2s [%(filename)s:%(lineno)d] %(req_id)s %(message)s
@@ -573,6 +575,10 @@ LOGSTORE_DUAL_WRITE_ENABLED=false
# Enable dual-read fallback to SQL database when LogStore returns no results (default: true)
# Useful for migration scenarios where historical data exists only in SQL database
LOGSTORE_DUAL_READ_ENABLED=true
# Control flag for whether to write the `graph` field to LogStore.
# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field;
# otherwise write an empty {} instead. Defaults to writing the `graph` field.
LOGSTORE_ENABLE_PUT_GRAPH_FIELD=true
# Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1

View File

@@ -3,9 +3,11 @@ root_packages =
core
configs
controllers
extensions
models
tasks
services
include_external_packages = True
[importlinter:contract:workflow]
name = Workflow
@@ -33,6 +35,29 @@ ignore_imports =
core.workflow.nodes.loop.loop_node -> core.workflow.graph
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
[importlinter:contract:workflow-infrastructure-dependencies]
name = Workflow Infrastructure Dependencies
type = forbidden
source_modules =
core.workflow
forbidden_modules =
extensions.ext_database
extensions.ext_redis
allow_indirect_imports = True
ignore_imports =
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
core.workflow.nodes.variable_assigner.common.impl -> extensions.ext_database
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
core.workflow.graph_engine.manager -> extensions.ext_redis
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
[importlinter:contract:rsc]
name = RSC
type = layers

View File

@@ -79,7 +79,8 @@ COPY --from=packages --chown=dify:dify ${VIRTUAL_ENV} ${VIRTUAL_ENV}
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Download nltk data
RUN mkdir -p /usr/local/share/nltk_data && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords')" \
RUN mkdir -p /usr/local/share/nltk_data \
&& NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; from unstructured.nlp.tokenize import download_nltk_packages; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords'); download_nltk_packages()" \
&& chmod -R 755 /usr/local/share/nltk_data
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache

View File

@@ -2,9 +2,11 @@ import logging
import time
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from core.logging.context import init_request_context
from dify_app import DifyApp
logger = logging.getLogger(__name__)
@@ -25,28 +27,35 @@ def create_flask_app_with_configs() -> DifyApp:
# add before request hook
@dify_app.before_request
def before_request():
# add an unique identifier to each request
# Initialize logging context for this request
init_request_context()
RecyclableContextVar.increment_thread_recycles()
# add after request hook for injecting X-Trace-Id header from OpenTelemetry span context
# add after request hook for injecting trace headers from OpenTelemetry span context
# Only adds headers when OTEL is enabled and has valid context
@dify_app.after_request
def add_trace_id_header(response):
def add_trace_headers(response):
try:
span = get_current_span()
ctx = span.get_span_context() if span else None
if ctx and ctx.is_valid:
trace_id_hex = format(ctx.trace_id, "032x")
# Avoid duplicates if some middleware added it
if "X-Trace-Id" not in response.headers:
response.headers["X-Trace-Id"] = trace_id_hex
if not ctx or not ctx.is_valid:
return response
# Inject trace headers from OTEL context
if ctx.trace_id != INVALID_TRACE_ID and "X-Trace-Id" not in response.headers:
response.headers["X-Trace-Id"] = format(ctx.trace_id, "032x")
if ctx.span_id != INVALID_SPAN_ID and "X-Span-Id" not in response.headers:
response.headers["X-Span-Id"] = format(ctx.span_id, "016x")
except Exception:
# Never break the response due to tracing header injection
logger.warning("Failed to add trace ID to response header", exc_info=True)
logger.warning("Failed to add trace headers to response", exc_info=True)
return response
# Capture the decorator's return value to avoid pyright reportUnusedFunction
_ = before_request
_ = add_trace_id_header
_ = add_trace_headers
return dify_app

View File

@@ -235,7 +235,7 @@ def migrate_annotation_vector_database():
if annotations:
for annotation in annotations:
document = Document(
page_content=annotation.question,
page_content=annotation.question_text,
metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id},
)
documents.append(document)
@@ -1184,6 +1184,217 @@ def remove_orphaned_files_on_storage(force: bool):
click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow"))
@click.command("file-usage", help="Query file usages and show where files are referenced.")
@click.option("--file-id", type=str, default=None, help="Filter by file UUID.")
@click.option("--key", type=str, default=None, help="Filter by storage key.")
@click.option("--src", type=str, default=None, help="Filter by table.column pattern (e.g., 'documents.%' or '%.icon').")
@click.option("--limit", type=int, default=100, help="Limit number of results (default: 100).")
@click.option("--offset", type=int, default=0, help="Offset for pagination (default: 0).")
@click.option("--json", "output_json", is_flag=True, help="Output results in JSON format.")
def file_usage(
file_id: str | None,
key: str | None,
src: str | None,
limit: int,
offset: int,
output_json: bool,
):
"""
Query file usages and show where files are referenced in the database.
This command reuses the same reference checking logic as clear-orphaned-file-records
and displays detailed information about where each file is referenced.
"""
# define tables and columns to process
files_tables = [
{"table": "upload_files", "id_column": "id", "key_column": "key"},
{"table": "tool_files", "id_column": "id", "key_column": "file_key"},
]
ids_tables = [
{"type": "uuid", "table": "message_files", "column": "upload_file_id", "pk_column": "id"},
{"type": "text", "table": "documents", "column": "data_source_info", "pk_column": "id"},
{"type": "text", "table": "document_segments", "column": "content", "pk_column": "id"},
{"type": "text", "table": "messages", "column": "answer", "pk_column": "id"},
{"type": "text", "table": "workflow_node_executions", "column": "inputs", "pk_column": "id"},
{"type": "text", "table": "workflow_node_executions", "column": "process_data", "pk_column": "id"},
{"type": "text", "table": "workflow_node_executions", "column": "outputs", "pk_column": "id"},
{"type": "text", "table": "conversations", "column": "introduction", "pk_column": "id"},
{"type": "text", "table": "conversations", "column": "system_instruction", "pk_column": "id"},
{"type": "text", "table": "accounts", "column": "avatar", "pk_column": "id"},
{"type": "text", "table": "apps", "column": "icon", "pk_column": "id"},
{"type": "text", "table": "sites", "column": "icon", "pk_column": "id"},
{"type": "json", "table": "messages", "column": "inputs", "pk_column": "id"},
{"type": "json", "table": "messages", "column": "message", "pk_column": "id"},
]
# Stream file usages with pagination to avoid holding all results in memory
paginated_usages = []
total_count = 0
# First, build a mapping of file_id -> storage_key from the base tables
file_key_map = {}
for files_table in files_tables:
query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}"
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
file_key_map[str(row[0])] = f"{files_table['table']}:{row[1]}"
# If filtering by key or file_id, verify it exists
if file_id and file_id not in file_key_map:
if output_json:
click.echo(json.dumps({"error": f"File ID {file_id} not found in base tables"}))
else:
click.echo(click.style(f"File ID {file_id} not found in base tables.", fg="red"))
return
if key:
valid_prefixes = {f"upload_files:{key}", f"tool_files:{key}"}
matching_file_ids = [fid for fid, fkey in file_key_map.items() if fkey in valid_prefixes]
if not matching_file_ids:
if output_json:
click.echo(json.dumps({"error": f"Key {key} not found in base tables"}))
else:
click.echo(click.style(f"Key {key} not found in base tables.", fg="red"))
return
guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
# For each reference table/column, find matching file IDs and record the references
for ids_table in ids_tables:
src_filter = f"{ids_table['table']}.{ids_table['column']}"
# Skip if src filter doesn't match (use fnmatch for wildcard patterns)
if src:
if "%" in src or "_" in src:
import fnmatch
# Convert SQL LIKE wildcards to fnmatch wildcards (% -> *, _ -> ?)
pattern = src.replace("%", "*").replace("_", "?")
if not fnmatch.fnmatch(src_filter, pattern):
continue
else:
if src_filter != src:
continue
if ids_table["type"] == "uuid":
# Direct UUID match
query = (
f"SELECT {ids_table['pk_column']}, {ids_table['column']} "
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
record_id = str(row[0])
ref_file_id = str(row[1])
if ref_file_id not in file_key_map:
continue
storage_key = file_key_map[ref_file_id]
# Apply filters
if file_id and ref_file_id != file_id:
continue
if key and not storage_key.endswith(key):
continue
# Only collect items within the requested page range
if offset <= total_count < offset + limit:
paginated_usages.append(
{
"src": f"{ids_table['table']}.{ids_table['column']}",
"record_id": record_id,
"file_id": ref_file_id,
"key": storage_key,
}
)
total_count += 1
elif ids_table["type"] in ("text", "json"):
# Extract UUIDs from text/json content
column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"]
query = (
f"SELECT {ids_table['pk_column']}, {column_cast} "
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
record_id = str(row[0])
content = str(row[1])
# Find all UUIDs in the content
import re
uuid_pattern = re.compile(guid_regexp, re.IGNORECASE)
matches = uuid_pattern.findall(content)
for ref_file_id in matches:
if ref_file_id not in file_key_map:
continue
storage_key = file_key_map[ref_file_id]
# Apply filters
if file_id and ref_file_id != file_id:
continue
if key and not storage_key.endswith(key):
continue
# Only collect items within the requested page range
if offset <= total_count < offset + limit:
paginated_usages.append(
{
"src": f"{ids_table['table']}.{ids_table['column']}",
"record_id": record_id,
"file_id": ref_file_id,
"key": storage_key,
}
)
total_count += 1
# Output results
if output_json:
result = {
"total": total_count,
"offset": offset,
"limit": limit,
"usages": paginated_usages,
}
click.echo(json.dumps(result, indent=2))
else:
click.echo(
click.style(f"Found {total_count} file usages (showing {len(paginated_usages)} results)", fg="white")
)
click.echo("")
if not paginated_usages:
click.echo(click.style("No file usages found matching the specified criteria.", fg="yellow"))
return
# Print table header
click.echo(
click.style(
f"{'Src (Table.Column)':<50} {'Record ID':<40} {'File ID':<40} {'Storage Key':<60}",
fg="cyan",
)
)
click.echo(click.style("-" * 190, fg="white"))
# Print each usage
for usage in paginated_usages:
click.echo(f"{usage['src']:<50} {usage['record_id']:<40} {usage['file_id']:<40} {usage['key']:<60}")
# Show pagination info
if offset + limit < total_count:
click.echo("")
click.echo(
click.style(
f"Showing {offset + 1}-{offset + len(paginated_usages)} of {total_count} results", fg="white"
)
)
click.echo(click.style(f"Use --offset {offset + limit} to see next page", fg="white"))
@click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.")
@click.option("--provider", prompt=True, help="Provider name")
@click.option("--client-params", prompt=True, help="Client Params")

View File

@@ -587,6 +587,11 @@ class LoggingConfig(BaseSettings):
default="INFO",
)
LOG_OUTPUT_FORMAT: Literal["text", "json"] = Field(
description="Log output format: 'text' for human-readable, 'json' for structured JSON logs.",
default="text",
)
LOG_FILE: str | None = Field(
description="File path for log output.",
default=None,

View File

@@ -16,7 +16,6 @@ class MilvusConfig(BaseSettings):
description="Authentication token for Milvus, if token-based authentication is enabled",
default=None,
)
MILVUS_USER: str | None = Field(
description="Username for authenticating with Milvus, if username/password authentication is enabled",
default=None,

View File

@@ -13,7 +13,6 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import MessageTextField
from fields.raws import FilesContainedField
from libs.datetime_utils import naive_utc_now, parse_time_range
from libs.helper import TimestampField
@@ -177,6 +176,12 @@ annotation_hit_history_model = console_ns.model(
},
)
class MessageTextField(fields.Raw):
def format(self, value):
return value[0]["text"] if value else ""
# Simple message detail model
simple_message_detail_model = console_ns.model(
"SimpleMessageDetail",

View File

@@ -1,13 +1,9 @@
import logging
from collections.abc import Sequence
from typing import Any
from flask_restx import Resource
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
from controllers.console import console_ns
from controllers.console.app.error import (
CompletionRequestError,
@@ -22,7 +18,6 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.generator import WorkflowGenerator
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models import App
@@ -60,30 +55,6 @@ class InstructionTemplatePayload(BaseModel):
type: str = Field(..., description="Instruction template type")
class PreviousWorkflow(BaseModel):
"""Previous workflow attempt for regeneration context."""
nodes: list[dict[str, Any]] = Field(default_factory=list, description="Previously generated nodes")
edges: list[dict[str, Any]] = Field(default_factory=list, description="Previously generated edges")
warnings: list[str] = Field(default_factory=list, description="Warnings from previous generation")
class FlowchartGeneratePayload(BaseModel):
instruction: str = Field(..., description="Workflow flowchart generation instruction")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
available_nodes: list[dict[str, Any]] = Field(default_factory=list, description="Available node types")
existing_nodes: list[dict[str, Any]] = Field(default_factory=list, description="Existing workflow nodes")
existing_edges: list[dict[str, Any]] = Field(default_factory=list, description="Existing workflow edges")
available_tools: list[dict[str, Any]] = Field(default_factory=list, description="Available tools")
selected_node_ids: list[str] = Field(default_factory=list, description="IDs of selected nodes for context")
previous_workflow: PreviousWorkflow | None = Field(default=None, description="Previous workflow for regeneration")
regenerate_mode: bool = Field(default=False, description="Whether this is a regeneration request")
# Language preference for generated content (node titles, descriptions)
language: str | None = Field(default=None, description="Preferred language for generated content")
# Available models that user has configured (for LLM/question-classifier nodes)
available_models: list[dict[str, Any]] = Field(default_factory=list, description="User's configured models")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@@ -93,7 +64,6 @@ reg(RuleCodeGeneratePayload)
reg(RuleStructuredOutputPayload)
reg(InstructionGeneratePayload)
reg(InstructionTemplatePayload)
reg(FlowchartGeneratePayload)
@console_ns.route("/rule-generate")
@@ -285,52 +255,6 @@ class InstructionGenerateApi(Resource):
raise CompletionRequestError(e.description)
@console_ns.route("/flowchart-generate")
class FlowchartGenerateApi(Resource):
@console_ns.doc("generate_workflow_flowchart")
@console_ns.doc(description="Generate workflow flowchart using LLM with intent classification")
@console_ns.expect(console_ns.models[FlowchartGeneratePayload.__name__])
@console_ns.response(200, "Flowchart generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(402, "Provider quota exceeded")
@setup_required
@login_required
@account_initialization_required
def post(self):
args = FlowchartGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
try:
# Convert PreviousWorkflow to dict if present
previous_workflow_dict = args.previous_workflow.model_dump() if args.previous_workflow else None
result = WorkflowGenerator.generate_workflow_flowchart(
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
available_nodes=args.available_nodes,
existing_nodes=args.existing_nodes,
existing_edges=args.existing_edges,
available_tools=args.available_tools,
selected_node_ids=args.selected_node_ids,
previous_workflow=previous_workflow_dict,
regenerate_mode=args.regenerate_mode,
preferred_language=args.language,
available_models=args.available_models,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
return result
@console_ns.route("/instruction-generate/template")
class InstructionGenerationTemplateApi(Resource):
@console_ns.doc("get_instruction_template")

View File

@@ -751,12 +751,12 @@ class DocumentApi(DocumentResource):
elif metadata == "without":
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
data_source_info = document.data_source_detail_dict
response = {
"id": document.id,
"position": document.position,
"data_source_type": document.data_source_type,
"data_source_info": data_source_info,
"data_source_info": document.data_source_info_dict,
"data_source_detail_dict": document.data_source_detail_dict,
"dataset_process_rule_id": document.dataset_process_rule_id,
"dataset_process_rule": dataset_process_rules,
"document_process_rule": document_process_rules,
@@ -784,12 +784,12 @@ class DocumentApi(DocumentResource):
else:
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
data_source_info = document.data_source_detail_dict
response = {
"id": document.id,
"position": document.position,
"data_source_type": document.data_source_type,
"data_source_info": data_source_info,
"data_source_info": document.data_source_info_dict,
"data_source_detail_dict": document.data_source_detail_dict,
"dataset_process_rule_id": document.dataset_process_rule_id,
"dataset_process_rule": dataset_process_rules,
"document_process_rule": document_process_rules,

View File

@@ -1,8 +1,7 @@
from typing import Any
from flask import request
from flask_restx import marshal_with
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field, TypeAdapter, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@@ -11,7 +10,11 @@ from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from fields.conversation_fields import (
ConversationInfiniteScrollPagination,
ResultResponse,
SimpleConversation,
)
from libs.helper import UUIDStrOrEmpty
from libs.login import current_user
from models import Account
@@ -49,7 +52,6 @@ register_schema_models(console_ns, ConversationListQuery, ConversationRenamePayl
endpoint="installed_app_conversations",
)
class ConversationListApi(InstalledAppResource):
@marshal_with(conversation_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[ConversationListQuery.__name__])
def get(self, installed_app):
app_model = installed_app.app
@@ -73,7 +75,7 @@ class ConversationListApi(InstalledAppResource):
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
with Session(db.engine) as session:
return WebConversationService.pagination_by_last_id(
pagination = WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=current_user,
@@ -82,6 +84,13 @@ class ConversationListApi(InstalledAppResource):
invoke_from=InvokeFrom.EXPLORE,
pinned=args.pinned,
)
adapter = TypeAdapter(SimpleConversation)
conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
return ConversationInfiniteScrollPagination(
limit=pagination.limit,
has_more=pagination.has_more,
data=conversations,
).model_dump(mode="json")
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
@@ -105,7 +114,7 @@ class ConversationApi(InstalledAppResource):
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return {"result": "success"}, 204
return ResultResponse(result="success").model_dump(mode="json"), 204
@console_ns.route(
@@ -113,7 +122,6 @@ class ConversationApi(InstalledAppResource):
endpoint="installed_app_conversation_rename",
)
class ConversationRenameApi(InstalledAppResource):
@marshal_with(simple_conversation_fields)
@console_ns.expect(console_ns.models[ConversationRenamePayload.__name__])
def post(self, installed_app, c_id):
app_model = installed_app.app
@@ -128,9 +136,14 @@ class ConversationRenameApi(InstalledAppResource):
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
return ConversationService.rename(
conversation = ConversationService.rename(
app_model, conversation_id, current_user, payload.name, payload.auto_generate
)
return (
TypeAdapter(SimpleConversation)
.validate_python(conversation, from_attributes=True)
.model_dump(mode="json")
)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@@ -155,7 +168,7 @@ class ConversationPinApi(InstalledAppResource):
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return {"result": "success"}
return ResultResponse(result="success").model_dump(mode="json")
@console_ns.route(
@@ -174,4 +187,4 @@ class ConversationUnPinApi(InstalledAppResource):
raise ValueError("current_user must be an Account instance")
WebConversationService.unpin(app_model, conversation_id, current_user)
return {"result": "success"}
return ResultResponse(result="success").model_dump(mode="json")

View File

@@ -2,8 +2,7 @@ import logging
from typing import Literal
from flask import request
from flask_restx import marshal_with
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
@@ -23,7 +22,8 @@ from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from fields.message_fields import message_infinite_scroll_pagination_fields
from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse
from libs import helper
from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant
@@ -66,7 +66,6 @@ register_schema_models(console_ns, MessageListQuery, MessageFeedbackPayload, Mor
endpoint="installed_app_messages",
)
class MessageListApi(InstalledAppResource):
@marshal_with(message_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[MessageListQuery.__name__])
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
@@ -78,13 +77,20 @@ class MessageListApi(InstalledAppResource):
args = MessageListQuery.model_validate(request.args.to_dict())
try:
return MessageService.pagination_by_first_id(
pagination = MessageService.pagination_by_first_id(
app_model,
current_user,
str(args.conversation_id),
str(args.first_id) if args.first_id else None,
args.limit,
)
adapter = TypeAdapter(MessageListItem)
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
return MessageInfiniteScrollPagination(
limit=pagination.limit,
has_more=pagination.has_more,
data=items,
).model_dump(mode="json")
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except FirstMessageNotExistsError:
@@ -116,7 +122,7 @@ class MessageFeedbackApi(InstalledAppResource):
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
return {"result": "success"}
return ResultResponse(result="success").model_dump(mode="json")
@console_ns.route(
@@ -201,4 +207,4 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
logger.exception("internal server error.")
raise InternalServerError()
return {"data": questions}
return SuggestedQuestionsResponse(data=questions).model_dump(mode="json")

View File

@@ -1,14 +1,14 @@
from flask import request
from flask_restx import fields, marshal_with
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, UUIDStrOrEmpty
from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
@@ -26,28 +26,8 @@ class SavedMessageCreatePayload(BaseModel):
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
feedback_fields = {"rating": fields.String}
message_fields = {
"id": fields.String,
"inputs": fields.Raw,
"query": fields.String,
"answer": fields.String,
"message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"created_at": TimestampField,
}
@console_ns.route("/installed-apps/<uuid:installed_app_id>/saved-messages", endpoint="installed_app_saved_messages")
class SavedMessageListApi(InstalledAppResource):
saved_message_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_fields)),
}
@marshal_with(saved_message_infinite_scroll_pagination_fields)
@console_ns.expect(console_ns.models[SavedMessageListQuery.__name__])
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
@@ -57,12 +37,19 @@ class SavedMessageListApi(InstalledAppResource):
args = SavedMessageListQuery.model_validate(request.args.to_dict())
return SavedMessageService.pagination_by_last_id(
pagination = SavedMessageService.pagination_by_last_id(
app_model,
current_user,
str(args.last_id) if args.last_id else None,
args.limit,
)
adapter = TypeAdapter(SavedMessageItem)
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
return SavedMessageInfiniteScrollPagination(
limit=pagination.limit,
has_more=pagination.has_more,
data=items,
).model_dump(mode="json")
@console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__])
def post(self, installed_app):
@@ -78,7 +65,7 @@ class SavedMessageListApi(InstalledAppResource):
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
return {"result": "success"}
return ResultResponse(result="success").model_dump(mode="json")
@console_ns.route(
@@ -96,4 +83,4 @@ class SavedMessageApi(InstalledAppResource):
SavedMessageService.delete(app_model, current_user, message_id)
return {"result": "success"}, 204
return ResultResponse(result="success").model_dump(mode="json"), 204

View File

@@ -4,12 +4,11 @@ from typing import Any
from flask import make_response, redirect, request
from flask_restx import Resource, reqparse
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from controllers.web.error import NotFoundError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
@@ -44,6 +43,12 @@ class TriggerSubscriptionUpdateRequest(BaseModel):
parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters for the subscription")
properties: Mapping[str, Any] | None = Field(default=None, description="The properties for the subscription")
@model_validator(mode="after")
def check_at_least_one_field(self):
if all(v is None for v in (self.name, self.credentials, self.parameters, self.properties)):
raise ValueError("At least one of name, credentials, parameters, or properties must be provided")
return self
class TriggerSubscriptionVerifyRequest(BaseModel):
"""Request payload for verifying subscription credentials."""
@@ -333,7 +338,7 @@ class TriggerSubscriptionUpdateApi(Resource):
user = current_user
assert user.current_tenant_id is not None
args = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload)
request = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload)
subscription = TriggerProviderService.get_subscription_by_id(
tenant_id=user.current_tenant_id,
@@ -345,50 +350,32 @@ class TriggerSubscriptionUpdateApi(Resource):
provider_id = TriggerProviderID(subscription.provider_id)
try:
# rename only
if (
args.name is not None
and args.credentials is None
and args.parameters is None
and args.properties is None
):
# For rename only, just update the name
rename = request.name is not None and not any((request.credentials, request.parameters, request.properties))
# When credential type is UNAUTHORIZED, it indicates the subscription was manually created
# For Manually created subscription, they dont have credentials, parameters
# They only have name and properties(which is input by user)
manually_created = subscription.credential_type == CredentialType.UNAUTHORIZED
if rename or manually_created:
TriggerProviderService.update_trigger_subscription(
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
name=args.name,
name=request.name,
properties=request.properties,
)
return 200
# rebuild for create automatically by the provider
match subscription.credential_type:
case CredentialType.UNAUTHORIZED:
TriggerProviderService.update_trigger_subscription(
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
name=args.name,
properties=args.properties,
)
return 200
case CredentialType.API_KEY | CredentialType.OAUTH2:
if args.credentials:
new_credentials: dict[str, Any] = {
key: value if value != HIDDEN_VALUE else subscription.credentials.get(key, UNKNOWN_VALUE)
for key, value in args.credentials.items()
}
else:
new_credentials = subscription.credentials
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=user.current_tenant_id,
name=args.name,
provider_id=provider_id,
subscription_id=subscription_id,
credentials=new_credentials,
parameters=args.parameters or subscription.parameters,
)
return 200
case _:
raise BadRequest("Invalid credential type")
# For the rest cases(API_KEY, OAUTH2)
# we need to call third party provider(e.g. GitHub) to rebuild the subscription
TriggerProviderService.rebuild_trigger_subscription(
tenant_id=user.current_tenant_id,
name=request.name,
provider_id=provider_id,
subscription_id=subscription_id,
credentials=request.credentials or subscription.credentials,
parameters=request.parameters or subscription.parameters,
)
return 200
except ValueError as e:
raise BadRequest(str(e))
except Exception as e:

View File

@@ -3,8 +3,7 @@ from uuid import UUID
from flask import request
from flask_restx import Resource
from flask_restx._http import HTTPStatus
from pydantic import BaseModel, Field, field_validator, model_validator
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound
@@ -16,9 +15,9 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import (
build_conversation_delete_model,
build_conversation_infinite_scroll_pagination_model,
build_simple_conversation_model,
ConversationDelete,
ConversationInfiniteScrollPagination,
SimpleConversation,
)
from fields.conversation_variable_fields import (
build_conversation_variable_infinite_scroll_pagination_model,
@@ -105,7 +104,6 @@ class ConversationApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@service_api_ns.marshal_with(build_conversation_infinite_scroll_pagination_model(service_api_ns))
def get(self, app_model: App, end_user: EndUser):
"""List all conversations for the current user.
@@ -120,7 +118,7 @@ class ConversationApi(Resource):
try:
with Session(db.engine) as session:
return ConversationService.pagination_by_last_id(
pagination = ConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=end_user,
@@ -129,6 +127,13 @@ class ConversationApi(Resource):
invoke_from=InvokeFrom.SERVICE_API,
sort_by=query_args.sort_by,
)
adapter = TypeAdapter(SimpleConversation)
conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
return ConversationInfiniteScrollPagination(
limit=pagination.limit,
has_more=pagination.has_more,
data=conversations,
).model_dump(mode="json")
except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
@@ -146,7 +151,6 @@ class ConversationDetailApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@service_api_ns.marshal_with(build_conversation_delete_model(service_api_ns), code=HTTPStatus.NO_CONTENT)
def delete(self, app_model: App, end_user: EndUser, c_id):
"""Delete a specific conversation."""
app_mode = AppMode.value_of(app_model.mode)
@@ -159,7 +163,7 @@ class ConversationDetailApi(Resource):
ConversationService.delete(app_model, conversation_id, end_user)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return {"result": "success"}, 204
return ConversationDelete(result="success").model_dump(mode="json"), 204
@service_api_ns.route("/conversations/<uuid:c_id>/name")
@@ -176,7 +180,6 @@ class ConversationRenameApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@service_api_ns.marshal_with(build_simple_conversation_model(service_api_ns))
def post(self, app_model: App, end_user: EndUser, c_id):
"""Rename a conversation or auto-generate a name."""
app_mode = AppMode.value_of(app_model.mode)
@@ -188,7 +191,14 @@ class ConversationRenameApi(Resource):
payload = ConversationRenamePayload.model_validate(service_api_ns.payload or {})
try:
return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate)
conversation = ConversationService.rename(
app_model, conversation_id, end_user, payload.name, payload.auto_generate
)
return (
TypeAdapter(SimpleConversation)
.validate_python(conversation, from_attributes=True)
.model_dump(mode="json")
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

View File

@@ -1,11 +1,10 @@
import json
import logging
from typing import Literal
from uuid import UUID
from flask import request
from flask_restx import Namespace, Resource, fields
from pydantic import BaseModel, Field
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
@@ -14,10 +13,8 @@ from controllers.service_api import service_api_ns
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import build_message_file_model
from fields.message_fields import build_agent_thought_model, build_feedback_model
from fields.raws import FilesContainedField
from libs.helper import TimestampField
from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
from models.model import App, AppMode, EndUser
from services.errors.message import (
FirstMessageNotExistsError,
@@ -48,49 +45,6 @@ class FeedbackListQuery(BaseModel):
register_schema_models(service_api_ns, MessageListQuery, MessageFeedbackPayload, FeedbackListQuery)
def build_message_model(api_or_ns: Namespace):
"""Build the message model for the API or Namespace."""
# First build the nested models
feedback_model = build_feedback_model(api_or_ns)
agent_thought_model = build_agent_thought_model(api_or_ns)
message_file_model = build_message_file_model(api_or_ns)
# Then build the message fields with nested models
message_fields = {
"id": fields.String,
"conversation_id": fields.String,
"parent_message_id": fields.String,
"inputs": FilesContainedField,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"message_files": fields.List(fields.Nested(message_file_model)),
"feedback": fields.Nested(feedback_model, attribute="user_feedback", allow_null=True),
"retriever_resources": fields.Raw(
attribute=lambda obj: json.loads(obj.message_metadata).get("retriever_resources", [])
if obj.message_metadata
else []
),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
"status": fields.String,
"error": fields.String,
}
return api_or_ns.model("Message", message_fields)
def build_message_infinite_scroll_pagination_model(api_or_ns: Namespace):
"""Build the message infinite scroll pagination model for the API or Namespace."""
# Build the nested message model first
message_model = build_message_model(api_or_ns)
message_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_model)),
}
return api_or_ns.model("MessageInfiniteScrollPagination", message_infinite_scroll_pagination_fields)
@service_api_ns.route("/messages")
class MessageListApi(Resource):
@service_api_ns.expect(service_api_ns.models[MessageListQuery.__name__])
@@ -104,7 +58,6 @@ class MessageListApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
@service_api_ns.marshal_with(build_message_infinite_scroll_pagination_model(service_api_ns))
def get(self, app_model: App, end_user: EndUser):
"""List messages in a conversation.
@@ -119,9 +72,16 @@ class MessageListApi(Resource):
first_id = str(query_args.first_id) if query_args.first_id else None
try:
return MessageService.pagination_by_first_id(
pagination = MessageService.pagination_by_first_id(
app_model, end_user, conversation_id, first_id, query_args.limit
)
adapter = TypeAdapter(MessageListItem)
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
return MessageInfiniteScrollPagination(
limit=pagination.limit,
has_more=pagination.has_more,
data=items,
).model_dump(mode="json")
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except FirstMessageNotExistsError:
@@ -162,7 +122,7 @@ class MessageFeedbackApi(Resource):
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
return {"result": "success"}
return ResultResponse(result="success").model_dump(mode="json")
@service_api_ns.route("/app/feedbacks")

View File

@@ -1,5 +1,6 @@
from flask_restx import fields, marshal_with, reqparse
from flask_restx import reqparse
from flask_restx.inputs import int_range
from pydantic import TypeAdapter
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@@ -8,7 +9,11 @@ from controllers.web.error import NotChatAppError
from controllers.web.wraps import WebApiResource
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from fields.conversation_fields import (
ConversationInfiniteScrollPagination,
ResultResponse,
SimpleConversation,
)
from libs.helper import uuid_value
from models.model import AppMode
from services.conversation_service import ConversationService
@@ -54,7 +59,6 @@ class ConversationListApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(conversation_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@@ -82,7 +86,7 @@ class ConversationListApi(WebApiResource):
try:
with Session(db.engine) as session:
return WebConversationService.pagination_by_last_id(
pagination = WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=end_user,
@@ -92,16 +96,19 @@ class ConversationListApi(WebApiResource):
pinned=pinned,
sort_by=args["sort_by"],
)
adapter = TypeAdapter(SimpleConversation)
conversations = [adapter.validate_python(item, from_attributes=True) for item in pagination.data]
return ConversationInfiniteScrollPagination(
limit=pagination.limit,
has_more=pagination.has_more,
data=conversations,
).model_dump(mode="json")
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
@web_ns.route("/conversations/<uuid:c_id>")
class ConversationApi(WebApiResource):
delete_response_fields = {
"result": fields.String,
}
@web_ns.doc("Delete Conversation")
@web_ns.doc(description="Delete a specific conversation.")
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
@@ -115,7 +122,6 @@ class ConversationApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(delete_response_fields)
def delete(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@@ -126,7 +132,7 @@ class ConversationApi(WebApiResource):
ConversationService.delete(app_model, conversation_id, end_user)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return {"result": "success"}, 204
return ResultResponse(result="success").model_dump(mode="json"), 204
@web_ns.route("/conversations/<uuid:c_id>/name")
@@ -155,7 +161,6 @@ class ConversationRenameApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(simple_conversation_fields)
def post(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@@ -171,17 +176,20 @@ class ConversationRenameApi(WebApiResource):
args = parser.parse_args()
try:
return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"])
conversation = ConversationService.rename(
app_model, conversation_id, end_user, args["name"], args["auto_generate"]
)
return (
TypeAdapter(SimpleConversation)
.validate_python(conversation, from_attributes=True)
.model_dump(mode="json")
)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@web_ns.route("/conversations/<uuid:c_id>/pin")
class ConversationPinApi(WebApiResource):
pin_response_fields = {
"result": fields.String,
}
@web_ns.doc("Pin Conversation")
@web_ns.doc(description="Pin a specific conversation to keep it at the top of the list.")
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
@@ -195,7 +203,6 @@ class ConversationPinApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(pin_response_fields)
def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@@ -208,15 +215,11 @@ class ConversationPinApi(WebApiResource):
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
return {"result": "success"}
return ResultResponse(result="success").model_dump(mode="json")
@web_ns.route("/conversations/<uuid:c_id>/unpin")
class ConversationUnPinApi(WebApiResource):
unpin_response_fields = {
"result": fields.String,
}
@web_ns.doc("Unpin Conversation")
@web_ns.doc(description="Unpin a specific conversation to remove it from the top of the list.")
@web_ns.doc(params={"c_id": {"description": "Conversation UUID", "type": "string", "required": True}})
@@ -230,7 +233,6 @@ class ConversationUnPinApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(unpin_response_fields)
def patch(self, app_model, end_user, c_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@@ -239,4 +241,4 @@ class ConversationUnPinApi(WebApiResource):
conversation_id = str(c_id)
WebConversationService.unpin(app_model, conversation_id, end_user)
return {"result": "success"}
return ResultResponse(result="success").model_dump(mode="json")

View File

@@ -2,8 +2,7 @@ import logging
from typing import Literal
from flask import request
from flask_restx import fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
@@ -22,11 +21,10 @@ from controllers.web.wraps import WebApiResource
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from fields.conversation_fields import message_file_fields
from fields.message_fields import agent_thought_fields, feedback_fields, retriever_resource_fields
from fields.raws import FilesContainedField
from fields.conversation_fields import ResultResponse
from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem
from libs import helper
from libs.helper import TimestampField, uuid_value
from libs.helper import uuid_value
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError
@@ -70,29 +68,6 @@ register_schema_models(web_ns, MessageListQuery, MessageFeedbackPayload, Message
@web_ns.route("/messages")
class MessageListApi(WebApiResource):
message_fields = {
"id": fields.String,
"conversation_id": fields.String,
"parent_message_id": fields.String,
"inputs": FilesContainedField,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String,
"error": fields.String,
}
message_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_fields)),
}
@web_ns.doc("Get Message List")
@web_ns.doc(description="Retrieve paginated list of messages from a conversation in a chat application.")
@web_ns.doc(
@@ -121,7 +96,6 @@ class MessageListApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@@ -131,9 +105,16 @@ class MessageListApi(WebApiResource):
query = MessageListQuery.model_validate(raw_args)
try:
return MessageService.pagination_by_first_id(
pagination = MessageService.pagination_by_first_id(
app_model, end_user, query.conversation_id, query.first_id, query.limit
)
adapter = TypeAdapter(WebMessageListItem)
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
return WebMessageInfiniteScrollPagination(
limit=pagination.limit,
has_more=pagination.has_more,
data=items,
).model_dump(mode="json")
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except FirstMessageNotExistsError:
@@ -142,10 +123,6 @@ class MessageListApi(WebApiResource):
@web_ns.route("/messages/<uuid:message_id>/feedbacks")
class MessageFeedbackApi(WebApiResource):
feedback_response_fields = {
"result": fields.String,
}
@web_ns.doc("Create Message Feedback")
@web_ns.doc(description="Submit feedback (like/dislike) for a specific message.")
@web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}})
@@ -170,7 +147,6 @@ class MessageFeedbackApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(feedback_response_fields)
def post(self, app_model, end_user, message_id):
message_id = str(message_id)
@@ -187,7 +163,7 @@ class MessageFeedbackApi(WebApiResource):
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
return {"result": "success"}
return ResultResponse(result="success").model_dump(mode="json")
@web_ns.route("/messages/<uuid:message_id>/more-like-this")
@@ -247,10 +223,6 @@ class MessageMoreLikeThisApi(WebApiResource):
@web_ns.route("/messages/<uuid:message_id>/suggested-questions")
class MessageSuggestedQuestionApi(WebApiResource):
suggested_questions_response_fields = {
"data": fields.List(fields.String),
}
@web_ns.doc("Get Suggested Questions")
@web_ns.doc(description="Get suggested follow-up questions after a message (chat apps only).")
@web_ns.doc(params={"message_id": {"description": "Message UUID", "type": "string", "required": True}})
@@ -264,7 +236,6 @@ class MessageSuggestedQuestionApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(suggested_questions_response_fields)
def get(self, app_model, end_user, message_id):
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
@@ -277,7 +248,6 @@ class MessageSuggestedQuestionApi(WebApiResource):
app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP
)
# questions is a list of strings, not a list of Message objects
# so we can directly return it
except MessageNotExistsError:
raise NotFound("Message not found")
except ConversationNotExistsError:
@@ -296,4 +266,4 @@ class MessageSuggestedQuestionApi(WebApiResource):
logger.exception("internal server error.")
raise InternalServerError()
return {"data": questions}
return SuggestedQuestionsResponse(data=questions).model_dump(mode="json")

View File

@@ -1,40 +1,20 @@
from flask_restx import fields, marshal_with, reqparse
from flask_restx import reqparse
from flask_restx.inputs import int_range
from pydantic import TypeAdapter
from werkzeug.exceptions import NotFound
from controllers.web import web_ns
from controllers.web.error import NotCompletionAppError
from controllers.web.wraps import WebApiResource
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField, uuid_value
from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from libs.helper import uuid_value
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
feedback_fields = {"rating": fields.String}
message_fields = {
"id": fields.String,
"inputs": fields.Raw,
"query": fields.String,
"answer": fields.String,
"message_files": fields.List(fields.Nested(message_file_fields)),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"created_at": TimestampField,
}
@web_ns.route("/saved-messages")
class SavedMessageListApi(WebApiResource):
saved_message_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_fields)),
}
post_response_fields = {
"result": fields.String,
}
@web_ns.doc("Get Saved Messages")
@web_ns.doc(description="Retrieve paginated list of saved messages for a completion application.")
@web_ns.doc(
@@ -58,7 +38,6 @@ class SavedMessageListApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(saved_message_infinite_scroll_pagination_fields)
def get(self, app_model, end_user):
if app_model.mode != "completion":
raise NotCompletionAppError()
@@ -70,7 +49,14 @@ class SavedMessageListApi(WebApiResource):
)
args = parser.parse_args()
return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
pagination = SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"])
adapter = TypeAdapter(SavedMessageItem)
items = [adapter.validate_python(message, from_attributes=True) for message in pagination.data]
return SavedMessageInfiniteScrollPagination(
limit=pagination.limit,
has_more=pagination.has_more,
data=items,
).model_dump(mode="json")
@web_ns.doc("Save Message")
@web_ns.doc(description="Save a specific message for later reference.")
@@ -89,7 +75,6 @@ class SavedMessageListApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(post_response_fields)
def post(self, app_model, end_user):
if app_model.mode != "completion":
raise NotCompletionAppError()
@@ -102,15 +87,11 @@ class SavedMessageListApi(WebApiResource):
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
return {"result": "success"}
return ResultResponse(result="success").model_dump(mode="json")
@web_ns.route("/saved-messages/<uuid:message_id>")
class SavedMessageApi(WebApiResource):
delete_response_fields = {
"result": fields.String,
}
@web_ns.doc("Delete Saved Message")
@web_ns.doc(description="Remove a message from saved messages.")
@web_ns.doc(params={"message_id": {"description": "Message UUID to delete", "type": "string", "required": True}})
@@ -124,7 +105,6 @@ class SavedMessageApi(WebApiResource):
500: "Internal Server Error",
}
)
@marshal_with(delete_response_fields)
def delete(self, app_model, end_user, message_id):
message_id = str(message_id)
@@ -133,4 +113,4 @@ class SavedMessageApi(WebApiResource):
SavedMessageService.delete(app_model, end_user, message_id)
return {"result": "success"}, 204
return ResultResponse(result="success").model_dump(mode="json"), 204

View File

@@ -75,7 +75,7 @@ class AnnotationReplyFeature:
AppAnnotationService.add_annotation_history(
annotation.id,
app_record.id,
annotation.question,
annotation.question_text,
annotation.content,
query,
user_id,

View File

@@ -66,6 +66,7 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
"""
if isinstance(session_factory, Engine):
session_factory = sessionmaker(session_factory)
super().__init__()
self._session_maker = session_factory
self._state_owner_user_id = state_owner_user_id
self._generate_entity = generate_entity
@@ -98,8 +99,6 @@ class PauseStatePersistenceLayer(GraphEngineLayer):
if not isinstance(event, GraphRunPausedEvent):
return
assert self.graph_runtime_state is not None
entity_wrapper: _GenerateEntityUnion
if isinstance(self._generate_entity, WorkflowAppGenerateEntity):
entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity)

View File

@@ -33,6 +33,7 @@ class TriggerPostLayer(GraphEngineLayer):
trigger_log_id: str,
session_maker: sessionmaker[Session],
):
super().__init__()
self.trigger_log_id = trigger_log_id
self.start_time = start_time
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
@@ -57,10 +58,6 @@ class TriggerPostLayer(GraphEngineLayer):
elapsed_time = (datetime.now(UTC) - self.start_time).total_seconds()
# Extract relevant data from result
if not self.graph_runtime_state:
logger.exception("Graph runtime state is not set")
return
outputs = self.graph_runtime_state.outputs
# BASICLY, workflow_execution_id is the same as workflow_run_id

View File

@@ -30,7 +30,6 @@ class SimpleModelProviderEntity(BaseModel):
label: I18nObject
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
icon_large: I18nObject | None = None
supported_model_types: list[ModelType]
def __init__(self, provider_entity: ProviderEntity):
@@ -44,7 +43,6 @@ class SimpleModelProviderEntity(BaseModel):
label=provider_entity.label,
icon_small=provider_entity.icon_small,
icon_small_dark=provider_entity.icon_small_dark,
icon_large=provider_entity.icon_large,
supported_model_types=provider_entity.supported_model_types,
)
@@ -94,7 +92,6 @@ class DefaultModelProviderEntity(BaseModel):
provider: str
label: I18nObject
icon_small: I18nObject | None = None
icon_large: I18nObject | None = None
supported_model_types: Sequence[ModelType] = []

View File

@@ -88,7 +88,41 @@ def _get_user_provided_host_header(headers: dict | None) -> str | None:
return None
def _inject_trace_headers(headers: dict | None) -> dict:
"""
Inject W3C traceparent header for distributed tracing.
When OTEL is enabled, HTTPXClientInstrumentor handles trace propagation automatically.
When OTEL is disabled, we manually inject the traceparent header.
"""
if headers is None:
headers = {}
# Skip if already present (case-insensitive check)
for key in headers:
if key.lower() == "traceparent":
return headers
# Skip if OTEL is enabled - HTTPXClientInstrumentor handles this automatically
if dify_config.ENABLE_OTEL:
return headers
# Generate and inject traceparent for non-OTEL scenarios
try:
from core.helper.trace_id_helper import generate_traceparent_header
traceparent = generate_traceparent_header()
if traceparent:
headers["traceparent"] = traceparent
except Exception:
# Silently ignore errors to avoid breaking requests
logger.debug("Failed to generate traceparent header", exc_info=True)
return headers
def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
# Convert requests-style allow_redirects to httpx-style follow_redirects
if "allow_redirects" in kwargs:
allow_redirects = kwargs.pop("allow_redirects")
if "follow_redirects" not in kwargs:
@@ -106,27 +140,26 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
client = _get_ssrf_client(verify_option)
# Extract follow_redirects for client.send() - it's not a build_request parameter
follow_redirects = kwargs.pop("follow_redirects", True)
# Inject traceparent header for distributed tracing (when OTEL is not enabled)
headers = kwargs.get("headers") or {}
headers = _inject_trace_headers(headers)
kwargs["headers"] = headers
# Preserve user-provided Host header
# When using a forward proxy, httpx may override the Host header based on the URL.
# We extract and preserve any explicitly set Host header to support virtual hosting.
headers = kwargs.get("headers", {})
user_provided_host = _get_user_provided_host_header(headers)
retries = 0
while retries <= max_retries:
try:
# Build the request manually to preserve the Host header
# httpx may override the Host header when using a proxy, so we use
# the request API to explicitly set headers before sending
# Preserve the user-provided Host header
# httpx may override the Host header when using a proxy
headers = {k: v for k, v in headers.items() if k.lower() != "host"}
if user_provided_host is not None:
headers["Host"] = user_provided_host
request = client.build_request(method, url, headers=headers, **kwargs)
response = client.send(request, follow_redirects=follow_redirects)
headers["host"] = user_provided_host
kwargs["headers"] = headers
response = client.request(method=method, url=url, **kwargs)
# Check for SSRF protection by Squid proxy
if response.status_code in (401, 403):

View File

@@ -103,3 +103,60 @@ def parse_traceparent_header(traceparent: str) -> str | None:
if len(parts) == 4 and len(parts[1]) == 32:
return parts[1]
return None
def get_span_id_from_otel_context() -> str | None:
"""
Retrieve the current span ID from the active OpenTelemetry trace context.
Returns:
A 16-character hex string representing the span ID, or None if not available.
"""
try:
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID
span = get_current_span()
if not span:
return None
span_context = span.get_span_context()
if not span_context or span_context.span_id == INVALID_SPAN_ID:
return None
return f"{span_context.span_id:016x}"
except Exception:
return None
def generate_traceparent_header() -> str | None:
"""
Generate a W3C traceparent header from the current context.
Uses OpenTelemetry context if available, otherwise uses the
ContextVar-based trace_id from the logging context.
Format: {version}-{trace_id}-{span_id}-{flags}
Example: 00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01
Returns:
A valid traceparent header string, or None if generation fails.
"""
import uuid
# Try OTEL context first
trace_id = get_trace_id_from_otel_context()
span_id = get_span_id_from_otel_context()
if trace_id and span_id:
return f"00-{trace_id}-{span_id}-01"
# Fallback: use ContextVar-based trace_id or generate new one
from core.logging.context import get_trace_id as get_logging_trace_id
trace_id = get_logging_trace_id() or uuid.uuid4().hex
# Generate a new span_id (16 hex chars)
span_id = uuid.uuid4().hex[:16]
return f"00-{trace_id}-{span_id}-01"

View File

@@ -1,5 +1,6 @@
import json
import logging
import re
from collections.abc import Sequence
from typing import Protocol, cast
@@ -11,6 +12,8 @@ from core.llm_generator.prompts import (
CONVERSATION_TITLE_PROMPT,
GENERATOR_QA_PROMPT,
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
LLM_MODIFY_CODE_SYSTEM,
LLM_MODIFY_PROMPT_SYSTEM,
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
SUGGESTED_QUESTIONS_MAX_TOKENS,
SUGGESTED_QUESTIONS_TEMPERATURE,
@@ -27,7 +30,6 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.generator import WorkflowGenerator
from extensions.ext_database import db
from extensions.ext_storage import storage
from models import App, Message, WorkflowNodeExecutionModel
@@ -283,35 +285,6 @@ class LLMGenerator:
return rule_config
@classmethod
def generate_workflow_flowchart(
cls,
tenant_id: str,
instruction: str,
model_config: dict,
available_nodes: Sequence[dict[str, object]] | None = None,
existing_nodes: Sequence[dict[str, object]] | None = None,
available_tools: Sequence[dict[str, object]] | None = None,
selected_node_ids: Sequence[str] | None = None,
previous_workflow: dict[str, object] | None = None,
regenerate_mode: bool = False,
preferred_language: str | None = None,
available_models: Sequence[dict[str, object]] | None = None,
):
return WorkflowGenerator.generate_workflow_flowchart(
tenant_id=tenant_id,
instruction=instruction,
model_config=model_config,
available_nodes=available_nodes,
existing_nodes=existing_nodes,
available_tools=available_tools,
selected_node_ids=selected_node_ids,
previous_workflow=previous_workflow,
regenerate_mode=regenerate_mode,
preferred_language=preferred_language,
available_models=available_models,
)
@classmethod
def generate_code(cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"):
if code_language == "python":

View File

@@ -143,50 +143,6 @@ Based on task description, please create a well-structured prompt template that
Please generate the full prompt template with at least 300 words and output only the prompt template.
""" # noqa: E501
WORKFLOW_FLOWCHART_PROMPT_TEMPLATE = """
You are an expert workflow designer. Generate a Mermaid flowchart based on the user's request.
Constraints:
- Detect the language of the user's request. Generate all node titles in the same language as the user's input.
- If the input language cannot be determined, use {{PREFERRED_LANGUAGE}} as the fallback language.
- Use only node types listed in <available_nodes>.
- Use only tools listed in <available_tools>. When using a tool node, set type=tool and tool=<tool_key>.
- Tools may include MCP providers (provider_type=mcp). Tool selection still uses tool_key.
- Prefer reusing node titles from <existing_nodes> when possible.
- Output must be valid Mermaid flowchart syntax, no markdown, no extra text.
- First line must be: flowchart LR
- Every node must be declared on its own line using:
<id>["type=<type>|title=<title>|tool=<tool_key>"]
- type is required and must match a type in <available_nodes>.
- title is required for non-tool nodes.
- tool is required only when type=tool, otherwise omit tool.
- Declare all node lines before any edges.
- Edges must use:
<id> --> <id>
<id> -->|true| <id>
<id> -->|false| <id>
- Keep node ids unique and simple (N1, N2, ...).
- For complex orchestration:
- Break the request into stages (ingest, transform, decision, action, output).
- Use IfElse for branching and label edges true/false only.
- Fan-in branches by connecting multiple nodes into a shared downstream node.
- Avoid cycles unless explicitly requested.
- Keep each branch complete with a clear downstream target.
<user_request>
{{TASK_DESCRIPTION}}
</user_request>
<available_nodes>
{{AVAILABLE_NODES}}
</available_nodes>
<existing_nodes>
{{EXISTING_NODES}}
</existing_nodes>
<available_tools>
{{AVAILABLE_TOOLS}}
</available_tools>
"""
RULE_CONFIG_PROMPT_GENERATE_TEMPLATE = """
Here is a task description for which I would like you to create a high-quality prompt template for:
<task_description>

View File

@@ -0,0 +1,20 @@
"""Structured logging components for Dify."""
from core.logging.context import (
clear_request_context,
get_request_id,
get_trace_id,
init_request_context,
)
from core.logging.filters import IdentityContextFilter, TraceContextFilter
from core.logging.structured_formatter import StructuredJSONFormatter
__all__ = [
"IdentityContextFilter",
"StructuredJSONFormatter",
"TraceContextFilter",
"clear_request_context",
"get_request_id",
"get_trace_id",
"init_request_context",
]

View File

@@ -0,0 +1,35 @@
"""Request context for logging - framework agnostic.
This module provides request-scoped context variables for logging,
using Python's contextvars for thread-safe and async-safe storage.
"""
import uuid
from contextvars import ContextVar
_request_id: ContextVar[str] = ContextVar("log_request_id", default="")
_trace_id: ContextVar[str] = ContextVar("log_trace_id", default="")
def get_request_id() -> str:
"""Get current request ID (10 hex chars)."""
return _request_id.get()
def get_trace_id() -> str:
"""Get fallback trace ID when OTEL is unavailable (32 hex chars)."""
return _trace_id.get()
def init_request_context() -> None:
"""Initialize request context. Call at start of each request."""
req_id = uuid.uuid4().hex[:10]
trace_id = uuid.uuid5(uuid.NAMESPACE_DNS, req_id).hex
_request_id.set(req_id)
_trace_id.set(trace_id)
def clear_request_context() -> None:
"""Clear request context. Call at end of request (optional)."""
_request_id.set("")
_trace_id.set("")

View File

@@ -0,0 +1,94 @@
"""Logging filters for structured logging."""
import contextlib
import logging
import flask
from core.logging.context import get_request_id, get_trace_id
class TraceContextFilter(logging.Filter):
"""
Filter that adds trace_id and span_id to log records.
Integrates with OpenTelemetry when available, falls back to ContextVar-based trace_id.
"""
def filter(self, record: logging.LogRecord) -> bool:
# Get trace context from OpenTelemetry
trace_id, span_id = self._get_otel_context()
# Set trace_id (fallback to ContextVar if no OTEL context)
if trace_id:
record.trace_id = trace_id
else:
record.trace_id = get_trace_id()
record.span_id = span_id or ""
# For backward compatibility, also set req_id
record.req_id = get_request_id()
return True
def _get_otel_context(self) -> tuple[str, str]:
"""Extract trace_id and span_id from OpenTelemetry context."""
with contextlib.suppress(Exception):
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
span = get_current_span()
if span and span.get_span_context():
ctx = span.get_span_context()
if ctx.is_valid and ctx.trace_id != INVALID_TRACE_ID:
trace_id = f"{ctx.trace_id:032x}"
span_id = f"{ctx.span_id:016x}" if ctx.span_id != INVALID_SPAN_ID else ""
return trace_id, span_id
return "", ""
class IdentityContextFilter(logging.Filter):
"""
Filter that adds user identity context to log records.
Extracts tenant_id, user_id, and user_type from Flask-Login current_user.
"""
def filter(self, record: logging.LogRecord) -> bool:
identity = self._extract_identity()
record.tenant_id = identity.get("tenant_id", "")
record.user_id = identity.get("user_id", "")
record.user_type = identity.get("user_type", "")
return True
def _extract_identity(self) -> dict[str, str]:
"""Extract identity from current_user if in request context."""
try:
if not flask.has_request_context():
return {}
from flask_login import current_user
# Check if user is authenticated using the proxy
if not current_user.is_authenticated:
return {}
# Access the underlying user object
user = current_user
from models import Account
from models.model import EndUser
identity: dict[str, str] = {}
if isinstance(user, Account):
if user.current_tenant_id:
identity["tenant_id"] = user.current_tenant_id
identity["user_id"] = user.id
identity["user_type"] = "account"
elif isinstance(user, EndUser):
identity["tenant_id"] = user.tenant_id
identity["user_id"] = user.id
identity["user_type"] = user.type or "end_user"
return identity
except Exception:
return {}

View File

@@ -0,0 +1,107 @@
"""Structured JSON log formatter for Dify."""
import logging
import traceback
from datetime import UTC, datetime
from typing import Any
import orjson
from configs import dify_config
class StructuredJSONFormatter(logging.Formatter):
"""
JSON log formatter following the specified schema:
{
"ts": "ISO 8601 UTC",
"severity": "INFO|ERROR|WARN|DEBUG",
"service": "service name",
"caller": "file:line",
"trace_id": "hex 32",
"span_id": "hex 16",
"identity": { "tenant_id", "user_id", "user_type" },
"message": "log message",
"attributes": { ... },
"stack_trace": "..."
}
"""
SEVERITY_MAP: dict[int, str] = {
logging.DEBUG: "DEBUG",
logging.INFO: "INFO",
logging.WARNING: "WARN",
logging.ERROR: "ERROR",
logging.CRITICAL: "ERROR",
}
def __init__(self, service_name: str | None = None):
super().__init__()
self._service_name = service_name or dify_config.APPLICATION_NAME
def format(self, record: logging.LogRecord) -> str:
log_dict = self._build_log_dict(record)
try:
return orjson.dumps(log_dict).decode("utf-8")
except TypeError:
# Fallback: convert non-serializable objects to string
import json
return json.dumps(log_dict, default=str, ensure_ascii=False)
def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]:
# Core fields
log_dict: dict[str, Any] = {
"ts": datetime.now(UTC).isoformat(timespec="milliseconds").replace("+00:00", "Z"),
"severity": self.SEVERITY_MAP.get(record.levelno, "INFO"),
"service": self._service_name,
"caller": f"{record.filename}:{record.lineno}",
"message": record.getMessage(),
}
# Trace context (from TraceContextFilter)
trace_id = getattr(record, "trace_id", "")
span_id = getattr(record, "span_id", "")
if trace_id:
log_dict["trace_id"] = trace_id
if span_id:
log_dict["span_id"] = span_id
# Identity context (from IdentityContextFilter)
identity = self._extract_identity(record)
if identity:
log_dict["identity"] = identity
# Dynamic attributes
attributes = getattr(record, "attributes", None)
if attributes:
log_dict["attributes"] = attributes
# Stack trace for errors with exceptions
if record.exc_info and record.levelno >= logging.ERROR:
log_dict["stack_trace"] = self._format_exception(record.exc_info)
return log_dict
def _extract_identity(self, record: logging.LogRecord) -> dict[str, str] | None:
tenant_id = getattr(record, "tenant_id", None)
user_id = getattr(record, "user_id", None)
user_type = getattr(record, "user_type", None)
if not any([tenant_id, user_id, user_type]):
return None
identity: dict[str, str] = {}
if tenant_id:
identity["tenant_id"] = tenant_id
if user_id:
identity["user_id"] = user_id
if user_type:
identity["user_type"] = user_type
return identity
def _format_exception(self, exc_info: tuple[Any, ...]) -> str:
if exc_info and exc_info[0] is not None:
return "".join(traceback.format_exception(*exc_info))
return ""

View File

@@ -100,7 +100,6 @@ class SimpleProviderEntity(BaseModel):
label: I18nObject
icon_small: I18nObject | None = None
icon_small_dark: I18nObject | None = None
icon_large: I18nObject | None = None
supported_model_types: Sequence[ModelType]
models: list[AIModelEntity] = []
@@ -123,7 +122,6 @@ class ProviderEntity(BaseModel):
label: I18nObject
description: I18nObject | None = None
icon_small: I18nObject | None = None
icon_large: I18nObject | None = None
icon_small_dark: I18nObject | None = None
background: str | None = None
help: ProviderHelpEntity | None = None
@@ -157,7 +155,6 @@ class ProviderEntity(BaseModel):
provider=self.provider,
label=self.label,
icon_small=self.icon_small,
icon_large=self.icon_large,
supported_model_types=self.supported_model_types,
models=self.models,
)

View File

@@ -285,7 +285,7 @@ class ModelProviderFactory:
"""
Get provider icon
:param provider: provider name
:param icon_type: icon type (icon_small or icon_large)
:param icon_type: icon type (icon_small or icon_small_dark)
:param lang: language (zh_Hans or en_US)
:return: provider icon
"""
@@ -309,13 +309,7 @@ class ModelProviderFactory:
else:
file_name = provider_schema.icon_small_dark.en_US
else:
if not provider_schema.icon_large:
raise ValueError(f"Provider {provider} does not have large icon.")
if lang.lower() == "zh_hans":
file_name = provider_schema.icon_large.zh_Hans
else:
file_name = provider_schema.icon_large.en_US
raise ValueError(f"Unsupported icon type: {icon_type}.")
if not file_name:
raise ValueError(f"Provider {provider} does not have icon.")

View File

@@ -103,6 +103,9 @@ class BasePluginClient:
prepared_headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY
prepared_headers.setdefault("Accept-Encoding", "gzip, deflate, br")
# Inject traceparent header for distributed tracing
self._inject_trace_headers(prepared_headers)
prepared_data: bytes | dict[str, Any] | str | None = (
data if isinstance(data, (bytes, str, dict)) or data is None else None
)
@@ -114,6 +117,31 @@ class BasePluginClient:
return str(url), prepared_headers, prepared_data, params, files
def _inject_trace_headers(self, headers: dict[str, str]) -> None:
"""
Inject W3C traceparent header for distributed tracing.
This ensures trace context is propagated to plugin daemon even if
HTTPXClientInstrumentor doesn't cover module-level httpx functions.
"""
if not dify_config.ENABLE_OTEL:
return
import contextlib
# Skip if already present (case-insensitive check)
for key in headers:
if key.lower() == "traceparent":
return
# Inject traceparent - works as fallback when OTEL instrumentation doesn't cover this call
with contextlib.suppress(Exception):
from core.helper.trace_id_helper import generate_traceparent_header
traceparent = generate_traceparent_header()
if traceparent:
headers["traceparent"] = traceparent
def _stream_request(
self,
method: str,

View File

@@ -331,7 +331,6 @@ class ProviderManager:
provider=provider_schema.provider,
label=provider_schema.label,
icon_small=provider_schema.icon_small,
icon_large=provider_schema.icon_large,
supported_model_types=provider_schema.supported_model_types,
),
)

View File

@@ -66,6 +66,8 @@ class WeaviateVector(BaseVector):
in a Weaviate collection.
"""
_DOCUMENT_ID_PROPERTY = "document_id"
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
"""
Initializes the Weaviate vector store.
@@ -353,15 +355,12 @@ class WeaviateVector(BaseVector):
return []
col = self._client.collections.use(self._collection_name)
props = list({*self._attributes, "document_id", Field.TEXT_KEY.value})
props = list({*self._attributes, self._DOCUMENT_ID_PROPERTY, Field.TEXT_KEY.value})
where = None
doc_ids = kwargs.get("document_ids_filter") or []
if doc_ids:
ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
where = ors[0]
for f in ors[1:]:
where = where | f
where = Filter.by_property(self._DOCUMENT_ID_PROPERTY).contains_any(doc_ids)
top_k = int(kwargs.get("top_k", 4))
score_threshold = float(kwargs.get("score_threshold") or 0.0)
@@ -408,10 +407,7 @@ class WeaviateVector(BaseVector):
where = None
doc_ids = kwargs.get("document_ids_filter") or []
if doc_ids:
ors = [Filter.by_property("document_id").equal(x) for x in doc_ids]
where = ors[0]
for f in ors[1:]:
where = where | f
where = Filter.by_property(self._DOCUMENT_ID_PROPERTY).contains_any(doc_ids)
top_k = int(kwargs.get("top_k", 4))

View File

@@ -515,6 +515,7 @@ class DatasetRetrieval:
0
].embedding_model_provider
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
dataset_count = len(available_datasets)
with measure_time() as timer:
cancel_event = threading.Event()
thread_exceptions: list[Exception] = []
@@ -537,6 +538,7 @@ class DatasetRetrieval:
"score_threshold": score_threshold,
"query": query,
"attachment_id": None,
"dataset_count": dataset_count,
"cancel_event": cancel_event,
"thread_exceptions": thread_exceptions,
},
@@ -562,6 +564,7 @@ class DatasetRetrieval:
"score_threshold": score_threshold,
"query": None,
"attachment_id": attachment_id,
"dataset_count": dataset_count,
"cancel_event": cancel_event,
"thread_exceptions": thread_exceptions,
},
@@ -1422,6 +1425,7 @@ class DatasetRetrieval:
score_threshold: float,
query: str | None,
attachment_id: str | None,
dataset_count: int,
cancel_event: threading.Event | None = None,
thread_exceptions: list[Exception] | None = None,
):
@@ -1470,37 +1474,38 @@ class DatasetRetrieval:
if cancel_event and cancel_event.is_set():
break
if reranking_enable:
# do rerank for searched documents
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
if query:
all_documents_item = data_post_processor.invoke(
query=query,
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.TEXT_QUERY,
)
if attachment_id:
all_documents_item = data_post_processor.invoke(
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.IMAGE_QUERY,
query=attachment_id,
)
else:
if index_type == IndexTechniqueType.ECONOMY:
if not query:
all_documents_item = []
else:
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
elif index_type == IndexTechniqueType.HIGH_QUALITY:
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
# Skip second reranking when there is only one dataset
if reranking_enable and dataset_count > 1:
# do rerank for searched documents
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
if query:
all_documents_item = data_post_processor.invoke(
query=query,
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.TEXT_QUERY,
)
if attachment_id:
all_documents_item = data_post_processor.invoke(
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.IMAGE_QUERY,
query=attachment_id,
)
else:
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
if all_documents_item:
all_documents.extend(all_documents_item)
if index_type == IndexTechniqueType.ECONOMY:
if not query:
all_documents_item = []
else:
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
elif index_type == IndexTechniqueType.HIGH_QUALITY:
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
else:
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
if all_documents_item:
all_documents.extend(all_documents_item)
except Exception as e:
if cancel_event:
cancel_event.set()

View File

@@ -64,6 +64,9 @@ engine.layer(DebugLoggingLayer(level="INFO"))
engine.layer(ExecutionLimitsLayer(max_nodes=100))
```
`engine.layer()` binds the read-only runtime state before execution, so layer hooks
can assume `graph_runtime_state` is available.
### Event-Driven Architecture
All node executions emit events for monitoring and integration:

View File

@@ -1 +0,0 @@
from .runner import WorkflowGenerator

View File

@@ -1,29 +0,0 @@
"""
Vibe Workflow Generator Configuration Module.
This module centralizes configuration for the Vibe workflow generation feature,
including node schemas, fallback rules, and response templates.
"""
from core.workflow.generator.config.node_schemas import (
BUILTIN_NODE_SCHEMAS,
FALLBACK_RULES,
FIELD_NAME_CORRECTIONS,
NODE_TYPE_ALIASES,
get_builtin_node_schemas,
get_corrected_field_name,
validate_node_schemas,
)
from core.workflow.generator.config.responses import DEFAULT_SUGGESTIONS, OFF_TOPIC_RESPONSES
__all__ = [
"BUILTIN_NODE_SCHEMAS",
"DEFAULT_SUGGESTIONS",
"FALLBACK_RULES",
"FIELD_NAME_CORRECTIONS",
"NODE_TYPE_ALIASES",
"OFF_TOPIC_RESPONSES",
"get_builtin_node_schemas",
"get_corrected_field_name",
"validate_node_schemas",
]

View File

@@ -1,501 +0,0 @@
"""
Unified Node Configuration for Vibe Workflow Generation.
This module centralizes all node-related configuration:
- Node schemas (parameter definitions)
- Fallback rules (keyword-based node type inference)
- Node type aliases (natural language to canonical type mapping)
- Field name corrections (LLM output normalization)
- Validation utilities
Note: These definitions are the single source of truth.
Frontend has a mirrored copy at web/app/components/workflow/hooks/use-workflow-vibe-config.ts
"""
from typing import Any
# =============================================================================
# NODE SCHEMAS
# =============================================================================
# Built-in node schemas with parameter definitions
# These help the model understand what config each node type requires
_HARDCODED_SCHEMAS: dict[str, dict[str, Any]] = {
"http-request": {
"description": "Send HTTP requests to external APIs or fetch web content",
"required": ["url", "method"],
"parameters": {
"url": {
"type": "string",
"description": "Full URL including protocol (https://...)",
"example": "{{#start.url#}} or https://api.example.com/data",
},
"method": {
"type": "enum",
"options": ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"],
"description": "HTTP method",
},
"headers": {
"type": "string",
"description": "HTTP headers as newline-separated 'Key: Value' pairs",
"example": "Content-Type: application/json\nAuthorization: Bearer {{#start.api_key#}}",
},
"params": {
"type": "string",
"description": "URL query parameters as newline-separated 'key: value' pairs",
},
"body": {
"type": "object",
"description": "Request body with type field required",
"example": {"type": "none", "data": []},
},
"authorization": {
"type": "object",
"description": "Authorization config",
"example": {"type": "no-auth"},
},
"timeout": {
"type": "number",
"description": "Request timeout in seconds",
"default": 60,
},
},
"outputs": ["body (response content)", "status_code", "headers"],
},
"code": {
"description": "Execute Python or JavaScript code for custom logic",
"required": ["code", "language"],
"parameters": {
"code": {
"type": "string",
"description": "Code to execute. Must define a main() function that returns a dict.",
},
"language": {
"type": "enum",
"options": ["python3", "javascript"],
},
"variables": {
"type": "array",
"description": "Input variables passed to the code",
"item_schema": {"variable": "string", "value_selector": "array"},
},
"outputs": {
"type": "object",
"description": "Output variable definitions",
},
},
"outputs": ["Variables defined in outputs schema"],
},
"llm": {
"description": "Call a large language model for text generation/processing",
"required": ["prompt_template"],
"parameters": {
"model": {
"type": "object",
"description": "Model configuration (provider, name, mode)",
},
"prompt_template": {
"type": "array",
"description": "Messages for the LLM",
"item_schema": {
"role": "enum: system, user, assistant",
"text": "string - message content, can include {{#node_id.field#}} references",
},
},
"context": {
"type": "object",
"description": "Optional context settings",
},
"memory": {
"type": "object",
"description": "Optional memory/conversation settings",
},
},
"outputs": ["text (generated response)"],
},
"if-else": {
"description": "Conditional branching based on conditions",
"required": ["cases"],
"parameters": {
"cases": {
"type": "array",
"description": "List of condition cases. Each case defines when 'true' branch is taken.",
"item_schema": {
"case_id": "string - unique case identifier (e.g., 'case_1')",
"logical_operator": "enum: and, or - how multiple conditions combine",
"conditions": {
"type": "array",
"item_schema": {
"variable_selector": "array of strings - path to variable, e.g. ['node_id', 'field']",
"comparison_operator": (
"enum: =, ≠, >, <, ≥, ≤, contains, not contains, is, is not, empty, not empty"
),
"value": "string or number - value to compare against",
},
},
},
},
},
"outputs": ["Branches: true (first case conditions met), false (else/no case matched)"],
},
"knowledge-retrieval": {
"description": "Query knowledge base for relevant content",
"required": ["query_variable_selector", "dataset_ids"],
"parameters": {
"query_variable_selector": {
"type": "array",
"description": "Path to query variable, e.g. ['start', 'query']",
},
"dataset_ids": {
"type": "array",
"description": "List of knowledge base IDs to search",
},
"retrieval_mode": {
"type": "enum",
"options": ["single", "multiple"],
},
},
"outputs": ["result (retrieved documents)"],
},
"template-transform": {
"description": "Transform data using Jinja2 templates",
"required": ["template", "variables"],
"parameters": {
"template": {
"type": "string",
"description": "Jinja2 template string. Use {{ variable_name }} to reference variables.",
},
"variables": {
"type": "array",
"description": "Input variables defined for the template",
"item_schema": {
"variable": "string - variable name to use in template",
"value_selector": "array - path to source value, e.g. ['start', 'user_input']",
},
},
},
"outputs": ["output (transformed string)"],
},
"variable-aggregator": {
"description": "Aggregate variables from multiple branches",
"required": ["variables"],
"parameters": {
"variables": {
"type": "array",
"description": "List of variable selectors to aggregate",
"item_schema": "array of strings - path to source variable, e.g. ['node_id', 'field']",
},
},
"outputs": ["output (aggregated value)"],
},
"iteration": {
"description": "Loop over array items",
"required": ["iterator_selector"],
"parameters": {
"iterator_selector": {
"type": "array",
"description": "Path to array variable to iterate",
},
},
"outputs": ["item (current iteration item)", "index (current index)"],
},
"parameter-extractor": {
"description": "Extract structured parameters from user input using LLM",
"required": ["query", "parameters"],
"parameters": {
"model": {
"type": "object",
"description": "Model configuration (provider, name, mode)",
},
"query": {
"type": "array",
"description": "Path to input text to extract parameters from, e.g. ['start', 'user_input']",
},
"parameters": {
"type": "array",
"description": "Parameters to extract from the input",
"item_schema": {
"name": "string - parameter name (required)",
"type": (
"enum: string, number, boolean, array[string], array[number], array[object], array[boolean]"
),
"description": "string - description of what to extract (required)",
"required": "boolean - whether this parameter is required (MUST be specified)",
"options": "array of strings (optional) - for enum-like selection",
},
},
"instruction": {
"type": "string",
"description": "Additional instructions for extraction",
},
"reasoning_mode": {
"type": "enum",
"options": ["function_call", "prompt"],
"description": "How to perform extraction (defaults to function_call)",
},
},
"outputs": ["Extracted parameters as defined in parameters array", "__is_success", "__reason"],
},
"question-classifier": {
"description": "Classify user input into predefined categories using LLM",
"required": ["query", "classes"],
"parameters": {
"model": {
"type": "object",
"description": "Model configuration (provider, name, mode)",
},
"query": {
"type": "array",
"description": "Path to input text to classify, e.g. ['start', 'user_input']",
},
"classes": {
"type": "array",
"description": "Classification categories",
"item_schema": {
"id": "string - unique class identifier",
"name": "string - class name/label",
},
},
"instruction": {
"type": "string",
"description": "Additional instructions for classification",
},
},
"outputs": ["class_name (selected class)"],
},
}
def _get_dynamic_schemas() -> dict[str, dict[str, Any]]:
"""
Dynamically load schemas from node classes.
Uses lazy import to avoid circular dependency.
"""
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
schemas = {}
for node_type, version_map in NODE_TYPE_CLASSES_MAPPING.items():
# Get the latest version class
node_cls = version_map.get(LATEST_VERSION)
if not node_cls:
continue
# Get schema from the class
schema = node_cls.get_default_config_schema()
if schema:
schemas[node_type.value] = schema
return schemas
# Cache for built-in schemas (populated on first access)
_builtin_schemas_cache: dict[str, dict[str, Any]] | None = None
def get_builtin_node_schemas() -> dict[str, dict[str, Any]]:
"""
Get the complete set of built-in node schemas.
Combines hardcoded schemas with dynamically loaded ones.
Results are cached after first call.
"""
global _builtin_schemas_cache
if _builtin_schemas_cache is None:
_builtin_schemas_cache = {**_HARDCODED_SCHEMAS, **_get_dynamic_schemas()}
return _builtin_schemas_cache
# For backward compatibility - but use get_builtin_node_schemas() for lazy loading
BUILTIN_NODE_SCHEMAS: dict[str, dict[str, Any]] = _HARDCODED_SCHEMAS.copy()
# =============================================================================
# FALLBACK RULES
# =============================================================================
# Keyword rules for smart fallback detection
# Maps node type to keywords that suggest using that node type as a fallback
FALLBACK_RULES: dict[str, list[str]] = {
"http-request": [
"http",
"url",
"web",
"scrape",
"scraper",
"fetch",
"api",
"request",
"download",
"upload",
"webhook",
"endpoint",
"rest",
"get",
"post",
],
"code": [
"code",
"script",
"calculate",
"compute",
"process",
"transform",
"parse",
"convert",
"format",
"filter",
"sort",
"math",
"logic",
],
"llm": [
"analyze",
"summarize",
"summary",
"extract",
"classify",
"translate",
"generate",
"write",
"rewrite",
"explain",
"answer",
"chat",
],
}
# =============================================================================
# NODE TYPE ALIASES
# =============================================================================
# Node type aliases for inference from natural language
# Maps common terms to canonical node type names
NODE_TYPE_ALIASES: dict[str, str] = {
# Start node aliases
"start": "start",
"begin": "start",
"input": "start",
# End node aliases
"end": "end",
"finish": "end",
"output": "end",
# LLM node aliases
"llm": "llm",
"ai": "llm",
"gpt": "llm",
"model": "llm",
"chat": "llm",
# Code node aliases
"code": "code",
"script": "code",
"python": "code",
"javascript": "code",
# HTTP request node aliases
"http-request": "http-request",
"http": "http-request",
"request": "http-request",
"api": "http-request",
"fetch": "http-request",
"webhook": "http-request",
# Conditional node aliases
"if-else": "if-else",
"condition": "if-else",
"branch": "if-else",
"switch": "if-else",
# Loop node aliases
"iteration": "iteration",
"loop": "loop",
"foreach": "iteration",
# Tool node alias
"tool": "tool",
}
# =============================================================================
# FIELD NAME CORRECTIONS
# =============================================================================
# Field name corrections for LLM-generated node configs
# Maps incorrect field names to correct ones for specific node types
FIELD_NAME_CORRECTIONS: dict[str, dict[str, str]] = {
"http-request": {
"text": "body", # LLM might use "text" instead of "body"
"content": "body",
"response": "body",
},
"code": {
"text": "result", # LLM might use "text" instead of "result"
"output": "result",
},
"llm": {
"response": "text",
"answer": "text",
},
}
def get_corrected_field_name(node_type: str, field: str) -> str:
"""
Get the corrected field name for a node type.
Args:
node_type: The type of the node (e.g., "http-request", "code")
field: The field name to correct
Returns:
The corrected field name, or the original if no correction needed
"""
corrections = FIELD_NAME_CORRECTIONS.get(node_type, {})
return corrections.get(field, field)
# =============================================================================
# VALIDATION UTILITIES
# =============================================================================
# Node types that are internal and don't need schemas for LLM generation
_INTERNAL_NODE_TYPES: set[str] = {
# Internal workflow nodes
"answer", # Internal to chatflow
"loop", # Uses iteration internally
"assigner", # Variable assignment utility
"variable-assigner", # Variable assignment utility
"agent", # Agent node (complex, handled separately)
"document-extractor", # Internal document processing
"list-operator", # Internal list operations
# Iteration internal nodes
"iteration-start", # Internal to iteration loop
"loop-start", # Internal to loop
"loop-end", # Internal to loop
# Trigger nodes (not user-creatable via LLM)
"trigger-plugin", # Plugin trigger
"trigger-schedule", # Scheduled trigger
"trigger-webhook", # Webhook trigger
# Other internal nodes
"datasource", # Data source configuration
"human-input", # Human-in-the-loop node
"knowledge-index", # Knowledge indexing node
}
def validate_node_schemas() -> list[str]:
"""
Validate that all registered node types have corresponding schemas.
This function checks if BUILTIN_NODE_SCHEMAS covers all node types
registered in NODE_TYPE_CLASSES_MAPPING, excluding internal node types.
Returns:
List of warning messages for missing schemas (empty if all valid)
"""
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
schemas = get_builtin_node_schemas()
warnings = []
for node_type in NODE_TYPE_CLASSES_MAPPING:
type_value = node_type.value
if type_value in _INTERNAL_NODE_TYPES:
continue
if type_value not in schemas:
warnings.append(f"Missing schema for node type: {type_value}")
return warnings

View File

@@ -1,72 +0,0 @@
"""
Response Templates for Vibe Workflow Generation.
This module defines templates for off-topic responses and default suggestions
to guide users back to workflow-related requests.
"""
# Off-topic response templates for different categories
# Each category has messages in multiple languages
OFF_TOPIC_RESPONSES: dict[str, dict[str, str]] = {
"weather": {
"en": (
"I'm the workflow design assistant - I can't check the weather, "
"but I can help you build AI workflows! For example, I could help you "
"create a workflow that fetches weather data from an API."
),
"zh": "我是工作流设计助手无法查询天气。但我可以帮你创建一个从API获取天气数据的工作流",
},
"math": {
"en": (
"I focus on workflow design rather than calculations. However, "
"if you need calculations in a workflow, I can help you add a Code node "
"that handles math operations!"
),
"zh": "我专注于工作流设计而非计算。但如果您需要在工作流中进行计算,我可以帮您添加一个处理数学运算的代码节点!",
},
"joke": {
"en": (
"While I'd love to share a laugh, I'm specialized in workflow design. "
"How about we create something fun instead - like a workflow that generates jokes using AI?"
),
"zh": "虽然我很想讲笑话但我专门从事工作流设计。不如我们创建一个有趣的东西——比如使用AI生成笑话的工作流",
},
"translation": {
"en": (
"I can't translate directly, but I can help you build a translation workflow! "
"Would you like to create one using an LLM node?"
),
"zh": "我不能直接翻译但我可以帮你构建一个翻译工作流要创建一个使用LLM节点的翻译流程吗",
},
"general_coding": {
"en": (
"I'm specialized in Dify workflow design rather than general coding help. "
"But if you want to add code logic to your workflow, I can help you configure a Code node!"
),
"zh": (
"我专注于Dify工作流设计而非通用编程帮助。但如果您想在工作流中添加代码逻辑我可以帮您配置一个代码节点"
),
},
"default": {
"en": (
"I'm the Dify workflow design assistant. I help create AI automation workflows, "
"but I can't help with general questions. Would you like to create a workflow instead?"
),
"zh": "我是Dify工作流设计助手。我帮助创建AI自动化工作流但无法回答一般性问题。您想创建一个工作流吗",
},
}
# Default suggestions for off-topic requests
# These help guide users towards valid workflow requests
DEFAULT_SUGGESTIONS: dict[str, list[str]] = {
"en": [
"Create a chatbot workflow",
"Build a document summarization pipeline",
"Add email notification to workflow",
],
"zh": [
"创建一个聊天机器人工作流",
"构建文档摘要处理流程",
"添加邮件通知到工作流",
],
}

View File

@@ -1,457 +0,0 @@
BUILDER_SYSTEM_PROMPT = """<role>
You are a Workflow Configuration Engineer.
Your goal is to implement the Architect's plan by generating a precise, runnable Dify Workflow JSON configuration.
</role>
<language_rules>
- Detect the language of the user's request automatically (e.g., English, Chinese, Japanese, etc.).
- Generate ALL node titles, descriptions, and user-facing text in the SAME language as the user's input.
- If the input language is ambiguous or cannot be determined (e.g. code-only input),
use {preferred_language} as the target language.
</language_rules>
<inputs>
<plan>
{plan_context}
</plan>
<tool_schemas>
{tool_schemas}
</tool_schemas>
<node_specs>
{builtin_node_specs}
</node_specs>
<available_models>
{available_models}
</available_models>
<workflow_context>
<existing_nodes>
{existing_nodes_context}
</existing_nodes>
<existing_edges>
{existing_edges_context}
</existing_edges>
<selected_nodes>
{selected_nodes_context}
</selected_nodes>
</workflow_context>
</inputs>
<rules>
1. **Configuration**:
- You MUST fill ALL required parameters for every node.
- Use `{{{{#node_id.field#}}}}` syntax to reference outputs from previous nodes in text fields.
- For 'start' node, define all necessary user inputs.
2. **Variable References**:
- For text fields (like prompts, queries): use string format `{{{{#node_id.field#}}}}`
- For 'end' node outputs: use `value_selector` array format `["node_id", "field"]`
- Example: to reference 'llm' node's 'text' output in end node, use `["llm", "text"]`
3. **Tools**:
- ONLY use the tools listed in `<tool_schemas>`.
- If a planned tool is missing from schemas, fallback to `http-request` or `code`.
4. **Model Selection** (CRITICAL):
- For LLM, question-classifier, and parameter-extractor nodes, you MUST include a "model" config.
- You MUST use ONLY models from the `<available_models>` section above.
- Copy the EXACT provider and name values from available_models.
- NEVER use openai/gpt-4o, gpt-3.5-turbo, gpt-4, or any other models unless they appear in available_models.
- If available_models is empty or shows "No models configured", omit the model config entirely.
5. **Node Specifics**:
- For `if-else` comparison_operator, use literal symbols: `≥`, `≤`, `=`, `≠` (NOT `>=` or `==`).
6. **Modification Mode**:
- If `<existing_nodes>` contains nodes, you are MODIFYING an existing workflow.
- Keep nodes that are NOT mentioned in the user's instruction UNCHANGED.
- Only modify/add/remove nodes that the user explicitly requested.
- Preserve node IDs for unchanged nodes to maintain connections.
- If user says "add X", append new nodes to existing workflow.
- If user says "change Y to Z", only modify that specific node.
- If user says "remove X", exclude that node from output.
**Edge Modification**:
- Use `<existing_edges>` to understand current node connections.
- If user mentions "fix edge", "connect", "link", or "add connection",
review existing_edges and correct missing/wrong connections.
- For multi-branch nodes (if-else, question-classifier),
ensure EACH branch has proper sourceHandle (e.g., "true"/"false") and target.
- Common edge issues to fix:
* Missing edge: Two nodes should connect but don't - add the edge
* Wrong target: Edge points to wrong node - update the target
* Missing sourceHandle: if-else/classifier branches lack sourceHandle - add "true"/"false"
* Disconnected nodes: Node has no incoming or outgoing edges - connect it properly
- When modifying edges, ensure logical flow makes sense (start → middle → end).
- ALWAYS output complete edges array, even if only modifying one edge.
**Validation Feedback** (Automatic Retry):
- If `<validation_feedback>` is present, you are RETRYING after validation errors.
- Focus ONLY on fixing the specific validation issues mentioned.
- Keep everything else from the previous attempt UNCHANGED (preserve node IDs, edges, etc).
- Common validation issues and fixes:
* "Missing required connection" → Add the missing edge
* "Invalid node configuration" → Fix the specific node's config section
* "Type mismatch in variable reference" → Correct the variable selector path
* "Unknown variable" → Update variable reference to existing output
- When fixing, make MINIMAL changes to address each specific error.
7. **Output**:
- Return ONLY the JSON object with `nodes` and `edges`.
- Do NOT generate Mermaid diagrams.
- Do NOT generate explanations.
</rules>
<edge_rules priority="critical">
**EDGES ARE CRITICAL** - Every node except 'end' MUST have at least one outgoing edge.
1. **Linear Flow**: Simple source -> target connection
```
{{"source": "node_a", "target": "node_b"}}
```
2. **question-classifier Branching**: Each class MUST have a separate edge with `sourceHandle` = class `id`
- If you define classes: [{{"id": "cls_refund", "name": "Refund"}}, {{"id": "cls_inquiry", "name": "Inquiry"}}]
- You MUST create edges:
- {{"source": "classifier", "sourceHandle": "cls_refund", "target": "refund_handler"}}
- {{"source": "classifier", "sourceHandle": "cls_inquiry", "target": "inquiry_handler"}}
3. **if-else Branching**: MUST have exactly TWO edges with sourceHandle "true" and "false"
- {{"source": "condition", "sourceHandle": "true", "target": "true_branch"}}
- {{"source": "condition", "sourceHandle": "false", "target": "false_branch"}}
4. **Branch Convergence**: Multiple branches can connect to same downstream node
- Both true_branch and false_branch can connect to the same 'end' node
5. **NEVER leave orphan nodes**: Every node must be connected in the graph
</edge_rules>
<examples>
<example name="simple_linear">
```json
{{
"nodes": [
{{
"id": "start",
"type": "start",
"title": "Start",
"config": {{
"variables": [{{"variable": "query", "label": "Query", "type": "text-input"}}]
}}
}},
{{
"id": "llm",
"type": "llm",
"title": "Generate Response",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"prompt_template": [{{"role": "user", "text": "Answer: {{{{#start.query#}}}}"}}]
}}
}},
{{
"id": "end",
"type": "end",
"title": "End",
"config": {{
"outputs": [
{{"variable": "result", "value_selector": ["llm", "text"]}}
]
}}
}}
],
"edges": [
{{"source": "start", "target": "llm"}},
{{"source": "llm", "target": "end"}}
]
}}
```
</example>
<example name="question_classifier_branching" description="Customer service with intent classification">
```json
{{
"nodes": [
{{
"id": "start",
"type": "start",
"title": "Start",
"config": {{
"variables": [{{"variable": "user_input", "label": "User Message", "type": "text-input", "required": true}}]
}}
}},
{{
"id": "classifier",
"type": "question-classifier",
"title": "Classify Intent",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"query_variable_selector": ["start", "user_input"],
"classes": [
{{"id": "cls_refund", "name": "Refund Request"}},
{{"id": "cls_inquiry", "name": "Product Inquiry"}},
{{"id": "cls_complaint", "name": "Complaint"}},
{{"id": "cls_other", "name": "Other"}}
],
"instruction": "Classify the user's intent"
}}
}},
{{
"id": "handle_refund",
"type": "llm",
"title": "Handle Refund",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"prompt_template": [{{"role": "user", "text": "Extract order number and respond: {{{{#start.user_input#}}}}"}}]
}}
}},
{{
"id": "handle_inquiry",
"type": "llm",
"title": "Handle Inquiry",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"prompt_template": [{{"role": "user", "text": "Answer product question: {{{{#start.user_input#}}}}"}}]
}}
}},
{{
"id": "handle_complaint",
"type": "llm",
"title": "Handle Complaint",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"prompt_template": [{{"role": "user", "text": "Respond with empathy: {{{{#start.user_input#}}}}"}}]
}}
}},
{{
"id": "handle_other",
"type": "llm",
"title": "Handle Other",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"prompt_template": [{{"role": "user", "text": "Provide general response: {{{{#start.user_input#}}}}"}}]
}}
}},
{{
"id": "end",
"type": "end",
"title": "End",
"config": {{
"outputs": [{{"variable": "response", "value_selector": ["handle_refund", "text"]}}]
}}
}}
],
"edges": [
{{"source": "start", "target": "classifier"}},
{{"source": "classifier", "sourceHandle": "cls_refund", "target": "handle_refund"}},
{{"source": "classifier", "sourceHandle": "cls_inquiry", "target": "handle_inquiry"}},
{{"source": "classifier", "sourceHandle": "cls_complaint", "target": "handle_complaint"}},
{{"source": "classifier", "sourceHandle": "cls_other", "target": "handle_other"}},
{{"source": "handle_refund", "target": "end"}},
{{"source": "handle_inquiry", "target": "end"}},
{{"source": "handle_complaint", "target": "end"}},
{{"source": "handle_other", "target": "end"}}
]
}}
```
CRITICAL: Notice that each class id (cls_refund, cls_inquiry, etc.) becomes a sourceHandle in the edges!
</example>
<example name="if_else_branching" description="Conditional logic with if-else">
```json
{{
"nodes": [
{{
"id": "start",
"type": "start",
"title": "Start",
"config": {{
"variables": [{{"variable": "years", "label": "Years of Experience", "type": "number", "required": true}}]
}}
}},
{{
"id": "check_experience",
"type": "if-else",
"title": "Check Experience",
"config": {{
"cases": [
{{
"case_id": "case_1",
"logical_operator": "and",
"conditions": [
{{
"variable_selector": ["start", "years"],
"comparison_operator": "",
"value": "3"
}}
]
}}
]
}}
}},
{{
"id": "qualified",
"type": "llm",
"title": "Qualified Response",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"prompt_template": [{{"role": "user", "text": "Generate qualified candidate response"}}]
}}
}},
{{
"id": "not_qualified",
"type": "llm",
"title": "Not Qualified Response",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"prompt_template": [{{"role": "user", "text": "Generate rejection response"}}]
}}
}},
{{
"id": "end",
"type": "end",
"title": "End",
"config": {{
"outputs": [{{"variable": "result", "value_selector": ["qualified", "text"]}}]
}}
}}
],
"edges": [
{{"source": "start", "target": "check_experience"}},
{{"source": "check_experience", "sourceHandle": "true", "target": "qualified"}},
{{"source": "check_experience", "sourceHandle": "false", "target": "not_qualified"}},
{{"source": "qualified", "target": "end"}},
{{"source": "not_qualified", "target": "end"}}
]
}}
```
CRITICAL: if-else MUST have exactly two edges with sourceHandle "true" and "false"!
</example>
<example name="parameter_extractor" description="Extract structured data from text">
```json
{{
"nodes": [
{{
"id": "start",
"type": "start",
"title": "Start",
"config": {{
"variables": [{{"variable": "resume", "label": "Resume Text", "type": "paragraph", "required": true}}]
}}
}},
{{
"id": "extract",
"type": "parameter-extractor",
"title": "Extract Info",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"query": ["start", "resume"],
"parameters": [
{{"name": "name", "type": "string", "description": "Candidate name", "required": true}},
{{"name": "years", "type": "number", "description": "Years of experience", "required": true}},
{{"name": "skills", "type": "array[string]", "description": "List of skills", "required": true}}
],
"instruction": "Extract candidate information from resume"
}}
}},
{{
"id": "process",
"type": "llm",
"title": "Process Data",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"prompt_template": [{{"role": "user", "text": "Name: {{{{#extract.name#}}}}, Years: {{{{#extract.years#}}}}"}}]
}}
}},
{{
"id": "end",
"type": "end",
"title": "End",
"config": {{
"outputs": [{{"variable": "result", "value_selector": ["process", "text"]}}]
}}
}}
],
"edges": [
{{"source": "start", "target": "extract"}},
{{"source": "extract", "target": "process"}},
{{"source": "process", "target": "end"}}
]
}}
```
</example>
</examples>
<edge_checklist>
Before finalizing, verify:
1. [ ] Every node (except 'end') has at least one outgoing edge
2. [ ] 'start' node has exactly one outgoing edge
3. [ ] 'question-classifier' has one edge per class, each with sourceHandle = class id
4. [ ] 'if-else' has exactly two edges: sourceHandle "true" and sourceHandle "false"
5. [ ] All branches eventually connect to 'end' (directly or through other nodes)
6. [ ] No orphan nodes exist (every node is reachable from 'start')
</edge_checklist>
"""
BUILDER_USER_PROMPT = """<instruction>
{instruction}
</instruction>
Generate the full workflow configuration now. Pay special attention to:
1. Creating edges for ALL branches of question-classifier and if-else nodes
2. Using correct sourceHandle values for branching nodes
3. Ensuring every node is connected in the graph
"""
def format_existing_nodes(nodes: list[dict] | None) -> str:
"""Format existing workflow nodes for context."""
if not nodes:
return "No existing nodes in workflow (creating from scratch)."
lines = []
for node in nodes:
node_id = node.get("id", "unknown")
node_type = node.get("type", "unknown")
title = node.get("title", "Untitled")
lines.append(f"- [{node_id}] {title} ({node_type})")
return "\n".join(lines)
def format_selected_nodes(
selected_ids: list[str] | None,
existing_nodes: list[dict] | None,
) -> str:
"""Format selected nodes for modification context."""
if not selected_ids:
return "No nodes selected (generating new workflow)."
node_map = {n.get("id"): n for n in (existing_nodes or [])}
lines = []
for node_id in selected_ids:
if node_id in node_map:
node = node_map[node_id]
lines.append(f"- [{node_id}] {node.get('title', 'Untitled')} ({node.get('type', 'unknown')})")
else:
lines.append(f"- [{node_id}] (not found in current workflow)")
return "\n".join(lines)
def format_existing_edges(edges: list[dict] | None) -> str:
"""Format existing workflow edges to show connections."""
if not edges:
return "No existing edges (creating new workflow)."
lines = []
for edge in edges:
source = edge.get("source", "unknown")
target = edge.get("target", "unknown")
source_handle = edge.get("sourceHandle", "")
if source_handle:
lines.append(f"- {source} ({source_handle}) -> {target}")
else:
lines.append(f"- {source} -> {target}")
return "\n".join(lines)

View File

@@ -1,75 +0,0 @@
PLANNER_SYSTEM_PROMPT = """<role>
You are an expert Workflow Architect.
Your job is to analyze user requests and plan a high-level automation workflow.
</role>
<task>
1. **Classify Intent**:
- Is the user asking to create an automation/workflow? -> Intent: "generate"
- Is it general chat/weather/jokes? -> Intent: "off_topic"
2. **Plan Steps** (if intent is "generate"):
- Break down the user's goal into logical steps.
- For each step, identify if a specific capability/tool is needed.
- Select the MOST RELEVANT tools from the available_tools list.
- DO NOT configure parameters yet. Just identify the tool.
3. **Output Format**:
Return a JSON object.
</task>
<available_tools>
{tools_summary}
</available_tools>
<response_format>
If intent is "generate":
```json
{{
"intent": "generate",
"plan_thought": "Brief explanation of the plan...",
"steps": [
{{ "step": 1, "description": "Fetch data from URL", "tool": "http-request" }},
{{ "step": 2, "description": "Summarize content", "tool": "llm" }},
{{ "step": 3, "description": "Search for info", "tool": "google_search" }}
],
"required_tool_keys": ["google_search"]
}}
```
(Note: 'http-request', 'llm', 'code' are built-in, you don't need to list them in required_tool_keys,
only external tools)
If intent is "off_topic":
```json
{{
"intent": "off_topic",
"message": "I can only help you build workflows. Try asking me to 'Create a workflow that...'",
"suggestions": ["Scrape a website", "Summarize a PDF"]
}}
```
</response_format>
"""
PLANNER_USER_PROMPT = """<user_request>
{instruction}
</user_request>
"""
def format_tools_for_planner(tools: list[dict]) -> str:
"""Format tools list for planner (Lightweight: Name + Description only)."""
if not tools:
return "No external tools available."
lines = []
for t in tools:
key = t.get("tool_key") or t.get("tool_name")
provider = t.get("provider_id") or t.get("provider", "")
desc = t.get("tool_description") or t.get("description", "")
label = t.get("tool_label") or key
# Format: - [provider/key] Label: Description
full_key = f"{provider}/{key}" if provider else key
lines.append(f"- [{full_key}] {label}: {desc}")
return "\n".join(lines)

File diff suppressed because it is too large Load Diff

View File

@@ -1,287 +0,0 @@
import json
import logging
import re
from collections.abc import Sequence
import json_repair
from core.model_manager import ModelManager
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.workflow.generator.prompts.builder_prompts import (
BUILDER_SYSTEM_PROMPT,
BUILDER_USER_PROMPT,
format_existing_edges,
format_existing_nodes,
format_selected_nodes,
)
from core.workflow.generator.prompts.planner_prompts import (
PLANNER_SYSTEM_PROMPT,
PLANNER_USER_PROMPT,
format_tools_for_planner,
)
from core.workflow.generator.prompts.vibe_prompts import (
format_available_models,
format_available_nodes,
format_available_tools,
parse_vibe_response,
)
from core.workflow.generator.utils.mermaid_generator import generate_mermaid
from core.workflow.generator.utils.workflow_validator import ValidationHint, WorkflowValidator
logger = logging.getLogger(__name__)
class WorkflowGenerator:
"""
Refactored Vibe Workflow Generator (Planner-Builder Architecture).
Extracts Vibe logic from the monolithic LLMGenerator.
"""
@classmethod
def generate_workflow_flowchart(
cls,
tenant_id: str,
instruction: str,
model_config: dict,
available_nodes: Sequence[dict[str, object]] | None = None,
existing_nodes: Sequence[dict[str, object]] | None = None,
existing_edges: Sequence[dict[str, object]] | None = None,
available_tools: Sequence[dict[str, object]] | None = None,
selected_node_ids: Sequence[str] | None = None,
previous_workflow: dict[str, object] | None = None,
regenerate_mode: bool = False,
preferred_language: str | None = None,
available_models: Sequence[dict[str, object]] | None = None,
):
"""
Generates a Dify Workflow Flowchart from natural language instruction.
Pipeline:
1. Planner: Analyze intent & select tools.
2. Context Filter: Filter relevant tools (reduce tokens).
3. Builder: Generate node configurations.
4. Repair: Fix common node/edge issues (NodeRepair, EdgeRepair).
5. Validator: Check for errors & generate friendly hints.
6. Renderer: Deterministic Mermaid generation.
"""
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
)
model_parameters = model_config.get("completion_params", {})
available_tools_list = list(available_tools) if available_tools else []
# Check if this is modification mode (user is refining existing workflow)
has_existing_nodes = existing_nodes and len(list(existing_nodes)) > 0
# --- STEP 1: PLANNER (Skip in modification mode) ---
if has_existing_nodes:
# In modification mode, skip Planner:
# - User intent is clear: modify the existing workflow
# - Tools are already in use (from existing nodes)
# - No need for intent classification or tool selection
plan_data = {"intent": "generate", "steps": [], "required_tool_keys": []}
filtered_tools = available_tools_list # Use all available tools
else:
# In creation mode, run Planner to validate intent and select tools
planner_tools_context = format_tools_for_planner(available_tools_list)
planner_system = PLANNER_SYSTEM_PROMPT.format(tools_summary=planner_tools_context)
planner_user = PLANNER_USER_PROMPT.format(instruction=instruction)
try:
response = model_instance.invoke_llm(
prompt_messages=[
SystemPromptMessage(content=planner_system),
UserPromptMessage(content=planner_user),
],
model_parameters=model_parameters,
stream=False,
)
plan_content = response.message.content
# Reuse parse_vibe_response logic or simple load
plan_data = parse_vibe_response(plan_content)
except Exception as e:
logger.exception("Planner failed")
return {"intent": "error", "error": f"Planning failed: {str(e)}"}
if plan_data.get("intent") == "off_topic":
return {
"intent": "off_topic",
"message": plan_data.get("message", "I can only help with workflow creation."),
"suggestions": plan_data.get("suggestions", []),
}
# --- STEP 2: CONTEXT FILTERING ---
required_tools = plan_data.get("required_tool_keys", [])
filtered_tools = []
if required_tools:
# Simple linear search (optimized version would use a map)
for tool in available_tools_list:
t_key = tool.get("tool_key") or tool.get("tool_name")
provider = tool.get("provider_id") or tool.get("provider")
full_key = f"{provider}/{t_key}" if provider else t_key
# Check if this tool is in required list (match either full key or short name)
if t_key in required_tools or full_key in required_tools:
filtered_tools.append(tool)
else:
# If logic only, no tools needed
filtered_tools = []
# --- STEP 3: BUILDER (with retry loop) ---
MAX_GLOBAL_RETRIES = 2 # Total attempts: 1 initial + 1 retry
workflow_data = None
mermaid_code = None
all_warnings = []
all_fixes = []
retry_count = 0
validation_hints = []
for attempt in range(MAX_GLOBAL_RETRIES):
retry_count = attempt
logger.info("Generation attempt %s/%s", attempt + 1, MAX_GLOBAL_RETRIES)
# Prepare context
tool_schemas = format_available_tools(filtered_tools)
node_specs = format_available_nodes(list(available_nodes) if available_nodes else [])
existing_nodes_context = format_existing_nodes(list(existing_nodes) if existing_nodes else None)
existing_edges_context = format_existing_edges(list(existing_edges) if existing_edges else None)
selected_nodes_context = format_selected_nodes(
list(selected_node_ids) if selected_node_ids else None, list(existing_nodes) if existing_nodes else None
)
# Build retry context
retry_context = ""
# NOTE: Manual regeneration/refinement mode removed
# Only handle automatic retry (validation errors)
# For automatic retry (validation errors)
if attempt > 0 and validation_hints:
severe_issues = [h for h in validation_hints if h.severity == "error"]
if severe_issues:
retry_context = "\n<validation_feedback>\n"
retry_context += "The previous generation had validation errors:\n"
for idx, hint in enumerate(severe_issues[:5], 1):
retry_context += f"{idx}. {hint.message}\n"
retry_context += "\nPlease fix these specific issues while keeping everything else UNCHANGED.\n"
retry_context += "</validation_feedback>\n"
builder_system = BUILDER_SYSTEM_PROMPT.format(
plan_context=json.dumps(plan_data.get("steps", []), indent=2),
tool_schemas=tool_schemas,
builtin_node_specs=node_specs,
available_models=format_available_models(list(available_models or [])),
preferred_language=preferred_language or "English",
existing_nodes_context=existing_nodes_context,
existing_edges_context=existing_edges_context,
selected_nodes_context=selected_nodes_context,
)
builder_user = BUILDER_USER_PROMPT.format(instruction=instruction) + retry_context
try:
build_res = model_instance.invoke_llm(
prompt_messages=[
SystemPromptMessage(content=builder_system),
UserPromptMessage(content=builder_user),
],
model_parameters=model_parameters,
stream=False,
)
# Builder output is raw JSON nodes/edges
build_content = build_res.message.content
match = re.search(r"```(?:json)?\s*([\s\S]+?)```", build_content)
if match:
build_content = match.group(1)
workflow_data = json_repair.loads(build_content)
if "nodes" not in workflow_data:
workflow_data["nodes"] = []
if "edges" not in workflow_data:
workflow_data["edges"] = []
except Exception as e:
logger.exception("Builder failed on attempt %d", attempt + 1)
if attempt == MAX_GLOBAL_RETRIES - 1:
return {"intent": "error", "error": f"Building failed: {str(e)}"}
continue # Try again
# NOTE: NodeRepair and EdgeRepair have been removed.
# Validation will detect structural issues, and LLM will fix them on retry.
# This is more accurate because LLM understands the workflow context.
# --- STEP 4: RENDERER (Generate Mermaid early for validation) ---
mermaid_code = generate_mermaid(workflow_data)
# --- STEP 5: VALIDATOR ---
is_valid, validation_hints = WorkflowValidator.validate(workflow_data, available_tools_list)
# --- STEP 6: GRAPH VALIDATION (structural checks using graph algorithms) ---
if attempt < MAX_GLOBAL_RETRIES - 1:
try:
from core.workflow.generator.utils.graph_validator import GraphValidator
graph_result = GraphValidator.validate(workflow_data)
if not graph_result.success:
# Convert graph errors to validation hints
for graph_error in graph_result.errors:
validation_hints.append(
ValidationHint(
node_id=graph_error.node_id,
field="edges",
message=f"[Graph] {graph_error.message}",
severity="error",
)
)
# Also add warnings (dead ends) as hints
for graph_warning in graph_result.warnings:
validation_hints.append(
ValidationHint(
node_id=graph_warning.node_id,
field="edges",
message=f"[Graph] {graph_warning.message}",
severity="warning",
)
)
except Exception as e:
logger.warning("Graph validation error: %s", e)
# Collect all validation warnings
all_warnings = [h.message for h in validation_hints]
# Check if we should retry
severe_issues = [h for h in validation_hints if h.severity == "error"]
if not severe_issues or attempt == MAX_GLOBAL_RETRIES - 1:
break
# Has severe errors and retries remaining - continue to next attempt
# Collect all validation warnings
all_warnings = [h.message for h in validation_hints]
# Add stability warning (as requested by user)
stability_warning = "The generated workflow may require debugging."
if preferred_language and preferred_language.startswith("zh"):
stability_warning = "生成的 Workflow 可能需要调试。"
all_warnings.append(stability_warning)
return {
"intent": "generate",
"flowchart": mermaid_code,
"nodes": workflow_data["nodes"],
"edges": workflow_data["edges"],
"message": plan_data.get("plan_thought", "Generated workflow based on your request."),
"warnings": all_warnings,
"tool_recommendations": [], # Legacy field
"error": "",
"fixed_issues": all_fixes, # Track what was auto-fixed
"retry_count": retry_count, # Track how many retries were needed
}

View File

@@ -1,217 +0,0 @@
"""
Type definitions for Vibe Workflow Generator.
This module provides:
- TypedDict classes for lightweight type hints (no runtime overhead)
- Pydantic models for runtime validation where needed
Usage:
# For type hints only (no runtime validation):
from core.workflow.generator.types import WorkflowNodeDict, WorkflowEdgeDict
# For runtime validation:
from core.workflow.generator.types import WorkflowNode, WorkflowEdge
"""
from typing import Any, TypedDict
from pydantic import BaseModel, Field
# ============================================================
# TypedDict definitions (lightweight, for type hints only)
# ============================================================
class WorkflowNodeDict(TypedDict, total=False):
"""
Workflow node structure (TypedDict for hints).
Attributes:
id: Unique node identifier
type: Node type (e.g., "start", "end", "llm", "if-else", "http-request")
title: Human-readable node title
config: Node-specific configuration
data: Additional node data
"""
id: str
type: str
title: str
config: dict[str, Any]
data: dict[str, Any]
class WorkflowEdgeDict(TypedDict, total=False):
"""
Workflow edge structure (TypedDict for hints).
Attributes:
source: Source node ID
target: Target node ID
sourceHandle: Branch handle for if-else/question-classifier nodes
"""
source: str
target: str
sourceHandle: str
class AvailableModelDict(TypedDict):
"""
Available model structure.
Attributes:
provider: Model provider (e.g., "openai", "anthropic")
model: Model name (e.g., "gpt-4", "claude-3")
"""
provider: str
model: str
class ToolParameterDict(TypedDict, total=False):
"""
Tool parameter structure.
Attributes:
name: Parameter name
type: Parameter type (e.g., "string", "number", "boolean")
required: Whether parameter is required
human_description: Human-readable description
llm_description: LLM-oriented description
options: Available options for enum-type parameters
"""
name: str
type: str
required: bool
human_description: str | dict[str, str]
llm_description: str
options: list[Any]
class AvailableToolDict(TypedDict, total=False):
"""
Available tool structure.
Attributes:
provider_id: Tool provider ID
provider: Tool provider name (alternative to provider_id)
tool_key: Unique tool key
tool_name: Tool name (alternative to tool_key)
tool_description: Tool description
description: Alternative description field
is_team_authorization: Whether tool is configured/authorized
parameters: List of tool parameters
"""
provider_id: str
provider: str
tool_key: str
tool_name: str
tool_description: str
description: str
is_team_authorization: bool
parameters: list[ToolParameterDict]
class WorkflowDataDict(TypedDict, total=False):
"""
Complete workflow data structure.
Attributes:
nodes: List of workflow nodes
edges: List of workflow edges
warnings: List of warning messages
"""
nodes: list[WorkflowNodeDict]
edges: list[WorkflowEdgeDict]
warnings: list[str]
# ============================================================
# Pydantic models (for runtime validation)
# ============================================================
class WorkflowNode(BaseModel):
"""
Workflow node with runtime validation.
Use this model when you need to validate node data at runtime.
For lightweight type hints without validation, use WorkflowNodeDict.
"""
id: str
type: str
title: str = ""
config: dict[str, Any] = Field(default_factory=dict)
data: dict[str, Any] = Field(default_factory=dict)
class WorkflowEdge(BaseModel):
"""
Workflow edge with runtime validation.
Use this model when you need to validate edge data at runtime.
For lightweight type hints without validation, use WorkflowEdgeDict.
"""
source: str
target: str
sourceHandle: str | None = None
class AvailableModel(BaseModel):
"""
Available model with runtime validation.
Use this model when you need to validate model data at runtime.
For lightweight type hints without validation, use AvailableModelDict.
"""
provider: str
model: str
class ToolParameter(BaseModel):
"""Tool parameter with runtime validation."""
name: str = ""
type: str = "string"
required: bool = False
human_description: str | dict[str, str] = ""
llm_description: str = ""
options: list[Any] = Field(default_factory=list)
class AvailableTool(BaseModel):
"""
Available tool with runtime validation.
Use this model when you need to validate tool data at runtime.
For lightweight type hints without validation, use AvailableToolDict.
"""
provider_id: str = ""
provider: str = ""
tool_key: str = ""
tool_name: str = ""
tool_description: str = ""
description: str = ""
is_team_authorization: bool = False
parameters: list[ToolParameter] = Field(default_factory=list)
class WorkflowData(BaseModel):
"""
Complete workflow data with runtime validation.
Use this model when you need to validate workflow data at runtime.
For lightweight type hints without validation, use WorkflowDataDict.
"""
nodes: list[WorkflowNode] = Field(default_factory=list)
edges: list[WorkflowEdge] = Field(default_factory=list)
warnings: list[str] = Field(default_factory=list)

View File

@@ -1,384 +0,0 @@
"""
Edge Repair Utility for Vibe Workflow Generation.
This module provides intelligent edge repair capabilities for generated workflows.
It can detect and fix common edge issues:
- Missing edges between sequential nodes
- Incomplete branches for question-classifier and if-else nodes
- Orphaned nodes without connections
The repair logic is deterministic and doesn't require LLM calls.
"""
import logging
from dataclasses import dataclass, field
from core.workflow.generator.types import WorkflowDataDict, WorkflowEdgeDict, WorkflowNodeDict
logger = logging.getLogger(__name__)
@dataclass
class RepairResult:
"""Result of edge repair operation."""
nodes: list[WorkflowNodeDict]
edges: list[WorkflowEdgeDict]
repairs_made: list[str] = field(default_factory=list)
warnings: list[str] = field(default_factory=list)
@property
def was_repaired(self) -> bool:
"""Check if any repairs were made."""
return len(self.repairs_made) > 0
class EdgeRepair:
"""
Intelligent edge repair for workflow graphs.
Repairs are applied in order:
1. Infer linear connections from node order (if no edges exist)
2. Add missing branch edges for question-classifier
3. Add missing branch edges for if-else
4. Connect orphaned nodes
"""
@classmethod
def repair(cls, workflow_data: WorkflowDataDict) -> RepairResult:
"""
Repair edges in the workflow data.
Args:
workflow_data: Dict containing 'nodes' and 'edges'
Returns:
RepairResult with repaired nodes, edges, and repair logs
"""
nodes = list(workflow_data.get("nodes", []))
edges = list(workflow_data.get("edges", []))
repairs: list[str] = []
warnings: list[str] = []
logger.info("[EDGE REPAIR] Starting repair process for %s nodes, %s edges", len(nodes), len(edges))
# Build node lookup
# Build node lookup
node_map = {n.get("id"): n for n in nodes if n.get("id")}
node_ids = set(node_map.keys())
# 1. If no edges at all, infer linear chain
if not edges and len(nodes) > 1:
edges, inferred_repairs = cls._infer_linear_chain(nodes)
repairs.extend(inferred_repairs)
# 2. Build edge index for analysis
outgoing_edges: dict[str, list[WorkflowEdgeDict]] = {}
incoming_edges: dict[str, list[WorkflowEdgeDict]] = {}
for edge in edges:
src = edge.get("source")
tgt = edge.get("target")
if src:
outgoing_edges.setdefault(src, []).append(edge)
if tgt:
incoming_edges.setdefault(tgt, []).append(edge)
# 3. Repair question-classifier branches
for node in nodes:
if node.get("type") == "question-classifier":
new_edges, branch_repairs, branch_warnings = cls._repair_classifier_branches(
node, edges, outgoing_edges, node_ids
)
edges.extend(new_edges)
repairs.extend(branch_repairs)
warnings.extend(branch_warnings)
# Update outgoing index
for edge in new_edges:
outgoing_edges.setdefault(edge.get("source"), []).append(edge)
# 4. Repair if-else branches
for node in nodes:
if node.get("type") == "if-else":
new_edges, branch_repairs, branch_warnings = cls._repair_if_else_branches(
node, edges, outgoing_edges, node_ids
)
edges.extend(new_edges)
repairs.extend(branch_repairs)
warnings.extend(branch_warnings)
# Update outgoing index
for edge in new_edges:
outgoing_edges.setdefault(edge.get("source"), []).append(edge)
# 5. Connect orphaned nodes (nodes with no incoming edge, except start)
new_edges, orphan_repairs = cls._connect_orphaned_nodes(nodes, edges, outgoing_edges, incoming_edges)
edges.extend(new_edges)
repairs.extend(orphan_repairs)
# 6. Connect nodes with no outgoing edge to 'end' (except end nodes)
new_edges, terminal_repairs = cls._connect_terminal_nodes(nodes, edges, outgoing_edges)
edges.extend(new_edges)
repairs.extend(terminal_repairs)
if repairs:
logger.info("[EDGE REPAIR] Completed with %s repairs:", len(repairs))
for i, repair in enumerate(repairs, 1):
logger.info("[EDGE REPAIR] %s. %s", i, repair)
else:
logger.info("[EDGE REPAIR] Completed - no repairs needed")
return RepairResult(
nodes=nodes,
edges=edges,
repairs_made=repairs,
warnings=warnings,
)
@classmethod
def _infer_linear_chain(cls, nodes: list[WorkflowNodeDict]) -> tuple[list[WorkflowEdgeDict], list[str]]:
"""
Infer a linear chain of edges from node order.
This is used when no edges are provided at all.
"""
edges: list[WorkflowEdgeDict] = []
repairs: list[str] = []
# Filter to get ordered node IDs
node_ids = [n.get("id") for n in nodes if n.get("id")]
if len(node_ids) < 2:
return edges, repairs
# Create edges between consecutive nodes
for i in range(len(node_ids) - 1):
src = node_ids[i]
tgt = node_ids[i + 1]
edges.append({"source": src, "target": tgt})
repairs.append(f"Inferred edge: {src} -> {tgt}")
return edges, repairs
@classmethod
def _repair_classifier_branches(
cls,
node: WorkflowNodeDict,
edges: list[WorkflowEdgeDict],
outgoing_edges: dict[str, list[WorkflowEdgeDict]],
valid_node_ids: set[str],
) -> tuple[list[WorkflowEdgeDict], list[str], list[str]]:
"""
Repair missing branches for question-classifier nodes.
For each class that doesn't have an edge, create one pointing to 'end'.
"""
new_edges: list[WorkflowEdgeDict] = []
repairs: list[str] = []
warnings: list[str] = []
node_id = node.get("id")
if not node_id:
return new_edges, repairs, warnings
config = node.get("config", {})
classes = config.get("classes", [])
if not classes:
return new_edges, repairs, warnings
# Get existing sourceHandles for this node
existing_handles = set()
for edge in outgoing_edges.get(node_id, []):
handle = edge.get("sourceHandle")
if handle:
existing_handles.add(handle)
# Find 'end' node as default target
end_node_id = "end"
if "end" not in valid_node_ids:
# Try to find an end node
for nid in valid_node_ids:
if "end" in nid.lower():
end_node_id = nid
break
# Add missing branches
for cls_def in classes:
if not isinstance(cls_def, dict):
continue
cls_id = cls_def.get("id")
cls_name = cls_def.get("name", cls_id)
if cls_id and cls_id not in existing_handles:
new_edge = {
"source": node_id,
"sourceHandle": cls_id,
"target": end_node_id,
}
new_edges.append(new_edge)
repairs.append(f"Added missing branch edge for class '{cls_name}' -> {end_node_id}")
warnings.append(
f"Auto-connected question-classifier branch '{cls_name}' to '{end_node_id}'. "
"You may want to redirect this to a specific handler node."
)
return new_edges, repairs, warnings
@classmethod
def _repair_if_else_branches(
cls,
node: WorkflowNodeDict,
edges: list[WorkflowEdgeDict],
outgoing_edges: dict[str, list[WorkflowEdgeDict]],
valid_node_ids: set[str],
) -> tuple[list[WorkflowEdgeDict], list[str], list[str]]:
"""
Repair missing branches for if-else nodes.
If-else in Dify uses case_id as sourceHandle for each condition,
plus 'false' for the else branch.
"""
new_edges: list[WorkflowEdgeDict] = []
repairs: list[str] = []
warnings: list[str] = []
node_id = node.get("id")
if not node_id:
return new_edges, repairs, warnings
# Get existing sourceHandles
existing_handles = set()
for edge in outgoing_edges.get(node_id, []):
handle = edge.get("sourceHandle")
if handle:
existing_handles.add(handle)
# Find 'end' node as default target
end_node_id = "end"
if "end" not in valid_node_ids:
for nid in valid_node_ids:
if "end" in nid.lower():
end_node_id = nid
break
# Get required branches from config
config = node.get("config", {})
cases = config.get("cases", [])
# Build required handles: each case_id + 'false' for else
required_branches = set()
for case in cases:
case_id = case.get("case_id")
if case_id:
required_branches.add(case_id)
required_branches.add("false") # else branch
# Add missing branches
for branch in required_branches:
if branch not in existing_handles:
new_edge = {
"source": node_id,
"sourceHandle": branch,
"target": end_node_id,
}
new_edges.append(new_edge)
repairs.append(f"Added missing if-else branch '{branch}' -> {end_node_id}")
warnings.append(
f"Auto-connected if-else branch '{branch}' to '{end_node_id}'. "
"You may want to redirect this to a specific handler node."
)
return new_edges, repairs, warnings
@classmethod
def _connect_orphaned_nodes(
cls,
nodes: list[WorkflowNodeDict],
edges: list[WorkflowEdgeDict],
outgoing_edges: dict[str, list[WorkflowEdgeDict]],
incoming_edges: dict[str, list[WorkflowEdgeDict]],
) -> tuple[list[WorkflowEdgeDict], list[str]]:
"""
Connect orphaned nodes to the previous node in sequence.
An orphaned node has no incoming edges and is not a 'start' node.
"""
new_edges: list[WorkflowEdgeDict] = []
repairs: list[str] = []
node_ids = [n.get("id") for n in nodes if n.get("id")]
node_types = {n.get("id"): n.get("type") for n in nodes}
for i, node_id in enumerate(node_ids):
node_type = node_types.get(node_id)
# Skip start nodes - they don't need incoming edges
if node_type == "start":
continue
# Check if node has incoming edges
if node_id not in incoming_edges or not incoming_edges[node_id]:
# Find previous node to connect from
if i > 0:
prev_node_id = node_ids[i - 1]
new_edge = {"source": prev_node_id, "target": node_id}
new_edges.append(new_edge)
repairs.append(f"Connected orphaned node: {prev_node_id} -> {node_id}")
# Update incoming_edges for subsequent checks
incoming_edges.setdefault(node_id, []).append(new_edge)
return new_edges, repairs
@classmethod
def _connect_terminal_nodes(
cls,
nodes: list[WorkflowNodeDict],
edges: list[WorkflowEdgeDict],
outgoing_edges: dict[str, list[WorkflowEdgeDict]],
) -> tuple[list[WorkflowEdgeDict], list[str]]:
"""
Connect terminal nodes (no outgoing edges) to 'end'.
A terminal node has no outgoing edges and is not an 'end' node.
This ensures all branches eventually reach 'end'.
"""
new_edges: list[WorkflowEdgeDict] = []
repairs: list[str] = []
# Find end node
end_node_id = None
node_ids = set()
for n in nodes:
nid = n.get("id")
ntype = n.get("type")
if nid:
node_ids.add(nid)
if ntype == "end":
end_node_id = nid
if not end_node_id:
# No end node found, can't connect
return new_edges, repairs
for node in nodes:
node_id = node.get("id")
node_type = node.get("type")
# Skip end nodes
if node_type == "end":
continue
# Skip nodes that already have outgoing edges
if outgoing_edges.get(node_id):
continue
# Connect to end
new_edge = {"source": node_id, "target": end_node_id}
new_edges.append(new_edge)
repairs.append(f"Connected terminal node to end: {node_id} -> {end_node_id}")
# Update for subsequent checks
outgoing_edges.setdefault(node_id, []).append(new_edge)
return new_edges, repairs

View File

@@ -1,280 +0,0 @@
"""
Graph Validator for Workflow Generation
Validates workflow graph structure using graph algorithms:
- Reachability from start node (BFS)
- Reachability to end node (reverse BFS)
- Branch edge validation for if-else and classifier nodes
"""
import time
from collections import deque
from dataclasses import dataclass, field
@dataclass
class GraphError:
"""Represents a structural error in the workflow graph."""
node_id: str
node_type: str
error_type: str # "unreachable", "dead_end", "cycle", "missing_start", "missing_end"
message: str
@dataclass
class GraphValidationResult:
"""Result of graph validation."""
success: bool
errors: list[GraphError] = field(default_factory=list)
warnings: list[GraphError] = field(default_factory=list)
execution_time: float = 0.0
stats: dict = field(default_factory=dict)
class GraphValidator:
"""
Validates workflow graph structure using proper graph algorithms.
Performs:
1. Forward reachability analysis (BFS from start)
2. Backward reachability analysis (reverse BFS from end)
3. Branch edge validation for if-else and classifier nodes
"""
@staticmethod
def _build_adjacency(
nodes: dict[str, dict], edges: list[dict]
) -> tuple[dict[str, list[str]], dict[str, list[str]]]:
"""Build forward and reverse adjacency lists from edges."""
outgoing: dict[str, list[str]] = {node_id: [] for node_id in nodes}
incoming: dict[str, list[str]] = {node_id: [] for node_id in nodes}
for edge in edges:
source = edge.get("source")
target = edge.get("target")
if source in outgoing and target in incoming:
outgoing[source].append(target)
incoming[target].append(source)
return outgoing, incoming
@staticmethod
def _bfs_reachable(start: str, adjacency: dict[str, list[str]]) -> set[str]:
"""BFS to find all nodes reachable from start node."""
if start not in adjacency:
return set()
visited = set()
queue = deque([start])
visited.add(start)
while queue:
current = queue.popleft()
for neighbor in adjacency.get(current, []):
if neighbor not in visited:
visited.add(neighbor)
queue.append(neighbor)
return visited
@staticmethod
def validate(workflow_data: dict) -> GraphValidationResult:
"""Validate workflow graph structure."""
start_time = time.time()
errors: list[GraphError] = []
warnings: list[GraphError] = []
nodes_list = workflow_data.get("nodes", [])
edges_list = workflow_data.get("edges", [])
nodes = {n["id"]: n for n in nodes_list if n.get("id")}
# Find start and end nodes
start_node_id = None
end_node_ids = []
for node_id, node in nodes.items():
node_type = node.get("type")
if node_type == "start":
start_node_id = node_id
elif node_type == "end":
end_node_ids.append(node_id)
# Check start node exists
if not start_node_id:
errors.append(
GraphError(
node_id="workflow",
node_type="workflow",
error_type="missing_start",
message="Workflow has no start node",
)
)
# Check end node exists
if not end_node_ids:
errors.append(
GraphError(
node_id="workflow",
node_type="workflow",
error_type="missing_end",
message="Workflow has no end node",
)
)
# If missing start or end, can't do reachability analysis
if not start_node_id or not end_node_ids:
execution_time = time.time() - start_time
return GraphValidationResult(
success=False,
errors=errors,
warnings=warnings,
execution_time=execution_time,
stats={"nodes": len(nodes), "edges": len(edges_list)},
)
# Build adjacency lists
outgoing, incoming = GraphValidator._build_adjacency(nodes, edges_list)
# --- FORWARD REACHABILITY: BFS from start ---
reachable_from_start = GraphValidator._bfs_reachable(start_node_id, outgoing)
# Find unreachable nodes
unreachable_nodes = set(nodes.keys()) - reachable_from_start
for node_id in unreachable_nodes:
node = nodes[node_id]
errors.append(
GraphError(
node_id=node_id,
node_type=node.get("type", "unknown"),
error_type="unreachable",
message=f"Node '{node_id}' is not reachable from start node",
)
)
# --- BACKWARD REACHABILITY: Reverse BFS from end nodes ---
can_reach_end: set[str] = set()
for end_id in end_node_ids:
can_reach_end.update(GraphValidator._bfs_reachable(end_id, incoming))
# Find dead-end nodes (can't reach any end node)
dead_end_nodes = set(nodes.keys()) - can_reach_end
for node_id in dead_end_nodes:
if node_id in unreachable_nodes:
continue
node = nodes[node_id]
warnings.append(
GraphError(
node_id=node_id,
node_type=node.get("type", "unknown"),
error_type="dead_end",
message=f"Node '{node_id}' cannot reach any end node (dead end)",
)
)
# --- Start node has outgoing edges? ---
if not outgoing.get(start_node_id):
errors.append(
GraphError(
node_id=start_node_id,
node_type="start",
error_type="disconnected",
message="Start node has no outgoing connections",
)
)
# --- End nodes have incoming edges? ---
for end_id in end_node_ids:
if not incoming.get(end_id):
errors.append(
GraphError(
node_id=end_id,
node_type="end",
error_type="disconnected",
message="End node has no incoming connections",
)
)
# --- BRANCH EDGE VALIDATION ---
edge_handles: dict[str, set[str]] = {}
for edge in edges_list:
source = edge.get("source")
handle = edge.get("sourceHandle", "")
if source:
if source not in edge_handles:
edge_handles[source] = set()
edge_handles[source].add(handle)
# Check if-else and question-classifier nodes
for node_id, node in nodes.items():
node_type = node.get("type")
if node_type == "if-else":
handles = edge_handles.get(node_id, set())
config = node.get("config", {})
cases = config.get("cases", [])
required_handles = set()
for case in cases:
case_id = case.get("case_id")
if case_id:
required_handles.add(case_id)
required_handles.add("false")
missing = required_handles - handles
for handle in missing:
errors.append(
GraphError(
node_id=node_id,
node_type=node_type,
error_type="missing_branch",
message=f"If-else node '{node_id}' missing edge for branch '{handle}'",
)
)
elif node_type == "question-classifier":
handles = edge_handles.get(node_id, set())
config = node.get("config", {})
classes = config.get("classes", [])
required_handles = set()
for cls in classes:
if isinstance(cls, dict):
cls_id = cls.get("id")
if cls_id:
required_handles.add(cls_id)
missing = required_handles - handles
for handle in missing:
cls_name = handle
for cls in classes:
if isinstance(cls, dict) and cls.get("id") == handle:
cls_name = cls.get("name", handle)
break
errors.append(
GraphError(
node_id=node_id,
node_type=node_type,
error_type="missing_branch",
message=f"Classifier '{node_id}' missing edge for class '{cls_name}'",
)
)
execution_time = time.time() - start_time
success = len(errors) == 0
return GraphValidationResult(
success=success,
errors=errors,
warnings=warnings,
execution_time=execution_time,
stats={
"nodes": len(nodes),
"edges": len(edges_list),
"reachable_from_start": len(reachable_from_start),
"can_reach_end": len(can_reach_end),
"unreachable": len(unreachable_nodes),
"dead_ends": len(dead_end_nodes - unreachable_nodes),
},
)

View File

@@ -1,113 +0,0 @@
import logging
from core.workflow.generator.types import WorkflowDataDict
logger = logging.getLogger(__name__)
def generate_mermaid(workflow_data: WorkflowDataDict) -> str:
"""
Generate a Mermaid flowchart from workflow data consisting of nodes and edges.
Args:
workflow_data: Dict containing 'nodes' (list) and 'edges' (list)
Returns:
String containing the Mermaid flowchart syntax
"""
nodes = workflow_data.get("nodes", [])
edges = workflow_data.get("edges", [])
lines = ["flowchart TD"]
# 1. Define Nodes
# Format: node_id["title<br/>type"] or similar
# We will use the Vibe Workflow standard format: id["type=TYPE|title=TITLE"]
# Or specifically for tool nodes: id["type=tool|title=TITLE|tool=TOOL_KEY"]
# Map of original IDs to safe Mermaid IDs
id_map = {}
def get_safe_id(original_id: str) -> str:
if original_id == "end":
return "end_node"
if original_id == "subgraph":
return "subgraph_node"
# Mermaid IDs should be alphanumeric.
# If the ID has special chars, we might need to escape or hash, but Vibe usually generates simple IDs.
# We'll trust standard IDs but handle the reserved keyword 'end'.
return original_id
for node in nodes:
node_id = node.get("id")
if not node_id:
continue
safe_id = get_safe_id(node_id)
id_map[node_id] = safe_id
node_type = node.get("type", "unknown")
title = node.get("title", "Untitled")
# Escape quotes in title
safe_title = title.replace('"', "'")
if node_type == "tool":
config = node.get("config", {})
# Try multiple fields for tool reference
tool_ref = (
config.get("tool_key")
or config.get("tool")
or config.get("tool_name")
or node.get("tool_name")
or "unknown"
)
node_def = f'{safe_id}["type={node_type}|title={safe_title}|tool={tool_ref}"]'
else:
node_def = f'{safe_id}["type={node_type}|title={safe_title}"]'
lines.append(f" {node_def}")
# 2. Define Edges
# Format: source --> target
# Track defined nodes to avoid edge errors
defined_node_ids = {n.get("id") for n in nodes if n.get("id")}
for edge in edges:
source = edge.get("source")
target = edge.get("target")
# Skip invalid edges
if not source or not target:
continue
if source not in defined_node_ids or target not in defined_node_ids:
continue
safe_source = id_map.get(source, source)
safe_target = id_map.get(target, target)
# Handle conditional branches (true/false) if present
# In Dify workflow, sourceHandle is often used for this
source_handle = edge.get("sourceHandle")
label = ""
if source_handle == "true":
label = "|true|"
elif source_handle == "false":
label = "|false|"
elif source_handle and source_handle != "source":
# For question-classifier or other multi-path nodes
# Clean up handle for display if needed
safe_handle = str(source_handle).replace('"', "'")
label = f"|{safe_handle}|"
edge_line = f" {safe_source} -->{label} {safe_target}"
lines.append(edge_line)
# Start/End nodes are implicitly handled if they are in the 'nodes' list
# If not, we might need to add them, but usually the Builder should produce them.
result = "\n".join(lines)
return result

View File

@@ -1,304 +0,0 @@
"""
Node Repair Utility for Vibe Workflow Generation.
This module provides intelligent node configuration repair capabilities.
It can detect and fix common node configuration issues:
- Invalid comparison operators in if-else nodes (e.g. '>=' -> '')
"""
import copy
import logging
import uuid
from dataclasses import dataclass, field
from core.workflow.generator.types import WorkflowNodeDict
logger = logging.getLogger(__name__)
@dataclass
class NodeRepairResult:
"""Result of node repair operation."""
nodes: list[WorkflowNodeDict]
repairs_made: list[str] = field(default_factory=list)
warnings: list[str] = field(default_factory=list)
@property
def was_repaired(self) -> bool:
"""Check if any repairs were made."""
return len(self.repairs_made) > 0
class NodeRepair:
"""
Intelligent node configuration repair.
"""
OPERATOR_MAP = {
">=": "",
"<=": "",
"!=": "",
"==": "=",
}
TYPE_MAPPING = {
"json": "object",
"dict": "object",
"dictionary": "object",
"float": "number",
"int": "number",
"integer": "number",
"double": "number",
"str": "string",
"text": "string",
"bool": "boolean",
"list": "array[object]",
"array": "array[object]",
}
_REPAIR_HANDLERS = {
"if-else": "_repair_if_else_operators",
"variable-aggregator": "_repair_variable_aggregator_variables",
"code": "_repair_code_node_config",
}
@classmethod
def repair(
cls,
nodes: list[WorkflowNodeDict],
llm_callback=None,
) -> NodeRepairResult:
"""
Repair node configurations.
Args:
nodes: List of node dictionaries
llm_callback: Optional callback(node, issue_desc) -> fixed_config_part
Returns:
NodeRepairResult with repaired nodes and logs
"""
# Deep copy to avoid mutating original
nodes = copy.deepcopy(nodes)
repairs: list[str] = []
warnings: list[str] = []
logger.info("[NODE REPAIR] Starting repair process for %s nodes", len(nodes))
for node in nodes:
node_type = node.get("type")
# 1. Rule-based repairs
handler_name = cls._REPAIR_HANDLERS.get(node_type)
if handler_name:
handler = getattr(cls, handler_name)
# Check if handler accepts llm_callback (inspect signature or just pass generic kwargs?)
# Simplest for now: handlers signature: (node, repairs, llm_callback=None)
try:
handler(node, repairs, llm_callback=llm_callback)
except TypeError:
# Fallback for handlers that don't accept llm_callback yet
handler(node, repairs)
# Add other node type repairs here as needed
if repairs:
logger.info("[NODE REPAIR] Completed with %s repairs:", len(repairs))
for i, repair in enumerate(repairs, 1):
logger.info("[NODE REPAIR] %s. %s", i, repair)
else:
logger.info("[NODE REPAIR] Completed - no repairs needed")
return NodeRepairResult(
nodes=nodes,
repairs_made=repairs,
warnings=warnings,
)
@classmethod
def _repair_if_else_operators(cls, node: WorkflowNodeDict, repairs: list[str], **kwargs):
"""
Normalize comparison operators in if-else nodes.
And ensure 'id' field exists for cases and conditions (frontend requirement).
"""
node_id = node.get("id", "unknown")
config = node.get("config", {})
cases = config.get("cases", [])
for case in cases:
# Ensure case_id
if "case_id" not in case:
case["case_id"] = str(uuid.uuid4())
repairs.append(f"Generated missing case_id for case in node '{node_id}'")
conditions = case.get("conditions", [])
for condition in conditions:
# Ensure condition id
if "id" not in condition:
condition["id"] = str(uuid.uuid4())
# Not logging this repair to avoid clutter, as it's a structural fix
# Ensure value type (LLM might return int/float, but we need str/bool/list)
val = condition.get("value")
if isinstance(val, (int, float)) and not isinstance(val, bool):
condition["value"] = str(val)
repairs.append(f"Coerced numeric value to string in node '{node_id}'")
op = condition.get("comparison_operator")
if op in cls.OPERATOR_MAP:
new_op = cls.OPERATOR_MAP[op]
condition["comparison_operator"] = new_op
repairs.append(f"Normalized operator '{op}' to '{new_op}' in node '{node_id}'")
@classmethod
def _repair_variable_aggregator_variables(cls, node: WorkflowNodeDict, repairs: list[str]):
"""
Repair variable-aggregator variables format.
Converts dict format to list[list[str]] format.
Expected: [["node_id", "field"], ["node_id2", "field2"]]
May receive: [{"name": "...", "value_selector": ["node_id", "field"]}, ...]
"""
node_id = node.get("id", "unknown")
config = node.get("config", {})
variables = config.get("variables", [])
if not variables:
return
repaired = False
repaired_variables = []
for var in variables:
if isinstance(var, dict):
# Convert dict format to array format
value_selector = var.get("value_selector") or var.get("selector") or var.get("path")
if isinstance(value_selector, list) and len(value_selector) > 0:
repaired_variables.append(value_selector)
repaired = True
else:
# Try to extract from name field - LLM may generate {"name": "node_id.field"}
name = var.get("name")
if isinstance(name, str) and "." in name:
# Try to parse "node_id.field" format
parts = name.split(".", 1)
if len(parts) == 2:
repaired_variables.append([parts[0], parts[1]])
repaired = True
else:
logger.warning(
"Variable aggregator node '%s' has invalid variable format: %s",
node_id,
var,
)
repaired_variables.append([]) # Empty array as fallback
else:
# If no valid selector or name, skip this variable
logger.warning(
"Variable aggregator node '%s' has invalid variable format: %s",
node_id,
var,
)
# Don't add empty array - skip invalid variables
elif isinstance(var, list):
# Already in correct format
repaired_variables.append(var)
else:
# Unknown format, skip
logger.warning("Variable aggregator node '%s' has unknown variable format: %s", node_id, var)
# Don't add empty array - skip invalid variables
if repaired:
config["variables"] = repaired_variables
repairs.append(f"Repaired variable-aggregator variables format in node '{node_id}'")
@classmethod
def _repair_code_node_config(cls, node: WorkflowNodeDict, repairs: list[str], llm_callback=None):
"""
Repair code node configuration (outputs and variables).
1. Outputs: Converts list format to dict format AND normalizes types.
2. Variables: Ensures value_selector exists.
"""
node_id = node.get("id", "unknown")
config = node.get("config", {})
if "variables" not in config:
config["variables"] = []
# --- Repair Variables ---
variables = config.get("variables")
if isinstance(variables, list):
for var in variables:
if isinstance(var, dict):
# Ensure value_selector exists (frontend crashes if missing)
if "value_selector" not in var:
var["value_selector"] = []
# Not logging trivial repairs
# --- Repair Outputs ---
outputs = config.get("outputs")
if not outputs:
return
# Helper to normalize type
def normalize_type(t: str) -> str:
t_lower = str(t).lower()
return cls.TYPE_MAPPING.get(t_lower, t)
# 1. Handle Dict format (Standard) - Check for invalid types
if isinstance(outputs, dict):
changed = False
for var_name, var_config in outputs.items():
if isinstance(var_config, dict):
original_type = var_config.get("type")
if original_type:
new_type = normalize_type(original_type)
if new_type != original_type:
var_config["type"] = new_type
changed = True
repairs.append(
f"Normalized type '{original_type}' to '{new_type}' "
f"for var '{var_name}' in node '{node_id}'"
)
return
# 2. Handle List format (Repair needed)
if isinstance(outputs, list):
new_outputs = {}
for item in outputs:
if isinstance(item, dict):
var_name = item.get("variable") or item.get("name")
var_type = item.get("type")
if var_name and var_type:
norm_type = normalize_type(var_type)
new_outputs[var_name] = {"type": norm_type}
if norm_type != var_type:
repairs.append(
f"Normalized type '{var_type}' to '{norm_type}' "
f"during list conversion in node '{node_id}'"
)
if new_outputs:
config["outputs"] = new_outputs
repairs.append(f"Repaired code node outputs format in node '{node_id}'")
else:
# Fallback: Try LLM if available
if llm_callback:
try:
# Attempt to fix using LLM
fixed_outputs = llm_callback(
node,
"outputs must be a dictionary like {'var_name': {'type': 'string'}}, "
"but got a list or valid conversion failed.",
)
if isinstance(fixed_outputs, dict) and fixed_outputs:
config["outputs"] = fixed_outputs
repairs.append(f"Repaired code node outputs format using LLM in node '{node_id}'")
return
except Exception as e:
logger.warning("LLM fallback repair failed for node '%s': %s", node_id, e)
# If conversion/LLM failed, set to empty dict
config["outputs"] = {}
repairs.append(f"Reset invalid code node outputs to empty dict in node '{node_id}'")

View File

@@ -1,101 +0,0 @@
from dataclasses import dataclass
from core.workflow.generator.types import AvailableModelDict, AvailableToolDict, WorkflowDataDict
from core.workflow.generator.validation.context import ValidationContext
from core.workflow.generator.validation.engine import ValidationEngine
from core.workflow.generator.validation.rules import Severity
@dataclass
class ValidationHint:
"""Legacy compatibility class for validation hints."""
node_id: str
field: str
message: str
severity: str # 'error', 'warning'
suggestion: str = None
node_type: str = None # Added for test compatibility
# Alias for potential old code using 'type' instead of 'severity'
@property
def type(self) -> str:
return self.severity
@property
def element_id(self) -> str:
return self.node_id
FriendlyHint = ValidationHint # Alias for backward compatibility
class WorkflowValidator:
"""
Validates the generated workflow configuration (nodes and edges).
Wraps the new ValidationEngine for backward compatibility.
"""
@classmethod
def validate(
cls,
workflow_data: WorkflowDataDict,
available_tools: list[AvailableToolDict],
available_models: list[AvailableModelDict] | None = None,
) -> tuple[bool, list[ValidationHint]]:
"""
Validate workflow data and return validity status and hints.
Args:
workflow_data: Dict containing 'nodes' and 'edges'
available_tools: List of available tool configurations
available_models: List of available models (added for Vibe compat)
Returns:
Tuple(max_severity_is_not_error, list_of_hints)
"""
nodes = workflow_data.get("nodes", [])
edges = workflow_data.get("edges", [])
# Create context
context = ValidationContext(
nodes=nodes,
edges=edges,
available_models=available_models or [],
available_tools=available_tools or [],
)
# Run validation engine
engine = ValidationEngine()
result = engine.validate(context)
# Convert engine errors to legacy hints
hints: list[ValidationHint] = []
error_count = 0
warning_count = 0
for error in result.all_errors:
# Map severity
severity = "error" if error.severity == Severity.ERROR else "warning"
if severity == "error":
error_count += 1
else:
warning_count += 1
# Map field from message or details if possible (heuristic)
field_name = error.details.get("field", "unknown")
hints.append(
ValidationHint(
node_id=error.node_id,
field=field_name,
message=error.message,
severity=severity,
suggestion=error.fix_hint,
node_type=error.node_type,
)
)
return result.is_valid, hints

View File

@@ -1,42 +0,0 @@
"""
Validation Rule Engine for Vibe Workflow Generation.
This module provides a declarative, schema-based validation system for
generated workflow nodes. It classifies errors into fixable (LLM can auto-fix)
and user-required (needs manual intervention) categories.
Usage:
from core.workflow.generator.validation import ValidationEngine, ValidationContext
context = ValidationContext(
available_models=[...],
available_tools=[...],
nodes=[...],
edges=[...],
)
engine = ValidationEngine()
result = engine.validate(context)
# Access classified errors
fixable_errors = result.fixable_errors
user_required_errors = result.user_required_errors
"""
from core.workflow.generator.validation.context import ValidationContext
from core.workflow.generator.validation.engine import ValidationEngine, ValidationResult
from core.workflow.generator.validation.rules import (
RuleCategory,
Severity,
ValidationError,
ValidationRule,
)
__all__ = [
"RuleCategory",
"Severity",
"ValidationContext",
"ValidationEngine",
"ValidationError",
"ValidationResult",
"ValidationRule",
]

View File

@@ -1,115 +0,0 @@
"""
Validation Context for the Rule Engine.
The ValidationContext holds all the data needed for validation:
- Generated nodes and edges
- Available models, tools, and datasets
- Node output schemas for variable reference validation
"""
from dataclasses import dataclass, field
from core.workflow.generator.types import (
AvailableModelDict,
AvailableToolDict,
WorkflowEdgeDict,
WorkflowNodeDict,
)
@dataclass
class ValidationContext:
"""
Context object containing all data needed for validation.
This is passed to each validation rule, providing access to:
- The nodes being validated
- Edge connections between nodes
- Available external resources (models, tools)
"""
# Generated workflow data
nodes: list[WorkflowNodeDict] = field(default_factory=list)
edges: list[WorkflowEdgeDict] = field(default_factory=list)
# Available external resources
available_models: list[AvailableModelDict] = field(default_factory=list)
available_tools: list[AvailableToolDict] = field(default_factory=list)
# Cached lookups (populated lazily)
_node_map: dict[str, WorkflowNodeDict] | None = field(default=None, repr=False)
_model_set: set[tuple[str, str]] | None = field(default=None, repr=False)
_tool_set: set[str] | None = field(default=None, repr=False)
_configured_tool_set: set[str] | None = field(default=None, repr=False)
@property
def node_map(self) -> dict[str, WorkflowNodeDict]:
"""Get a map of node_id -> node for quick lookup."""
if self._node_map is None:
self._node_map = {node.get("id", ""): node for node in self.nodes}
return self._node_map
@property
def model_set(self) -> set[tuple[str, str]]:
"""Get a set of (provider, model_name) tuples for quick lookup."""
if self._model_set is None:
self._model_set = {(m.get("provider", ""), m.get("model", "")) for m in self.available_models}
return self._model_set
@property
def tool_set(self) -> set[str]:
"""Get a set of all tool keys (both configured and unconfigured)."""
if self._tool_set is None:
self._tool_set = set()
for tool in self.available_tools:
provider = tool.get("provider_id") or tool.get("provider", "")
tool_key = tool.get("tool_key") or tool.get("tool_name", "")
if provider and tool_key:
self._tool_set.add(f"{provider}/{tool_key}")
if tool_key:
self._tool_set.add(tool_key)
return self._tool_set
@property
def configured_tool_set(self) -> set[str]:
"""Get a set of configured (authorized) tool keys."""
if self._configured_tool_set is None:
self._configured_tool_set = set()
for tool in self.available_tools:
if not tool.get("is_team_authorization", False):
continue
provider = tool.get("provider_id") or tool.get("provider", "")
tool_key = tool.get("tool_key") or tool.get("tool_name", "")
if provider and tool_key:
self._configured_tool_set.add(f"{provider}/{tool_key}")
if tool_key:
self._configured_tool_set.add(tool_key)
return self._configured_tool_set
def has_model(self, provider: str, model_name: str) -> bool:
"""Check if a model is available."""
return (provider, model_name) in self.model_set
def has_tool(self, tool_key: str) -> bool:
"""Check if a tool exists (configured or not)."""
return tool_key in self.tool_set
def is_tool_configured(self, tool_key: str) -> bool:
"""Check if a tool is configured and ready to use."""
return tool_key in self.configured_tool_set
def get_node(self, node_id: str) -> WorkflowNodeDict | None:
"""Get a node by its ID."""
return self.node_map.get(node_id)
def get_node_ids(self) -> set[str]:
"""Get all node IDs in the workflow."""
return set(self.node_map.keys())
def get_upstream_nodes(self, node_id: str) -> list[str]:
"""Get IDs of nodes that connect to this node (upstream)."""
return [edge.get("source", "") for edge in self.edges if edge.get("target") == node_id]
def get_downstream_nodes(self, node_id: str) -> list[str]:
"""Get IDs of nodes that this node connects to (downstream)."""
return [edge.get("target", "") for edge in self.edges if edge.get("source") == node_id]

View File

@@ -1,260 +0,0 @@
"""
Validation Engine - Core validation logic.
The ValidationEngine orchestrates rule execution and aggregates results.
It provides a clean interface for validating workflow nodes.
"""
import logging
from dataclasses import dataclass, field
from typing import Any
from core.workflow.generator.types import (
AvailableModelDict,
AvailableToolDict,
WorkflowEdgeDict,
WorkflowNodeDict,
)
from core.workflow.generator.validation.context import ValidationContext
from core.workflow.generator.validation.rules import (
RuleCategory,
Severity,
ValidationError,
get_registry,
)
logger = logging.getLogger(__name__)
@dataclass
class ValidationResult:
"""
Result of validation containing all errors classified by fixability.
Attributes:
all_errors: All validation errors found
fixable_errors: Errors that LLM can automatically fix
user_required_errors: Errors that require user intervention
warnings: Non-blocking warnings
stats: Validation statistics
"""
all_errors: list[ValidationError] = field(default_factory=list)
fixable_errors: list[ValidationError] = field(default_factory=list)
user_required_errors: list[ValidationError] = field(default_factory=list)
warnings: list[ValidationError] = field(default_factory=list)
stats: dict[str, int] = field(default_factory=dict)
@property
def has_errors(self) -> bool:
"""Check if there are any errors (excluding warnings)."""
return len(self.fixable_errors) > 0 or len(self.user_required_errors) > 0
@property
def has_fixable_errors(self) -> bool:
"""Check if there are fixable errors."""
return len(self.fixable_errors) > 0
@property
def is_valid(self) -> bool:
"""Check if validation passed (no errors, warnings are OK)."""
return not self.has_errors
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for API response."""
return {
"fixable": [e.to_dict() for e in self.fixable_errors],
"user_required": [e.to_dict() for e in self.user_required_errors],
"warnings": [e.to_dict() for e in self.warnings],
"all_warnings": [e.message for e in self.all_errors],
"stats": self.stats,
}
def get_error_messages(self) -> list[str]:
"""Get all error messages as strings."""
return [e.message for e in self.all_errors]
def get_fixable_by_node(self) -> dict[str, list[ValidationError]]:
"""Group fixable errors by node ID."""
result: dict[str, list[ValidationError]] = {}
for error in self.fixable_errors:
if error.node_id not in result:
result[error.node_id] = []
result[error.node_id].append(error)
return result
class ValidationEngine:
"""
The main validation engine.
Usage:
engine = ValidationEngine()
context = ValidationContext(nodes=[...], available_models=[...])
result = engine.validate(context)
"""
def __init__(self):
self._registry = get_registry()
def validate(self, context: ValidationContext) -> ValidationResult:
"""
Validate all nodes in the context.
Args:
context: ValidationContext with nodes, edges, and available resources
Returns:
ValidationResult with classified errors
"""
result = ValidationResult()
stats = {
"total_nodes": len(context.nodes),
"total_rules_checked": 0,
"total_errors": 0,
"fixable_count": 0,
"user_required_count": 0,
"warning_count": 0,
}
# Validate each node
for node in context.nodes:
node_type = node.get("type", "unknown")
node_id = node.get("id", "unknown")
# Get applicable rules for this node type
rules = self._registry.get_rules_for_node(node_type)
for rule in rules:
stats["total_rules_checked"] += 1
try:
errors = rule.check(node, context)
for error in errors:
result.all_errors.append(error)
stats["total_errors"] += 1
# Classify by severity and fixability
if error.severity == Severity.WARNING:
result.warnings.append(error)
stats["warning_count"] += 1
elif error.is_fixable:
result.fixable_errors.append(error)
stats["fixable_count"] += 1
else:
result.user_required_errors.append(error)
stats["user_required_count"] += 1
except Exception:
logger.exception(
"Rule '%s' failed for node '%s'",
rule.id,
node_id,
)
# Don't let a rule failure break the entire validation
continue
# Validate edges separately
edge_errors = self._validate_edges(context)
for error in edge_errors:
result.all_errors.append(error)
stats["total_errors"] += 1
if error.is_fixable:
result.fixable_errors.append(error)
stats["fixable_count"] += 1
else:
result.user_required_errors.append(error)
stats["user_required_count"] += 1
result.stats = stats
return result
def _validate_edges(self, context: ValidationContext) -> list[ValidationError]:
"""Validate edge connections."""
errors: list[ValidationError] = []
valid_node_ids = context.get_node_ids()
for edge in context.edges:
source = edge.get("source", "")
target = edge.get("target", "")
if source and source not in valid_node_ids:
errors.append(
ValidationError(
rule_id="edge.source.invalid",
node_id=source,
node_type="edge",
category=RuleCategory.SEMANTIC,
severity=Severity.ERROR,
is_fixable=True,
message=f"Edge source '{source}' does not exist",
fix_hint="Update edge to reference existing node",
)
)
if target and target not in valid_node_ids:
errors.append(
ValidationError(
rule_id="edge.target.invalid",
node_id=target,
node_type="edge",
category=RuleCategory.SEMANTIC,
severity=Severity.ERROR,
is_fixable=True,
message=f"Edge target '{target}' does not exist",
fix_hint="Update edge to reference existing node",
)
)
return errors
def validate_single_node(
self,
node: WorkflowNodeDict,
context: ValidationContext,
) -> list[ValidationError]:
"""
Validate a single node.
Useful for incremental validation when a node is added/modified.
"""
node_type = node.get("type", "unknown")
rules = self._registry.get_rules_for_node(node_type)
errors: list[ValidationError] = []
for rule in rules:
try:
errors.extend(rule.check(node, context))
except Exception:
logger.exception("Rule '%s' failed", rule.id)
return errors
def validate_nodes(
nodes: list[WorkflowNodeDict],
edges: list[WorkflowEdgeDict] | None = None,
available_models: list[AvailableModelDict] | None = None,
available_tools: list[AvailableToolDict] | None = None,
) -> ValidationResult:
"""
Convenience function to validate nodes without creating engine/context manually.
Args:
nodes: List of workflow nodes to validate
edges: Optional list of edges
available_models: Optional list of available models
available_tools: Optional list of available tools
Returns:
ValidationResult with classified errors
"""
context = ValidationContext(
nodes=nodes,
edges=edges or [],
available_models=available_models or [],
available_tools=available_tools or [],
)
engine = ValidationEngine()
return engine.validate(context)

View File

@@ -1,947 +0,0 @@
"""
Validation Rules Definition and Registry.
This module defines:
- ValidationRule: The rule structure
- RuleCategory: Categories of validation rules
- Severity: Error severity levels
- ValidationError: Error output structure
- All built-in validation rules
"""
import re
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any
from core.workflow.generator.types import WorkflowNodeDict
if TYPE_CHECKING:
from core.workflow.generator.validation.context import ValidationContext
class RuleCategory(Enum):
"""Categories of validation rules."""
STRUCTURE = "structure" # Field existence, types, formats
SEMANTIC = "semantic" # Variable references, edge connections
REFERENCE = "reference" # External resources (models, tools, datasets)
class Severity(Enum):
"""Severity levels for validation errors."""
ERROR = "error" # Must be fixed
WARNING = "warning" # Should be fixed but not blocking
@dataclass
class ValidationError:
"""
Represents a validation error found during rule execution.
Attributes:
rule_id: The ID of the rule that generated this error
node_id: The ID of the node with the error
node_type: The type of the node
category: The rule category
severity: Error severity
is_fixable: Whether LLM can auto-fix this error
message: Human-readable error message
fix_hint: Hint for LLM to fix the error
details: Additional error details
"""
rule_id: str
node_id: str
node_type: str
category: RuleCategory
severity: Severity
is_fixable: bool
message: str
fix_hint: str = ""
details: dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for API response."""
return {
"rule_id": self.rule_id,
"node_id": self.node_id,
"node_type": self.node_type,
"category": self.category.value,
"severity": self.severity.value,
"is_fixable": self.is_fixable,
"message": self.message,
"fix_hint": self.fix_hint,
"details": self.details,
}
# Type alias for rule check functions
RuleCheckFn = Callable[
[WorkflowNodeDict, "ValidationContext"],
list[ValidationError],
]
@dataclass
class ValidationRule:
"""
A validation rule definition.
Attributes:
id: Unique rule identifier (e.g., "llm.model.required")
node_types: List of node types this rule applies to, or ["*"] for all
category: The rule category
severity: Default severity for errors from this rule
is_fixable: Whether errors from this rule can be auto-fixed by LLM
check: The validation function
description: Human-readable description of what this rule checks
fix_hint: Default hint for fixing errors from this rule
"""
id: str
node_types: list[str]
category: RuleCategory
severity: Severity
is_fixable: bool
check: RuleCheckFn
description: str = ""
fix_hint: str = ""
def applies_to(self, node_type: str) -> bool:
"""Check if this rule applies to a given node type."""
return "*" in self.node_types or node_type in self.node_types
# =============================================================================
# Rule Registry
# =============================================================================
class RuleRegistry:
"""
Registry for validation rules.
Rules are registered here and can be retrieved by category or node type.
"""
def __init__(self):
self._rules: list[ValidationRule] = []
def register(self, rule: ValidationRule) -> None:
"""Register a validation rule."""
self._rules.append(rule)
def get_rules_for_node(self, node_type: str) -> list[ValidationRule]:
"""Get all rules that apply to a given node type."""
return [r for r in self._rules if r.applies_to(node_type)]
def get_rules_by_category(self, category: RuleCategory) -> list[ValidationRule]:
"""Get all rules in a given category."""
return [r for r in self._rules if r.category == category]
def get_all_rules(self) -> list[ValidationRule]:
"""Get all registered rules."""
return list(self._rules)
# Global rule registry instance
_registry = RuleRegistry()
def register_rule(rule: ValidationRule) -> ValidationRule:
"""Decorator/function to register a rule with the global registry."""
_registry.register(rule)
return rule
def get_registry() -> RuleRegistry:
"""Get the global rule registry."""
return _registry
# =============================================================================
# Helper Functions for Rule Implementations
# =============================================================================
# Explicit placeholder value defined in prompt contract
# See: api/core/workflow/generator/prompts/vibe_prompts.py
PLACEHOLDER_VALUE = "__PLACEHOLDER__"
# Variable reference pattern: {{#node_id.field#}}
VARIABLE_REF_PATTERN = re.compile(r"\{\{#([^.#]+)\.([^#]+)#\}\}")
def is_placeholder(value: Any) -> bool:
"""Check if a value appears to be a placeholder."""
if not isinstance(value, str):
return False
return value == PLACEHOLDER_VALUE or PLACEHOLDER_VALUE in value
def extract_variable_refs(text: str) -> list[tuple[str, str]]:
"""
Extract variable references from text.
Returns list of (node_id, field_name) tuples.
"""
return VARIABLE_REF_PATTERN.findall(text)
def check_required_field(
config: dict[str, Any],
field_name: str,
node_id: str,
node_type: str,
rule_id: str,
fix_hint: str = "",
) -> ValidationError | None:
"""Helper to check if a required field exists and is non-empty."""
value = config.get(field_name)
if value is None or value == "" or (isinstance(value, list) and len(value) == 0):
return ValidationError(
rule_id=rule_id,
node_id=node_id,
node_type=node_type,
category=RuleCategory.STRUCTURE,
severity=Severity.ERROR,
is_fixable=True,
message=f"Node '{node_id}': missing required field '{field_name}'",
fix_hint=fix_hint or f"Add '{field_name}' to the node config",
)
return None
# =============================================================================
# Structure Rules - Field existence, types, formats
# =============================================================================
def _check_llm_prompt_template(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
"""Check that LLM node has prompt_template."""
errors: list[ValidationError] = []
node_id = node.get("id", "unknown")
config = node.get("config", {})
err = check_required_field(
config,
"prompt_template",
node_id,
"llm",
"llm.prompt_template.required",
"Add prompt_template with system and user messages",
)
if err:
errors.append(err)
return errors
def _check_http_request_url(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
"""Check that http-request node has url and method."""
errors: list[ValidationError] = []
node_id = node.get("id", "unknown")
config = node.get("config", {})
# Check url
url = config.get("url", "")
if not url:
errors.append(
ValidationError(
rule_id="http.url.required",
node_id=node_id,
node_type="http-request",
category=RuleCategory.STRUCTURE,
severity=Severity.ERROR,
is_fixable=True,
message=f"Node '{node_id}': http-request missing required 'url'",
fix_hint="Add url - use {{#start.url#}} or a concrete URL",
)
)
elif is_placeholder(url):
errors.append(
ValidationError(
rule_id="http.url.placeholder",
node_id=node_id,
node_type="http-request",
category=RuleCategory.STRUCTURE,
severity=Severity.ERROR,
is_fixable=True,
message=f"Node '{node_id}': url contains placeholder value",
fix_hint="Replace placeholder with actual URL or variable reference",
)
)
# Check method
method = config.get("method", "")
if not method:
errors.append(
ValidationError(
rule_id="http.method.required",
node_id=node_id,
node_type="http-request",
category=RuleCategory.STRUCTURE,
severity=Severity.ERROR,
is_fixable=True,
message=f"Node '{node_id}': http-request missing 'method'",
fix_hint="Add method: GET, POST, PUT, DELETE, or PATCH",
)
)
return errors
def _check_code_node(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
"""Check that code node has code and language."""
errors: list[ValidationError] = []
node_id = node.get("id", "unknown")
config = node.get("config", {})
err = check_required_field(
config,
"code",
node_id,
"code",
"code.code.required",
"Add code with a main() function that returns a dict",
)
if err:
errors.append(err)
err = check_required_field(
config,
"language",
node_id,
"code",
"code.language.required",
"Add language: python3 or javascript",
)
if err:
errors.append(err)
return errors
def _check_question_classifier(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
"""Check that question-classifier has classes."""
errors: list[ValidationError] = []
node_id = node.get("id", "unknown")
config = node.get("config", {})
err = check_required_field(
config,
"classes",
node_id,
"question-classifier",
"classifier.classes.required",
"Add classes array with id and name for each classification",
)
if err:
errors.append(err)
return errors
def _check_parameter_extractor(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
"""Check that parameter-extractor has parameters and instruction."""
errors: list[ValidationError] = []
node_id = node.get("id", "unknown")
config = node.get("config", {})
err = check_required_field(
config,
"parameters",
node_id,
"parameter-extractor",
"extractor.parameters.required",
"Add parameters array with name, type, description fields",
)
if err:
errors.append(err)
else:
# Check individual parameters for required fields
parameters = config.get("parameters", [])
if isinstance(parameters, list):
for i, param in enumerate(parameters):
if isinstance(param, dict):
# Check for 'required' field (boolean)
if "required" not in param:
errors.append(
ValidationError(
rule_id="extractor.param.required_field.missing",
node_id=node_id,
node_type="parameter-extractor",
category=RuleCategory.STRUCTURE,
severity=Severity.ERROR,
is_fixable=True,
message=f"Node '{node_id}': parameter[{i}] missing 'required' field",
fix_hint=f"Add 'required': True to parameter '{param.get('name', 'unknown')}'",
details={"param_index": i, "param_name": param.get("name")},
)
)
# instruction is recommended but not strictly required
if not config.get("instruction"):
errors.append(
ValidationError(
rule_id="extractor.instruction.recommended",
node_id=node_id,
node_type="parameter-extractor",
category=RuleCategory.STRUCTURE,
severity=Severity.WARNING,
is_fixable=True,
message=f"Node '{node_id}': parameter-extractor should have 'instruction'",
fix_hint="Add instruction describing what to extract",
)
)
return errors
def _check_knowledge_retrieval(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
"""Check that knowledge-retrieval has dataset_ids."""
errors: list[ValidationError] = []
node_id = node.get("id", "unknown")
config = node.get("config", {})
dataset_ids = config.get("dataset_ids", [])
if not dataset_ids:
errors.append(
ValidationError(
rule_id="knowledge.dataset.required",
node_id=node_id,
node_type="knowledge-retrieval",
category=RuleCategory.STRUCTURE,
severity=Severity.ERROR,
is_fixable=False, # User must select knowledge base
message=f"Node '{node_id}': knowledge-retrieval missing 'dataset_ids'",
fix_hint="User must select knowledge bases in the UI",
)
)
else:
# Check for placeholder values
for ds_id in dataset_ids:
if is_placeholder(ds_id):
errors.append(
ValidationError(
rule_id="knowledge.dataset.placeholder",
node_id=node_id,
node_type="knowledge-retrieval",
category=RuleCategory.STRUCTURE,
severity=Severity.ERROR,
is_fixable=False,
message=f"Node '{node_id}': dataset_ids contains placeholder",
fix_hint="User must replace placeholder with actual knowledge base ID",
details={"placeholder_value": ds_id},
)
)
break
return errors
def _check_end_node(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
"""Check that end node has outputs defined."""
errors: list[ValidationError] = []
node_id = node.get("id", "unknown")
config = node.get("config", {})
outputs = config.get("outputs", [])
if not outputs:
errors.append(
ValidationError(
rule_id="end.outputs.recommended",
node_id=node_id,
node_type="end",
category=RuleCategory.STRUCTURE,
severity=Severity.WARNING,
is_fixable=True,
message="End node should define output variables",
fix_hint="Add outputs array with variable and value_selector",
)
)
return errors
# =============================================================================
# Semantic Rules - Variable references, edge connections
# =============================================================================
def _check_variable_references(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
"""Check that variable references point to valid nodes."""
errors: list[ValidationError] = []
node_id = node.get("id", "unknown")
node_type = node.get("type", "unknown")
config = node.get("config", {})
# Get all valid node IDs (including 'start' which is always valid)
valid_node_ids = ctx.get_node_ids()
valid_node_ids.add("start")
valid_node_ids.add("sys") # System variables
def check_text_for_refs(text: str, field_path: str) -> None:
if not isinstance(text, str):
return
refs = extract_variable_refs(text)
for ref_node_id, ref_field in refs:
if ref_node_id not in valid_node_ids:
errors.append(
ValidationError(
rule_id="variable.ref.invalid_node",
node_id=node_id,
node_type=node_type,
category=RuleCategory.SEMANTIC,
severity=Severity.ERROR,
is_fixable=True,
message=f"Node '{node_id}': references non-existent node '{ref_node_id}'",
fix_hint=f"Change {{{{#{ref_node_id}.{ref_field}#}}}} to reference a valid node",
details={"field_path": field_path, "invalid_ref": ref_node_id},
)
)
# Check prompt_template for LLM nodes
prompt_template = config.get("prompt_template", [])
if isinstance(prompt_template, list):
for i, msg in enumerate(prompt_template):
if isinstance(msg, dict):
text = msg.get("text", "")
check_text_for_refs(text, f"prompt_template[{i}].text")
# Check instruction field
instruction = config.get("instruction", "")
check_text_for_refs(instruction, "instruction")
# Check url for http-request
url = config.get("url", "")
check_text_for_refs(url, "url")
return errors
# NOTE: _check_node_has_outgoing_edge removed - handled by GraphValidator
# NOTE: _check_node_has_incoming_edge removed - handled by GraphValidator
# NOTE: _check_question_classifier_branches removed - handled by EdgeRepair
# NOTE: _check_if_else_branches removed - handled by EdgeRepair
def _check_if_else_operators(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
"""Check that if-else comparison operators are valid."""
errors: list[ValidationError] = []
node_id = node.get("id", "unknown")
node_type = node.get("type", "unknown")
if node_type != "if-else":
return errors
valid_operators = {
"contains",
"not contains",
"start with",
"end with",
"is",
"is not",
"empty",
"not empty",
"in",
"not in",
"all of",
"=",
"",
">",
"<",
"",
"",
"null",
"not null",
"exists",
"not exists",
}
config = node.get("config", {})
cases = config.get("cases", [])
for case in cases:
conditions = case.get("conditions", [])
for condition in conditions:
op = condition.get("comparison_operator")
if op and op not in valid_operators:
errors.append(
ValidationError(
rule_id="ifelse.operator.invalid",
node_id=node_id,
node_type=node_type,
category=RuleCategory.SEMANTIC,
severity=Severity.ERROR,
is_fixable=True,
message=f"Invalid operator '{op}' in if-else node",
fix_hint=f"Use one of: {', '.join(sorted(valid_operators))}",
details={"invalid_operator": op, "field": "config.cases.conditions.comparison_operator"},
)
)
return errors
def _check_edge_targets_exist(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
"""Check that edge targets reference existing nodes."""
errors: list[ValidationError] = []
node_id = node.get("id", "unknown")
node_type = node.get("type", "unknown")
valid_node_ids = ctx.get_node_ids()
# Check all outgoing edges from this node
for edge in ctx.edges:
if edge.get("source") == node_id:
target = edge.get("target")
if target and target not in valid_node_ids:
errors.append(
ValidationError(
rule_id="edge.target.invalid",
node_id=node_id,
node_type=node_type,
category=RuleCategory.SEMANTIC,
severity=Severity.ERROR,
is_fixable=True,
message=f"Edge from '{node_id}' targets non-existent node '{target}'",
fix_hint=f"Change edge target from '{target}' to an existing node",
details={"invalid_target": target, "field": "edges"},
)
)
return errors
# =============================================================================
# Reference Rules - External resources (models, tools, datasets)
# =============================================================================
# Node types that require model configuration
MODEL_REQUIRED_NODE_TYPES = {"llm", "question-classifier", "parameter-extractor"}
def _check_model_config(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
"""Check that model configuration is valid."""
errors: list[ValidationError] = []
node_id = node.get("id", "unknown")
node_type = node.get("type", "unknown")
config = node.get("config", {})
if node_type not in MODEL_REQUIRED_NODE_TYPES:
return errors
model = config.get("model")
# Check if model config exists
if not model:
if ctx.available_models:
errors.append(
ValidationError(
rule_id="model.required",
node_id=node_id,
node_type=node_type,
category=RuleCategory.REFERENCE,
severity=Severity.ERROR,
is_fixable=True,
message=f"Node '{node_id}' ({node_type}): missing required 'model' configuration",
fix_hint="Add model config using one of the available models",
)
)
else:
errors.append(
ValidationError(
rule_id="model.no_available",
node_id=node_id,
node_type=node_type,
category=RuleCategory.REFERENCE,
severity=Severity.ERROR,
is_fixable=False,
message=f"Node '{node_id}' ({node_type}): needs model but no models available",
fix_hint="User must configure a model provider first",
)
)
return errors
# Check if model config is valid
if isinstance(model, dict):
provider = model.get("provider", "")
name = model.get("name", "")
# Check for placeholder values
if is_placeholder(provider) or is_placeholder(name):
if ctx.available_models:
errors.append(
ValidationError(
rule_id="model.placeholder",
node_id=node_id,
node_type=node_type,
category=RuleCategory.REFERENCE,
severity=Severity.ERROR,
is_fixable=True,
message=f"Node '{node_id}': model config contains placeholder",
fix_hint="Replace placeholder with actual model from available_models",
)
)
return errors
# Check if model exists in available_models
if ctx.available_models and provider and name:
if not ctx.has_model(provider, name):
errors.append(
ValidationError(
rule_id="model.not_found",
node_id=node_id,
node_type=node_type,
category=RuleCategory.REFERENCE,
severity=Severity.ERROR,
is_fixable=True,
message=f"Node '{node_id}': model '{provider}/{name}' not in available models",
fix_hint="Replace with a model from available_models",
details={"provider": provider, "model": name},
)
)
return errors
def _check_tool_reference(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
"""Check that tool references are valid and configured."""
errors: list[ValidationError] = []
node_id = node.get("id", "unknown")
node_type = node.get("type", "unknown")
if node_type != "tool":
return errors
config = node.get("config", {})
tool_ref = (
config.get("tool_key")
or config.get("tool_name")
or config.get("provider_id", "") + "/" + config.get("tool_name", "")
)
if not tool_ref:
errors.append(
ValidationError(
rule_id="tool.key.required",
node_id=node_id,
node_type=node_type,
category=RuleCategory.REFERENCE,
severity=Severity.ERROR,
is_fixable=True,
message=f"Node '{node_id}': tool node missing tool_key",
fix_hint="Add tool_key from available_tools",
)
)
return errors
# Check if tool exists
if not ctx.has_tool(tool_ref):
errors.append(
ValidationError(
rule_id="tool.not_found",
node_id=node_id,
node_type=node_type,
category=RuleCategory.REFERENCE,
severity=Severity.ERROR,
is_fixable=True, # Can be replaced with http-request fallback
message=f"Node '{node_id}': tool '{tool_ref}' not found",
fix_hint="Use http-request or code node as fallback",
details={"tool_ref": tool_ref},
)
)
elif not ctx.is_tool_configured(tool_ref):
errors.append(
ValidationError(
rule_id="tool.not_configured",
node_id=node_id,
node_type=node_type,
category=RuleCategory.REFERENCE,
severity=Severity.WARNING,
is_fixable=False, # User needs to configure
message=f"Node '{node_id}': tool '{tool_ref}' requires configuration",
fix_hint="Configure the tool in Tools settings",
details={"tool_ref": tool_ref},
)
)
return errors
# =============================================================================
# Register All Rules
# =============================================================================
# Structure Rules
register_rule(
ValidationRule(
id="llm.prompt_template.required",
node_types=["llm"],
category=RuleCategory.STRUCTURE,
severity=Severity.ERROR,
is_fixable=True,
check=_check_llm_prompt_template,
description="LLM node must have prompt_template",
fix_hint="Add prompt_template with system and user messages",
)
)
register_rule(
ValidationRule(
id="http.config.required",
node_types=["http-request"],
category=RuleCategory.STRUCTURE,
severity=Severity.ERROR,
is_fixable=True,
check=_check_http_request_url,
description="HTTP request node must have url and method",
fix_hint="Add url and method to config",
)
)
register_rule(
ValidationRule(
id="code.config.required",
node_types=["code"],
category=RuleCategory.STRUCTURE,
severity=Severity.ERROR,
is_fixable=True,
check=_check_code_node,
description="Code node must have code and language",
fix_hint="Add code with main() function and language",
)
)
register_rule(
ValidationRule(
id="classifier.classes.required",
node_types=["question-classifier"],
category=RuleCategory.STRUCTURE,
severity=Severity.ERROR,
is_fixable=True,
check=_check_question_classifier,
description="Question classifier must have classes",
fix_hint="Add classes array with classification options",
)
)
register_rule(
ValidationRule(
id="extractor.config.required",
node_types=["parameter-extractor"],
category=RuleCategory.STRUCTURE,
severity=Severity.ERROR,
is_fixable=True,
check=_check_parameter_extractor,
description="Parameter extractor must have parameters",
fix_hint="Add parameters array",
)
)
register_rule(
ValidationRule(
id="knowledge.config.required",
node_types=["knowledge-retrieval"],
category=RuleCategory.STRUCTURE,
severity=Severity.ERROR,
is_fixable=False,
check=_check_knowledge_retrieval,
description="Knowledge retrieval must have dataset_ids",
fix_hint="User must select knowledge base",
)
)
register_rule(
ValidationRule(
id="end.outputs.check",
node_types=["end"],
category=RuleCategory.STRUCTURE,
severity=Severity.WARNING,
is_fixable=True,
check=_check_end_node,
description="End node should have outputs",
fix_hint="Add outputs array",
)
)
# Semantic Rules
register_rule(
ValidationRule(
id="variable.references.valid",
node_types=["*"],
category=RuleCategory.SEMANTIC,
severity=Severity.ERROR,
is_fixable=True,
check=_check_variable_references,
description="Variable references must point to valid nodes",
fix_hint="Fix variable reference to use valid node ID",
)
)
# Edge Validation Rules
# NOTE: Edge connectivity and branch completeness are now handled by:
# - GraphValidator (BFS-based reachability analysis)
# - EdgeRepair (automatic branch edge repair)
register_rule(
ValidationRule(
id="edge.targets.valid",
node_types=["*"],
category=RuleCategory.SEMANTIC,
severity=Severity.ERROR,
is_fixable=True,
check=_check_edge_targets_exist,
description="Edge targets must reference existing nodes",
fix_hint="Change edge target to an existing node ID",
)
)
# Reference Rules
register_rule(
ValidationRule(
id="model.config.valid",
node_types=["llm", "question-classifier", "parameter-extractor"],
category=RuleCategory.REFERENCE,
severity=Severity.ERROR,
is_fixable=True,
check=_check_model_config,
description="Model configuration must be valid",
fix_hint="Add valid model from available_models",
)
)
register_rule(
ValidationRule(
id="tool.reference.valid",
node_types=["tool"],
category=RuleCategory.REFERENCE,
severity=Severity.ERROR,
is_fixable=True,
check=_check_tool_reference,
description="Tool reference must be valid and configured",
fix_hint="Use valid tool or fallback node",
)
)
register_rule(
ValidationRule(
id="ifelse.operator.valid",
node_types=["if-else"],
category=RuleCategory.SEMANTIC,
severity=Severity.ERROR,
is_fixable=True,
check=_check_if_else_operators,
description="If-else operators must be valid",
fix_hint="Use standard operators like ≥, ≤, =, ≠",
)
)

View File

@@ -9,7 +9,7 @@ Each instance uses a unique key for its command queue.
import json
from typing import TYPE_CHECKING, Any, final
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
if TYPE_CHECKING:
from extensions.ext_redis import RedisClientWrapper
@@ -113,6 +113,8 @@ class RedisChannel:
return AbortCommand.model_validate(data)
if command_type == CommandType.PAUSE:
return PauseCommand.model_validate(data)
if command_type == CommandType.UPDATE_VARIABLES:
return UpdateVariablesCommand.model_validate(data)
# For other command types, use base class
return GraphEngineCommand.model_validate(data)

View File

@@ -5,11 +5,12 @@ This package handles external commands sent to the engine
during execution.
"""
from .command_handlers import AbortCommandHandler, PauseCommandHandler
from .command_handlers import AbortCommandHandler, PauseCommandHandler, UpdateVariablesCommandHandler
from .command_processor import CommandProcessor
__all__ = [
"AbortCommandHandler",
"CommandProcessor",
"PauseCommandHandler",
"UpdateVariablesCommandHandler",
]

View File

@@ -4,9 +4,10 @@ from typing import final
from typing_extensions import override
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.runtime import VariablePool
from ..domain.graph_execution import GraphExecution
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
from .command_processor import CommandHandler
logger = logging.getLogger(__name__)
@@ -31,3 +32,25 @@ class PauseCommandHandler(CommandHandler):
reason = command.reason
pause_reason = SchedulingPause(message=reason)
execution.pause(pause_reason)
@final
class UpdateVariablesCommandHandler(CommandHandler):
def __init__(self, variable_pool: VariablePool) -> None:
self._variable_pool = variable_pool
@override
def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
assert isinstance(command, UpdateVariablesCommand)
for update in command.updates:
try:
variable = update.value
self._variable_pool.add(variable.selector, variable)
logger.debug("Updated variable %s for workflow %s", variable.selector, execution.workflow_id)
except ValueError as exc:
logger.warning(
"Skipping invalid variable selector %s for workflow %s: %s",
getattr(update.value, "selector", None),
execution.workflow_id,
exc,
)

View File

@@ -5,17 +5,21 @@ This module defines command types that can be sent to a running GraphEngine
instance to control its execution flow.
"""
from enum import StrEnum
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any
from pydantic import BaseModel, Field
from core.variables.variables import VariableUnion
class CommandType(StrEnum):
"""Types of commands that can be sent to GraphEngine."""
ABORT = "abort"
PAUSE = "pause"
ABORT = auto()
PAUSE = auto()
UPDATE_VARIABLES = auto()
class GraphEngineCommand(BaseModel):
@@ -37,3 +41,16 @@ class PauseCommand(GraphEngineCommand):
command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
reason: str = Field(default="unknown reason", description="reason for pause")
class VariableUpdate(BaseModel):
"""Represents a single variable update instruction."""
value: VariableUnion = Field(description="New variable value")
class UpdateVariablesCommand(GraphEngineCommand):
"""Command to update a group of variables in the variable pool."""
command_type: CommandType = Field(default=CommandType.UPDATE_VARIABLES, description="Type of command")
updates: Sequence[VariableUpdate] = Field(default_factory=list, description="Variable updates")

View File

@@ -30,8 +30,13 @@ from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWr
if TYPE_CHECKING: # pragma: no cover - used only for static analysis
from core.workflow.runtime.graph_runtime_state import GraphProtocol
from .command_processing import AbortCommandHandler, CommandProcessor, PauseCommandHandler
from .entities.commands import AbortCommand, PauseCommand
from .command_processing import (
AbortCommandHandler,
CommandProcessor,
PauseCommandHandler,
UpdateVariablesCommandHandler,
)
from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand
from .error_handler import ErrorHandler
from .event_management import EventHandler, EventManager
from .graph_state_manager import GraphStateManager
@@ -140,6 +145,9 @@ class GraphEngine:
pause_handler = PauseCommandHandler()
self._command_processor.register_handler(PauseCommand, pause_handler)
update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool)
self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler)
# === Extensibility ===
# Layers allow plugins to extend engine functionality
self._layers: list[GraphEngineLayer] = []
@@ -212,9 +220,16 @@ class GraphEngine:
if id(node.graph_runtime_state) != expected_state_id:
raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")
def _bind_layer_context(
self,
layer: GraphEngineLayer,
) -> None:
layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel)
def layer(self, layer: GraphEngineLayer) -> "GraphEngine":
"""Add a layer for extending functionality."""
self._layers.append(layer)
self._bind_layer_context(layer)
return self
def run(self) -> Generator[GraphEngineEvent, None, None]:
@@ -301,14 +316,7 @@ class GraphEngine:
def _initialize_layers(self) -> None:
"""Initialize layers with context."""
self._event_manager.set_layers(self._layers)
# Create a read-only wrapper for the runtime state
read_only_state = ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state)
for layer in self._layers:
try:
layer.initialize(read_only_state, self._command_channel)
except Exception as e:
logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e)
try:
layer.on_graph_start()
except Exception as e:

View File

@@ -8,7 +8,7 @@ Pluggable middleware for engine extensions.
Abstract base class for layers.
- `initialize()` - Receive runtime context
- `initialize()` - Receive runtime context (runtime state is bound here and always available to hooks)
- `on_graph_start()` - Execution start hook
- `on_event()` - Process all events
- `on_graph_end()` - Execution end hook
@@ -34,6 +34,9 @@ engine.layer(debug_layer)
engine.run()
```
`engine.layer()` binds the read-only runtime state before execution, so
`graph_runtime_state` is always available inside layer hooks.
## Custom Layers
```python

View File

@@ -13,6 +13,14 @@ from core.workflow.nodes.base.node import Node
from core.workflow.runtime import ReadOnlyGraphRuntimeState
class GraphEngineLayerNotInitializedError(Exception):
"""Raised when a layer's runtime state is accessed before initialization."""
def __init__(self, layer_name: str | None = None) -> None:
name = layer_name or "GraphEngineLayer"
super().__init__(f"{name} runtime state is not initialized. Bind the layer to a GraphEngine before access.")
class GraphEngineLayer(ABC):
"""
Abstract base class for GraphEngine layers.
@@ -28,22 +36,27 @@ class GraphEngineLayer(ABC):
def __init__(self) -> None:
"""Initialize the layer. Subclasses can override with custom parameters."""
self.graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
self._graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
self.command_channel: CommandChannel | None = None
@property
def graph_runtime_state(self) -> ReadOnlyGraphRuntimeState:
if self._graph_runtime_state is None:
raise GraphEngineLayerNotInitializedError(type(self).__name__)
return self._graph_runtime_state
def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None:
"""
Initialize the layer with engine dependencies.
Called by GraphEngine before execution starts to inject the read-only runtime state
and command channel. This allows layers to observe engine context and send
commands, but prevents direct state modification.
Called by GraphEngine to inject the read-only runtime state and command channel.
This is invoked when the layer is registered with a `GraphEngine` instance.
Implementations should be idempotent.
Args:
graph_runtime_state: Read-only view of the runtime state
command_channel: Channel for sending commands to the engine
"""
self.graph_runtime_state = graph_runtime_state
self._graph_runtime_state = graph_runtime_state
self.command_channel = command_channel
@abstractmethod

View File

@@ -109,10 +109,8 @@ class DebugLoggingLayer(GraphEngineLayer):
self.logger.info("=" * 80)
self.logger.info("🚀 GRAPH EXECUTION STARTED")
self.logger.info("=" * 80)
if self.graph_runtime_state:
# Log initial state
self.logger.info("Initial State:")
# Log initial state
self.logger.info("Initial State:")
@override
def on_event(self, event: GraphEngineEvent) -> None:
@@ -243,8 +241,7 @@ class DebugLoggingLayer(GraphEngineLayer):
self.logger.info(" Node retries: %s", self.retry_count)
# Log final state if available
if self.graph_runtime_state and self.include_outputs:
if self.graph_runtime_state.outputs:
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
if self.include_outputs and self.graph_runtime_state.outputs:
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
self.logger.info("=" * 80)

View File

@@ -337,8 +337,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
if update_finished:
execution.finished_at = naive_utc_now()
runtime_state = self.graph_runtime_state
if runtime_state is None:
return
execution.total_tokens = runtime_state.total_tokens
execution.total_steps = runtime_state.node_run_steps
execution.outputs = execution.outputs or runtime_state.outputs
@@ -404,6 +402,4 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
def _system_variables(self) -> Mapping[str, Any]:
runtime_state = self.graph_runtime_state
if runtime_state is None:
return {}
return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)

View File

@@ -3,14 +3,20 @@ GraphEngine Manager for sending control commands via Redis channel.
This module provides a simplified interface for controlling workflow executions
using the new Redis command channel, without requiring user permission checks.
Supports stop, pause, and resume operations.
"""
import logging
from collections.abc import Sequence
from typing import final
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
from core.workflow.graph_engine.entities.commands import (
AbortCommand,
GraphEngineCommand,
PauseCommand,
UpdateVariablesCommand,
VariableUpdate,
)
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
@@ -23,7 +29,6 @@ class GraphEngineManager:
This class provides a simple interface for controlling workflow executions
by sending commands through Redis channels, without user validation.
Supports stop and pause operations.
"""
@staticmethod
@@ -45,6 +50,16 @@ class GraphEngineManager:
pause_command = PauseCommand(reason=reason or "User requested pause")
GraphEngineManager._send_command(task_id, pause_command)
@staticmethod
def send_update_variables_command(task_id: str, updates: Sequence[VariableUpdate]) -> None:
"""Send a command to update variables in a running workflow."""
if not updates:
return
update_command = UpdateVariablesCommand(updates=updates)
GraphEngineManager._send_command(task_id, update_command)
@staticmethod
def _send_command(task_id: str, command: GraphEngineCommand) -> None:
"""Send a command to the workflow-specific Redis channel."""

View File

@@ -197,14 +197,6 @@ class Node(Generic[NodeDataT]):
return None
@classmethod
def get_default_config_schema(cls) -> dict[str, Any] | None:
"""
Get the default configuration schema for the node.
Used for LLM generation.
"""
return None
# Global registry populated via __init_subclass__
_registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {}

View File

@@ -1,8 +1,7 @@
from collections.abc import Mapping, Sequence
from decimal import Decimal
from typing import Any, cast
from typing import TYPE_CHECKING, Any, ClassVar, cast
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
@@ -13,6 +12,7 @@ from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.nodes.code.limits import CodeNodeLimits
from .exc import (
CodeNodeError,
@@ -20,9 +20,41 @@ from .exc import (
OutputValidationError,
)
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
class CodeNode(Node[CodeNodeData]):
node_type = NodeType.CODE
_DEFAULT_CODE_PROVIDERS: ClassVar[tuple[type[CodeNodeProvider], ...]] = (
Python3CodeProvider,
JavascriptCodeProvider,
)
_limits: CodeNodeLimits
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
code_executor: type[CodeExecutor] | None = None,
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
self._code_providers: tuple[type[CodeNodeProvider], ...] = (
tuple(code_providers) if code_providers else self._DEFAULT_CODE_PROVIDERS
)
self._limits = code_limits
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@@ -35,11 +67,16 @@ class CodeNode(Node[CodeNodeData]):
if filters:
code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3))
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] = next(p for p in providers if p.is_accept_language(code_language))
code_provider: type[CodeNodeProvider] = next(
provider for provider in cls._DEFAULT_CODE_PROVIDERS if provider.is_accept_language(code_language)
)
return code_provider.get_default_config()
@classmethod
def default_code_providers(cls) -> tuple[type[CodeNodeProvider], ...]:
return cls._DEFAULT_CODE_PROVIDERS
@classmethod
def version(cls) -> str:
return "1"
@@ -60,7 +97,8 @@ class CodeNode(Node[CodeNodeData]):
variables[variable_name] = variable.to_object() if variable else None
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(
_ = self._select_code_provider(code_language)
result = self._code_executor.execute_workflow_code_template(
language=code_language,
code=code,
inputs=variables,
@@ -75,6 +113,12 @@ class CodeNode(Node[CodeNodeData]):
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
def _select_code_provider(self, code_language: CodeLanguage) -> type[CodeNodeProvider]:
for provider in self._code_providers:
if provider.is_accept_language(code_language):
return provider
raise CodeNodeError(f"Unsupported code language: {code_language}")
def _check_string(self, value: str | None, variable: str) -> str | None:
"""
Check string
@@ -85,10 +129,10 @@ class CodeNode(Node[CodeNodeData]):
if value is None:
return None
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
if len(value) > self._limits.max_string_length:
raise OutputValidationError(
f"The length of output variable `{variable}` must be"
f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters"
f" less than {self._limits.max_string_length} characters"
)
return value.replace("\x00", "")
@@ -109,20 +153,20 @@ class CodeNode(Node[CodeNodeData]):
if value is None:
return None
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
if value > self._limits.max_number or value < self._limits.min_number:
raise OutputValidationError(
f"Output variable `{variable}` is out of range,"
f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}."
f" it must be between {self._limits.min_number} and {self._limits.max_number}."
)
if isinstance(value, float):
decimal_value = Decimal(str(value)).normalize()
precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator]
# raise error if precision is too high
if precision > dify_config.CODE_MAX_PRECISION:
if precision > self._limits.max_precision:
raise OutputValidationError(
f"Output variable `{variable}` has too high precision,"
f" it must be less than {dify_config.CODE_MAX_PRECISION} digits."
f" it must be less than {self._limits.max_precision} digits."
)
return value
@@ -137,8 +181,8 @@ class CodeNode(Node[CodeNodeData]):
# TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes.
# Note that `_transform_result` may produce lists containing `None` values,
# which don't conform to the type requirements of `Array*Segment` classes.
if depth > dify_config.CODE_MAX_DEPTH:
raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")
if depth > self._limits.max_depth:
raise DepthLimitError(f"Depth limit {self._limits.max_depth} reached, object too deep.")
transformed_result: dict[str, Any] = {}
if output_schema is None:
@@ -272,10 +316,10 @@ class CodeNode(Node[CodeNodeData]):
f"Output {prefix}{dot}{output_name} is not an array, got {type(value)} instead."
)
else:
if len(value) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
if len(value) > self._limits.max_number_array_length:
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements."
f" less than {self._limits.max_number_array_length} elements."
)
for i, inner_value in enumerate(value):
@@ -305,10 +349,10 @@ class CodeNode(Node[CodeNodeData]):
f" got {type(result.get(output_name))} instead."
)
else:
if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH:
if len(result[output_name]) > self._limits.max_string_array_length:
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements."
f" less than {self._limits.max_string_array_length} elements."
)
transformed_result[output_name] = [
@@ -326,10 +370,10 @@ class CodeNode(Node[CodeNodeData]):
f" got {type(result.get(output_name))} instead."
)
else:
if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH:
if len(result[output_name]) > self._limits.max_object_array_length:
raise OutputValidationError(
f"The length of output variable `{prefix}{dot}{output_name}` must be"
f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements."
f" less than {self._limits.max_object_array_length} elements."
)
for i, value in enumerate(result[output_name]):

View File

@@ -0,0 +1,13 @@
from dataclasses import dataclass
@dataclass(frozen=True)
class CodeNodeLimits:
max_string_length: int
max_number: int | float
min_number: int | float
max_precision: int
max_depth: int
max_number_array_length: int
max_string_array_length: int
max_object_array_length: int

View File

@@ -1,5 +1,3 @@
from typing import Any
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
@@ -11,24 +9,6 @@ class EndNode(Node[EndNodeData]):
node_type = NodeType.END
execution_type = NodeExecutionType.RESPONSE
@classmethod
def get_default_config_schema(cls) -> dict[str, Any] | None:
return {
"description": "Workflow exit point - defines output variables",
"required": ["outputs"],
"parameters": {
"outputs": {
"type": "array",
"description": "Output variables to return",
"item_schema": {
"variable": "string - output variable name",
"type": "enum: string, number, object, array",
"value_selector": "array - path to source value, e.g. ['node_id', 'field']",
},
},
},
}
@classmethod
def version(cls) -> str:
return "1"

View File

@@ -1,10 +1,21 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, final
from typing_extensions import override
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutor
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.workflow.enums import NodeType
from core.workflow.graph import NodeFactory
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
)
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from libs.typing import is_str, is_str_dict
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
@@ -27,9 +38,29 @@ class DifyNodeFactory(NodeFactory):
self,
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
code_executor: type[CodeExecutor] | None = None,
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits | None = None,
template_renderer: Jinja2TemplateRenderer | None = None,
) -> None:
self.graph_init_params = graph_init_params
self.graph_runtime_state = graph_runtime_state
self._code_executor: type[CodeExecutor] = code_executor or CodeExecutor
self._code_providers: tuple[type[CodeNodeProvider], ...] = (
tuple(code_providers) if code_providers else CodeNode.default_code_providers()
)
self._code_limits = code_limits or CodeNodeLimits(
max_string_length=dify_config.CODE_MAX_STRING_LENGTH,
max_number=dify_config.CODE_MAX_NUMBER,
min_number=dify_config.CODE_MIN_NUMBER,
max_precision=dify_config.CODE_MAX_PRECISION,
max_depth=dify_config.CODE_MAX_DEPTH,
max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH,
max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH,
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
@override
def create_node(self, node_config: dict[str, object]) -> Node:
@@ -72,6 +103,26 @@ class DifyNodeFactory(NodeFactory):
raise ValueError(f"No latest version class found for node type: {node_type}")
# Create node instance
if node_type == NodeType.CODE:
return CodeNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
code_executor=self._code_executor,
code_providers=self._code_providers,
code_limits=self._code_limits,
)
if node_type == NodeType.TEMPLATE_TRANSFORM:
return TemplateTransformNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
template_renderer=self._template_renderer,
)
return node_class(
id=node_id,
config=node_config,

View File

@@ -15,27 +15,6 @@ class StartNode(Node[StartNodeData]):
node_type = NodeType.START
execution_type = NodeExecutionType.ROOT
@classmethod
def get_default_config_schema(cls) -> dict[str, Any] | None:
return {
"description": "Workflow entry point - defines input variables",
"required": [],
"parameters": {
"variables": {
"type": "array",
"description": "Input variables for the workflow",
"item_schema": {
"variable": "string - variable name",
"label": "string - display label",
"type": "enum: text-input, paragraph, number, select, file, file-list",
"required": "boolean",
"max_length": "number (optional)",
},
},
},
"outputs": ["All defined variables are available as {{#start.variable_name#}}"],
}
@classmethod
def version(cls) -> str:
return "1"

View File

@@ -0,0 +1,40 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, Protocol
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
class TemplateRenderError(ValueError):
"""Raised when rendering a Jinja2 template fails."""
class Jinja2TemplateRenderer(Protocol):
"""Render Jinja2 templates for template transform nodes."""
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
"""Render a Jinja2 template with provided variables."""
raise NotImplementedError
class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer):
"""Adapter that renders Jinja2 templates via CodeExecutor."""
_code_executor: type[CodeExecutor]
def __init__(self, code_executor: type[CodeExecutor] | None = None) -> None:
self._code_executor = code_executor or CodeExecutor
def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
try:
result = self._code_executor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=template, inputs=variables
)
except CodeExecutionError as exc:
raise TemplateRenderError(str(exc)) from exc
rendered = result.get("result")
if not isinstance(rendered, str):
raise TemplateRenderError("Template render result must be a string.")
return rendered

View File

@@ -1,18 +1,44 @@
from collections.abc import Mapping, Sequence
from typing import Any
from typing import TYPE_CHECKING, Any
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
from core.workflow.nodes.template_transform.template_renderer import (
CodeExecutorJinja2TemplateRenderer,
Jinja2TemplateRenderer,
TemplateRenderError,
)
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
class TemplateTransformNode(Node[TemplateTransformNodeData]):
node_type = NodeType.TEMPLATE_TRANSFORM
_template_renderer: Jinja2TemplateRenderer
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
*,
template_renderer: Jinja2TemplateRenderer | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@@ -39,13 +65,11 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
variables[variable_name] = value.to_object() if value else None
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
)
except CodeExecutionError as e:
rendered = self._template_renderer.render_template(self.node_data.template, variables)
except TemplateRenderError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
return NodeRunResult(
inputs=variables,
status=WorkflowNodeExecutionStatus.FAILED,
@@ -53,7 +77,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": result["result"]}
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": rendered}
)
@classmethod

View File

@@ -50,19 +50,6 @@ class ToolNode(Node[ToolNodeData]):
def version(cls) -> str:
return "1"
@classmethod
def get_default_config_schema(cls) -> dict[str, Any] | None:
return {
"description": "Execute an external tool",
"required": ["provider_id", "tool_id", "tool_parameters"],
"parameters": {
"provider_id": {"type": "string"},
"provider_type": {"type": "string"},
"tool_id": {"type": "string"},
"tool_parameters": {"type": "object"},
},
}
def _run(self) -> Generator[NodeEventBase, None, None]:
"""
Run the tool node

View File

@@ -46,7 +46,11 @@ def _get_celery_ssl_options() -> dict[str, Any] | None:
def init_app(app: DifyApp) -> Celery:
class FlaskTask(Task):
def __call__(self, *args: object, **kwargs: object) -> object:
from core.logging.context import init_request_context
with app.app_context():
# Initialize logging context for this task (similar to before_request in Flask)
init_request_context()
return self.run(*args, **kwargs)
broker_transport_options = {}

View File

@@ -11,6 +11,7 @@ def init_app(app: DifyApp):
create_tenant,
extract_plugins,
extract_unique_plugins,
file_usage,
fix_app_site_missing,
install_plugins,
install_rag_pipeline_plugins,
@@ -47,6 +48,7 @@ def init_app(app: DifyApp):
clear_free_plan_tenant_expired_logs,
clear_orphaned_file_records,
remove_orphaned_files_on_storage,
file_usage,
setup_system_tool_oauth_client,
setup_system_trigger_oauth_client,
cleanup_orphaned_draft_variables,

View File

@@ -53,3 +53,10 @@ def _setup_gevent_compatibility():
def init_app(app: DifyApp):
db.init_app(app)
_setup_gevent_compatibility()
# Eagerly build the engine so pool_size/max_overflow/etc. come from config
try:
with app.app_context():
_ = db.engine # triggers engine creation with the configured options
except Exception:
logger.exception("Failed to initialize SQLAlchemy engine during app startup")

View File

@@ -1,18 +1,19 @@
"""Logging extension for Dify Flask application."""
import logging
import os
import sys
import uuid
from logging.handlers import RotatingFileHandler
import flask
from configs import dify_config
from core.helper.trace_id_helper import get_trace_id_from_otel_context
from dify_app import DifyApp
def init_app(app: DifyApp):
"""Initialize logging with support for text or JSON format."""
log_handlers: list[logging.Handler] = []
# File handler
log_file = dify_config.LOG_FILE
if log_file:
log_dir = os.path.dirname(log_file)
@@ -25,27 +26,53 @@ def init_app(app: DifyApp):
)
)
# Always add StreamHandler to log to console
# Console handler
sh = logging.StreamHandler(sys.stdout)
log_handlers.append(sh)
# Apply RequestIdFilter to all handlers
for handler in log_handlers:
handler.addFilter(RequestIdFilter())
# Apply filters to all handlers
from core.logging.filters import IdentityContextFilter, TraceContextFilter
for handler in log_handlers:
handler.addFilter(TraceContextFilter())
handler.addFilter(IdentityContextFilter())
# Configure formatter based on format type
formatter = _create_formatter()
for handler in log_handlers:
handler.setFormatter(formatter)
# Configure root logger
logging.basicConfig(
level=dify_config.LOG_LEVEL,
format=dify_config.LOG_FORMAT,
datefmt=dify_config.LOG_DATEFORMAT,
handlers=log_handlers,
force=True,
)
# Apply RequestIdFormatter to all handlers
apply_request_id_formatter()
# Disable propagation for noisy loggers to avoid duplicate logs
logging.getLogger("sqlalchemy.engine").propagate = False
# Apply timezone if specified (only for text format)
if dify_config.LOG_OUTPUT_FORMAT == "text":
_apply_timezone(log_handlers)
def _create_formatter() -> logging.Formatter:
"""Create appropriate formatter based on configuration."""
if dify_config.LOG_OUTPUT_FORMAT == "json":
from core.logging.structured_formatter import StructuredJSONFormatter
return StructuredJSONFormatter()
else:
# Text format - use existing pattern with backward compatible formatter
return _TextFormatter(
fmt=dify_config.LOG_FORMAT,
datefmt=dify_config.LOG_DATEFORMAT,
)
def _apply_timezone(handlers: list[logging.Handler]):
"""Apply timezone conversion to text formatters."""
log_tz = dify_config.LOG_TZ
if log_tz:
from datetime import datetime
@@ -57,34 +84,51 @@ def init_app(app: DifyApp):
def time_converter(seconds):
return datetime.fromtimestamp(seconds, tz=timezone).timetuple()
for handler in logging.root.handlers:
for handler in handlers:
if handler.formatter:
handler.formatter.converter = time_converter
handler.formatter.converter = time_converter # type: ignore[attr-defined]
def get_request_id():
if getattr(flask.g, "request_id", None):
return flask.g.request_id
class _TextFormatter(logging.Formatter):
"""Text formatter that ensures trace_id and req_id are always present."""
new_uuid = uuid.uuid4().hex[:10]
flask.g.request_id = new_uuid
return new_uuid
def format(self, record: logging.LogRecord) -> str:
if not hasattr(record, "req_id"):
record.req_id = ""
if not hasattr(record, "trace_id"):
record.trace_id = ""
if not hasattr(record, "span_id"):
record.span_id = ""
return super().format(record)
def get_request_id() -> str:
"""Get request ID for current request context.
Deprecated: Use core.logging.context.get_request_id() directly.
"""
from core.logging.context import get_request_id as _get_request_id
return _get_request_id()
# Backward compatibility aliases
class RequestIdFilter(logging.Filter):
# This is a logging filter that makes the request ID available for use in
# the logging format. Note that we're checking if we're in a request
# context, as we may want to log things before Flask is fully loaded.
def filter(self, record):
trace_id = get_trace_id_from_otel_context() or ""
record.req_id = get_request_id() if flask.has_request_context() else ""
record.trace_id = trace_id
"""Deprecated: Use TraceContextFilter from core.logging.filters instead."""
def filter(self, record: logging.LogRecord) -> bool:
from core.logging.context import get_request_id as _get_request_id
from core.logging.context import get_trace_id as _get_trace_id
record.req_id = _get_request_id()
record.trace_id = _get_trace_id()
return True
class RequestIdFormatter(logging.Formatter):
def format(self, record):
"""Deprecated: Use _TextFormatter instead."""
def format(self, record: logging.LogRecord) -> str:
if not hasattr(record, "req_id"):
record.req_id = ""
if not hasattr(record, "trace_id"):
@@ -93,6 +137,7 @@ class RequestIdFormatter(logging.Formatter):
def apply_request_id_formatter():
"""Deprecated: Formatter is now applied in init_app."""
for handler in logging.root.handlers:
if handler.formatter:
handler.formatter = RequestIdFormatter(dify_config.LOG_FORMAT, dify_config.LOG_DATEFORMAT)

View File

@@ -22,6 +22,18 @@ from models.enums import WorkflowRunTriggeredFrom
logger = logging.getLogger(__name__)
def to_serializable(obj):
"""
Convert non-JSON-serializable objects into JSON-compatible formats.
- Uses `to_dict()` if it's a callable method.
- Falls back to string representation.
"""
if hasattr(obj, "to_dict") and callable(obj.to_dict):
return obj.to_dict()
return str(obj)
class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
def __init__(
self,
@@ -69,6 +81,11 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
# Set to True to enable dual-write for safe migration, False to use LogStore only
self._enable_dual_write = os.environ.get("LOGSTORE_DUAL_WRITE_ENABLED", "true").lower() == "true"
# Control flag for whether to write the `graph` field to LogStore.
# If LOGSTORE_ENABLE_PUT_GRAPH_FIELD is "true", write the full `graph` field;
# otherwise write an empty {} instead. Defaults to writing the `graph` field.
self._enable_put_graph_field = os.environ.get("LOGSTORE_ENABLE_PUT_GRAPH_FIELD", "true").lower() == "true"
def _to_logstore_model(self, domain_model: WorkflowExecution) -> list[tuple[str, str]]:
"""
Convert a domain model to a logstore model (List[Tuple[str, str]]).
@@ -108,9 +125,24 @@ class LogstoreWorkflowExecutionRepository(WorkflowExecutionRepository):
),
("type", domain_model.workflow_type.value),
("version", domain_model.workflow_version),
("graph", json.dumps(domain_model.graph, ensure_ascii=False) if domain_model.graph else "{}"),
("inputs", json.dumps(domain_model.inputs, ensure_ascii=False) if domain_model.inputs else "{}"),
("outputs", json.dumps(domain_model.outputs, ensure_ascii=False) if domain_model.outputs else "{}"),
(
"graph",
json.dumps(domain_model.graph, ensure_ascii=False, default=to_serializable)
if domain_model.graph and self._enable_put_graph_field
else "{}",
),
(
"inputs",
json.dumps(domain_model.inputs, ensure_ascii=False, default=to_serializable)
if domain_model.inputs
else "{}",
),
(
"outputs",
json.dumps(domain_model.outputs, ensure_ascii=False, default=to_serializable)
if domain_model.outputs
else "{}",
),
("status", domain_model.status.value),
("error_message", domain_model.error_message or ""),
("total_tokens", str(domain_model.total_tokens)),

View File

@@ -19,26 +19,43 @@ logger = logging.getLogger(__name__)
class ExceptionLoggingHandler(logging.Handler):
"""
Handler that records exceptions to the current OpenTelemetry span.
Unlike creating a new span, this records exceptions on the existing span
to maintain trace context consistency throughout the request lifecycle.
"""
def emit(self, record: logging.LogRecord):
with contextlib.suppress(Exception):
if record.exc_info:
tracer = get_tracer_provider().get_tracer("dify.exception.logging")
with tracer.start_as_current_span(
"log.exception",
attributes={
"log.level": record.levelname,
"log.message": record.getMessage(),
"log.logger": record.name,
"log.file.path": record.pathname,
"log.file.line": record.lineno,
},
) as span:
span.set_status(StatusCode.ERROR)
if record.exc_info[1]:
span.record_exception(record.exc_info[1])
span.set_attribute("exception.message", str(record.exc_info[1]))
if record.exc_info[0]:
span.set_attribute("exception.type", record.exc_info[0].__name__)
if not record.exc_info:
return
from opentelemetry.trace import get_current_span
span = get_current_span()
if not span or not span.is_recording():
return
# Record exception on the current span instead of creating a new one
span.set_status(StatusCode.ERROR, record.getMessage())
# Add log context as span events/attributes
span.add_event(
"log.exception",
attributes={
"log.level": record.levelname,
"log.message": record.getMessage(),
"log.logger": record.name,
"log.file.path": record.pathname,
"log.file.line": record.lineno,
},
)
if record.exc_info[1]:
span.record_exception(record.exc_info[1])
if record.exc_info[0]:
span.set_attribute("exception.type", record.exc_info[0].__name__)
def instrument_exception_logging() -> None:

View File

@@ -1,236 +1,338 @@
from flask_restx import Namespace, fields
from __future__ import annotations
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
from datetime import datetime
from typing import Any, TypeAlias
from .raws import FilesContainedField
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from core.file import File
JSONValue: TypeAlias = Any
class MessageTextField(fields.Raw):
def format(self, value):
return value[0]["text"] if value else ""
class ResponseModel(BaseModel):
model_config = ConfigDict(
from_attributes=True,
extra="ignore",
populate_by_name=True,
serialize_by_alias=True,
protected_namespaces=(),
)
feedback_fields = {
"rating": fields.String,
"content": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account": fields.Nested(simple_account_fields, allow_null=True),
}
class MessageFile(ResponseModel):
id: str
filename: str
type: str
url: str | None = None
mime_type: str | None = None
size: int | None = None
transfer_method: str
belongs_to: str | None = None
upload_file_id: str | None = None
annotation_fields = {
"id": fields.String,
"question": fields.String,
"content": fields.String,
"account": fields.Nested(simple_account_fields, allow_null=True),
"created_at": TimestampField,
}
annotation_hit_history_fields = {
"annotation_id": fields.String(attribute="id"),
"annotation_create_account": fields.Nested(simple_account_fields, allow_null=True),
"created_at": TimestampField,
}
message_file_fields = {
"id": fields.String,
"filename": fields.String,
"type": fields.String,
"url": fields.String,
"mime_type": fields.String,
"size": fields.Integer,
"transfer_method": fields.String,
"belongs_to": fields.String(default="user"),
"upload_file_id": fields.String(default=None),
}
@field_validator("transfer_method", mode="before")
@classmethod
def _normalize_transfer_method(cls, value: object) -> str:
if isinstance(value, str):
return value
return str(value)
def build_message_file_model(api_or_ns: Namespace):
"""Build the message file fields for the API or Namespace."""
return api_or_ns.model("MessageFile", message_file_fields)
class SimpleConversation(ResponseModel):
id: str
name: str
inputs: dict[str, JSONValue]
status: str
introduction: str | None = None
created_at: int | None = None
updated_at: int | None = None
@field_validator("inputs", mode="before")
@classmethod
def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
return format_files_contained(value)
@field_validator("created_at", "updated_at", mode="before")
@classmethod
def _normalize_timestamp(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
agent_thought_fields = {
"id": fields.String,
"chain_id": fields.String,
"message_id": fields.String,
"position": fields.Integer,
"thought": fields.String,
"tool": fields.String,
"tool_labels": fields.Raw,
"tool_input": fields.String,
"created_at": TimestampField,
"observation": fields.String,
"files": fields.List(fields.String),
}
message_detail_fields = {
"id": fields.String,
"conversation_id": fields.String,
"inputs": FilesContainedField,
"query": fields.String,
"message": fields.Raw,
"message_tokens": fields.Integer,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"answer_tokens": fields.Integer,
"provider_response_latency": fields.Float,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account_id": fields.String,
"feedbacks": fields.List(fields.Nested(feedback_fields)),
"workflow_run_id": fields.String,
"annotation": fields.Nested(annotation_fields, allow_null=True),
"annotation_hit_history": fields.Nested(annotation_hit_history_fields, allow_null=True),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
"message_files": fields.List(fields.Nested(message_file_fields)),
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String,
"error": fields.String,
"parent_message_id": fields.String,
}
feedback_stat_fields = {"like": fields.Integer, "dislike": fields.Integer}
status_count_fields = {"success": fields.Integer, "failed": fields.Integer, "partial_success": fields.Integer}
model_config_fields = {
"opening_statement": fields.String,
"suggested_questions": fields.Raw,
"model": fields.Raw,
"user_input_form": fields.Raw,
"pre_prompt": fields.String,
"agent_mode": fields.Raw,
}
simple_model_config_fields = {
"model": fields.Raw(attribute="model_dict"),
"pre_prompt": fields.String,
}
simple_message_detail_fields = {
"inputs": FilesContainedField,
"query": fields.String,
"message": MessageTextField,
"answer": fields.String,
}
conversation_fields = {
"id": fields.String,
"status": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_end_user_session_id": fields.String(),
"from_account_id": fields.String,
"from_account_name": fields.String,
"read_at": TimestampField,
"created_at": TimestampField,
"updated_at": TimestampField,
"annotation": fields.Nested(annotation_fields, allow_null=True),
"model_config": fields.Nested(simple_model_config_fields),
"user_feedback_stats": fields.Nested(feedback_stat_fields),
"admin_feedback_stats": fields.Nested(feedback_stat_fields),
"message": fields.Nested(simple_message_detail_fields, attribute="first_message"),
}
conversation_pagination_fields = {
"page": fields.Integer,
"limit": fields.Integer(attribute="per_page"),
"total": fields.Integer,
"has_more": fields.Boolean(attribute="has_next"),
"data": fields.List(fields.Nested(conversation_fields), attribute="items"),
}
conversation_message_detail_fields = {
"id": fields.String,
"status": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account_id": fields.String,
"created_at": TimestampField,
"model_config": fields.Nested(model_config_fields),
"message": fields.Nested(message_detail_fields, attribute="first_message"),
}
conversation_with_summary_fields = {
"id": fields.String,
"status": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_end_user_session_id": fields.String,
"from_account_id": fields.String,
"from_account_name": fields.String,
"name": fields.String,
"summary": fields.String(attribute="summary_or_query"),
"read_at": TimestampField,
"created_at": TimestampField,
"updated_at": TimestampField,
"annotated": fields.Boolean,
"model_config": fields.Nested(simple_model_config_fields),
"message_count": fields.Integer,
"user_feedback_stats": fields.Nested(feedback_stat_fields),
"admin_feedback_stats": fields.Nested(feedback_stat_fields),
"status_count": fields.Nested(status_count_fields),
}
conversation_with_summary_pagination_fields = {
"page": fields.Integer,
"limit": fields.Integer(attribute="per_page"),
"total": fields.Integer,
"has_more": fields.Boolean(attribute="has_next"),
"data": fields.List(fields.Nested(conversation_with_summary_fields), attribute="items"),
}
conversation_detail_fields = {
"id": fields.String,
"status": fields.String,
"from_source": fields.String,
"from_end_user_id": fields.String,
"from_account_id": fields.String,
"created_at": TimestampField,
"updated_at": TimestampField,
"annotated": fields.Boolean,
"introduction": fields.String,
"model_config": fields.Nested(model_config_fields),
"message_count": fields.Integer,
"user_feedback_stats": fields.Nested(feedback_stat_fields),
"admin_feedback_stats": fields.Nested(feedback_stat_fields),
}
simple_conversation_fields = {
"id": fields.String,
"name": fields.String,
"inputs": FilesContainedField,
"status": fields.String,
"introduction": fields.String,
"created_at": TimestampField,
"updated_at": TimestampField,
}
conversation_delete_fields = {
"result": fields.String,
}
conversation_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(simple_conversation_fields)),
}
class ConversationInfiniteScrollPagination(ResponseModel):
limit: int
has_more: bool
data: list[SimpleConversation]
def build_conversation_infinite_scroll_pagination_model(api_or_ns: Namespace):
"""Build the conversation infinite scroll pagination model for the API or Namespace."""
simple_conversation_model = build_simple_conversation_model(api_or_ns)
copied_fields = conversation_infinite_scroll_pagination_fields.copy()
copied_fields["data"] = fields.List(fields.Nested(simple_conversation_model))
return api_or_ns.model("ConversationInfiniteScrollPagination", copied_fields)
class ConversationDelete(ResponseModel):
result: str
def build_conversation_delete_model(api_or_ns: Namespace):
"""Build the conversation delete model for the API or Namespace."""
return api_or_ns.model("ConversationDelete", conversation_delete_fields)
class ResultResponse(ResponseModel):
result: str
def build_simple_conversation_model(api_or_ns: Namespace):
"""Build the simple conversation model for the API or Namespace."""
return api_or_ns.model("SimpleConversation", simple_conversation_fields)
class SimpleAccount(ResponseModel):
id: str
name: str
email: str
class Feedback(ResponseModel):
rating: str
content: str | None = None
from_source: str
from_end_user_id: str | None = None
from_account: SimpleAccount | None = None
class Annotation(ResponseModel):
id: str
question: str | None = None
content: str
account: SimpleAccount | None = None
created_at: int | None = None
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
class AnnotationHitHistory(ResponseModel):
annotation_id: str
annotation_create_account: SimpleAccount | None = None
created_at: int | None = None
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
class AgentThought(ResponseModel):
id: str
chain_id: str | None = None
message_chain_id: str | None = Field(default=None, exclude=True, validation_alias="message_chain_id")
message_id: str
position: int
thought: str | None = None
tool: str | None = None
tool_labels: JSONValue
tool_input: str | None = None
created_at: int | None = None
observation: str | None = None
files: list[str]
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
@model_validator(mode="after")
def _fallback_chain_id(self):
if self.chain_id is None and self.message_chain_id:
self.chain_id = self.message_chain_id
return self
class MessageDetail(ResponseModel):
id: str
conversation_id: str
inputs: dict[str, JSONValue]
query: str
message: JSONValue
message_tokens: int
answer: str
answer_tokens: int
provider_response_latency: float
from_source: str
from_end_user_id: str | None = None
from_account_id: str | None = None
feedbacks: list[Feedback]
workflow_run_id: str | None = None
annotation: Annotation | None = None
annotation_hit_history: AnnotationHitHistory | None = None
created_at: int | None = None
agent_thoughts: list[AgentThought]
message_files: list[MessageFile]
metadata: JSONValue
status: str
error: str | None = None
parent_message_id: str | None = None
@field_validator("inputs", mode="before")
@classmethod
def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
return format_files_contained(value)
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
class FeedbackStat(ResponseModel):
like: int
dislike: int
class StatusCount(ResponseModel):
success: int
failed: int
partial_success: int
class ModelConfig(ResponseModel):
opening_statement: str | None = None
suggested_questions: JSONValue | None = None
model: JSONValue | None = None
user_input_form: JSONValue | None = None
pre_prompt: str | None = None
agent_mode: JSONValue | None = None
class SimpleModelConfig(ResponseModel):
model: JSONValue | None = None
pre_prompt: str | None = None
class SimpleMessageDetail(ResponseModel):
inputs: dict[str, JSONValue]
query: str
message: str
answer: str
@field_validator("inputs", mode="before")
@classmethod
def _normalize_inputs(cls, value: JSONValue) -> JSONValue:
return format_files_contained(value)
class Conversation(ResponseModel):
id: str
status: str
from_source: str
from_end_user_id: str | None = None
from_end_user_session_id: str | None = None
from_account_id: str | None = None
from_account_name: str | None = None
read_at: int | None = None
created_at: int | None = None
updated_at: int | None = None
annotation: Annotation | None = None
model_config_: SimpleModelConfig | None = Field(default=None, alias="model_config")
user_feedback_stats: FeedbackStat | None = None
admin_feedback_stats: FeedbackStat | None = None
message: SimpleMessageDetail | None = None
class ConversationPagination(ResponseModel):
page: int
limit: int
total: int
has_more: bool
data: list[Conversation]
class ConversationMessageDetail(ResponseModel):
id: str
status: str
from_source: str
from_end_user_id: str | None = None
from_account_id: str | None = None
created_at: int | None = None
model_config_: ModelConfig | None = Field(default=None, alias="model_config")
message: MessageDetail | None = None
class ConversationWithSummary(ResponseModel):
id: str
status: str
from_source: str
from_end_user_id: str | None = None
from_end_user_session_id: str | None = None
from_account_id: str | None = None
from_account_name: str | None = None
name: str
summary: str
read_at: int | None = None
created_at: int | None = None
updated_at: int | None = None
annotated: bool
model_config_: SimpleModelConfig | None = Field(default=None, alias="model_config")
message_count: int
user_feedback_stats: FeedbackStat | None = None
admin_feedback_stats: FeedbackStat | None = None
status_count: StatusCount | None = None
class ConversationWithSummaryPagination(ResponseModel):
page: int
limit: int
total: int
has_more: bool
data: list[ConversationWithSummary]
class ConversationDetail(ResponseModel):
id: str
status: str
from_source: str
from_end_user_id: str | None = None
from_account_id: str | None = None
created_at: int | None = None
updated_at: int | None = None
annotated: bool
introduction: str | None = None
model_config_: ModelConfig | None = Field(default=None, alias="model_config")
message_count: int
user_feedback_stats: FeedbackStat | None = None
admin_feedback_stats: FeedbackStat | None = None
def to_timestamp(value: datetime | None) -> int | None:
if value is None:
return None
return int(value.timestamp())
def format_files_contained(value: JSONValue) -> JSONValue:
if isinstance(value, File):
return value.model_dump()
if isinstance(value, dict):
return {k: format_files_contained(v) for k, v in value.items()}
if isinstance(value, list):
return [format_files_contained(v) for v in value]
return value
def message_text(value: JSONValue) -> str:
if isinstance(value, list) and value:
first = value[0]
if isinstance(first, dict):
text = first.get("text")
if isinstance(text, str):
return text
return ""
def extract_model_config(value: object | None) -> dict[str, JSONValue]:
if value is None:
return {}
if isinstance(value, dict):
return value
if hasattr(value, "to_dict"):
return value.to_dict()
return {}

View File

@@ -1,77 +1,137 @@
from flask_restx import Namespace, fields
from __future__ import annotations
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField
from datetime import datetime
from typing import TypeAlias
from .raws import FilesContainedField
from pydantic import BaseModel, ConfigDict, Field, field_validator
feedback_fields = {
"rating": fields.String,
}
from core.file import File
from fields.conversation_fields import AgentThought, JSONValue, MessageFile
JSONValueType: TypeAlias = JSONValue
def build_feedback_model(api_or_ns: Namespace):
"""Build the feedback model for the API or Namespace."""
return api_or_ns.model("Feedback", feedback_fields)
class ResponseModel(BaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore")
agent_thought_fields = {
"id": fields.String,
"chain_id": fields.String,
"message_id": fields.String,
"position": fields.Integer,
"thought": fields.String,
"tool": fields.String,
"tool_labels": fields.Raw,
"tool_input": fields.String,
"created_at": TimestampField,
"observation": fields.String,
"files": fields.List(fields.String),
}
class SimpleFeedback(ResponseModel):
rating: str | None = None
def build_agent_thought_model(api_or_ns: Namespace):
"""Build the agent thought model for the API or Namespace."""
return api_or_ns.model("AgentThought", agent_thought_fields)
class RetrieverResource(ResponseModel):
id: str
message_id: str
position: int
dataset_id: str | None = None
dataset_name: str | None = None
document_id: str | None = None
document_name: str | None = None
data_source_type: str | None = None
segment_id: str | None = None
score: float | None = None
hit_count: int | None = None
word_count: int | None = None
segment_position: int | None = None
index_node_hash: str | None = None
content: str | None = None
created_at: int | None = None
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
retriever_resource_fields = {
"id": fields.String,
"message_id": fields.String,
"position": fields.Integer,
"dataset_id": fields.String,
"dataset_name": fields.String,
"document_id": fields.String,
"document_name": fields.String,
"data_source_type": fields.String,
"segment_id": fields.String,
"score": fields.Float,
"hit_count": fields.Integer,
"word_count": fields.Integer,
"segment_position": fields.Integer,
"index_node_hash": fields.String,
"content": fields.String,
"created_at": TimestampField,
}
class MessageListItem(ResponseModel):
id: str
conversation_id: str
parent_message_id: str | None = None
inputs: dict[str, JSONValueType]
query: str
answer: str = Field(validation_alias="re_sign_file_url_answer")
feedback: SimpleFeedback | None = Field(default=None, validation_alias="user_feedback")
retriever_resources: list[RetrieverResource]
created_at: int | None = None
agent_thoughts: list[AgentThought]
message_files: list[MessageFile]
status: str
error: str | None = None
message_fields = {
"id": fields.String,
"conversation_id": fields.String,
"parent_message_id": fields.String,
"inputs": FilesContainedField,
"query": fields.String,
"answer": fields.String(attribute="re_sign_file_url_answer"),
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_fields)),
"message_files": fields.List(fields.Nested(message_file_fields)),
"status": fields.String,
"error": fields.String,
}
@field_validator("inputs", mode="before")
@classmethod
def _normalize_inputs(cls, value: JSONValueType) -> JSONValueType:
return format_files_contained(value)
message_infinite_scroll_pagination_fields = {
"limit": fields.Integer,
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(message_fields)),
}
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
class WebMessageListItem(MessageListItem):
metadata: JSONValueType | None = Field(default=None, validation_alias="message_metadata_dict")
class MessageInfiniteScrollPagination(ResponseModel):
limit: int
has_more: bool
data: list[MessageListItem]
class WebMessageInfiniteScrollPagination(ResponseModel):
limit: int
has_more: bool
data: list[WebMessageListItem]
class SavedMessageItem(ResponseModel):
id: str
inputs: dict[str, JSONValueType]
query: str
answer: str
message_files: list[MessageFile]
feedback: SimpleFeedback | None = Field(default=None, validation_alias="user_feedback")
created_at: int | None = None
@field_validator("inputs", mode="before")
@classmethod
def _normalize_inputs(cls, value: JSONValueType) -> JSONValueType:
return format_files_contained(value)
@field_validator("created_at", mode="before")
@classmethod
def _normalize_created_at(cls, value: datetime | int | None) -> int | None:
if isinstance(value, datetime):
return to_timestamp(value)
return value
class SavedMessageInfiniteScrollPagination(ResponseModel):
limit: int
has_more: bool
data: list[SavedMessageItem]
class SuggestedQuestionsResponse(ResponseModel):
data: list[str]
def to_timestamp(value: datetime | None) -> int | None:
if value is None:
return None
return int(value.timestamp())
def format_files_contained(value: JSONValueType) -> JSONValueType:
if isinstance(value, File):
return value.model_dump()
if isinstance(value, dict):
return {k: format_files_contained(v) for k, v in value.items()}
if isinstance(value, list):
return [format_files_contained(v) for v in value]
return value

View File

@@ -1,5 +1,4 @@
import re
import sys
from collections.abc import Mapping
from typing import Any
@@ -109,11 +108,8 @@ def register_external_error_handlers(api: Api):
data.setdefault("code", "unknown")
data.setdefault("status", status_code)
# Log stack
exc_info: Any = sys.exc_info()
if exc_info[1] is None:
exc_info = (None, None, None)
current_app.log_exception(exc_info)
# Note: Exception logging is handled by Flask/Flask-RESTX framework automatically
# Explicit log_exception call removed to avoid duplicate log entries
return data, status_code

View File

@@ -11,9 +11,6 @@ from alembic import op
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '00bacef91f18'
down_revision = '8ec536f3c800'
@@ -23,31 +20,17 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('description', sa.Text(), nullable=False))
batch_op.drop_column('description_str')
else:
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('description', models.types.LongText(), nullable=False))
batch_op.drop_column('description_str')
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('description', models.types.LongText(), nullable=False))
batch_op.drop_column('description_str')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('description_str', sa.TEXT(), autoincrement=False, nullable=False))
batch_op.drop_column('description')
else:
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('description_str', models.types.LongText(), autoincrement=False, nullable=False))
batch_op.drop_column('description')
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('description_str', models.types.LongText(), autoincrement=False, nullable=False))
batch_op.drop_column('description')
# ### end Alembic commands ###

View File

@@ -7,14 +7,10 @@ Create Date: 2024-01-10 04:40:57.257824
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '114eed84c228'
down_revision = 'c71211c8f604'
@@ -32,13 +28,7 @@ def upgrade():
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
batch_op.add_column(sa.Column('tool_id', postgresql.UUID(), autoincrement=False, nullable=False))
else:
with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
batch_op.add_column(sa.Column('tool_id', models.types.StringUUID(), autoincrement=False, nullable=False))
with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
batch_op.add_column(sa.Column('tool_id', models.types.StringUUID(), autoincrement=False, nullable=False))
# ### end Alembic commands ###

View File

@@ -11,9 +11,6 @@ from alembic import op
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '161cadc1af8d'
down_revision = '7e6a8693e07a'
@@ -23,16 +20,9 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('dataset_permissions', schema=None) as batch_op:
# Step 1: Add column without NOT NULL constraint
op.add_column('dataset_permissions', sa.Column('tenant_id', sa.UUID(), nullable=False))
else:
with op.batch_alter_table('dataset_permissions', schema=None) as batch_op:
# Step 1: Add column without NOT NULL constraint
op.add_column('dataset_permissions', sa.Column('tenant_id', models.types.StringUUID(), nullable=False))
with op.batch_alter_table('dataset_permissions', schema=None) as batch_op:
# Step 1: Add column without NOT NULL constraint
op.add_column('dataset_permissions', sa.Column('tenant_id', models.types.StringUUID(), nullable=False))
# ### end Alembic commands ###

View File

@@ -9,11 +9,6 @@ from alembic import op
import models.types
def _is_pg(conn):
return conn.dialect.name == "postgresql"
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '6af6a521a53e'
down_revision = 'd57ba9ebb251'
@@ -23,58 +18,30 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
batch_op.alter_column('document_id',
existing_type=sa.UUID(),
nullable=True)
batch_op.alter_column('data_source_type',
existing_type=sa.TEXT(),
nullable=True)
batch_op.alter_column('segment_id',
existing_type=sa.UUID(),
nullable=True)
else:
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
batch_op.alter_column('document_id',
existing_type=models.types.StringUUID(),
nullable=True)
batch_op.alter_column('data_source_type',
existing_type=models.types.LongText(),
nullable=True)
batch_op.alter_column('segment_id',
existing_type=models.types.StringUUID(),
nullable=True)
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
batch_op.alter_column('document_id',
existing_type=models.types.StringUUID(),
nullable=True)
batch_op.alter_column('data_source_type',
existing_type=models.types.LongText(),
nullable=True)
batch_op.alter_column('segment_id',
existing_type=models.types.StringUUID(),
nullable=True)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
batch_op.alter_column('segment_id',
existing_type=sa.UUID(),
nullable=False)
batch_op.alter_column('data_source_type',
existing_type=sa.TEXT(),
nullable=False)
batch_op.alter_column('document_id',
existing_type=sa.UUID(),
nullable=False)
else:
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
batch_op.alter_column('segment_id',
existing_type=models.types.StringUUID(),
nullable=False)
batch_op.alter_column('data_source_type',
existing_type=models.types.LongText(),
nullable=False)
batch_op.alter_column('document_id',
existing_type=models.types.StringUUID(),
nullable=False)
with op.batch_alter_table('dataset_retriever_resources', schema=None) as batch_op:
batch_op.alter_column('segment_id',
existing_type=models.types.StringUUID(),
nullable=False)
batch_op.alter_column('data_source_type',
existing_type=models.types.LongText(),
nullable=False)
batch_op.alter_column('document_id',
existing_type=models.types.StringUUID(),
nullable=False)
# ### end Alembic commands ###

View File

@@ -8,7 +8,6 @@ Create Date: 2024-11-01 04:34:23.816198
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'd3f6769a94a3'

View File

@@ -28,85 +28,45 @@ def upgrade():
op.execute("UPDATE sites SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL")
op.execute("UPDATE tool_api_providers SET custom_disclaimer = '' WHERE custom_disclaimer IS NULL")
if _is_pg(conn):
with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=sa.TEXT(),
nullable=False)
with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=models.types.LongText(),
nullable=False)
with op.batch_alter_table('sites', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=sa.TEXT(),
nullable=False)
with op.batch_alter_table('sites', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=models.types.LongText(),
nullable=False)
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=sa.TEXT(),
nullable=False)
else:
with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=models.types.LongText(),
nullable=False)
with op.batch_alter_table('sites', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=models.types.LongText(),
nullable=False)
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=models.types.LongText(),
nullable=False)
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.VARCHAR(length=255),
type_=models.types.LongText(),
nullable=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.TEXT(),
type_=sa.VARCHAR(length=255),
nullable=True)
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=models.types.LongText(),
type_=sa.VARCHAR(length=255),
nullable=True)
with op.batch_alter_table('sites', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.TEXT(),
type_=sa.VARCHAR(length=255),
nullable=True)
with op.batch_alter_table('sites', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=models.types.LongText(),
type_=sa.VARCHAR(length=255),
nullable=True)
with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=sa.TEXT(),
type_=sa.VARCHAR(length=255),
nullable=True)
else:
with op.batch_alter_table('tool_api_providers', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=models.types.LongText(),
type_=sa.VARCHAR(length=255),
nullable=True)
with op.batch_alter_table('sites', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=models.types.LongText(),
type_=sa.VARCHAR(length=255),
nullable=True)
with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=models.types.LongText(),
type_=sa.VARCHAR(length=255),
nullable=True)
with op.batch_alter_table('recommended_apps', schema=None) as batch_op:
batch_op.alter_column('custom_disclaimer',
existing_type=models.types.LongText(),
type_=sa.VARCHAR(length=255),
nullable=True)
# ### end Alembic commands ###

View File

@@ -49,57 +49,33 @@ def upgrade():
op.execute("UPDATE workflows SET updated_at = created_at WHERE updated_at IS NULL")
op.execute("UPDATE workflows SET graph = '' WHERE graph IS NULL")
op.execute("UPDATE workflows SET features = '' WHERE features IS NULL")
if _is_pg(conn):
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.alter_column('graph',
existing_type=sa.TEXT(),
nullable=False)
batch_op.alter_column('features',
existing_type=sa.TEXT(),
nullable=False)
batch_op.alter_column('updated_at',
existing_type=postgresql.TIMESTAMP(),
nullable=False)
else:
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.alter_column('graph',
existing_type=models.types.LongText(),
nullable=False)
batch_op.alter_column('features',
existing_type=models.types.LongText(),
nullable=False)
batch_op.alter_column('updated_at',
existing_type=sa.TIMESTAMP(),
nullable=False)
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.alter_column('graph',
existing_type=models.types.LongText(),
nullable=False)
batch_op.alter_column('features',
existing_type=models.types.LongText(),
nullable=False)
batch_op.alter_column('updated_at',
existing_type=sa.TIMESTAMP(),
nullable=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
if _is_pg(conn):
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.alter_column('updated_at',
existing_type=postgresql.TIMESTAMP(),
nullable=True)
batch_op.alter_column('features',
existing_type=sa.TEXT(),
nullable=True)
batch_op.alter_column('graph',
existing_type=sa.TEXT(),
nullable=True)
else:
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.alter_column('updated_at',
existing_type=sa.TIMESTAMP(),
nullable=True)
batch_op.alter_column('features',
existing_type=models.types.LongText(),
nullable=True)
batch_op.alter_column('graph',
existing_type=models.types.LongText(),
nullable=True)
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.alter_column('updated_at',
existing_type=sa.TIMESTAMP(),
nullable=True)
batch_op.alter_column('features',
existing_type=models.types.LongText(),
nullable=True)
batch_op.alter_column('graph',
existing_type=models.types.LongText(),
nullable=True)
if _is_pg(conn):
with op.batch_alter_table('messages', schema=None) as batch_op:

Some files were not shown because too many files have changed in this diff Show More