Compare commits

..

1 Commits

Author SHA1 Message Date
dependabot[bot]
d7e33788be chore(deps): bump the python-packages group in /api with 2 updates
Updates the requirements on [sentry-sdk](https://github.com/getsentry/sentry-python) and [unstructured](https://github.com/Unstructured-IO/unstructured) to permit the latest version.

Updates `sentry-sdk` to 2.56.0
- [Release notes](https://github.com/getsentry/sentry-python/releases)
- [Changelog](https://github.com/getsentry/sentry-python/blob/master/CHANGELOG.md)
- [Commits](https://github.com/getsentry/sentry-python/compare/2.55.0...2.56.0)

Updates `unstructured` to 0.22.6
- [Release notes](https://github.com/Unstructured-IO/unstructured/releases)
- [Changelog](https://github.com/Unstructured-IO/unstructured/blob/main/CHANGELOG.md)
- [Commits](https://github.com/Unstructured-IO/unstructured/compare/0.21.5...0.22.6)

---
updated-dependencies:
- dependency-name: sentry-sdk
  dependency-version: 2.56.0
  dependency-type: direct:production
  dependency-group: python-packages
- dependency-name: unstructured
  dependency-version: 0.22.6
  dependency-type: direct:production
  dependency-group: python-packages
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-03-30 09:48:44 +00:00
343 changed files with 11900 additions and 12997 deletions

View File

@@ -6,6 +6,7 @@ runs:
- name: Setup Vite+
uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0
with:
working-directory: web
node-version-file: .nvmrc
cache: true
run-install: true

View File

@@ -39,10 +39,6 @@ jobs:
with:
files: |
web/**
package.json
pnpm-lock.yaml
pnpm-workspace.yaml
.nvmrc
- name: Check api inputs
if: github.event_name != 'merge_group'
id: api-changes

View File

@@ -24,39 +24,27 @@ env:
jobs:
build:
runs-on: ${{ matrix.runs_on }}
runs-on: ${{ matrix.platform == 'linux/arm64' && 'arm64_runner' || 'ubuntu-latest' }}
if: github.repository == 'langgenius/dify'
strategy:
matrix:
include:
- service_name: "build-api-amd64"
image_name_env: "DIFY_API_IMAGE_NAME"
artifact_context: "api"
build_context: "{{defaultContext}}:api"
file: "Dockerfile"
context: "api"
platform: linux/amd64
runs_on: ubuntu-latest
- service_name: "build-api-arm64"
image_name_env: "DIFY_API_IMAGE_NAME"
artifact_context: "api"
build_context: "{{defaultContext}}:api"
file: "Dockerfile"
context: "api"
platform: linux/arm64
runs_on: ubuntu-24.04-arm
- service_name: "build-web-amd64"
image_name_env: "DIFY_WEB_IMAGE_NAME"
artifact_context: "web"
build_context: "{{defaultContext}}"
file: "web/Dockerfile"
context: "web"
platform: linux/amd64
runs_on: ubuntu-latest
- service_name: "build-web-arm64"
image_name_env: "DIFY_WEB_IMAGE_NAME"
artifact_context: "web"
build_context: "{{defaultContext}}"
file: "web/Dockerfile"
context: "web"
platform: linux/arm64
runs_on: ubuntu-24.04-arm
steps:
- name: Prepare
@@ -70,6 +58,9 @@ jobs:
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}
- name: Set up QEMU
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
@@ -83,8 +74,7 @@ jobs:
id: build
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
with:
context: ${{ matrix.build_context }}
file: ${{ matrix.file }}
context: "{{defaultContext}}:${{ matrix.context }}"
platforms: ${{ matrix.platform }}
build-args: COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
labels: ${{ steps.meta.outputs.labels }}
@@ -103,7 +93,7 @@ jobs:
- name: Upload digest
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: digests-${{ matrix.artifact_context }}-${{ env.PLATFORM_PAIR }}
name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*
if-no-files-found: error
retention-days: 1

View File

@@ -6,12 +6,7 @@ on:
- "main"
paths:
- api/Dockerfile
- web/docker/**
- web/Dockerfile
- package.json
- pnpm-lock.yaml
- pnpm-workspace.yaml
- .nvmrc
concurrency:
group: docker-build-${{ github.head_ref || github.run_id }}
@@ -19,31 +14,26 @@ concurrency:
jobs:
build-docker:
runs-on: ${{ matrix.runs_on }}
runs-on: ubuntu-latest
strategy:
matrix:
include:
- service_name: "api-amd64"
platform: linux/amd64
runs_on: ubuntu-latest
context: "{{defaultContext}}:api"
file: "Dockerfile"
context: "api"
- service_name: "api-arm64"
platform: linux/arm64
runs_on: ubuntu-24.04-arm
context: "{{defaultContext}}:api"
file: "Dockerfile"
context: "api"
- service_name: "web-amd64"
platform: linux/amd64
runs_on: ubuntu-latest
context: "{{defaultContext}}"
file: "web/Dockerfile"
context: "web"
- service_name: "web-arm64"
platform: linux/arm64
runs_on: ubuntu-24.04-arm
context: "{{defaultContext}}"
file: "web/Dockerfile"
context: "web"
steps:
- name: Set up QEMU
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
@@ -51,8 +41,8 @@ jobs:
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
with:
push: false
context: ${{ matrix.context }}
file: ${{ matrix.file }}
context: "{{defaultContext}}:${{ matrix.context }}"
file: "${{ matrix.file }}"
platforms: ${{ matrix.platform }}
cache-from: type=gha
cache-to: type=gha,mode=max

View File

@@ -65,10 +65,6 @@ jobs:
- 'docker/volumes/sandbox/conf/**'
web:
- 'web/**'
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.nvmrc'
- '.github/workflows/web-tests.yml'
- '.github/actions/setup-web/**'
e2e:
@@ -77,10 +73,6 @@ jobs:
- 'api/uv.lock'
- 'e2e/**'
- 'web/**'
- 'package.json'
- 'pnpm-lock.yaml'
- 'pnpm-workspace.yaml'
- '.nvmrc'
- 'docker/docker-compose.middleware.yaml'
- 'docker/middleware.env.example'
- '.github/workflows/web-e2e.yml'

View File

@@ -50,17 +50,6 @@ jobs:
run: |
diff -u /tmp/pyrefly_base.txt /tmp/pyrefly_pr.txt > pyrefly_diff.txt || true
- name: Check if line counts match
id: line_count_check
run: |
base_lines=$(wc -l < /tmp/pyrefly_base.txt)
pr_lines=$(wc -l < /tmp/pyrefly_pr.txt)
if [ "$base_lines" -eq "$pr_lines" ]; then
echo "same=true" >> $GITHUB_OUTPUT
else
echo "same=false" >> $GITHUB_OUTPUT
fi
- name: Save PR number
run: |
echo ${{ github.event.pull_request.number }} > pr_number.txt
@@ -74,7 +63,7 @@ jobs:
pr_number.txt
- name: Comment PR with pyrefly diff
if: ${{ github.event.pull_request.head.repo.full_name == github.repository && steps.line_count_check.outputs.same == 'false' }}
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -77,10 +77,6 @@ jobs:
with:
files: |
web/**
package.json
pnpm-lock.yaml
pnpm-workspace.yaml
.nvmrc
.github/workflows/style.yml
.github/actions/setup-web/**
@@ -94,9 +90,9 @@ jobs:
uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
with:
path: web/.eslintcache
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
restore-keys: |
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
- name: Web style check
if: steps.changed-files.outputs.any_changed == 'true'

View File

@@ -6,9 +6,6 @@ on:
- main
paths:
- sdks/**
- package.json
- pnpm-lock.yaml
- pnpm-workspace.yaml
concurrency:
group: sdk-tests-${{ github.head_ref || github.run_id }}

View File

@@ -1,10 +1,10 @@
name: Translate i18n Files with Claude Code
# Note: claude-code-action doesn't support push events directly.
# Push events are bridged by trigger-i18n-sync.yml via repository_dispatch.
on:
repository_dispatch:
types: [i18n-sync]
push:
branches: [main]
paths:
- 'web/i18n/en-US/*.json'
workflow_dispatch:
inputs:
files:
@@ -30,7 +30,7 @@ permissions:
concurrency:
group: translate-i18n-${{ github.event_name }}-${{ github.ref }}
cancel-in-progress: false
cancel-in-progress: ${{ github.event_name == 'push' }}
jobs:
translate:
@@ -67,113 +67,19 @@ jobs:
}
" web/i18n-config/languages.ts | sed 's/[[:space:]]*$//')
generate_changes_json() {
node <<'NODE'
const { execFileSync } = require('node:child_process')
const fs = require('node:fs')
const path = require('node:path')
const repoRoot = process.cwd()
const baseSha = process.env.BASE_SHA || ''
const headSha = process.env.HEAD_SHA || ''
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
const readCurrentJson = (fileStem) => {
const filePath = englishPath(fileStem)
if (!fs.existsSync(filePath))
return null
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
}
const readBaseJson = (fileStem) => {
if (!baseSha)
return null
try {
const relativePath = `web/i18n/en-US/${fileStem}.json`
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
return JSON.parse(content)
}
catch (error) {
return null
}
}
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
const changes = {}
for (const fileStem of files) {
const currentJson = readCurrentJson(fileStem)
const beforeJson = readBaseJson(fileStem) || {}
const afterJson = currentJson || {}
const added = {}
const updated = {}
const deleted = []
for (const [key, value] of Object.entries(afterJson)) {
if (!(key in beforeJson)) {
added[key] = value
continue
}
if (!compareJson(beforeJson[key], value)) {
updated[key] = {
before: beforeJson[key],
after: value,
}
}
}
for (const key of Object.keys(beforeJson)) {
if (!(key in afterJson))
deleted.push(key)
}
changes[fileStem] = {
fileDeleted: currentJson === null,
added,
updated,
deleted,
}
}
fs.writeFileSync(
'/tmp/i18n-changes.json',
JSON.stringify({
baseSha,
headSha,
files,
changes,
})
)
NODE
}
if [ "${{ github.event_name }}" = "repository_dispatch" ]; then
BASE_SHA="${{ github.event.client_payload.base_sha }}"
HEAD_SHA="${{ github.event.client_payload.head_sha }}"
CHANGED_FILES="${{ github.event.client_payload.changed_files }}"
TARGET_LANGS="$DEFAULT_TARGET_LANGS"
SYNC_MODE="${{ github.event.client_payload.sync_mode || 'incremental' }}"
if [ -n "${{ github.event.client_payload.changes_base64 }}" ]; then
printf '%s' '${{ github.event.client_payload.changes_base64 }}' | base64 -d > /tmp/i18n-changes.json
CHANGES_AVAILABLE="true"
CHANGES_SOURCE="embedded"
elif [ -n "$BASE_SHA" ] && [ -n "$CHANGED_FILES" ]; then
export BASE_SHA HEAD_SHA CHANGED_FILES
generate_changes_json
CHANGES_AVAILABLE="true"
CHANGES_SOURCE="recomputed"
else
printf '%s' '{"baseSha":"","headSha":"","files":[],"changes":{}}' > /tmp/i18n-changes.json
CHANGES_AVAILABLE="false"
CHANGES_SOURCE="unavailable"
if [ "${{ github.event_name }}" = "push" ]; then
BASE_SHA="${{ github.event.before }}"
if [ -z "$BASE_SHA" ] || [ "$BASE_SHA" = "0000000000000000000000000000000000000000" ]; then
BASE_SHA=$(git rev-parse HEAD~1 2>/dev/null || true)
fi
HEAD_SHA="${{ github.sha }}"
if [ -n "$BASE_SHA" ]; then
CHANGED_FILES=$(git diff --name-only "$BASE_SHA" "$HEAD_SHA" -- 'web/i18n/en-US/*.json' 2>/dev/null | sed -n 's@^.*/@@p' | sed 's/\.json$//' | tr '\n' ' ' | sed 's/[[:space:]]*$//')
else
CHANGED_FILES=$(find web/i18n/en-US -maxdepth 1 -type f -name '*.json' -print | sed -n 's@^.*/@@p' | sed 's/\.json$//' | sort | tr '\n' ' ' | sed 's/[[:space:]]*$//')
fi
TARGET_LANGS="$DEFAULT_TARGET_LANGS"
SYNC_MODE="incremental"
else
BASE_SHA=""
HEAD_SHA=$(git rev-parse HEAD)
@@ -198,17 +104,6 @@ jobs:
else
CHANGED_FILES=""
fi
if [ "$SYNC_MODE" = "incremental" ] && [ -n "$CHANGED_FILES" ]; then
export BASE_SHA HEAD_SHA CHANGED_FILES
generate_changes_json
CHANGES_AVAILABLE="true"
CHANGES_SOURCE="local"
else
printf '%s' '{"baseSha":"","headSha":"","files":[],"changes":{}}' > /tmp/i18n-changes.json
CHANGES_AVAILABLE="false"
CHANGES_SOURCE="unavailable"
fi
fi
FILE_ARGS=""
@@ -228,8 +123,6 @@ jobs:
echo "CHANGED_FILES=$CHANGED_FILES"
echo "TARGET_LANGS=$TARGET_LANGS"
echo "SYNC_MODE=$SYNC_MODE"
echo "CHANGES_AVAILABLE=$CHANGES_AVAILABLE"
echo "CHANGES_SOURCE=$CHANGES_SOURCE"
echo "FILE_ARGS=$FILE_ARGS"
echo "LANG_ARGS=$LANG_ARGS"
} >> "$GITHUB_OUTPUT"
@@ -248,7 +141,7 @@ jobs:
show_full_output: ${{ github.event_name == 'workflow_dispatch' }}
prompt: |
You are the i18n sync agent for the Dify repository.
Your job is to keep translations synchronized with the English source files under `${{ github.workspace }}/web/i18n/en-US/`.
Your job is to keep translations synchronized with the English source files under `${{ github.workspace }}/web/i18n/en-US/`, then open a PR with the result.
Use absolute paths at all times:
- Repo root: `${{ github.workspace }}`
@@ -263,15 +156,12 @@ jobs:
- Head SHA: `${{ steps.context.outputs.HEAD_SHA }}`
- Scoped file args: `${{ steps.context.outputs.FILE_ARGS }}`
- Scoped language args: `${{ steps.context.outputs.LANG_ARGS }}`
- Structured change set available: `${{ steps.context.outputs.CHANGES_AVAILABLE }}`
- Structured change set source: `${{ steps.context.outputs.CHANGES_SOURCE }}`
- Structured change set file: `/tmp/i18n-changes.json`
Tool rules:
- Use Read for repository files.
- Use Edit for JSON updates.
- Use Bash only for `pnpm`.
- Do not use Bash for `git`, `gh`, or branch management.
- Use Bash only for `git`, `gh`, `pnpm`, and `date`.
- Run Bash commands one by one. Do not combine commands with `&&`, `||`, pipes, or command substitution.
Required execution plan:
1. Resolve target languages.
@@ -282,25 +172,27 @@ jobs:
- Only process the resolved target languages, never `en-US`.
- Do not touch unrelated i18n files.
- Do not modify `${{ github.workspace }}/web/i18n/en-US/`.
3. Resolve source changes.
- If `Structured change set available` is `true`, read `/tmp/i18n-changes.json` and use it as the source of truth for file-level and key-level changes.
- For each file entry:
- `added` contains new English keys that need translations.
- `updated` contains stale keys whose English source changed; re-translate using the `after` value.
- `deleted` contains keys that should be removed from locale files.
- `fileDeleted: true` means the English file no longer exists; remove the matching locale file if present.
- Read the current English JSON file for any file that still exists so wording, placeholders, and surrounding terminology stay accurate.
- If `Structured change set available` is `false`, treat this as a scoped full sync and use the current English files plus scoped checks as the source of truth.
3. Detect English changes per file.
- Read the current English JSON file for each file in scope.
- If sync mode is `incremental` and `Base SHA` is not empty, run:
`git -C ${{ github.workspace }} show <Base SHA>:web/i18n/en-US/<file>.json`
- If sync mode is `full` or `Base SHA` is empty, skip historical comparison and treat the current English file as the only source of truth for structural sync.
- If the file did not exist at Base SHA, treat all current keys as ADD.
- Compare previous and current English JSON to identify:
- ADD: key only in current
- UPDATE: key exists in both and the English value changed
- DELETE: key only in previous
- Do not rely on a truncated diff file.
4. Run a scoped pre-check before editing:
- `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
- Use this command as the source of truth for missing and extra keys inside the current scope.
5. Apply translations.
- For every target language and scoped file:
- If `fileDeleted` is `true`, remove the locale file if it exists and skip the rest of that file.
- If the locale file does not exist yet, create it with `Write` and then continue with `Edit` as needed.
- ADD missing keys.
- UPDATE stale translations when the English value changed.
- DELETE removed keys. Prefer `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope.
- For `zh-Hans` and `ja-JP`, if the locale file also changed between Base SHA and Head SHA, preserve manual translations unless they are clearly wrong for the new English value. If in doubt, keep the manual translation.
- Preserve placeholders exactly: `{{variable}}`, `${variable}`, HTML tags, component tags, and variable names.
- Match the existing terminology and register used by each locale.
- Prefer one Edit per file when stable, but prioritize correctness over batching.
@@ -308,119 +200,14 @@ jobs:
- Run `pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- <relative edited i18n file paths>`
- Run `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
- If verification fails, fix the remaining problems before continuing.
7. Stop after the scoped locale files are updated and verification passes.
- Do not create branches, commits, or pull requests.
7. Create a PR only when there are changes in `web/i18n/`.
- Check `git -C ${{ github.workspace }} status --porcelain -- web/i18n/`
- Create branch `chore/i18n-sync-<timestamp>`
- Commit message: `chore(i18n): sync translations with en-US`
- Push the branch and open a PR against `main`
- PR title: `chore(i18n): sync translations with en-US`
- PR body: summarize files, languages, sync mode, and verification commands
8. If there are no translation changes after verification, do not create a branch, commit, or PR.
claude_args: |
--max-turns 120
--allowedTools "Read,Write,Edit,Bash(pnpm *),Bash(pnpm:*),Glob,Grep"
- name: Prepare branch metadata
id: pr_meta
if: steps.context.outputs.CHANGED_FILES != ''
shell: bash
run: |
if [ -z "$(git -C "${{ github.workspace }}" status --porcelain -- web/i18n/)" ]; then
echo "has_changes=false" >> "$GITHUB_OUTPUT"
exit 0
fi
SCOPE_HASH=$(printf '%s|%s|%s' "${{ steps.context.outputs.CHANGED_FILES }}" "${{ steps.context.outputs.TARGET_LANGS }}" "${{ steps.context.outputs.SYNC_MODE }}" | sha256sum | cut -c1-8)
HEAD_SHORT=$(printf '%s' "${{ steps.context.outputs.HEAD_SHA }}" | cut -c1-12)
BRANCH_NAME="chore/i18n-sync-${HEAD_SHORT}-${SCOPE_HASH}"
{
echo "has_changes=true"
echo "branch_name=$BRANCH_NAME"
} >> "$GITHUB_OUTPUT"
- name: Commit translation changes
if: steps.pr_meta.outputs.has_changes == 'true'
shell: bash
run: |
git -C "${{ github.workspace }}" checkout -B "${{ steps.pr_meta.outputs.branch_name }}"
git -C "${{ github.workspace }}" add web/i18n/
git -C "${{ github.workspace }}" commit -m "chore(i18n): sync translations with en-US"
- name: Push translation branch
if: steps.pr_meta.outputs.has_changes == 'true'
shell: bash
run: |
if git -C "${{ github.workspace }}" ls-remote --exit-code --heads origin "${{ steps.pr_meta.outputs.branch_name }}" >/dev/null 2>&1; then
git -C "${{ github.workspace }}" push --force-with-lease origin "${{ steps.pr_meta.outputs.branch_name }}"
else
git -C "${{ github.workspace }}" push --set-upstream origin "${{ steps.pr_meta.outputs.branch_name }}"
fi
- name: Create or update translation PR
if: steps.pr_meta.outputs.has_changes == 'true'
env:
BRANCH_NAME: ${{ steps.pr_meta.outputs.branch_name }}
FILES_IN_SCOPE: ${{ steps.context.outputs.CHANGED_FILES }}
TARGET_LANGS: ${{ steps.context.outputs.TARGET_LANGS }}
SYNC_MODE: ${{ steps.context.outputs.SYNC_MODE }}
CHANGES_SOURCE: ${{ steps.context.outputs.CHANGES_SOURCE }}
BASE_SHA: ${{ steps.context.outputs.BASE_SHA }}
HEAD_SHA: ${{ steps.context.outputs.HEAD_SHA }}
REPO_NAME: ${{ github.repository }}
shell: bash
run: |
PR_BODY_FILE=/tmp/i18n-pr-body.md
LANG_COUNT=$(printf '%s\n' "$TARGET_LANGS" | wc -w | tr -d ' ')
if [ "$LANG_COUNT" = "0" ]; then
LANG_COUNT="0"
fi
export LANG_COUNT
node <<'NODE' > "$PR_BODY_FILE"
const fs = require('node:fs')
const changesPath = '/tmp/i18n-changes.json'
const changes = fs.existsSync(changesPath)
? JSON.parse(fs.readFileSync(changesPath, 'utf8'))
: { changes: {} }
const filesInScope = (process.env.FILES_IN_SCOPE || '').split(/\s+/).filter(Boolean)
const lines = [
'## Summary',
'',
`- **Files synced**: \`${process.env.FILES_IN_SCOPE || '<none>'}\``,
`- **Languages updated**: ${process.env.TARGET_LANGS || '<none>'} (${process.env.LANG_COUNT} languages)`,
`- **Sync mode**: ${process.env.SYNC_MODE}${process.env.BASE_SHA ? ` (base: \`${process.env.BASE_SHA.slice(0, 10)}\`, head: \`${process.env.HEAD_SHA.slice(0, 10)}\`)` : ` (head: \`${process.env.HEAD_SHA.slice(0, 10)}\`)`}`,
'',
'### Key changes',
]
for (const fileName of filesInScope) {
const fileChange = changes.changes?.[fileName] || { added: {}, updated: {}, deleted: [], fileDeleted: false }
const addedKeys = Object.keys(fileChange.added || {})
const updatedKeys = Object.keys(fileChange.updated || {})
const deletedKeys = fileChange.deleted || []
lines.push(`- \`${fileName}\`: +${addedKeys.length} / ~${updatedKeys.length} / -${deletedKeys.length}${fileChange.fileDeleted ? ' (file deleted in en-US)' : ''}`)
}
lines.push(
'',
'## Verification',
'',
`- \`pnpm --dir web run i18n:check --file ${process.env.FILES_IN_SCOPE} --lang ${process.env.TARGET_LANGS}\``,
`- \`pnpm --dir web lint:fix --quiet -- <edited i18n files>\``,
'',
'## Notes',
'',
'- This PR was generated from structured en-US key changes produced by `trigger-i18n-sync.yml`.',
`- Structured change source: ${process.env.CHANGES_SOURCE || 'unknown'}.`,
'- Branch name is deterministic for the head SHA and scope, so reruns update the same PR instead of opening duplicates.',
'',
'🤖 Generated with [Claude Code](https://claude.com/claude-code)'
)
process.stdout.write(lines.join('\n'))
NODE
EXISTING_PR_NUMBER=$(gh pr list --repo "$REPO_NAME" --head "$BRANCH_NAME" --state open --json number --jq '.[0].number')
if [ -n "$EXISTING_PR_NUMBER" ] && [ "$EXISTING_PR_NUMBER" != "null" ]; then
gh pr edit "$EXISTING_PR_NUMBER" --repo "$REPO_NAME" --title "chore(i18n): sync translations with en-US" --body-file "$PR_BODY_FILE"
else
gh pr create --repo "$REPO_NAME" --head "$BRANCH_NAME" --base main --title "chore(i18n): sync translations with en-US" --body-file "$PR_BODY_FILE"
fi
--max-turns 80
--allowedTools "Read,Write,Edit,Bash(git *),Bash(git:*),Bash(gh *),Bash(gh:*),Bash(pnpm *),Bash(pnpm:*),Bash(date *),Bash(date:*),Glob,Grep"

View File

@@ -1,171 +0,0 @@
name: Trigger i18n Sync on Push
on:
push:
branches: [main]
paths:
- 'web/i18n/en-US/*.json'
permissions:
contents: write
concurrency:
group: trigger-i18n-sync-${{ github.ref }}
cancel-in-progress: true
jobs:
trigger:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
timeout-minutes: 5
steps:
- name: Checkout repository
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
- name: Detect changed files and build structured change set
id: detect
shell: bash
run: |
BASE_SHA="${{ github.event.before }}"
if [ -z "$BASE_SHA" ] || [ "$BASE_SHA" = "0000000000000000000000000000000000000000" ]; then
BASE_SHA=$(git rev-parse HEAD~1 2>/dev/null || true)
fi
HEAD_SHA="${{ github.sha }}"
if [ -n "$BASE_SHA" ]; then
CHANGED_FILES=$(git diff --name-only "$BASE_SHA" "$HEAD_SHA" -- 'web/i18n/en-US/*.json' 2>/dev/null | sed -n 's@^.*/@@p' | sed 's/\.json$//' | tr '\n' ' ' | sed 's/[[:space:]]*$//')
else
CHANGED_FILES=$(find web/i18n/en-US -maxdepth 1 -type f -name '*.json' -print | sed -n 's@^.*/@@p' | sed 's/\.json$//' | sort | tr '\n' ' ' | sed 's/[[:space:]]*$//')
fi
export BASE_SHA HEAD_SHA CHANGED_FILES
node <<'NODE'
const { execFileSync } = require('node:child_process')
const fs = require('node:fs')
const path = require('node:path')
const repoRoot = process.cwd()
const baseSha = process.env.BASE_SHA || ''
const headSha = process.env.HEAD_SHA || ''
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
const readCurrentJson = (fileStem) => {
const filePath = englishPath(fileStem)
if (!fs.existsSync(filePath))
return null
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
}
const readBaseJson = (fileStem) => {
if (!baseSha)
return null
try {
const relativePath = `web/i18n/en-US/${fileStem}.json`
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
return JSON.parse(content)
}
catch (error) {
return null
}
}
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
const changes = {}
for (const fileStem of files) {
const beforeJson = readBaseJson(fileStem) || {}
const afterJson = readCurrentJson(fileStem) || {}
const added = {}
const updated = {}
const deleted = []
for (const [key, value] of Object.entries(afterJson)) {
if (!(key in beforeJson)) {
added[key] = value
continue
}
if (!compareJson(beforeJson[key], value)) {
updated[key] = {
before: beforeJson[key],
after: value,
}
}
}
for (const key of Object.keys(beforeJson)) {
if (!(key in afterJson))
deleted.push(key)
}
changes[fileStem] = {
fileDeleted: readCurrentJson(fileStem) === null,
added,
updated,
deleted,
}
}
fs.writeFileSync(
'/tmp/i18n-changes.json',
JSON.stringify({
baseSha,
headSha,
files,
changes,
})
)
NODE
if [ -n "$CHANGED_FILES" ]; then
echo "has_changes=true" >> "$GITHUB_OUTPUT"
else
echo "has_changes=false" >> "$GITHUB_OUTPUT"
fi
echo "base_sha=$BASE_SHA" >> "$GITHUB_OUTPUT"
echo "head_sha=$HEAD_SHA" >> "$GITHUB_OUTPUT"
echo "changed_files=$CHANGED_FILES" >> "$GITHUB_OUTPUT"
- name: Trigger i18n sync workflow
if: steps.detect.outputs.has_changes == 'true'
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
env:
BASE_SHA: ${{ steps.detect.outputs.base_sha }}
HEAD_SHA: ${{ steps.detect.outputs.head_sha }}
CHANGED_FILES: ${{ steps.detect.outputs.changed_files }}
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const fs = require('fs')
const changesJson = fs.readFileSync('/tmp/i18n-changes.json', 'utf8')
const changesBase64 = Buffer.from(changesJson).toString('base64')
const maxEmbeddedChangesChars = 48000
const changesEmbedded = changesBase64.length <= maxEmbeddedChangesChars
if (!changesEmbedded) {
console.log(`Structured change set too large to embed safely (${changesBase64.length} chars). Downstream workflow will regenerate it from git history.`)
}
await github.rest.repos.createDispatchEvent({
owner: context.repo.owner,
repo: context.repo.repo,
event_type: 'i18n-sync',
client_payload: {
changed_files: process.env.CHANGED_FILES,
changes_base64: changesEmbedded ? changesBase64 : '',
changes_embedded: changesEmbedded,
sync_mode: 'incremental',
base_sha: process.env.BASE_SHA,
head_sha: process.env.HEAD_SHA,
},
})

View File

@@ -1,95 +0,0 @@
name: Run Full VDB Tests
on:
schedule:
- cron: '0 3 * * 1'
workflow_dispatch:
permissions:
contents: read
concurrency:
group: vdb-tests-full-${{ github.ref || github.run_id }}
cancel-in-progress: true
jobs:
test:
name: Full VDB Tests
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
strategy:
matrix:
python-version:
- "3.12"
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Free Disk Space
uses: endersonmenezes/free-disk-space@7901478139cff6e9d44df5972fd8ab8fcade4db1 # v3.2.2
with:
remove_dotnet: true
remove_haskell: true
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
cache-dependency-glob: api/uv.lock
- name: Check UV lockfile
run: uv lock --project api --check
- name: Install dependencies
run: uv sync --project api --dev
- name: Set up dotenvs
run: |
cp docker/.env.example docker/.env
cp docker/middleware.env.example docker/middleware.env
- name: Expose Service Ports
run: sh .github/workflows/expose_service_ports.sh
# - name: Set up Vector Store (TiDB)
# uses: hoverkraft-tech/compose-action@v2.0.2
# with:
# compose-file: docker/tidb/docker-compose.yaml
# services: |
# tidb
# tiflash
- name: Set up Full Vector Store Matrix
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
with:
compose-file: |
docker/docker-compose.yaml
services: |
weaviate
qdrant
couchbase-server
etcd
minio
milvus-standalone
pgvecto-rs
pgvector
chroma
elasticsearch
oceanbase
- name: setup test config
run: |
echo $(pwd)
ls -lah .
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
# - name: Check VDB Ready (TiDB)
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
- name: Test Vector Stores
run: uv run --project api bash dev/pytest/pytest_vdb.sh

View File

@@ -1,18 +1,15 @@
name: Run VDB Smoke Tests
name: Run VDB Tests
on:
workflow_call:
permissions:
contents: read
concurrency:
group: vdb-tests-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
test:
name: VDB Smoke Tests
name: VDB Tests
runs-on: ubuntu-latest
strategy:
matrix:
@@ -61,18 +58,23 @@ jobs:
# tidb
# tiflash
- name: Set up Vector Stores for Smoke Coverage
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase)
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
with:
compose-file: |
docker/docker-compose.yaml
services: |
db_postgres
redis
weaviate
qdrant
couchbase-server
etcd
minio
milvus-standalone
pgvecto-rs
pgvector
chroma
elasticsearch
oceanbase
- name: setup test config
run: |
@@ -84,9 +86,4 @@ jobs:
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
- name: Test Vector Stores
run: |
uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \
api/tests/integration_tests/vdb/chroma \
api/tests/integration_tests/vdb/pgvector \
api/tests/integration_tests/vdb/qdrant \
api/tests/integration_tests/vdb/weaviate
run: uv run --project api bash dev/pytest/pytest_vdb.sh

View File

@@ -27,6 +27,10 @@ jobs:
- name: Setup web dependencies
uses: ./.github/actions/setup-web
- name: Install E2E package dependencies
working-directory: ./e2e
run: vp install --frozen-lockfile
- name: Setup UV and Python
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
with:

View File

@@ -89,3 +89,34 @@ jobs:
flags: web
env:
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
web-build:
name: Web Build
runs-on: ubuntu-latest
defaults:
run:
working-directory: ./web
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
with:
files: |
web/**
.github/workflows/web-tests.yml
.github/actions/setup-web/**
- name: Setup web environment
if: steps.changed-files.outputs.any_changed == 'true'
uses: ./.github/actions/setup-web
- name: Web build check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: vp run build

4
.gitignore vendored
View File

@@ -212,8 +212,6 @@ api/.vscode
# pnpm
/.pnpm-store
/node_modules
.vite-hooks/_
# plugin migrate
plugins.jsonl
@@ -241,4 +239,4 @@ scripts/stress-test/reports/
*.local.md
# Code Agent Folder
.qoder/*
.qoder/*

View File

@@ -24,8 +24,8 @@ prepare-docker:
# Step 2: Prepare web environment
prepare-web:
@echo "🌐 Setting up web environment..."
@cp -n web/.env.example web/.env.local 2>/dev/null || echo "Web .env.local already exists"
@pnpm install
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
@cd web && pnpm install
@echo "✅ Web environment prepared (not started)"
# Step 3: Prepare API environment
@@ -93,7 +93,7 @@ test:
# Build Docker images
build-web:
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
docker build -f web/Dockerfile -t $(WEB_IMAGE):$(VERSION) .
docker build -t $(WEB_IMAGE):$(VERSION) ./web
@echo "Web Docker image built successfully: $(WEB_IMAGE):$(VERSION)"
build-api:

View File

@@ -40,8 +40,6 @@ The scripts resolve paths relative to their location, so you can run them from a
./dev/start-web
```
`./dev/setup` and `./dev/start-web` install JavaScript dependencies through the repository root workspace, so you do not need a separate `cd web && pnpm install` step.
1. Set up your application by visiting `http://localhost:3000`.
1. Start the worker service (async and scheduler tasks, runs from `api`).

View File

@@ -7,16 +7,15 @@ UUID_NIL = "00000000-0000-0000-0000-000000000000"
DEFAULT_FILE_NUMBER_LIMITS = 3
_IMAGE_EXTENSION_BASE: frozenset[str] = frozenset(("jpg", "jpeg", "png", "webp", "gif", "svg"))
_VIDEO_EXTENSION_BASE: frozenset[str] = frozenset(("mp4", "mov", "mpeg", "webm"))
_AUDIO_EXTENSION_BASE: frozenset[str] = frozenset(("mp3", "m4a", "wav", "amr", "mpga"))
IMAGE_EXTENSIONS = convert_to_lower_and_upper_set({"jpg", "jpeg", "png", "webp", "gif", "svg"})
IMAGE_EXTENSIONS: frozenset[str] = frozenset(convert_to_lower_and_upper_set(_IMAGE_EXTENSION_BASE))
VIDEO_EXTENSIONS: frozenset[str] = frozenset(convert_to_lower_and_upper_set(_VIDEO_EXTENSION_BASE))
AUDIO_EXTENSIONS: frozenset[str] = frozenset(convert_to_lower_and_upper_set(_AUDIO_EXTENSION_BASE))
VIDEO_EXTENSIONS = convert_to_lower_and_upper_set({"mp4", "mov", "mpeg", "webm"})
_UNSTRUCTURED_DOCUMENT_EXTENSION_BASE: frozenset[str] = frozenset(
(
AUDIO_EXTENSIONS = convert_to_lower_and_upper_set({"mp3", "m4a", "wav", "amr", "mpga"})
_doc_extensions: set[str]
if dify_config.ETL_TYPE == "Unstructured":
_doc_extensions = {
"txt",
"markdown",
"md",
@@ -36,10 +35,11 @@ _UNSTRUCTURED_DOCUMENT_EXTENSION_BASE: frozenset[str] = frozenset(
"pptx",
"xml",
"epub",
)
)
_DEFAULT_DOCUMENT_EXTENSION_BASE: frozenset[str] = frozenset(
(
}
if dify_config.UNSTRUCTURED_API_URL:
_doc_extensions.add("ppt")
else:
_doc_extensions = {
"txt",
"markdown",
"md",
@@ -53,17 +53,8 @@ _DEFAULT_DOCUMENT_EXTENSION_BASE: frozenset[str] = frozenset(
"csv",
"vtt",
"properties",
)
)
_doc_extensions: set[str]
if dify_config.ETL_TYPE == "Unstructured":
_doc_extensions = set(_UNSTRUCTURED_DOCUMENT_EXTENSION_BASE)
if dify_config.UNSTRUCTURED_API_URL:
_doc_extensions.add("ppt")
else:
_doc_extensions = set(_DEFAULT_DOCUMENT_EXTENSION_BASE)
DOCUMENT_EXTENSIONS: frozenset[str] = frozenset(convert_to_lower_and_upper_set(_doc_extensions))
}
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
# console
COOKIE_NAME_ACCESS_TOKEN = "access_token"

View File

@@ -4,8 +4,8 @@ from urllib.parse import quote
from flask import Response
HTML_MIME_TYPES: frozenset[str] = frozenset(("text/html", "application/xhtml+xml"))
HTML_EXTENSIONS: frozenset[str] = frozenset(("html", "htm"))
HTML_MIME_TYPES = frozenset({"text/html", "application/xhtml+xml"})
HTML_EXTENSIONS = frozenset({"html", "htm"})
def _normalize_mime_type(mime_type: str | None) -> str:

View File

@@ -2,7 +2,7 @@ import flask_restx
from flask_restx import Resource, fields, marshal_with
from flask_restx._http import HTTPStatus
from sqlalchemy import delete, func, select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
@@ -34,7 +34,7 @@ api_key_list_model = console_ns.model(
def _get_resource(resource_id, tenant_id, resource_model):
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
resource = session.execute(
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
).scalar_one_or_none()

View File

@@ -9,7 +9,7 @@ from graphon.enums import WorkflowExecutionStatus
from graphon.file import helpers as file_helpers
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest
from controllers.common.helpers import FileInfo
@@ -642,7 +642,7 @@ class AppCopyApi(Resource):
args = CopyAppPayload.model_validate(console_ns.payload or {})
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with Session(db.engine) as session:
import_service = AppDslService(session)
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
result = import_service.import_app(
@@ -655,6 +655,7 @@ class AppCopyApi(Resource):
icon=args.icon,
icon_background=args.icon_background,
)
session.commit()
# Inherit web app permission from original app
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:

View File

@@ -1,6 +1,6 @@
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (
@@ -71,7 +71,7 @@ class AppImportApi(Resource):
args = AppImportPayload.model_validate(console_ns.payload)
# Create service with session
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
import_service = AppDslService(session)
# Import app
account = current_user
@@ -87,6 +87,7 @@ class AppImportApi(Resource):
icon_background=args.icon_background,
app_id=args.app_id,
)
session.commit()
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
# update web app setting as private
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
@@ -111,11 +112,12 @@ class AppImportConfirmApi(Resource):
current_user, _ = current_account_with_tenant()
# Create service with session
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
import_service = AppDslService(session)
# Confirm import
account = current_user
result = import_service.confirm_import(import_id=import_id, account=account)
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED:
@@ -132,7 +134,7 @@ class AppImportCheckDependenciesApi(Resource):
@marshal_with(app_import_check_dependencies_model)
@edit_permission_required
def get(self, app_model: App):
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
import_service = AppDslService(session)
result = import_service.check_dependencies(app_model=app_model)

View File

@@ -2,7 +2,7 @@ from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
@@ -69,7 +69,7 @@ class ConversationVariablesApi(Resource):
page_size = 100
stmt = stmt.limit(page_size).offset((page - 1) * page_size)
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with Session(db.engine) as session:
rows = session.scalars(stmt).all()
return {

View File

@@ -9,8 +9,8 @@ from graphon.enums import NodeType
from graphon.file import File
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.utils.encoders import jsonable_encoder
from pydantic import BaseModel, Field, ValidationError, field_validator
from sqlalchemy.orm import sessionmaker
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
import services
@@ -268,18 +268,22 @@ class DraftWorkflowApi(Resource):
content_type = request.headers.get("Content-Type", "")
payload_data: dict[str, Any] | None = None
if "application/json" in content_type:
payload_data = request.get_json(silent=True)
if not isinstance(payload_data, dict):
return {"message": "Invalid JSON data"}, 400
args_model = SyncDraftWorkflowPayload.model_validate(payload_data)
elif "text/plain" in content_type:
try:
args_model = SyncDraftWorkflowPayload.model_validate_json(request.data)
except (ValueError, ValidationError):
payload_data = json.loads(request.data.decode("utf-8"))
except json.JSONDecodeError:
return {"message": "Invalid JSON data"}, 400
if not isinstance(payload_data, dict):
return {"message": "Invalid JSON data"}, 400
else:
abort(415)
args_model = SyncDraftWorkflowPayload.model_validate(payload_data)
args = args_model.model_dump()
workflow_service = WorkflowService()
@@ -836,7 +840,7 @@ class PublishedWorkflowApi(Resource):
args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
workflow_service = WorkflowService()
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
workflow = workflow_service.publish_workflow(
session=session,
app_model=app_model,
@@ -854,6 +858,8 @@ class PublishedWorkflowApi(Resource):
workflow_created_at = TimestampField().format(workflow.created_at)
session.commit()
return {
"result": "success",
"created_at": workflow_created_at,
@@ -976,7 +982,7 @@ class PublishedAllWorkflowApi(Resource):
raise Forbidden()
workflow_service = WorkflowService()
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
workflows, has_more = workflow_service.get_all_published_workflow(
session=session,
app_model=app_model,
@@ -1066,7 +1072,7 @@ class WorkflowByIdApi(Resource):
workflow_service = WorkflowService()
# Create a session and manage the transaction
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with Session(db.engine, expire_on_commit=False) as session:
workflow = workflow_service.update_workflow(
session=session,
workflow_id=workflow_id,
@@ -1078,6 +1084,9 @@ class WorkflowByIdApi(Resource):
if not workflow:
raise NotFound("Workflow not found")
# Commit the transaction in the controller
session.commit()
return workflow
@setup_required
@@ -1092,11 +1101,13 @@ class WorkflowByIdApi(Resource):
workflow_service = WorkflowService()
# Create a session and manage the transaction
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
try:
workflow_service.delete_workflow(
session=session, workflow_id=workflow_id, tenant_id=app_model.tenant_id
)
# Commit the transaction in the controller
session.commit()
except WorkflowInUseError as e:
abort(400, description=str(e))
except DraftWorkflowDeletionError as e:

View File

@@ -5,7 +5,7 @@ from flask import request
from flask_restx import Resource, marshal_with
from graphon.enums import WorkflowExecutionStatus
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
@@ -87,7 +87,7 @@ class WorkflowAppLogApi(Resource):
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
session=session,
app_model=app_model,
@@ -124,7 +124,7 @@ class WorkflowArchivedLogApi(Resource):
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
workflow_app_service = WorkflowAppService()
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs(
session=session,
app_model=app_model,

View File

@@ -10,7 +10,7 @@ from graphon.variables.segment_group import SegmentGroup
from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
from graphon.variables.types import SegmentType
from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from controllers.console import console_ns
from controllers.console.app.error import (
@@ -244,7 +244,7 @@ class WorkflowVariableCollectionApi(Resource):
raise DraftWorkflowNotExist()
# fetch draft workflow by app_model
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
@@ -298,7 +298,7 @@ class NodeVariableCollectionApi(Resource):
@marshal_with(workflow_draft_variable_list_model)
def get(self, app_model: App, node_id: str):
validate_node_id(node_id)
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
@@ -465,7 +465,7 @@ class VariableResetApi(Resource):
def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)

View File

@@ -4,7 +4,7 @@ from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from configs import dify_config
@@ -64,7 +64,7 @@ class WebhookTriggerApi(Resource):
node_id = args.node_id
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
# Get webhook trigger for this app and node
webhook_trigger = (
session.query(WorkflowWebhookTrigger)
@@ -95,7 +95,7 @@ class AppTriggersApi(Resource):
assert isinstance(current_user, Account)
assert current_user.current_tenant_id is not None
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
# Get all triggers for this app using select API
triggers = (
session.execute(
@@ -137,7 +137,7 @@ class AppTriggerEnableApi(Resource):
assert current_user.current_tenant_id is not None
trigger_id = args.trigger_id
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with Session(db.engine) as session:
# Find the trigger using select
trigger = session.execute(
select(AppTrigger).where(
@@ -153,6 +153,9 @@ class AppTriggerEnableApi(Resource):
# Update status based on enable_trigger boolean
trigger.status = AppTriggerStatus.ENABLED if args.enable_trigger else AppTriggerStatus.DISABLED
session.commit()
session.refresh(trigger)
# Add computed icon field
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
if trigger.trigger_type == "trigger-plugin":

View File

@@ -36,7 +36,7 @@ class Subscription(Resource):
@only_edition_cloud
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True))
args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
BillingService.is_tenant_owner_or_admin(current_user)
return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id)

View File

@@ -31,7 +31,7 @@ class ComplianceApi(Resource):
@only_edition_cloud
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True))
args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
ip_address = extract_remote_ip(request)
device_info = request.headers.get("User-Agent", "Unknown device")

View File

@@ -6,7 +6,7 @@ from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.common.schema import get_or_create_model, register_schema_model
@@ -159,7 +159,7 @@ class DataSourceApi(Resource):
@account_initialization_required
def patch(self, binding_id, action: Literal["enable", "disable"]):
binding_id = str(binding_id)
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with Session(db.engine) as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).filter_by(id=binding_id)
).scalar_one_or_none()
@@ -211,7 +211,7 @@ class DataSourceNotionListApi(Resource):
if not credential:
raise NotFound("Credential not found.")
exist_page_ids = []
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
# import notion in the exist dataset
if query.dataset_id:
dataset = DatasetService.get_dataset(query.dataset_id)

View File

@@ -3,7 +3,7 @@ import logging
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
@@ -85,7 +85,7 @@ class CustomizedPipelineTemplateApi(Resource):
@account_initialization_required
@enterprise_license_required
def post(self, template_id: str):
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with Session(db.engine) as session:
template = (
session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
)

View File

@@ -1,6 +1,6 @@
from flask_restx import Resource, marshal
from pydantic import BaseModel
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
import services
@@ -54,7 +54,7 @@ class CreateRagPipelineDatasetApi(Resource):
yaml_content=payload.yaml_content,
)
try:
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
tenant_id=current_tenant_id,

View File

@@ -5,7 +5,7 @@ from flask import Response, request
from flask_restx import Resource, marshal, marshal_with
from graphon.variables.types import SegmentType
from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
@@ -96,7 +96,7 @@ class RagPipelineVariableCollectionApi(Resource):
raise DraftWorkflowNotExist()
# fetch draft workflow by app_model
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
@@ -143,7 +143,7 @@ class RagPipelineNodeVariableCollectionApi(Resource):
@marshal_with(workflow_draft_variable_list_model)
def get(self, pipeline: Pipeline, node_id: str):
validate_node_id(node_id)
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
@@ -289,7 +289,7 @@ class RagPipelineVariableResetApi(Resource):
def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList:
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)

View File

@@ -1,7 +1,7 @@
from flask import request
from flask_restx import Resource, fields, marshal_with # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from controllers.common.schema import get_or_create_model, register_schema_models
from controllers.console import console_ns
@@ -68,7 +68,7 @@ class RagPipelineImportApi(Resource):
payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
# Create service with session
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
# Import app
account = current_user
@@ -80,6 +80,7 @@ class RagPipelineImportApi(Resource):
pipeline_id=payload.pipeline_id,
dataset_name=payload.name,
)
session.commit()
# Return appropriate status code based on result
status = result.status
@@ -101,11 +102,12 @@ class RagPipelineImportConfirmApi(Resource):
current_user, _ = current_account_with_tenant()
# Create service with session
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
# Confirm import
account = current_user
result = import_service.confirm_import(import_id=import_id, account=account)
session.commit()
# Return appropriate status code based on result
if result.status == ImportStatus.FAILED:
@@ -122,7 +124,7 @@ class RagPipelineImportCheckDependenciesApi(Resource):
@edit_permission_required
@marshal_with(pipeline_import_check_dependencies_model)
def get(self, pipeline: Pipeline):
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
import_service = RagPipelineDslService(session)
result = import_service.check_dependencies(pipeline=pipeline)
@@ -140,7 +142,7 @@ class RagPipelineExportApi(Resource):
# Add include_secret params
query = IncludeSecretQuery.model_validate(request.args.to_dict())
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
export_service = RagPipelineDslService(session)
result = export_service.export_rag_pipeline_dsl(
pipeline=pipeline, include_secret=query.include_secret == "true"

View File

@@ -5,8 +5,8 @@ from typing import Any, Literal, cast
from flask import abort, request
from flask_restx import Resource, marshal_with # type: ignore
from graphon.model_runtime.utils.encoders import jsonable_encoder
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.orm import sessionmaker
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
import services
@@ -186,14 +186,29 @@ class DraftRagPipelineApi(Resource):
if "application/json" in content_type:
payload_dict = console_ns.payload or {}
payload = DraftWorkflowSyncPayload.model_validate(payload_dict)
elif "text/plain" in content_type:
try:
payload = DraftWorkflowSyncPayload.model_validate_json(request.data)
except (ValueError, ValidationError):
data = json.loads(request.data.decode("utf-8"))
if "graph" not in data or "features" not in data:
raise ValueError("graph or features not found in data")
if not isinstance(data.get("graph"), dict):
raise ValueError("graph is not a dict")
payload_dict = {
"graph": data.get("graph"),
"features": data.get("features"),
"hash": data.get("hash"),
"environment_variables": data.get("environment_variables"),
"conversation_variables": data.get("conversation_variables"),
"rag_pipeline_variables": data.get("rag_pipeline_variables"),
}
except json.JSONDecodeError:
return {"message": "Invalid JSON data"}, 400
else:
abort(415)
payload = DraftWorkflowSyncPayload.model_validate(payload_dict)
rag_pipeline_service = RagPipelineService()
try:
@@ -593,15 +608,19 @@ class PublishedRagPipelineApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, _ = current_account_with_tenant()
rag_pipeline_service = RagPipelineService()
workflow = rag_pipeline_service.publish_workflow(
session=db.session, # type: ignore[reportArgumentType,arg-type]
pipeline=pipeline,
account=current_user,
)
pipeline.is_published = True
pipeline.workflow_id = workflow.id
db.session.commit()
workflow_created_at = TimestampField().format(workflow.created_at)
with Session(db.engine) as session:
pipeline = session.merge(pipeline)
workflow = rag_pipeline_service.publish_workflow(
session=session,
pipeline=pipeline,
account=current_user,
)
pipeline.is_published = True
pipeline.workflow_id = workflow.id
session.add(pipeline)
workflow_created_at = TimestampField().format(workflow.created_at)
session.commit()
return {
"result": "success",
@@ -676,7 +695,7 @@ class PublishedAllRagPipelineApi(Resource):
raise Forbidden()
rag_pipeline_service = RagPipelineService()
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
workflows, has_more = rag_pipeline_service.get_all_published_workflow(
session=session,
pipeline=pipeline,
@@ -748,7 +767,7 @@ class RagPipelineByIdApi(Resource):
rag_pipeline_service = RagPipelineService()
# Create a session and manage the transaction
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with Session(db.engine, expire_on_commit=False) as session:
workflow = rag_pipeline_service.update_workflow(
session=session,
workflow_id=workflow_id,
@@ -760,6 +779,9 @@ class RagPipelineByIdApi(Resource):
if not workflow:
raise NotFound("Workflow not found")
# Commit the transaction in the controller
session.commit()
return workflow
@setup_required
@@ -776,13 +798,14 @@ class RagPipelineByIdApi(Resource):
workflow_service = WorkflowService()
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
try:
workflow_service.delete_workflow(
session=session,
workflow_id=workflow_id,
tenant_id=pipeline.tenant_id,
)
session.commit()
except WorkflowInUseError as e:
abort(400, description=str(e))
except DraftWorkflowDeletionError as e:

View File

@@ -2,7 +2,7 @@ from typing import Any
from flask import request
from pydantic import BaseModel, Field, TypeAdapter, model_validator
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
@@ -74,7 +74,7 @@ class ConversationListApi(InstalledAppResource):
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
pagination = WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,

View File

@@ -2,7 +2,7 @@ from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from extensions.ext_database import db
@@ -24,7 +24,7 @@ def plugin_permission_required(
user = current_user
tenant_id = current_tenant_id
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
permission = (
session.query(TenantPluginPermission)
.where(

View File

@@ -8,7 +8,7 @@ from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from configs import dify_config
from constants.languages import supported_language
@@ -519,7 +519,7 @@ class EducationAutoCompleteApi(Resource):
@cloud_edition_billing_enabled
@marshal_with(data_fields)
def get(self):
payload = request.args.to_dict(flat=True)
payload = request.args.to_dict(flat=True) # type: ignore
args = EducationAutocompleteQuery.model_validate(payload)
return BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit)
@@ -562,7 +562,7 @@ class ChangeEmailSendEmailApi(Resource):
user_email = current_user.email
else:
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
if account is None:
raise AccountNotFound()

View File

@@ -99,7 +99,7 @@ class ModelProviderListApi(Resource):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
payload = request.args.to_dict(flat=True)
payload = request.args.to_dict(flat=True) # type: ignore
args = ParserModelList.model_validate(payload)
model_provider_service = ModelProviderService()
@@ -118,7 +118,7 @@ class ModelProviderCredentialApi(Resource):
_, current_tenant_id = current_account_with_tenant()
tenant_id = current_tenant_id
# if credential_id is not provided, return current used credential
payload = request.args.to_dict(flat=True)
payload = request.args.to_dict(flat=True) # type: ignore
args = ParserCredentialId.model_validate(payload)
model_provider_service = ModelProviderService()

View File

@@ -287,10 +287,12 @@ class ModelProviderModelCredentialApi(Resource):
provider=provider,
)
else:
# Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM)
normalized_model_type = args.model_type.to_origin_model_type()
available_credentials = model_provider_service.get_provider_model_available_credentials(
tenant_id=tenant_id,
provider=provider,
model_type=args.model_type,
model_type=normalized_model_type,
model=args.model,
)

View File

@@ -7,7 +7,7 @@ from flask import make_response, redirect, request, send_file
from flask_restx import Resource
from graphon.model_runtime.utils.encoders import jsonable_encoder
from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from configs import dify_config
@@ -1019,7 +1019,7 @@ class ToolProviderMCPApi(Resource):
# Step 1: Get provider data for URL validation (short-lived session, no network I/O)
validation_data = None
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
validation_data = service.get_provider_for_url_validation(
tenant_id=current_tenant_id, provider_id=payload.provider_id
@@ -1034,7 +1034,7 @@ class ToolProviderMCPApi(Resource):
)
# Step 3: Perform database update in a transaction
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.update_provider(
tenant_id=current_tenant_id,
@@ -1061,7 +1061,7 @@ class ToolProviderMCPApi(Resource):
payload = MCPProviderDeletePayload.model_validate(console_ns.payload or {})
_, current_tenant_id = current_account_with_tenant()
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.delete_provider(tenant_id=current_tenant_id, provider_id=payload.provider_id)
@@ -1079,7 +1079,7 @@ class ToolMCPAuthApi(Resource):
provider_id = payload.provider_id
_, tenant_id = current_account_with_tenant()
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
if not db_provider:
@@ -1100,7 +1100,7 @@ class ToolMCPAuthApi(Resource):
sse_read_timeout=provider_entity.sse_read_timeout,
):
# Update credentials in new transaction
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.update_provider_credentials(
provider_id=provider_id,
@@ -1118,17 +1118,17 @@ class ToolMCPAuthApi(Resource):
resource_metadata_url=e.resource_metadata_url,
scope_hint=e.scope_hint,
)
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
response = service.execute_auth_actions(auth_result)
return response
except MCPRefreshTokenError as e:
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
except (MCPError, ValueError) as e:
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
raise ValueError(f"Failed to connect to MCP server: {e}") from e
@@ -1141,7 +1141,7 @@ class ToolMCPDetailApi(Resource):
@account_initialization_required
def get(self, provider_id):
_, tenant_id = current_account_with_tenant()
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
@@ -1155,7 +1155,7 @@ class ToolMCPListAllApi(Resource):
def get(self):
_, tenant_id = current_account_with_tenant()
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
# Skip sensitive data decryption for list view to improve performance
tools = service.list_providers(tenant_id=tenant_id, include_sensitive=False)
@@ -1170,7 +1170,7 @@ class ToolMCPUpdateApi(Resource):
@account_initialization_required
def get(self, provider_id):
_, tenant_id = current_account_with_tenant()
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
tools = service.list_provider_tools(
tenant_id=tenant_id,
@@ -1188,7 +1188,7 @@ class ToolMCPCallbackApi(Resource):
authorization_code = query.code
# Create service instance for handle_callback
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session, session.begin():
mcp_service = MCPToolManageService(session=session)
# handle_callback now returns state data and tokens
state_data, tokens = handle_callback(state_key, authorization_code)

View File

@@ -5,7 +5,7 @@ from flask import make_response, redirect, request
from flask_restx import Resource
from graphon.model_runtime.utils.encoders import jsonable_encoder
from pydantic import BaseModel, model_validator
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config
@@ -375,7 +375,7 @@ class TriggerSubscriptionDeleteApi(Resource):
assert user.current_tenant_id is not None
try:
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
# Delete trigger provider subscription
TriggerProviderService.delete_trigger_provider(
session=session,
@@ -388,6 +388,7 @@ class TriggerSubscriptionDeleteApi(Resource):
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
)
session.commit()
return {"result": "success"}
except ValueError as e:
raise BadRequest(str(e))

View File

@@ -155,7 +155,7 @@ class WorkspaceListApi(Resource):
@setup_required
@admin_required
def get(self):
payload = request.args.to_dict(flat=True)
payload = request.args.to_dict(flat=True) # type: ignore
args = WorkspaceListQuery.model_validate(payload)
stmt = select(Tenant).order_by(Tenant.created_at.desc())

View File

@@ -6,7 +6,7 @@ from flask import current_app, request
from flask_login import user_logged_in
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from extensions.ext_database import db
from libs.login import current_user
@@ -33,7 +33,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
try:
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with Session(db.engine) as session:
user_model = None
if is_anonymous:
@@ -56,7 +56,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
session_id=user_id,
)
session.add(user_model)
session.flush()
session.commit()
session.refresh(user_model)
except Exception:

View File

@@ -4,7 +4,7 @@ from flask import Response
from flask_restx import Resource
from graphon.variables.input_entities import VariableEntity
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_model
from controllers.mcp import mcp_ns
@@ -67,7 +67,7 @@ class MCPAppApi(Resource):
request_id: Union[int, str] | None = args.id
mcp_request = self._parse_mcp_request(args.model_dump(exclude_none=True))
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with Session(db.engine, expire_on_commit=False) as session:
# Get MCP server and app
mcp_server, app = self._get_mcp_server_and_app(server_code, session)
self._validate_server_status(mcp_server)
@@ -174,7 +174,6 @@ class MCPAppApi(Resource):
required=variable.get("required", False),
max_length=variable.get("max_length"),
options=variable.get("options") or [],
json_schema=variable.get("json_schema"),
)
def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
@@ -189,7 +188,7 @@ class MCPAppApi(Resource):
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None:
"""Get end user - manages its own database session"""
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with Session(db.engine, expire_on_commit=False) as session, session.begin():
return (
session.query(EndUser)
.where(EndUser.tenant_id == tenant_id)
@@ -229,7 +228,9 @@ class MCPAppApi(Resource):
if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest):
client_info = mcp_request.root.params.clientInfo
client_name = f"{client_info.name}@{client_info.version}"
with sessionmaker(db.engine, expire_on_commit=False).begin() as create_session:
# Commit the session before creating end user to avoid transaction conflicts
session.commit()
with Session(db.engine, expire_on_commit=False) as create_session, create_session.begin():
end_user = self._create_end_user(client_name, app.tenant_id, app.id, mcp_server.id, create_session)
return handle_mcp_request(app, mcp_request, user_input_form, mcp_server, end_user, request_id)

View File

@@ -3,7 +3,7 @@ from typing import Any, Literal
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound
import services
@@ -116,7 +116,7 @@ class ConversationApi(Resource):
last_id = str(query_args.last_id) if query_args.last_id else None
try:
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
pagination = ConversationService.pagination_by_last_id(
session=session,
app_model=app_model,

View File

@@ -8,7 +8,7 @@ from graphon.enums import WorkflowExecutionStatus
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError
from pydantic import BaseModel, Field
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from controllers.common.schema import register_schema_models
@@ -314,7 +314,7 @@ class WorkflowAppLogApi(Resource):
# get paginate workflow app logs
workflow_app_service = WorkflowAppService()
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
session=session,
app_model=app_model,

View File

@@ -29,31 +29,6 @@ from services.entities.knowledge_entities.knowledge_entities import SegmentUpdat
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
from services.summary_index_service import SummaryIndexService
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict:
"""Marshal a single segment and enrich it with summary content."""
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
segment_dict["summary"] = summary.summary_content if summary else None
return segment_dict
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]:
"""Marshal multiple segments and enrich them with summary content (batch query)."""
segment_ids = [segment.id for segment in segments]
summaries: dict = {}
if segment_ids:
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
result = []
for segment in segments:
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
segment_dict["summary"] = summaries.get(segment.id)
result.append(segment_dict)
return result
class SegmentCreatePayload(BaseModel):
@@ -157,7 +132,7 @@ class SegmentApi(DatasetApiResource):
for args_item in payload.segments:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
return {"data": _marshal_segments_with_summary(segments, dataset_id), "doc_form": document.doc_form}, 200
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
else:
return {"error": "Segments is required"}, 400
@@ -221,7 +196,7 @@ class SegmentApi(DatasetApiResource):
)
response = {
"data": _marshal_segments_with_summary(segments, dataset_id),
"data": marshal(segments, segment_fields),
"doc_form": document.doc_form,
"total": total,
"has_more": len(segments) == limit,
@@ -321,7 +296,7 @@ class DatasetSegmentApi(DatasetApiResource):
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
return {"data": _marshal_segment_with_summary(updated_segment, dataset_id), "doc_form": document.doc_form}, 200
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
@service_api_ns.doc("get_segment")
@service_api_ns.doc(description="Get a specific segment by ID")
@@ -351,7 +326,7 @@ class DatasetSegmentApi(DatasetApiResource):
if not segment:
raise NotFound("Segment not found.")
return {"data": _marshal_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@service_api_ns.route(

View File

@@ -2,7 +2,7 @@ from typing import Literal
from flask import request
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
@@ -99,7 +99,7 @@ class ConversationListApi(WebApiResource):
query = ConversationListQuery.model_validate(raw_args)
try:
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
pagination = WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,

View File

@@ -4,7 +4,7 @@ import secrets
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.console.auth.error import (
@@ -81,7 +81,7 @@ class ForgotPasswordSendEmailApi(Resource):
else:
language = "en-US"
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
token = None
if account is None:
@@ -180,17 +180,18 @@ class ForgotPasswordResetApi(Resource):
email = reset_data.get("email", "")
with sessionmaker(db.engine).begin() as session:
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
self._update_existing_account(account, password_hashed, salt)
self._update_existing_account(account, password_hashed, salt, session)
else:
raise AuthenticationFailedError()
return {"result": "success"}
def _update_existing_account(self, account: Account, password_hashed, salt):
def _update_existing_account(self, account: Account, password_hashed, salt, session):
# Update existing account credentials
account.password = base64.b64encode(password_hashed).decode()
account.password_salt = base64.b64encode(salt).decode()
session.commit()

View File

@@ -6,7 +6,7 @@ from typing import Concatenate, ParamSpec, TypeVar
from flask import request
from flask_restx import Resource
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from constants import HEADER_NAME_APP_CODE
@@ -49,7 +49,7 @@ def decode_jwt_token(app_code: str | None = None, user_id: str | None = None):
decoded = PassportService().verify(tk)
app_code = decoded.get("app_code")
app_id = decoded.get("app_id")
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
with Session(db.engine, expire_on_commit=False) as session:
app_model = session.scalar(select(App).where(App.id == app_id))
site = session.scalar(select(Site).where(Site.code == app_code))
if not app_model:

View File

@@ -302,7 +302,7 @@ class PipelineGenerator(BaseAppGenerator):
"""
with preserve_flask_contexts(flask_app, context_vars=context):
# init queue manager
workflow = db.session.get(Workflow, workflow_id)
workflow = db.session.query(Workflow).where(Workflow.id == workflow_id).first()
if not workflow:
raise ValueError(f"Workflow not found: {workflow_id}")
queue_manager = PipelineQueueManager(

View File

@@ -9,7 +9,6 @@ from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent
from graphon.runtime import GraphRuntimeState, VariablePool
from graphon.variable_loader import VariableLoader
from graphon.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
@@ -85,13 +84,13 @@ class PipelineRunner(WorkflowBasedAppRunner):
user_id = None
if invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.get(EndUser, self.application_generate_entity.user_id)
end_user = db.session.query(EndUser).where(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
pipeline = db.session.get(Pipeline, app_config.app_id)
pipeline = db.session.query(Pipeline).where(Pipeline.id == app_config.app_id).first()
if not pipeline:
raise ValueError("Pipeline not found")
@@ -214,10 +213,10 @@ class PipelineRunner(WorkflowBasedAppRunner):
Get workflow
"""
# fetch workflow by workflow_id
workflow = db.session.scalar(
select(Workflow)
workflow = (
db.session.query(Workflow)
.where(Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id)
.limit(1)
.first()
)
# return workflow
@@ -298,8 +297,10 @@ class PipelineRunner(WorkflowBasedAppRunner):
"""
if isinstance(event, GraphRunFailedEvent):
if document_id and dataset_id:
document = db.session.scalar(
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
document = (
db.session.query(Document)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
)
if document:
document.indexing_status = "error"

View File

@@ -81,7 +81,7 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(

View File

@@ -8,7 +8,6 @@ associates with the node span.
"""
import logging
from contextvars import Token
from dataclasses import dataclass
from typing import cast, final
@@ -36,7 +35,7 @@ logger = logging.getLogger(__name__)
@dataclass(slots=True)
class _NodeSpanContext:
span: "Span"
token: Token[context_api.Context]
token: object
@final

View File

@@ -153,7 +153,7 @@ class DatasourceFileManager:
:return: the binary of the file, mime type
"""
upload_file: UploadFile | None = db.session.get(UploadFile, id)
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == id).first()
if not upload_file:
return None
@@ -171,7 +171,7 @@ class DatasourceFileManager:
:return: the binary of the file, mime type
"""
message_file: MessageFile | None = db.session.get(MessageFile, id)
message_file: MessageFile | None = db.session.query(MessageFile).where(MessageFile.id == id).first()
# Check if message_file is not None
if message_file is not None:
@@ -185,7 +185,7 @@ class DatasourceFileManager:
else:
tool_file_id = None
tool_file: ToolFile | None = db.session.get(ToolFile, tool_file_id)
tool_file: ToolFile | None = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first()
if not tool_file:
return None
@@ -203,7 +203,7 @@ class DatasourceFileManager:
:return: the binary of the file, mime type
"""
upload_file: UploadFile | None = db.session.get(UploadFile, upload_file_id)
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
if not upload_file:
return None, None

View File

@@ -403,7 +403,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
),
)
@@ -753,7 +753,7 @@ class ProviderConfiguration(BaseModel):
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name.in_(provider_names),
ProviderModel.model_name == model,
ProviderModel.model_type == model_type,
ProviderModel.model_type == model_type.to_origin_model_type(),
)
return session.execute(stmt).scalar_one_or_none()
@@ -778,7 +778,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
credential_record = session.execute(stmt).scalar_one_or_none()
@@ -825,7 +825,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.credential_name == credential_name,
)
if exclude_id:
@@ -901,7 +901,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
credential_record = s.execute(stmt).scalar_one_or_none()
original_credentials = (
@@ -970,7 +970,7 @@ class ProviderConfiguration(BaseModel):
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_name=model,
model_type=model_type,
model_type=model_type.to_origin_model_type(),
encrypted_config=json.dumps(credentials),
credential_name=credential_name,
)
@@ -983,7 +983,7 @@ class ProviderConfiguration(BaseModel):
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_name=model,
model_type=model_type,
model_type=model_type.to_origin_model_type(),
credential_id=credential.id,
is_valid=True,
)
@@ -1038,7 +1038,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
@@ -1083,7 +1083,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
@@ -1116,7 +1116,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
available_credentials_count = session.execute(count_stmt).scalar() or 0
session.delete(credential_record)
@@ -1156,7 +1156,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
@@ -1171,7 +1171,7 @@ class ProviderConfiguration(BaseModel):
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_name=model,
model_type=model_type,
model_type=model_type.to_origin_model_type(),
is_valid=True,
credential_id=credential_id,
)
@@ -1207,7 +1207,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
@@ -1263,7 +1263,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelSetting).where(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name.in_(self._get_provider_names()),
ProviderModelSetting.model_type == model_type,
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
return session.execute(stmt).scalars().first()
@@ -1286,7 +1286,7 @@ class ProviderConfiguration(BaseModel):
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type,
model_type=model_type.to_origin_model_type(),
model_name=model,
enabled=True,
)
@@ -1312,7 +1312,7 @@ class ProviderConfiguration(BaseModel):
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type,
model_type=model_type.to_origin_model_type(),
model_name=model,
enabled=False,
)
@@ -1348,7 +1348,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(func.count(LoadBalancingModelConfig.id)).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(provider_names),
LoadBalancingModelConfig.model_type == model_type,
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
)
load_balancing_config_count = session.execute(stmt).scalar() or 0
@@ -1364,7 +1364,7 @@ class ProviderConfiguration(BaseModel):
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type,
model_type=model_type.to_origin_model_type(),
model_name=model,
load_balancing_enabled=True,
)
@@ -1391,7 +1391,7 @@ class ProviderConfiguration(BaseModel):
model_setting = ProviderModelSetting(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
model_type=model_type,
model_type=model_type.to_origin_model_type(),
model_name=model,
load_balancing_enabled=False,
)

View File

@@ -17,7 +17,7 @@ class CSVSanitizer:
"""
# Characters that can start a formula in Excel/LibreOffice/Google Sheets
FORMULA_CHARS = frozenset(("=", "+", "-", "@", "\t", "\r"))
FORMULA_CHARS = frozenset({"=", "+", "-", "@", "\t", "\r"})
@classmethod
def sanitize_value(cls, value: Any) -> str:

View File

@@ -10,7 +10,7 @@ from typing import Any
from flask import Flask, current_app
from graphon.model_runtime.entities.model_entities import ModelType
from sqlalchemy import delete, func, select, update
from sqlalchemy import select
from sqlalchemy.orm.exc import ObjectDeletedError
from configs import dify_config
@@ -78,7 +78,7 @@ class IndexingRunner:
continue
# get dataset
dataset = db.session.get(Dataset, requeried_document.dataset_id)
dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
@@ -95,7 +95,7 @@ class IndexingRunner:
text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
# transform
current_user = db.session.get(Account, requeried_document.created_by)
current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first()
if not current_user:
raise ValueError("no current user found")
current_user.set_tenant_id(dataset.tenant_id)
@@ -137,24 +137,23 @@ class IndexingRunner:
return
# get dataset
dataset = db.session.get(Dataset, requeried_document.dataset_id)
dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
# get exist document_segment list and delete
document_segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == requeried_document.id,
)
).all()
document_segments = (
db.session.query(DocumentSegment)
.filter_by(dataset_id=dataset.id, document_id=requeried_document.id)
.all()
)
for document_segment in document_segments:
db.session.delete(document_segment)
if requeried_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
# delete child chunks
db.session.execute(delete(ChildChunk).where(ChildChunk.segment_id == document_segment.id))
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit()
# get the process rule
stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == requeried_document.dataset_process_rule_id)
@@ -168,7 +167,7 @@ class IndexingRunner:
text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
# transform
current_user = db.session.get(Account, requeried_document.created_by)
current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first()
if not current_user:
raise ValueError("no current user found")
current_user.set_tenant_id(dataset.tenant_id)
@@ -208,18 +207,17 @@ class IndexingRunner:
return
# get dataset
dataset = db.session.get(Dataset, requeried_document.dataset_id)
dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
# get exist document_segment list and delete
document_segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == requeried_document.id,
)
).all()
document_segments = (
db.session.query(DocumentSegment)
.filter_by(dataset_id=dataset.id, document_id=requeried_document.id)
.all()
)
documents = []
if document_segments:
@@ -291,7 +289,7 @@ class IndexingRunner:
embedding_model_instance = None
if dataset_id:
dataset = db.session.get(Dataset, dataset_id)
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise ValueError("Dataset not found.")
if IndexTechniqueType.HIGH_QUALITY in {dataset.indexing_technique, indexing_technique}:
@@ -654,26 +652,24 @@ class IndexingRunner:
@staticmethod
def _process_keyword_index(flask_app, dataset_id, document_id, documents):
with flask_app.app_context():
dataset = db.session.get(Dataset, dataset_id)
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise ValueError("no dataset found")
keyword = Keyword(dataset)
keyword.create(documents)
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
document_ids = [document.metadata["doc_id"] for document in documents]
db.session.execute(
update(DocumentSegment)
.where(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.index_node_id.in_(document_ids),
DocumentSegment.status == SegmentStatus.INDEXING,
)
.values(
status=SegmentStatus.COMPLETED,
enabled=True,
completed_at=naive_utc_now(),
)
db.session.query(DocumentSegment).where(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.index_node_id.in_(document_ids),
DocumentSegment.status == SegmentStatus.INDEXING,
).update(
{
DocumentSegment.status: SegmentStatus.COMPLETED,
DocumentSegment.enabled: True,
DocumentSegment.completed_at: naive_utc_now(),
}
)
db.session.commit()
@@ -707,19 +703,17 @@ class IndexingRunner:
)
document_ids = [document.metadata["doc_id"] for document in chunk_documents]
db.session.execute(
update(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(document_ids),
DocumentSegment.status == SegmentStatus.INDEXING,
)
.values(
status=SegmentStatus.COMPLETED,
enabled=True,
completed_at=naive_utc_now(),
)
db.session.query(DocumentSegment).where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(document_ids),
DocumentSegment.status == SegmentStatus.INDEXING,
).update(
{
DocumentSegment.status: SegmentStatus.COMPLETED,
DocumentSegment.enabled: True,
DocumentSegment.completed_at: naive_utc_now(),
}
)
db.session.commit()
@@ -740,17 +734,10 @@ class IndexingRunner:
"""
Update the document indexing status.
"""
count = (
db.session.scalar(
select(func.count())
.select_from(DatasetDocument)
.where(DatasetDocument.id == document_id, DatasetDocument.is_paused == True)
)
or 0
)
count = db.session.query(DatasetDocument).filter_by(id=document_id, is_paused=True).count()
if count > 0:
raise DocumentIsPausedError()
document = db.session.get(DatasetDocument, document_id)
document = db.session.query(DatasetDocument).filter_by(id=document_id).first()
if not document:
raise DocumentIsDeletedPausedError()
@@ -758,7 +745,7 @@ class IndexingRunner:
if extra_update_params:
update_params.update(extra_update_params)
db.session.execute(update(DatasetDocument).where(DatasetDocument.id == document_id).values(update_params)) # type: ignore
db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params) # type: ignore
db.session.commit()
@staticmethod
@@ -766,9 +753,7 @@ class IndexingRunner:
"""
Update the document segment by document id.
"""
db.session.execute(
update(DocumentSegment).where(DocumentSegment.document_id == dataset_document_id).values(update_params)
)
db.session.query(DocumentSegment).filter_by(document_id=dataset_document_id).update(update_params)
db.session.commit()
def _transform(

View File

@@ -260,12 +260,4 @@ def convert_input_form_to_parameters(
parameters[item.variable]["enum"] = item.options
elif item.type == VariableEntityType.NUMBER:
parameters[item.variable]["type"] = "number"
elif item.type == VariableEntityType.CHECKBOX:
parameters[item.variable]["type"] = "boolean"
elif item.type == VariableEntityType.JSON_OBJECT:
parameters[item.variable]["type"] = "object"
if item.json_schema:
for key in ("properties", "required", "additionalProperties"):
if key in item.json_schema:
parameters[item.variable][key] = item.json_schema[key]
return parameters, required

View File

@@ -16,13 +16,7 @@ from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExport
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.util.instrumentation import InstrumentationScope
from opentelemetry.semconv._incubating.attributes.deployment_attributes import ( # type: ignore[import-untyped]
DEPLOYMENT_ENVIRONMENT,
)
from opentelemetry.semconv._incubating.attributes.host_attributes import ( # type: ignore[import-untyped]
HOST_NAME,
)
from opentelemetry.semconv.attributes import service_attributes
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.trace import Link, SpanContext, TraceFlags
from configs import dify_config
@@ -51,10 +45,10 @@ class TraceClient:
self.endpoint = endpoint
self.resource = Resource(
attributes={
service_attributes.SERVICE_NAME: service_name,
service_attributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
HOST_NAME: socket.gethostname(),
ResourceAttributes.SERVICE_NAME: service_name,
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
ResourceAttributes.HOST_NAME: socket.gethostname(),
ACS_ARMS_SERVICE_FEATURE: "genai_app",
}
)

View File

@@ -19,7 +19,7 @@ from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExport
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.semconv.attributes import exception_attributes
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
from opentelemetry.trace import Span, Status, StatusCode, set_span_in_context, use_span
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from opentelemetry.util.types import AttributeValue
@@ -38,7 +38,6 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.utils import JSON_DICT_ADAPTER
from core.repositories import DifyCoreRepositoryFactory
from extensions.ext_database import db
from models.model import EndUser, MessageFile
@@ -135,10 +134,10 @@ def set_span_status(current_span: Span, error: Exception | str | None = None):
if not exception_message:
exception_message = repr(error)
attributes: dict[str, AttributeValue] = {
exception_attributes.EXCEPTION_TYPE: exception_type,
exception_attributes.EXCEPTION_MESSAGE: exception_message,
exception_attributes.EXCEPTION_ESCAPED: False,
exception_attributes.EXCEPTION_STACKTRACE: error_string,
OTELSpanAttributes.EXCEPTION_TYPE: exception_type,
OTELSpanAttributes.EXCEPTION_MESSAGE: exception_message,
OTELSpanAttributes.EXCEPTION_ESCAPED: False,
OTELSpanAttributes.EXCEPTION_STACKTRACE: error_string,
}
current_span.add_event(name="exception", attributes=attributes)
else:
@@ -470,7 +469,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
llm_attributes[SpanAttributes.LLM_PROVIDER] = trace_info.message_data.model_provider
if trace_info.message_data and trace_info.message_data.message_metadata:
metadata_dict = JSON_DICT_ADAPTER.validate_json(trace_info.message_data.message_metadata)
metadata_dict = json.loads(trace_info.message_data.message_metadata)
if model_params := metadata_dict.get("model_parameters"):
llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps(model_params)

View File

@@ -1,19 +1,9 @@
import logging
import os
import uuid
from datetime import UTC, datetime, timedelta
from datetime import datetime, timedelta
from graphon.enums import BuiltinNodeTypes
from langfuse import Langfuse
from langfuse.api import (
CreateGenerationBody,
CreateSpanBody,
IngestionEvent_GenerationCreate,
IngestionEvent_SpanCreate,
IngestionEvent_TraceCreate,
TraceBody,
)
from langfuse.api.commons.types.usage import Usage
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
@@ -406,61 +396,18 @@ class LangFuseDataTrace(BaseTraceInstance):
)
self.add_span(langfuse_span_data=name_generation_span_data)
def _make_event_id(self) -> str:
return str(uuid.uuid4())
def _now_iso(self) -> str:
return datetime.now(UTC).isoformat()
def add_trace(self, langfuse_trace_data: LangfuseTrace | None = None):
data = filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {}
format_trace_data = filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {}
try:
body = TraceBody(
id=data.get("id"),
name=data.get("name"),
user_id=data.get("user_id"),
input=data.get("input"),
output=data.get("output"),
metadata=data.get("metadata"),
session_id=data.get("session_id"),
version=data.get("version"),
release=data.get("release"),
tags=data.get("tags"),
public=data.get("public"),
)
event = IngestionEvent_TraceCreate(
body=body,
id=self._make_event_id(),
timestamp=self._now_iso(),
)
self.langfuse_client.api.ingestion.batch(batch=[event])
self.langfuse_client.trace(**format_trace_data)
logger.debug("LangFuse Trace created successfully")
except Exception as e:
raise ValueError(f"LangFuse Failed to create trace: {str(e)}")
def add_span(self, langfuse_span_data: LangfuseSpan | None = None):
data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
try:
body = CreateSpanBody(
id=data.get("id"),
trace_id=data.get("trace_id"),
name=data.get("name"),
start_time=data.get("start_time"),
end_time=data.get("end_time"),
input=data.get("input"),
output=data.get("output"),
metadata=data.get("metadata"),
level=data.get("level"),
status_message=data.get("status_message"),
parent_observation_id=data.get("parent_observation_id"),
version=data.get("version"),
)
event = IngestionEvent_SpanCreate(
body=body,
id=self._make_event_id(),
timestamp=self._now_iso(),
)
self.langfuse_client.api.ingestion.batch(batch=[event])
self.langfuse_client.span(**format_span_data)
logger.debug("LangFuse Span created successfully")
except Exception as e:
raise ValueError(f"LangFuse Failed to create span: {str(e)}")
@@ -471,45 +418,11 @@ class LangFuseDataTrace(BaseTraceInstance):
span.end(**format_span_data)
def add_generation(self, langfuse_generation_data: LangfuseGeneration | None = None):
data = filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {}
format_generation_data = (
filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {}
)
try:
usage_data = data.pop("usage", None)
usage = None
if usage_data:
usage = Usage(
input=usage_data.get("input", 0) or 0,
output=usage_data.get("output", 0) or 0,
total=usage_data.get("total", 0) or 0,
unit=usage_data.get("unit"),
input_cost=usage_data.get("inputCost"),
output_cost=usage_data.get("outputCost"),
total_cost=usage_data.get("totalCost"),
)
body = CreateGenerationBody(
id=data.get("id"),
trace_id=data.get("trace_id"),
name=data.get("name"),
start_time=data.get("start_time"),
end_time=data.get("end_time"),
model=data.get("model"),
model_parameters=data.get("model_parameters"),
input=data.get("input"),
output=data.get("output"),
usage=usage,
metadata=data.get("metadata"),
level=data.get("level"),
status_message=data.get("status_message"),
parent_observation_id=data.get("parent_observation_id"),
version=data.get("version"),
completion_start_time=data.get("completion_start_time"),
)
event = IngestionEvent_GenerationCreate(
body=body,
id=self._make_event_id(),
timestamp=self._now_iso(),
)
self.langfuse_client.api.ingestion.batch(batch=[event])
self.langfuse_client.generation(**format_generation_data)
logger.debug("LangFuse Generation created successfully")
except Exception as e:
raise ValueError(f"LangFuse Failed to create generation: {str(e)}")
@@ -530,7 +443,7 @@ class LangFuseDataTrace(BaseTraceInstance):
def get_project_key(self):
try:
projects = self.langfuse_client.api.projects.get()
projects = self.langfuse_client.client.projects.get()
return projects.data[0].id
except Exception as e:
logger.debug("LangFuse get project key failed: %s", str(e))

View File

@@ -1,3 +1,4 @@
import json
import logging
import os
from datetime import datetime, timedelta
@@ -24,7 +25,6 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
from core.ops.utils import JSON_DICT_ADAPTER
from extensions.ext_database import db
from models import EndUser
from models.workflow import WorkflowNodeExecutionModel
@@ -153,7 +153,7 @@ class MLflowDataTrace(BaseTraceInstance):
inputs = node.process_data # contains request URL
if not inputs:
inputs = JSON_DICT_ADAPTER.validate_json(node.inputs) if node.inputs else {}
inputs = json.loads(node.inputs) if node.inputs else {}
node_span = start_span_no_context(
name=node.title,
@@ -180,7 +180,7 @@ class MLflowDataTrace(BaseTraceInstance):
# End node span
finished_at = node.created_at + timedelta(seconds=node.elapsed_time)
outputs = JSON_DICT_ADAPTER.validate_json(node.outputs) if node.outputs else {}
outputs = json.loads(node.outputs) if node.outputs else {}
if node.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
outputs = self._parse_knowledge_retrieval_outputs(outputs)
elif node.node_type == BuiltinNodeTypes.LLM:
@@ -216,8 +216,8 @@ class MLflowDataTrace(BaseTraceInstance):
return {}, {}
try:
data = JSON_DICT_ADAPTER.validate_json(node.process_data)
except (ValueError, TypeError):
data = json.loads(node.process_data)
except (json.JSONDecodeError, TypeError):
return {}, {}
inputs = self._parse_prompts(data.get("prompts"))

View File

@@ -11,10 +11,8 @@ from uuid import UUID, uuid4
from cachetools import LRUCache
from flask import current_app
from pydantic import TypeAdapter
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from typing_extensions import TypedDict
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
from core.ops.entities.config_entity import (
@@ -35,7 +33,7 @@ from core.ops.entities.trace_entity import (
WorkflowNodeTraceInfo,
WorkflowTraceInfo,
)
from core.ops.utils import JSON_DICT_ADAPTER, get_message_data
from core.ops.utils import get_message_data
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.account import Tenant
@@ -52,14 +50,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class _AppTracingConfig(TypedDict, total=False):
enabled: bool
tracing_provider: str | None
_app_tracing_config_adapter: TypeAdapter[_AppTracingConfig] = TypeAdapter(_AppTracingConfig)
def _lookup_app_and_workspace_names(app_id: str | None, tenant_id: str | None) -> tuple[str, str]:
"""Return (app_name, workspace_name) for the given IDs. Falls back to empty strings."""
app_name = ""
@@ -478,7 +468,7 @@ class OpsTraceManager:
if app is None:
return None
app_ops_trace_config = _app_tracing_config_adapter.validate_json(app.tracing) if app.tracing else None
app_ops_trace_config = json.loads(app.tracing) if app.tracing else None
if app_ops_trace_config is None:
return None
if not app_ops_trace_config.get("enabled"):
@@ -570,7 +560,7 @@ class OpsTraceManager:
raise ValueError("App not found")
if not app.tracing:
return {"enabled": False, "tracing_provider": None}
app_trace_config = _app_tracing_config_adapter.validate_json(app.tracing)
app_trace_config = json.loads(app.tracing)
return app_trace_config
@staticmethod
@@ -646,6 +636,7 @@ class TraceTask:
carries ``total_tokens``. Projects only the ``outputs`` column to avoid loading
large JSON blobs unnecessarily.
"""
import json
from models.workflow import WorkflowNodeExecutionModel
@@ -667,7 +658,7 @@ class TraceTask:
if not raw:
continue
try:
outputs = JSON_DICT_ADAPTER.validate_json(raw) if isinstance(raw, str) else raw
outputs = json.loads(raw) if isinstance(raw, str) else raw
except (ValueError, TypeError):
continue
if not isinstance(outputs, dict):
@@ -1429,7 +1420,7 @@ class TraceTask:
return {}
try:
metadata = JSON_DICT_ADAPTER.validate_json(message_data.message_metadata)
metadata = json.loads(message_data.message_metadata)
usage = metadata.get("usage", {})
time_to_first_token = usage.get("time_to_first_token")
time_to_generate = usage.get("time_to_generate")
@@ -1439,7 +1430,7 @@ class TraceTask:
"llm_streaming_time_to_generate": time_to_generate,
"is_streaming_request": time_to_first_token is not None,
}
except (ValueError, AttributeError):
except (json.JSONDecodeError, AttributeError):
return {}

View File

@@ -26,13 +26,7 @@ from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExport
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv._incubating.attributes.deployment_attributes import ( # type: ignore[import-untyped]
DEPLOYMENT_ENVIRONMENT,
)
from opentelemetry.semconv._incubating.attributes.host_attributes import ( # type: ignore[import-untyped]
HOST_NAME,
)
from opentelemetry.semconv.attributes import service_attributes
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.trace import SpanKind
from opentelemetry.util.types import AttributeValue
@@ -79,13 +73,13 @@ class TencentTraceClient:
self.resource = Resource(
attributes={
service_attributes.SERVICE_NAME: service_name,
service_attributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
HOST_NAME: socket.gethostname(),
"telemetry.sdk.language": "python",
"telemetry.sdk.name": "opentelemetry",
"telemetry.sdk.version": _get_opentelemetry_sdk_version(),
ResourceAttributes.SERVICE_NAME: service_name,
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
ResourceAttributes.HOST_NAME: socket.gethostname(),
ResourceAttributes.TELEMETRY_SDK_LANGUAGE: "python",
ResourceAttributes.TELEMETRY_SDK_NAME: "opentelemetry",
ResourceAttributes.TELEMETRY_SDK_VERSION: _get_opentelemetry_sdk_version(),
}
)
# Prepare gRPC endpoint/metadata

View File

@@ -3,14 +3,11 @@ from datetime import datetime
from typing import Any, Union
from urllib.parse import urlparse
from pydantic import TypeAdapter
from sqlalchemy import select
from models.engine import db
from models.model import Message
JSON_DICT_ADAPTER: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
def filter_none_values(data: dict[str, Any]) -> dict[str, Any]:
new_data = {}

View File

@@ -17,7 +17,6 @@ from pydantic import BaseModel
from yarl import URL
from configs import dify_config
from core.helper.http_client_pooling import get_pooled_http_client
from core.plugin.endpoint.exc import EndpointSetupFailedError
from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse, PluginDaemonError, PluginDaemonInnerError
from core.plugin.impl.exc import (
@@ -55,11 +54,6 @@ T = TypeVar("T", bound=(BaseModel | dict[str, Any] | list[Any] | bool | str))
logger = logging.getLogger(__name__)
_httpx_client: httpx.Client = get_pooled_http_client(
"plugin_daemon",
lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100), trust_env=False),
)
class BasePluginClient:
def _request(
@@ -90,7 +84,7 @@ class BasePluginClient:
request_kwargs["content"] = prepared_data
try:
response = _httpx_client.request(**request_kwargs)
response = httpx.request(**request_kwargs)
except httpx.RequestError:
logger.exception("Request to Plugin Daemon Service failed")
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")
@@ -177,7 +171,7 @@ class BasePluginClient:
stream_kwargs["content"] = prepared_data
try:
with _httpx_client.stream(**stream_kwargs) as response:
with httpx.stream(**stream_kwargs) as response:
for raw_line in response.iter_lines():
if not raw_line:
continue

View File

@@ -306,7 +306,7 @@ class ProviderManager:
"""
stmt = select(TenantDefaultModel).where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
default_model = db.session.scalar(stmt)
@@ -324,7 +324,7 @@ class ProviderManager:
default_model = TenantDefaultModel(
tenant_id=tenant_id,
model_type=model_type,
model_type=model_type.to_origin_model_type(),
provider_name=available_model.provider.provider,
model_name=available_model.model,
)
@@ -391,7 +391,7 @@ class ProviderManager:
raise ValueError(f"Model {model} does not exist.")
stmt = select(TenantDefaultModel).where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
default_model = db.session.scalar(stmt)
@@ -405,7 +405,7 @@ class ProviderManager:
# create default model
default_model = TenantDefaultModel(
tenant_id=tenant_id,
model_type=model_type,
model_type=model_type.to_origin_model_type(),
provider_name=provider,
model_name=model,
)
@@ -626,8 +626,9 @@ class ProviderManager:
if provider_record.provider_type != ProviderType.SYSTEM:
continue
if provider_record.quota_type is not None:
provider_quota_to_provider_record_dict[provider_record.quota_type] = provider_record
provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
provider_record
)
for quota in configuration.quotas:
if quota.quota_type in (ProviderQuotaType.TRIAL, ProviderQuotaType.PAID):
@@ -640,7 +641,7 @@ class ProviderManager:
# TODO: Use provider name with prefix after the data migration.
provider_name=ModelProviderID(provider_name).provider_name,
provider_type=ProviderType.SYSTEM,
quota_type=quota.quota_type, # type: ignore[arg-type]
quota_type=quota.quota_type,
quota_limit=0, # type: ignore
quota_used=0,
is_valid=True,
@@ -822,7 +823,7 @@ class ProviderManager:
custom_model_configurations.append(
CustomModelConfiguration(
model=provider_model_record.model_name,
model_type=provider_model_record.model_type,
model_type=ModelType.value_of(provider_model_record.model_type),
credentials=provider_model_credentials,
current_credential_id=provider_model_record.credential_id,
current_credential_name=provider_model_record.credential_name,
@@ -920,8 +921,9 @@ class ProviderManager:
if provider_record.provider_type != ProviderType.SYSTEM:
continue
if provider_record.quota_type is not None:
quota_type_to_provider_records_dict[provider_record.quota_type] = provider_record # type: ignore[index]
quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
provider_record
)
quota_configurations = []
if dify_config.EDITION == "CLOUD":
@@ -1201,7 +1203,7 @@ class ProviderManager:
model_settings.append(
ModelSettings(
model=provider_model_setting.model_name,
model_type=provider_model_setting.model_type,
model_type=ModelType.value_of(provider_model_setting.model_type),
enabled=provider_model_setting.enabled,
load_balancing_enabled=provider_model_setting.load_balancing_enabled,
load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [],

View File

@@ -122,6 +122,6 @@ class JiebaKeywordTableHandler:
results.add(token)
sub_tokens = re.findall(r"\w+", token)
if len(sub_tokens) > 1:
results.update({w for w in sub_tokens if w not in STOPWORDS})
results.update({w for w in sub_tokens if w not in list(STOPWORDS)})
return results

File diff suppressed because it is too large Load Diff

View File

@@ -10,7 +10,6 @@ from mysql.connector import Error as MySQLError
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@@ -179,7 +178,9 @@ class AlibabaCloudMySQLVector(BaseVector):
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN ({placeholders})", ids)
docs = []
for record in cur:
metadata = parse_metadata_json(record["meta"])
metadata = record["meta"]
if isinstance(metadata, str):
metadata = json.loads(metadata)
docs.append(Document(page_content=record["text"], metadata=metadata))
return docs
@@ -262,13 +263,15 @@ class AlibabaCloudMySQLVector(BaseVector):
# similarity = 1 / (1 + distance)
similarity = 1.0 / (1.0 + distance)
metadata = parse_metadata_json(record["meta"])
metadata = record["meta"]
if isinstance(metadata, str):
metadata = json.loads(metadata)
metadata["score"] = similarity
metadata["distance"] = distance
if similarity >= score_threshold:
docs.append(Document(page_content=record["text"], metadata=metadata))
except (ValueError, TypeError) as e:
except (ValueError, json.JSONDecodeError) as e:
logger.warning("Error processing search result: %s", e)
continue
@@ -303,7 +306,9 @@ class AlibabaCloudMySQLVector(BaseVector):
)
docs = []
for record in cur:
metadata = parse_metadata_json(record["meta"])
metadata = record["meta"]
if isinstance(metadata, str):
metadata = json.loads(metadata)
metadata["score"] = float(record["score"])
docs.append(Document(page_content=record["text"], metadata=metadata))
return docs

View File

@@ -8,7 +8,6 @@ _import_err_msg = (
"please run `pip install alibabacloud_gpdb20160503 alibabacloud_tea_openapi`"
)
from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
@@ -258,7 +257,7 @@ class AnalyticdbVectorOpenAPI:
documents = []
for match in response.body.matches.match:
if match.score >= score_threshold:
metadata = parse_metadata_json(match.metadata.get("metadata_"))
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
page_content=match.metadata.get("page_content"),
@@ -295,7 +294,7 @@ class AnalyticdbVectorOpenAPI:
documents = []
for match in response.body.matches.match:
if match.score >= score_threshold:
metadata = parse_metadata_json(match.metadata.get("metadata_"))
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
page_content=match.metadata.get("page_content"),

View File

@@ -29,7 +29,6 @@ from pymochow.model.table import AnnSearch, BM25SearchRequest, HNSWSearchParams,
from configs import dify_config
from core.rag.datasource.vdb.field import Field as VDBField
from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@@ -174,9 +173,15 @@ class BaiduVector(BaseVector):
score = row.get("score", 0.0)
meta = row_data.get(VDBField.METADATA_KEY, {})
try:
meta = parse_metadata_json(meta)
except (ValueError, TypeError):
# Handle both JSON string and dict formats for backward compatibility
if isinstance(meta, str):
try:
import json
meta = json.loads(meta)
except (json.JSONDecodeError, TypeError):
meta = {}
elif not isinstance(meta, dict):
meta = {}
if score >= score_threshold:
@@ -195,11 +200,7 @@ class BaiduVector(BaseVector):
raise
def _init_client(self, config) -> MochowClient:
config = Configuration(
credentials=BceCredentials(config.account, config.api_key),
endpoint=config.endpoint,
connection_timeout_in_mills=config.connection_timeout_in_mills,
)
config = Configuration(credentials=BceCredentials(config.account, config.api_key), endpoint=config.endpoint)
client = MochowClient(config)
return client

View File

@@ -17,7 +17,7 @@ if TYPE_CHECKING:
from clickzetta.connector.v0.connection import Connection # type: ignore
from configs import dify_config
from core.rag.datasource.vdb.field import Field, parse_metadata_json
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.embedding.embedding_base import Embeddings
@@ -357,19 +357,18 @@ class ClickzettaVector(BaseVector):
"""
try:
if raw_metadata:
# First parse may yield a string (double-encoded JSON) so use json.loads
first_pass = json.loads(raw_metadata)
metadata = json.loads(raw_metadata)
# Handle double-encoded JSON
if isinstance(first_pass, str):
metadata = parse_metadata_json(first_pass)
elif isinstance(first_pass, dict):
metadata = first_pass
else:
if isinstance(metadata, str):
metadata = json.loads(metadata)
# Ensure we have a dict
if not isinstance(metadata, dict):
metadata = {}
else:
metadata = {}
except (json.JSONDecodeError, ValueError, TypeError):
except (json.JSONDecodeError, TypeError):
logger.exception("JSON parsing failed for metadata")
# Fallback: extract document_id with regex
doc_id_match = re.search(r'"document_id":\s*"([^"]+)"', raw_metadata or "")
@@ -931,18 +930,17 @@ class ClickzettaVector(BaseVector):
# Parse metadata from JSON string (may be double-encoded)
try:
if row[2]:
# First parse may yield a string (double-encoded JSON)
first_pass = json.loads(row[2])
metadata = json.loads(row[2])
if isinstance(first_pass, str):
metadata = parse_metadata_json(first_pass)
elif isinstance(first_pass, dict):
metadata = first_pass
else:
# If result is a string, it's double-encoded JSON - parse again
if isinstance(metadata, str):
metadata = json.loads(metadata)
if not isinstance(metadata, dict):
metadata = {}
else:
metadata = {}
except (json.JSONDecodeError, ValueError, TypeError):
except (json.JSONDecodeError, TypeError):
logger.exception("JSON parsing failed")
# Fallback: extract document_id with regex

View File

@@ -1,24 +1,4 @@
from enum import StrEnum, auto
from typing import Any
from pydantic import TypeAdapter
_metadata_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
def parse_metadata_json(raw: Any) -> dict[str, Any]:
"""Parse metadata from a JSON string or pass through an existing dict.
Many VDB drivers return metadata as either a JSON string or an already-
decoded dict depending on the column type and driver version.
"""
if raw is None or raw in ("", b""):
return {}
if isinstance(raw, dict):
return raw
if not isinstance(raw, (str, bytes, bytearray)):
return {}
return _metadata_adapter.validate_json(raw)
class Field(StrEnum):

View File

@@ -9,7 +9,6 @@ from psycopg import sql as psql
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@@ -218,7 +217,8 @@ class HologresVector(BaseVector):
text = row[2]
meta = row[3]
meta = parse_metadata_json(meta)
if isinstance(meta, str):
meta = json.loads(meta)
# Convert distance to similarity score (consistent with pgvector)
score = 1 - distance
@@ -265,7 +265,8 @@ class HologresVector(BaseVector):
meta = row[2]
score = row[-1] # score is the last column from return_score
meta = parse_metadata_json(meta)
if isinstance(meta, str):
meta = json.loads(meta)
meta["score"] = score
docs.append(Document(page_content=text, metadata=meta))

View File

@@ -15,7 +15,6 @@ from typing import TYPE_CHECKING, Any
from configs import dify_config
from configs.middleware.vdb.iris_config import IrisVectorConfig
from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@@ -270,7 +269,7 @@ class IrisVector(BaseVector):
if len(row) >= 4:
text, meta_str, score = row[1], row[2], float(row[3])
if score >= score_threshold:
metadata = parse_metadata_json(meta_str)
metadata = json.loads(meta_str) if meta_str else {}
metadata["score"] = score
docs.append(Document(page_content=text, metadata=metadata))
return docs
@@ -385,7 +384,7 @@ class IrisVector(BaseVector):
meta_str = row[2]
score_value = row[3]
metadata = parse_metadata_json(meta_str)
metadata = json.loads(meta_str) if meta_str else {}
# Add score to metadata for hybrid search compatibility
score = float(score_value) if score_value is not None else 0.0
metadata["score"] = score

View File

@@ -9,7 +9,6 @@ from mo_vector.client import MoVectorClient # type: ignore
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@@ -197,7 +196,11 @@ class MatrixoneVector(BaseVector):
docs = []
for result in results:
metadata = parse_metadata_json(result.metadata)
metadata = result.metadata
if isinstance(metadata, str):
import json
metadata = json.loads(metadata)
score = 1 - result.distance
if score >= score_threshold:
metadata["score"] = score

View File

@@ -4,7 +4,7 @@ import uuid
from enum import StrEnum
from typing import Any
from clickhouse_connect import get_client # type: ignore[import-untyped]
from clickhouse_connect import get_client
from pydantic import BaseModel
from configs import dify_config

View File

@@ -10,7 +10,6 @@ from sqlalchemy.dialects.mysql import LONGTEXT
from sqlalchemy.exc import SQLAlchemyError
from configs import dify_config
from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@@ -367,8 +366,8 @@ class OceanBaseVector(BaseVector):
# Parse metadata JSON
try:
metadata = parse_metadata_json(metadata_str)
except (ValueError, TypeError):
metadata = json.loads(metadata_str) if isinstance(metadata_str, str) else metadata_str
except json.JSONDecodeError:
logger.warning("Invalid JSON metadata: %s", metadata_str)
metadata = {}

View File

@@ -9,7 +9,7 @@ from pydantic import BaseModel, model_validator
from tablestore import BatchGetRowRequest, TableInBatchGetRowItem
from configs import dify_config
from core.rag.datasource.vdb.field import Field, parse_metadata_json
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@@ -73,8 +73,7 @@ class TableStoreVector(BaseVector):
for item in table_result:
if item.is_ok and item.row:
kv = {k: v for k, v, _ in item.row.attribute_columns}
metadata = parse_metadata_json(kv[Field.METADATA_KEY])
docs.append(Document(page_content=kv[Field.CONTENT_KEY], metadata=metadata))
docs.append(Document(page_content=kv[Field.CONTENT_KEY], metadata=json.loads(kv[Field.METADATA_KEY])))
return docs
def get_type(self) -> str:
@@ -312,7 +311,7 @@ class TableStoreVector(BaseVector):
metadata_str = ots_column_map.get(Field.METADATA_KEY)
vector = json.loads(vector_str) if vector_str else None
metadata = parse_metadata_json(metadata_str)
metadata = json.loads(metadata_str) if metadata_str else {}
metadata["score"] = search_hit.score
@@ -372,7 +371,7 @@ class TableStoreVector(BaseVector):
ots_column_map[col[0]] = col[1]
metadata_str = ots_column_map.get(Field.METADATA_KEY)
metadata = parse_metadata_json(metadata_str)
metadata = json.loads(metadata_str) if metadata_str else {}
vector_str = ots_column_map.get(Field.VECTOR)
vector = json.loads(vector_str) if vector_str else None

View File

@@ -11,7 +11,6 @@ from tcvectordb.model import index as vdb_index # type: ignore
from tcvectordb.model.document import AnnSearch, Filter, KeywordSearch, WeightedRerank # type: ignore
from configs import dify_config
from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@@ -287,10 +286,13 @@ class TencentVector(BaseVector):
return docs
for result in res[0]:
raw_meta = result.get(self.field_metadata)
# Compatible with version 1.1.3 and below: str means old driver.
score = (1 - result.get("score", 0.0)) if isinstance(raw_meta, str) else result.get("score", 0.0)
meta = parse_metadata_json(raw_meta)
meta = result.get(self.field_metadata)
if isinstance(meta, str):
# Compatible with version 1.1.3 and below.
meta = json.loads(meta)
score = 1 - result.get("score", 0.0)
else:
score = result.get("score", 0.0)
if score >= score_threshold:
meta["score"] = score
doc = Document(page_content=result.get(self.field_text), metadata=meta)

View File

@@ -6,18 +6,11 @@ import httpx
from httpx import DigestAuth
from configs import dify_config
from core.helper.http_client_pooling import get_pooled_http_client
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import TidbAuthBinding
from models.enums import TidbAuthBindingStatus
# Reuse a pooled HTTP client for all TiDB Cloud requests to minimize connection churn
_tidb_http_client: httpx.Client = get_pooled_http_client(
"tidb:cloud",
lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)),
)
class TidbService:
@staticmethod
@@ -57,9 +50,7 @@ class TidbService:
"rootPassword": password,
}
response = _tidb_http_client.post(
f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key)
)
response = httpx.post(f"{api_url}/clusters", json=cluster_data, auth=DigestAuth(public_key, private_key))
if response.status_code == 200:
response_data = response.json()
@@ -93,9 +84,7 @@ class TidbService:
:return: The response from the API.
"""
response = _tidb_http_client.delete(
f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key)
)
response = httpx.delete(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key))
if response.status_code == 200:
return response.json()
@@ -114,7 +103,7 @@ class TidbService:
:return: The response from the API.
"""
response = _tidb_http_client.get(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key))
response = httpx.get(f"{api_url}/clusters/{cluster_id}", auth=DigestAuth(public_key, private_key))
if response.status_code == 200:
return response.json()
@@ -139,7 +128,7 @@ class TidbService:
body = {"password": new_password, "builtinRole": "role_admin", "customRoles": []}
response = _tidb_http_client.patch(
response = httpx.patch(
f"{api_url}/clusters/{cluster_id}/sqlUsers/{account}",
json=body,
auth=DigestAuth(public_key, private_key),
@@ -173,9 +162,7 @@ class TidbService:
tidb_serverless_list_map = {item.cluster_id: item for item in tidb_serverless_list}
cluster_ids = [item.cluster_id for item in tidb_serverless_list]
params = {"clusterIds": cluster_ids, "view": "BASIC"}
response = _tidb_http_client.get(
f"{api_url}/clusters:batchGet", params=params, auth=DigestAuth(public_key, private_key)
)
response = httpx.get(f"{api_url}/clusters:batchGet", params=params, auth=DigestAuth(public_key, private_key))
if response.status_code == 200:
response_data = response.json()
@@ -236,7 +223,7 @@ class TidbService:
clusters.append(cluster_data)
request_body = {"requests": clusters}
response = _tidb_http_client.post(
response = httpx.post(
f"{api_url}/clusters:batchCreate", json=request_body, auth=DigestAuth(public_key, private_key)
)

View File

@@ -9,7 +9,7 @@ from sqlalchemy import text as sql_text
from sqlalchemy.orm import Session, declarative_base
from configs import dify_config
from core.rag.datasource.vdb.field import Field, parse_metadata_json
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@@ -228,7 +228,7 @@ class TiDBVector(BaseVector):
)
results = [(row[0], row[1], row[2]) for row in res]
for meta, text, distance in results:
metadata = parse_metadata_json(meta)
metadata = json.loads(meta)
metadata["score"] = 1 - distance
docs.append(Document(page_content=text, metadata=metadata))
return docs

View File

@@ -15,7 +15,6 @@ from volcengine.viking_db import ( # type: ignore
from configs import dify_config
from core.rag.datasource.vdb.field import Field as vdb_Field
from core.rag.datasource.vdb.field import parse_metadata_json
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
@@ -164,7 +163,7 @@ class VikingDBVector(BaseVector):
for result in results:
metadata = result.fields.get(vdb_Field.METADATA_KEY)
if metadata is not None:
metadata = parse_metadata_json(metadata)
metadata = json.loads(metadata)
if metadata.get(key) == value:
ids.append(result.id)
return ids
@@ -190,7 +189,9 @@ class VikingDBVector(BaseVector):
docs = []
for result in results:
metadata = parse_metadata_json(result.fields.get(vdb_Field.METADATA_KEY))
metadata = result.fields.get(vdb_Field.METADATA_KEY)
if metadata is not None:
metadata = json.loads(metadata)
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY), metadata=metadata)

View File

@@ -35,7 +35,7 @@ class PdfExtractor(BaseExtractor):
"""
# Magic bytes for image format detection: (magic_bytes, extension, mime_type)
IMAGE_FORMATS: tuple[tuple[bytes, str, str], ...] = (
IMAGE_FORMATS = [
(b"\xff\xd8\xff", "jpg", "image/jpeg"),
(b"\x89PNG\r\n\x1a\n", "png", "image/png"),
(b"\x00\x00\x00\x0c\x6a\x50\x20\x20\x0d\x0a\x87\x0a", "jp2", "image/jp2"),
@@ -45,7 +45,7 @@ class PdfExtractor(BaseExtractor):
(b"MM\x00*", "tiff", "image/tiff"),
(b"II+\x00", "tiff", "image/tiff"),
(b"MM\x00+", "tiff", "image/tiff"),
)
]
MAX_MAGIC_LEN = max(len(m) for m, _, _ in IMAGE_FORMATS)
def __init__(self, file_path: str, tenant_id: str, user_id: str, file_cache_key: str | None = None):

View File

@@ -35,10 +35,7 @@ class IndexProcessor:
if "parent_mode" in preview:
data.parent_mode = preview["parent_mode"]
# Different index processors return different preview shapes:
# - paragraph/parent-child processors: {"preview": [...]}
# - QA processor: {"qa_preview": [...]} (no "preview" key)
for item in preview.get("preview", []):
for item in preview["preview"]:
if "content" in item and "child_chunks" in item:
data.preview.append(
PreviewItem(content=item["content"], child_chunks=item["child_chunks"], summary=None)
@@ -47,10 +44,6 @@ class IndexProcessor:
data.qa_preview.append(QaPreview(question=item["question"], answer=item["answer"]))
elif "content" in item:
data.preview.append(PreviewItem(content=item["content"], child_chunks=None, summary=None))
for item in preview.get("qa_preview", []):
if "question" in item and "answer" in item:
data.qa_preview.append(QaPreview(question=item["question"], answer=item["answer"]))
return data
def index_and_clean(

View File

@@ -5,11 +5,11 @@ TRIGGER_SCHEDULE_NODE_TYPE: Final[str] = "trigger-schedule"
TRIGGER_PLUGIN_NODE_TYPE: Final[str] = "trigger-plugin"
TRIGGER_NODE_TYPES: Final[frozenset[str]] = frozenset(
(
{
TRIGGER_WEBHOOK_NODE_TYPE,
TRIGGER_SCHEDULE_NODE_TYPE,
TRIGGER_PLUGIN_NODE_TYPE,
)
}
)

View File

@@ -8,20 +8,24 @@ from pydantic import BaseModel, Field, field_validator
from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE
_WEBHOOK_HEADER_ALLOWED_TYPES: frozenset[SegmentType] = frozenset((SegmentType.STRING,))
_WEBHOOK_HEADER_ALLOWED_TYPES = frozenset(
{
SegmentType.STRING,
}
)
_WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES: frozenset[SegmentType] = frozenset(
(
_WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES = frozenset(
{
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.BOOLEAN,
)
}
)
_WEBHOOK_PARAMETER_ALLOWED_TYPES = _WEBHOOK_HEADER_ALLOWED_TYPES | _WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES
_WEBHOOK_BODY_ALLOWED_TYPES: frozenset[SegmentType] = frozenset(
(
_WEBHOOK_BODY_ALLOWED_TYPES = frozenset(
{
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.BOOLEAN,
@@ -31,7 +35,7 @@ _WEBHOOK_BODY_ALLOWED_TYPES: frozenset[SegmentType] = frozenset(
SegmentType.ARRAY_BOOLEAN,
SegmentType.ARRAY_OBJECT,
SegmentType.FILE,
)
}
)

View File

@@ -27,10 +27,7 @@ from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio
from opentelemetry.semconv._incubating.attributes.host_attributes import ( # type: ignore[import-untyped]
HOST_NAME,
)
from opentelemetry.semconv.attributes import service_attributes
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.trace import SpanContext, TraceFlags
from opentelemetry.util.types import Attributes, AttributeValue
@@ -117,8 +114,8 @@ class EnterpriseExporter:
resource = Resource(
attributes={
service_attributes.SERVICE_NAME: service_name,
HOST_NAME: socket.gethostname(),
ResourceAttributes.SERVICE_NAME: service_name,
ResourceAttributes.HOST_NAME: socket.gethostname(),
}
)
sampler = ParentBasedTraceIdRatio(sampling_rate)

View File

@@ -157,7 +157,7 @@ def handle(sender: Message, **kwargs):
tenant_id=tenant_id,
provider_name=ModelProviderID(model_config.provider).provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=provider_configuration.system_configuration.current_quota_type,
quota_type=provider_configuration.system_configuration.current_quota_type.value,
),
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
additional_filters=_ProviderUpdateAdditionalFilters(

View File

@@ -6,24 +6,15 @@ def init_app(app: DifyApp):
if dify_config.SENTRY_DSN:
import sentry_sdk
from graphon.model_runtime.errors.invoke import InvokeRateLimitError
from langfuse import parse_error
from sentry_sdk.integrations.celery import CeleryIntegration
from sentry_sdk.integrations.flask import FlaskIntegration
from werkzeug.exceptions import HTTPException
try:
from langfuse._utils import parse_error
_langfuse_error_response = parse_error.defaultErrorResponse
except (ImportError, AttributeError):
_langfuse_error_response = (
"Unexpected error occurred. Please check your request"
" and contact support: https://langfuse.com/support."
)
def before_send(event, hint):
if "exc_info" in hint:
_, exc_value, _ = hint["exc_info"]
if _langfuse_error_response in str(exc_value):
if parse_error.defaultErrorResponse in str(exc_value):
return None
return event
@@ -36,7 +27,7 @@ def init_app(app: DifyApp):
ValueError,
FileNotFoundError,
InvokeRateLimitError,
_langfuse_error_response,
parse_error.defaultErrorResponse,
],
traces_sample_rate=dify_config.SENTRY_TRACES_SAMPLE_RATE,
profiles_sample_rate=dify_config.SENTRY_PROFILES_SAMPLE_RATE,

View File

@@ -20,7 +20,6 @@ from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from core.ops.utils import JSON_DICT_ADAPTER
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.factory import OrderConfig, WorkflowNodeExecutionRepository
from extensions.logstore.aliyun_logstore import AliyunLogStore
@@ -49,10 +48,10 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut
"""
logger.debug("_dict_to_workflow_node_execution: data keys=%s", list(data.keys())[:5])
# Parse JSON fields
inputs = JSON_DICT_ADAPTER.validate_json(data.get("inputs") or "{}")
process_data = JSON_DICT_ADAPTER.validate_json(data.get("process_data") or "{}")
outputs = JSON_DICT_ADAPTER.validate_json(data.get("outputs") or "{}")
metadata = JSON_DICT_ADAPTER.validate_json(data.get("execution_metadata") or "{}")
inputs = json.loads(data.get("inputs", "{}"))
process_data = json.loads(data.get("process_data", "{}"))
outputs = json.loads(data.get("outputs", "{}"))
metadata = json.loads(data.get("execution_metadata", "{}"))
# Convert metadata to domain enum keys
domain_metadata = {}

View File

@@ -1,7 +1,5 @@
import contextlib
import logging
from collections.abc import Callable
from typing import Protocol, cast
import flask
from opentelemetry.instrumentation.celery import CeleryInstrumentor
@@ -23,38 +21,6 @@ from extensions.otel.runtime import is_celery_worker
logger = logging.getLogger(__name__)
class SupportsInstrument(Protocol):
def instrument(self, **kwargs: object) -> None: ...
class SupportsFlaskInstrumentor(Protocol):
def instrument_app(
self, app: DifyApp, response_hook: Callable[[Span, str, list], None] | None = None, **kwargs: object
) -> None: ...
# Some OpenTelemetry instrumentor constructors are typed loosely enough that
# pyrefly infers `NoneType`. Narrow the instances to just the methods we use
# while leaving runtime behavior unchanged.
def _new_celery_instrumentor() -> SupportsInstrument:
return cast(
SupportsInstrument,
CeleryInstrumentor(tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider()),
)
def _new_httpx_instrumentor() -> SupportsInstrument:
return cast(SupportsInstrument, HTTPXClientInstrumentor())
def _new_redis_instrumentor() -> SupportsInstrument:
return cast(SupportsInstrument, RedisInstrumentor())
def _new_sqlalchemy_instrumentor() -> SupportsInstrument:
return cast(SupportsInstrument, SQLAlchemyInstrumentor())
class ExceptionLoggingHandler(logging.Handler):
"""
Handler that records exceptions to the current OpenTelemetry span.
@@ -131,7 +97,7 @@ def init_flask_instrumentor(app: DifyApp) -> None:
from opentelemetry.instrumentation.flask import FlaskInstrumentor
instrumentor = cast(SupportsFlaskInstrumentor, FlaskInstrumentor())
instrumentor = FlaskInstrumentor()
if dify_config.DEBUG:
logger.info("Initializing Flask instrumentor")
instrumentor.instrument_app(app, response_hook=response_hook)
@@ -140,21 +106,21 @@ def init_flask_instrumentor(app: DifyApp) -> None:
def init_sqlalchemy_instrumentor(app: DifyApp) -> None:
with app.app_context():
engines = list(app.extensions["sqlalchemy"].engines.values())
_new_sqlalchemy_instrumentor().instrument(enable_commenter=True, engines=engines)
SQLAlchemyInstrumentor().instrument(enable_commenter=True, engines=engines)
def init_redis_instrumentor() -> None:
_new_redis_instrumentor().instrument()
RedisInstrumentor().instrument()
def init_httpx_instrumentor() -> None:
_new_httpx_instrumentor().instrument()
HTTPXClientInstrumentor().instrument()
def init_instruments(app: DifyApp) -> None:
if not is_celery_worker():
init_flask_instrumentor(app)
_new_celery_instrumentor().instrument()
CeleryInstrumentor(tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider()).instrument()
instrument_exception_logging()
init_sqlalchemy_instrumentor(app)

View File

@@ -15,12 +15,8 @@ from datetime import datetime
from enum import StrEnum, auto
from typing import Any
from pydantic import TypeAdapter
logger = logging.getLogger(__name__)
_metadata_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
class FileStatus(StrEnum):
"""File status enumeration"""
@@ -459,8 +455,8 @@ class FileLifecycleManager:
try:
if self._storage.exists(self._metadata_file):
metadata_content = self._storage.load_once(self._metadata_file)
result = _metadata_adapter.validate_json(metadata_content)
return result or {}
result = json.loads(metadata_content.decode("utf-8"))
return dict(result) if result else {}
else:
return {}
except Exception as e:

View File

@@ -1,16 +1,13 @@
import base64
import io
import json
from collections.abc import Generator
from typing import Any
from google.cloud import storage as google_cloud_storage # type: ignore
from pydantic import TypeAdapter
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
_service_account_adapter: TypeAdapter[dict[str, Any]] = TypeAdapter(dict[str, Any])
class GoogleCloudStorage(BaseStorage):
"""Implementation for Google Cloud storage."""
@@ -24,7 +21,7 @@ class GoogleCloudStorage(BaseStorage):
if service_account_json_str:
service_account_json = base64.b64decode(service_account_json_str).decode("utf-8")
# convert str to object
service_account_obj = _service_account_adapter.validate_json(service_account_json)
service_account_obj = json.loads(service_account_json)
self.client = google_cloud_storage.Client.from_service_account_info(service_account_obj)
else:
self.client = google_cloud_storage.Client()

View File

@@ -1,12 +1,9 @@
from collections.abc import Collection
def convert_to_lower_and_upper_set(inputs: Collection[str]) -> set[str]:
def convert_to_lower_and_upper_set(inputs: list[str] | set[str]) -> set[str]:
"""
Convert a collection of strings to a set containing both lower and upper case versions of each string.
Convert a list or set of strings to a set containing both lower and upper case versions of each string.
Args:
inputs (Collection[str]): A collection of strings to be converted.
inputs (list[str] | set[str]): A list or set of strings to be converted.
Returns:
set[str]: A set containing both lower and upper case versions of each string.

View File

@@ -7,8 +7,6 @@ from typing import NotRequired
import httpx
from pydantic import TypeAdapter, ValidationError
from core.helper.http_client_pooling import get_pooled_http_client
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
@@ -22,12 +20,6 @@ JsonObjectList = list[JsonObject]
JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject)
JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList)
# Reuse a pooled httpx.Client for OAuth flows (public endpoints, no SSRF proxy).
_http_client: httpx.Client = get_pooled_http_client(
"oauth:default",
lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)),
)
class AccessTokenResponse(TypedDict, total=False):
access_token: str
@@ -123,7 +115,7 @@ class GitHubOAuth(OAuth):
"redirect_uri": self.redirect_uri,
}
headers = {"Accept": "application/json"}
response = _http_client.post(self._TOKEN_URL, data=data, headers=headers)
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response))
access_token = response_json.get("access_token")
@@ -135,7 +127,7 @@ class GitHubOAuth(OAuth):
def get_raw_user_info(self, token: str) -> JsonObject:
headers = {"Authorization": f"token {token}"}
response = _http_client.get(self._USER_INFO_URL, headers=headers)
response = httpx.get(self._USER_INFO_URL, headers=headers)
response.raise_for_status()
user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response))
@@ -155,7 +147,7 @@ class GitHubOAuth(OAuth):
Returns an empty string when no usable email is found.
"""
try:
email_response = _http_client.get(GitHubOAuth._EMAIL_INFO_URL, headers=headers)
email_response = httpx.get(GitHubOAuth._EMAIL_INFO_URL, headers=headers)
email_response.raise_for_status()
email_records = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response))
except (httpx.HTTPStatusError, ValidationError):
@@ -212,7 +204,7 @@ class GoogleOAuth(OAuth):
"redirect_uri": self.redirect_uri,
}
headers = {"Accept": "application/json"}
response = _http_client.post(self._TOKEN_URL, data=data, headers=headers)
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response))
access_token = response_json.get("access_token")
@@ -224,7 +216,7 @@ class GoogleOAuth(OAuth):
def get_raw_user_info(self, token: str) -> JsonObject:
headers = {"Authorization": f"Bearer {token}"}
response = _http_client.get(self._USER_INFO_URL, headers=headers)
response = httpx.get(self._USER_INFO_URL, headers=headers)
response.raise_for_status()
return _json_object(response)

View File

@@ -7,7 +7,6 @@ from flask_login import current_user
from pydantic import TypeAdapter
from sqlalchemy import select
from core.helper.http_client_pooling import get_pooled_http_client
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.source import DataSourceOauthBinding
@@ -39,13 +38,6 @@ NOTION_SOURCE_INFO_ADAPTER = TypeAdapter(NotionSourceInfo)
NOTION_PAGE_SUMMARY_ADAPTER = TypeAdapter(NotionPageSummary)
# Reuse a small pooled client for OAuth data source flows.
_http_client: httpx.Client = get_pooled_http_client(
"oauth:notion",
lambda: httpx.Client(limits=httpx.Limits(max_keepalive_connections=50, max_connections=100)),
)
class OAuthDataSource:
client_id: str
client_secret: str
@@ -83,7 +75,7 @@ class NotionOAuth(OAuthDataSource):
data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri}
headers = {"Accept": "application/json"}
auth = (self.client_id, self.client_secret)
response = _http_client.post(self._TOKEN_URL, data=data, auth=auth, headers=headers)
response = httpx.post(self._TOKEN_URL, data=data, auth=auth, headers=headers)
response_json = response.json()
access_token = response_json.get("access_token")
@@ -276,7 +268,7 @@ class NotionOAuth(OAuthDataSource):
"Notion-Version": "2022-06-28",
}
response = _http_client.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response_json = response.json()
results.extend(response_json.get("results", []))
@@ -291,7 +283,7 @@ class NotionOAuth(OAuthDataSource):
"Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28",
}
response = _http_client.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
response = httpx.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
response_json = response.json()
if response.status_code != 200:
message = response_json.get("message", "unknown error")
@@ -307,7 +299,7 @@ class NotionOAuth(OAuthDataSource):
"Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28",
}
response = _http_client.get(url=self._NOTION_BOT_USER, headers=headers)
response = httpx.get(url=self._NOTION_BOT_USER, headers=headers)
response_json = response.json()
if "object" in response_json and response_json["object"] == "user":
user_type = response_json["type"]
@@ -331,7 +323,7 @@ class NotionOAuth(OAuthDataSource):
"Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28",
}
response = _http_client.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response_json = response.json()
results.extend(response_json.get("results", []))

Some files were not shown because too many files have changed in this diff Show More