mirror of
https://github.com/langgenius/dify.git
synced 2026-04-04 07:12:39 +00:00
Compare commits
103 Commits
codex/forc
...
codex/dify
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
142f94e27a | ||
|
|
608958de1c | ||
|
|
7eb632eb34 | ||
|
|
33d4fd357c | ||
|
|
e55bd61c17 | ||
|
|
f2fc213d52 | ||
|
|
f814579ed2 | ||
|
|
71d299d0d3 | ||
|
|
e178451d04 | ||
|
|
9a6222f245 | ||
|
|
affe5ed30b | ||
|
|
4cc5401d7e | ||
|
|
36e840cd87 | ||
|
|
985b41c40b | ||
|
|
2e29ac2829 | ||
|
|
dbfb474eab | ||
|
|
d243de26ec | ||
|
|
894826771a | ||
|
|
a1bd929b3c | ||
|
|
ffb9ee3e36 | ||
|
|
485586f49a | ||
|
|
a3386da5d6 | ||
|
|
318a3d0308 | ||
|
|
5bafb163cc | ||
|
|
52b1bc5b09 | ||
|
|
1873b22e96 | ||
|
|
9a8c853a2e | ||
|
|
e54383d0fe | ||
|
|
43c48ba4d7 | ||
|
|
8f9dbf269e | ||
|
|
cb9ee5903a | ||
|
|
cd406d2794 | ||
|
|
993a301468 | ||
|
|
399d3f8da5 | ||
|
|
f9d9ad7a38 | ||
|
|
2d29345f26 | ||
|
|
725f9e3dc4 | ||
|
|
4e1d060439 | ||
|
|
391007d02e | ||
|
|
e41965061c | ||
|
|
2b9eb06555 | ||
|
|
31f7752ba9 | ||
|
|
b23ea0397a | ||
|
|
c51cd42cb4 | ||
|
|
09ee8ea1f5 | ||
|
|
beda78e911 | ||
|
|
42d7623cc6 | ||
|
|
4bd388669a | ||
|
|
324b47507c | ||
|
|
d2baacdd4b | ||
|
|
57f358a96b | ||
|
|
19530e880a | ||
|
|
dbdbb098d5 | ||
|
|
2c8b47ce44 | ||
|
|
cf50d7c7b5 | ||
|
|
d9a0665b2c | ||
|
|
b818cc0766 | ||
|
|
90f94be2b3 | ||
|
|
24111facdd | ||
|
|
424d34a9c0 | ||
|
|
fbd2d31624 | ||
|
|
b54a0dc1e4 | ||
|
|
f27d669f87 | ||
|
|
fcf04629d3 | ||
|
|
6b0c6d0cde | ||
|
|
1063e021f2 | ||
|
|
303f548408 | ||
|
|
cc68f0e640 | ||
|
|
9b7b432e08 | ||
|
|
88863609e9 | ||
|
|
adc6c6c13b | ||
|
|
2de818530b | ||
|
|
7e4754392d | ||
|
|
01c857a67a | ||
|
|
2c2cc72150 | ||
|
|
f7b78b08fd | ||
|
|
f0e6f11c1c | ||
|
|
a19243068b | ||
|
|
323c51e095 | ||
|
|
bbc3f90928 | ||
|
|
1344c3b280 | ||
|
|
5897b28355 | ||
|
|
15aa8071f8 | ||
|
|
097095a69b | ||
|
|
daebe26089 | ||
|
|
c58170f5b8 | ||
|
|
3a7885819d | ||
|
|
5fc4dfaf7b | ||
|
|
953bcc33b1 | ||
|
|
bc14ad6a8f | ||
|
|
cc89b57c1f | ||
|
|
623c8ae803 | ||
|
|
dede190be2 | ||
|
|
a1513f06c3 | ||
|
|
3c7180bfd5 | ||
|
|
51f6ca2bed | ||
|
|
ae9a16a397 | ||
|
|
52a4bea88f | ||
|
|
1aaba80211 | ||
|
|
944db46d4f | ||
|
|
456684dfc3 | ||
|
|
40fa0f365c | ||
|
|
2cb71ad443 |
@@ -64,7 +64,7 @@ export const useUpdateAccessMode = () => {
|
||||
|
||||
// Component only adds UI behavior.
|
||||
updateAccessMode({ appId, mode }, {
|
||||
onSuccess: () => Toast.notify({ type: 'success', message: '...' }),
|
||||
onSuccess: () => toast.success('...'),
|
||||
})
|
||||
|
||||
// Avoid putting invalidation knowledge in the component.
|
||||
@@ -114,10 +114,7 @@ try {
|
||||
router.push(`/orders/${order.id}`)
|
||||
}
|
||||
catch (error) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: error instanceof Error ? error.message : 'Unknown error',
|
||||
})
|
||||
toast.error(error instanceof Error ? error.message : 'Unknown error')
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
1
.github/actions/setup-web/action.yml
vendored
1
.github/actions/setup-web/action.yml
vendored
@@ -6,7 +6,6 @@ runs:
|
||||
- name: Setup Vite+
|
||||
uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0
|
||||
with:
|
||||
working-directory: web
|
||||
node-version-file: .nvmrc
|
||||
cache: true
|
||||
run-install: true
|
||||
|
||||
11
.github/workflows/api-tests.yml
vendored
11
.github/workflows/api-tests.yml
vendored
@@ -9,9 +9,6 @@ on:
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true
|
||||
|
||||
concurrency:
|
||||
group: api-tests-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
@@ -38,7 +35,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@@ -87,7 +84,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@@ -159,7 +156,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
@@ -206,7 +203,7 @@ jobs:
|
||||
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' }}
|
||||
uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3
|
||||
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
|
||||
with:
|
||||
files: ./coverage.xml
|
||||
disable_search: true
|
||||
|
||||
9
.github/workflows/autofix.yml
vendored
9
.github/workflows/autofix.yml
vendored
@@ -10,9 +10,6 @@ on:
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true
|
||||
|
||||
jobs:
|
||||
autofix:
|
||||
if: github.repository == 'langgenius/dify'
|
||||
@@ -42,6 +39,10 @@ jobs:
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
pnpm-workspace.yaml
|
||||
.nvmrc
|
||||
- name: Check api inputs
|
||||
if: github.event_name != 'merge_group'
|
||||
id: api-changes
|
||||
@@ -55,7 +56,7 @@ jobs:
|
||||
python-version: "3.11"
|
||||
|
||||
- if: github.event_name != 'merge_group'
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
|
||||
- name: Generate Docker Compose
|
||||
if: github.event_name != 'merge_group' && steps.docker-compose-changes.outputs.any_changed == 'true'
|
||||
|
||||
30
.github/workflows/build-push.yml
vendored
30
.github/workflows/build-push.yml
vendored
@@ -24,27 +24,39 @@ env:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ${{ matrix.platform == 'linux/arm64' && 'arm64_runner' || 'ubuntu-latest' }}
|
||||
runs-on: ${{ matrix.runs_on }}
|
||||
if: github.repository == 'langgenius/dify'
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- service_name: "build-api-amd64"
|
||||
image_name_env: "DIFY_API_IMAGE_NAME"
|
||||
context: "api"
|
||||
artifact_context: "api"
|
||||
build_context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
platform: linux/amd64
|
||||
runs_on: ubuntu-latest
|
||||
- service_name: "build-api-arm64"
|
||||
image_name_env: "DIFY_API_IMAGE_NAME"
|
||||
context: "api"
|
||||
artifact_context: "api"
|
||||
build_context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
platform: linux/arm64
|
||||
runs_on: ubuntu-24.04-arm
|
||||
- service_name: "build-web-amd64"
|
||||
image_name_env: "DIFY_WEB_IMAGE_NAME"
|
||||
context: "web"
|
||||
artifact_context: "web"
|
||||
build_context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
platform: linux/amd64
|
||||
runs_on: ubuntu-latest
|
||||
- service_name: "build-web-arm64"
|
||||
image_name_env: "DIFY_WEB_IMAGE_NAME"
|
||||
context: "web"
|
||||
artifact_context: "web"
|
||||
build_context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
platform: linux/arm64
|
||||
runs_on: ubuntu-24.04-arm
|
||||
|
||||
steps:
|
||||
- name: Prepare
|
||||
@@ -58,9 +70,6 @@ jobs:
|
||||
username: ${{ env.DOCKERHUB_USER }}
|
||||
password: ${{ env.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
|
||||
@@ -74,7 +83,8 @@ jobs:
|
||||
id: build
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
|
||||
with:
|
||||
context: "{{defaultContext}}:${{ matrix.context }}"
|
||||
context: ${{ matrix.build_context }}
|
||||
file: ${{ matrix.file }}
|
||||
platforms: ${{ matrix.platform }}
|
||||
build-args: COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
@@ -93,7 +103,7 @@ jobs:
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
with:
|
||||
name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }}
|
||||
name: digests-${{ matrix.artifact_context }}-${{ env.PLATFORM_PAIR }}
|
||||
path: /tmp/digests/*
|
||||
if-no-files-found: error
|
||||
retention-days: 1
|
||||
|
||||
7
.github/workflows/db-migration-test.yml
vendored
7
.github/workflows/db-migration-test.yml
vendored
@@ -3,9 +3,6 @@ name: DB Migration Test
|
||||
on:
|
||||
workflow_call:
|
||||
|
||||
env:
|
||||
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true
|
||||
|
||||
concurrency:
|
||||
group: db-migration-test-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
@@ -22,7 +19,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
@@ -72,7 +69,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
|
||||
30
.github/workflows/docker-build.yml
vendored
30
.github/workflows/docker-build.yml
vendored
@@ -6,7 +6,12 @@ on:
|
||||
- "main"
|
||||
paths:
|
||||
- api/Dockerfile
|
||||
- web/docker/**
|
||||
- web/Dockerfile
|
||||
- package.json
|
||||
- pnpm-lock.yaml
|
||||
- pnpm-workspace.yaml
|
||||
- .nvmrc
|
||||
|
||||
concurrency:
|
||||
group: docker-build-${{ github.head_ref || github.run_id }}
|
||||
@@ -14,26 +19,31 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
build-docker:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ${{ matrix.runs_on }}
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- service_name: "api-amd64"
|
||||
platform: linux/amd64
|
||||
context: "api"
|
||||
runs_on: ubuntu-latest
|
||||
context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
- service_name: "api-arm64"
|
||||
platform: linux/arm64
|
||||
context: "api"
|
||||
runs_on: ubuntu-24.04-arm
|
||||
context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
- service_name: "web-amd64"
|
||||
platform: linux/amd64
|
||||
context: "web"
|
||||
runs_on: ubuntu-latest
|
||||
context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
- service_name: "web-arm64"
|
||||
platform: linux/arm64
|
||||
context: "web"
|
||||
runs_on: ubuntu-24.04-arm
|
||||
context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
steps:
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
|
||||
@@ -41,8 +51,8 @@ jobs:
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
|
||||
with:
|
||||
push: false
|
||||
context: "{{defaultContext}}:${{ matrix.context }}"
|
||||
file: "${{ matrix.file }}"
|
||||
context: ${{ matrix.context }}
|
||||
file: ${{ matrix.file }}
|
||||
platforms: ${{ matrix.platform }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
11
.github/workflows/main-ci.yml
vendored
11
.github/workflows/main-ci.yml
vendored
@@ -16,9 +16,6 @@ permissions:
|
||||
checks: write
|
||||
statuses: write
|
||||
|
||||
env:
|
||||
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true
|
||||
|
||||
concurrency:
|
||||
group: main-ci-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
@@ -68,6 +65,10 @@ jobs:
|
||||
- 'docker/volumes/sandbox/conf/**'
|
||||
web:
|
||||
- 'web/**'
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.nvmrc'
|
||||
- '.github/workflows/web-tests.yml'
|
||||
- '.github/actions/setup-web/**'
|
||||
e2e:
|
||||
@@ -76,6 +77,10 @@ jobs:
|
||||
- 'api/uv.lock'
|
||||
- 'e2e/**'
|
||||
- 'web/**'
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.nvmrc'
|
||||
- 'docker/docker-compose.middleware.yaml'
|
||||
- 'docker/middleware.env.example'
|
||||
- '.github/workflows/web-e2e.yml'
|
||||
|
||||
15
.github/workflows/pyrefly-diff.yml
vendored
15
.github/workflows/pyrefly-diff.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
@@ -50,6 +50,17 @@ jobs:
|
||||
run: |
|
||||
diff -u /tmp/pyrefly_base.txt /tmp/pyrefly_pr.txt > pyrefly_diff.txt || true
|
||||
|
||||
- name: Check if line counts match
|
||||
id: line_count_check
|
||||
run: |
|
||||
base_lines=$(wc -l < /tmp/pyrefly_base.txt)
|
||||
pr_lines=$(wc -l < /tmp/pyrefly_pr.txt)
|
||||
if [ "$base_lines" -eq "$pr_lines" ]; then
|
||||
echo "same=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "same=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Save PR number
|
||||
run: |
|
||||
echo ${{ github.event.pull_request.number }} > pr_number.txt
|
||||
@@ -63,7 +74,7 @@ jobs:
|
||||
pr_number.txt
|
||||
|
||||
- name: Comment PR with pyrefly diff
|
||||
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
|
||||
if: ${{ github.event.pull_request.head.repo.full_name == github.repository && steps.line_count_check.outputs.same == 'false' }}
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
10
.github/workflows/style.yml
vendored
10
.github/workflows/style.yml
vendored
@@ -33,7 +33,7 @@ jobs:
|
||||
|
||||
- name: Setup UV and Python
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
with:
|
||||
enable-cache: false
|
||||
python-version: "3.12"
|
||||
@@ -77,6 +77,10 @@ jobs:
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
pnpm-workspace.yaml
|
||||
.nvmrc
|
||||
.github/workflows/style.yml
|
||||
.github/actions/setup-web/**
|
||||
|
||||
@@ -90,9 +94,9 @@ jobs:
|
||||
uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
|
||||
with:
|
||||
path: web/.eslintcache
|
||||
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
|
||||
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
|
||||
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
|
||||
|
||||
- name: Web style check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
|
||||
3
.github/workflows/tool-test-sdks.yaml
vendored
3
.github/workflows/tool-test-sdks.yaml
vendored
@@ -6,6 +6,9 @@ on:
|
||||
- main
|
||||
paths:
|
||||
- sdks/**
|
||||
- package.json
|
||||
- pnpm-lock.yaml
|
||||
- pnpm-workspace.yaml
|
||||
|
||||
concurrency:
|
||||
group: sdk-tests-${{ github.head_ref || github.run_id }}
|
||||
|
||||
297
.github/workflows/translate-i18n-claude.yml
vendored
297
.github/workflows/translate-i18n-claude.yml
vendored
@@ -1,10 +1,10 @@
|
||||
name: Translate i18n Files with Claude Code
|
||||
|
||||
# Note: claude-code-action doesn't support push events directly.
|
||||
# Push events are bridged by trigger-i18n-sync.yml via repository_dispatch.
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
paths:
|
||||
- 'web/i18n/en-US/*.json'
|
||||
repository_dispatch:
|
||||
types: [i18n-sync]
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
files:
|
||||
@@ -30,7 +30,7 @@ permissions:
|
||||
|
||||
concurrency:
|
||||
group: translate-i18n-${{ github.event_name }}-${{ github.ref }}
|
||||
cancel-in-progress: ${{ github.event_name == 'push' }}
|
||||
cancel-in-progress: false
|
||||
|
||||
jobs:
|
||||
translate:
|
||||
@@ -67,19 +67,113 @@ jobs:
|
||||
}
|
||||
" web/i18n-config/languages.ts | sed 's/[[:space:]]*$//')
|
||||
|
||||
if [ "${{ github.event_name }}" = "push" ]; then
|
||||
BASE_SHA="${{ github.event.before }}"
|
||||
if [ -z "$BASE_SHA" ] || [ "$BASE_SHA" = "0000000000000000000000000000000000000000" ]; then
|
||||
BASE_SHA=$(git rev-parse HEAD~1 2>/dev/null || true)
|
||||
fi
|
||||
HEAD_SHA="${{ github.sha }}"
|
||||
if [ -n "$BASE_SHA" ]; then
|
||||
CHANGED_FILES=$(git diff --name-only "$BASE_SHA" "$HEAD_SHA" -- 'web/i18n/en-US/*.json' 2>/dev/null | sed -n 's@^.*/@@p' | sed 's/\.json$//' | tr '\n' ' ' | sed 's/[[:space:]]*$//')
|
||||
else
|
||||
CHANGED_FILES=$(find web/i18n/en-US -maxdepth 1 -type f -name '*.json' -print | sed -n 's@^.*/@@p' | sed 's/\.json$//' | sort | tr '\n' ' ' | sed 's/[[:space:]]*$//')
|
||||
fi
|
||||
generate_changes_json() {
|
||||
node <<'NODE'
|
||||
const { execFileSync } = require('node:child_process')
|
||||
const fs = require('node:fs')
|
||||
const path = require('node:path')
|
||||
|
||||
const repoRoot = process.cwd()
|
||||
const baseSha = process.env.BASE_SHA || ''
|
||||
const headSha = process.env.HEAD_SHA || ''
|
||||
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
|
||||
|
||||
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
|
||||
|
||||
const readCurrentJson = (fileStem) => {
|
||||
const filePath = englishPath(fileStem)
|
||||
if (!fs.existsSync(filePath))
|
||||
return null
|
||||
|
||||
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
|
||||
}
|
||||
|
||||
const readBaseJson = (fileStem) => {
|
||||
if (!baseSha)
|
||||
return null
|
||||
|
||||
try {
|
||||
const relativePath = `web/i18n/en-US/${fileStem}.json`
|
||||
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
|
||||
return JSON.parse(content)
|
||||
}
|
||||
catch (error) {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
|
||||
|
||||
const changes = {}
|
||||
|
||||
for (const fileStem of files) {
|
||||
const currentJson = readCurrentJson(fileStem)
|
||||
const beforeJson = readBaseJson(fileStem) || {}
|
||||
const afterJson = currentJson || {}
|
||||
const added = {}
|
||||
const updated = {}
|
||||
const deleted = []
|
||||
|
||||
for (const [key, value] of Object.entries(afterJson)) {
|
||||
if (!(key in beforeJson)) {
|
||||
added[key] = value
|
||||
continue
|
||||
}
|
||||
|
||||
if (!compareJson(beforeJson[key], value)) {
|
||||
updated[key] = {
|
||||
before: beforeJson[key],
|
||||
after: value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const key of Object.keys(beforeJson)) {
|
||||
if (!(key in afterJson))
|
||||
deleted.push(key)
|
||||
}
|
||||
|
||||
changes[fileStem] = {
|
||||
fileDeleted: currentJson === null,
|
||||
added,
|
||||
updated,
|
||||
deleted,
|
||||
}
|
||||
}
|
||||
|
||||
fs.writeFileSync(
|
||||
'/tmp/i18n-changes.json',
|
||||
JSON.stringify({
|
||||
baseSha,
|
||||
headSha,
|
||||
files,
|
||||
changes,
|
||||
})
|
||||
)
|
||||
NODE
|
||||
}
|
||||
|
||||
if [ "${{ github.event_name }}" = "repository_dispatch" ]; then
|
||||
BASE_SHA="${{ github.event.client_payload.base_sha }}"
|
||||
HEAD_SHA="${{ github.event.client_payload.head_sha }}"
|
||||
CHANGED_FILES="${{ github.event.client_payload.changed_files }}"
|
||||
TARGET_LANGS="$DEFAULT_TARGET_LANGS"
|
||||
SYNC_MODE="incremental"
|
||||
SYNC_MODE="${{ github.event.client_payload.sync_mode || 'incremental' }}"
|
||||
|
||||
if [ -n "${{ github.event.client_payload.changes_base64 }}" ]; then
|
||||
printf '%s' '${{ github.event.client_payload.changes_base64 }}' | base64 -d > /tmp/i18n-changes.json
|
||||
CHANGES_AVAILABLE="true"
|
||||
CHANGES_SOURCE="embedded"
|
||||
elif [ -n "$BASE_SHA" ] && [ -n "$CHANGED_FILES" ]; then
|
||||
export BASE_SHA HEAD_SHA CHANGED_FILES
|
||||
generate_changes_json
|
||||
CHANGES_AVAILABLE="true"
|
||||
CHANGES_SOURCE="recomputed"
|
||||
else
|
||||
printf '%s' '{"baseSha":"","headSha":"","files":[],"changes":{}}' > /tmp/i18n-changes.json
|
||||
CHANGES_AVAILABLE="false"
|
||||
CHANGES_SOURCE="unavailable"
|
||||
fi
|
||||
else
|
||||
BASE_SHA=""
|
||||
HEAD_SHA=$(git rev-parse HEAD)
|
||||
@@ -104,6 +198,17 @@ jobs:
|
||||
else
|
||||
CHANGED_FILES=""
|
||||
fi
|
||||
|
||||
if [ "$SYNC_MODE" = "incremental" ] && [ -n "$CHANGED_FILES" ]; then
|
||||
export BASE_SHA HEAD_SHA CHANGED_FILES
|
||||
generate_changes_json
|
||||
CHANGES_AVAILABLE="true"
|
||||
CHANGES_SOURCE="local"
|
||||
else
|
||||
printf '%s' '{"baseSha":"","headSha":"","files":[],"changes":{}}' > /tmp/i18n-changes.json
|
||||
CHANGES_AVAILABLE="false"
|
||||
CHANGES_SOURCE="unavailable"
|
||||
fi
|
||||
fi
|
||||
|
||||
FILE_ARGS=""
|
||||
@@ -123,6 +228,8 @@ jobs:
|
||||
echo "CHANGED_FILES=$CHANGED_FILES"
|
||||
echo "TARGET_LANGS=$TARGET_LANGS"
|
||||
echo "SYNC_MODE=$SYNC_MODE"
|
||||
echo "CHANGES_AVAILABLE=$CHANGES_AVAILABLE"
|
||||
echo "CHANGES_SOURCE=$CHANGES_SOURCE"
|
||||
echo "FILE_ARGS=$FILE_ARGS"
|
||||
echo "LANG_ARGS=$LANG_ARGS"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
@@ -141,7 +248,7 @@ jobs:
|
||||
show_full_output: ${{ github.event_name == 'workflow_dispatch' }}
|
||||
prompt: |
|
||||
You are the i18n sync agent for the Dify repository.
|
||||
Your job is to keep translations synchronized with the English source files under `${{ github.workspace }}/web/i18n/en-US/`, then open a PR with the result.
|
||||
Your job is to keep translations synchronized with the English source files under `${{ github.workspace }}/web/i18n/en-US/`.
|
||||
|
||||
Use absolute paths at all times:
|
||||
- Repo root: `${{ github.workspace }}`
|
||||
@@ -156,12 +263,15 @@ jobs:
|
||||
- Head SHA: `${{ steps.context.outputs.HEAD_SHA }}`
|
||||
- Scoped file args: `${{ steps.context.outputs.FILE_ARGS }}`
|
||||
- Scoped language args: `${{ steps.context.outputs.LANG_ARGS }}`
|
||||
- Structured change set available: `${{ steps.context.outputs.CHANGES_AVAILABLE }}`
|
||||
- Structured change set source: `${{ steps.context.outputs.CHANGES_SOURCE }}`
|
||||
- Structured change set file: `/tmp/i18n-changes.json`
|
||||
|
||||
Tool rules:
|
||||
- Use Read for repository files.
|
||||
- Use Edit for JSON updates.
|
||||
- Use Bash only for `git`, `gh`, `pnpm`, and `date`.
|
||||
- Run Bash commands one by one. Do not combine commands with `&&`, `||`, pipes, or command substitution.
|
||||
- Use Bash only for `pnpm`.
|
||||
- Do not use Bash for `git`, `gh`, or branch management.
|
||||
|
||||
Required execution plan:
|
||||
1. Resolve target languages.
|
||||
@@ -172,27 +282,25 @@ jobs:
|
||||
- Only process the resolved target languages, never `en-US`.
|
||||
- Do not touch unrelated i18n files.
|
||||
- Do not modify `${{ github.workspace }}/web/i18n/en-US/`.
|
||||
3. Detect English changes per file.
|
||||
- Read the current English JSON file for each file in scope.
|
||||
- If sync mode is `incremental` and `Base SHA` is not empty, run:
|
||||
`git -C ${{ github.workspace }} show <Base SHA>:web/i18n/en-US/<file>.json`
|
||||
- If sync mode is `full` or `Base SHA` is empty, skip historical comparison and treat the current English file as the only source of truth for structural sync.
|
||||
- If the file did not exist at Base SHA, treat all current keys as ADD.
|
||||
- Compare previous and current English JSON to identify:
|
||||
- ADD: key only in current
|
||||
- UPDATE: key exists in both and the English value changed
|
||||
- DELETE: key only in previous
|
||||
- Do not rely on a truncated diff file.
|
||||
3. Resolve source changes.
|
||||
- If `Structured change set available` is `true`, read `/tmp/i18n-changes.json` and use it as the source of truth for file-level and key-level changes.
|
||||
- For each file entry:
|
||||
- `added` contains new English keys that need translations.
|
||||
- `updated` contains stale keys whose English source changed; re-translate using the `after` value.
|
||||
- `deleted` contains keys that should be removed from locale files.
|
||||
- `fileDeleted: true` means the English file no longer exists; remove the matching locale file if present.
|
||||
- Read the current English JSON file for any file that still exists so wording, placeholders, and surrounding terminology stay accurate.
|
||||
- If `Structured change set available` is `false`, treat this as a scoped full sync and use the current English files plus scoped checks as the source of truth.
|
||||
4. Run a scoped pre-check before editing:
|
||||
- `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
|
||||
- Use this command as the source of truth for missing and extra keys inside the current scope.
|
||||
5. Apply translations.
|
||||
- For every target language and scoped file:
|
||||
- If `fileDeleted` is `true`, remove the locale file if it exists and skip the rest of that file.
|
||||
- If the locale file does not exist yet, create it with `Write` and then continue with `Edit` as needed.
|
||||
- ADD missing keys.
|
||||
- UPDATE stale translations when the English value changed.
|
||||
- DELETE removed keys. Prefer `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope.
|
||||
- For `zh-Hans` and `ja-JP`, if the locale file also changed between Base SHA and Head SHA, preserve manual translations unless they are clearly wrong for the new English value. If in doubt, keep the manual translation.
|
||||
- Preserve placeholders exactly: `{{variable}}`, `${variable}`, HTML tags, component tags, and variable names.
|
||||
- Match the existing terminology and register used by each locale.
|
||||
- Prefer one Edit per file when stable, but prioritize correctness over batching.
|
||||
@@ -200,14 +308,119 @@ jobs:
|
||||
- Run `pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- <relative edited i18n file paths>`
|
||||
- Run `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
|
||||
- If verification fails, fix the remaining problems before continuing.
|
||||
7. Create a PR only when there are changes in `web/i18n/`.
|
||||
- Check `git -C ${{ github.workspace }} status --porcelain -- web/i18n/`
|
||||
- Create branch `chore/i18n-sync-<timestamp>`
|
||||
- Commit message: `chore(i18n): sync translations with en-US`
|
||||
- Push the branch and open a PR against `main`
|
||||
- PR title: `chore(i18n): sync translations with en-US`
|
||||
- PR body: summarize files, languages, sync mode, and verification commands
|
||||
8. If there are no translation changes after verification, do not create a branch, commit, or PR.
|
||||
7. Stop after the scoped locale files are updated and verification passes.
|
||||
- Do not create branches, commits, or pull requests.
|
||||
claude_args: |
|
||||
--max-turns 80
|
||||
--allowedTools "Read,Write,Edit,Bash(git *),Bash(git:*),Bash(gh *),Bash(gh:*),Bash(pnpm *),Bash(pnpm:*),Bash(date *),Bash(date:*),Glob,Grep"
|
||||
--max-turns 120
|
||||
--allowedTools "Read,Write,Edit,Bash(pnpm *),Bash(pnpm:*),Glob,Grep"
|
||||
|
||||
- name: Prepare branch metadata
|
||||
id: pr_meta
|
||||
if: steps.context.outputs.CHANGED_FILES != ''
|
||||
shell: bash
|
||||
run: |
|
||||
if [ -z "$(git -C "${{ github.workspace }}" status --porcelain -- web/i18n/)" ]; then
|
||||
echo "has_changes=false" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
SCOPE_HASH=$(printf '%s|%s|%s' "${{ steps.context.outputs.CHANGED_FILES }}" "${{ steps.context.outputs.TARGET_LANGS }}" "${{ steps.context.outputs.SYNC_MODE }}" | sha256sum | cut -c1-8)
|
||||
HEAD_SHORT=$(printf '%s' "${{ steps.context.outputs.HEAD_SHA }}" | cut -c1-12)
|
||||
BRANCH_NAME="chore/i18n-sync-${HEAD_SHORT}-${SCOPE_HASH}"
|
||||
|
||||
{
|
||||
echo "has_changes=true"
|
||||
echo "branch_name=$BRANCH_NAME"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Commit translation changes
|
||||
if: steps.pr_meta.outputs.has_changes == 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
git -C "${{ github.workspace }}" checkout -B "${{ steps.pr_meta.outputs.branch_name }}"
|
||||
git -C "${{ github.workspace }}" add web/i18n/
|
||||
git -C "${{ github.workspace }}" commit -m "chore(i18n): sync translations with en-US"
|
||||
|
||||
- name: Push translation branch
|
||||
if: steps.pr_meta.outputs.has_changes == 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
if git -C "${{ github.workspace }}" ls-remote --exit-code --heads origin "${{ steps.pr_meta.outputs.branch_name }}" >/dev/null 2>&1; then
|
||||
git -C "${{ github.workspace }}" push --force-with-lease origin "${{ steps.pr_meta.outputs.branch_name }}"
|
||||
else
|
||||
git -C "${{ github.workspace }}" push --set-upstream origin "${{ steps.pr_meta.outputs.branch_name }}"
|
||||
fi
|
||||
|
||||
- name: Create or update translation PR
|
||||
if: steps.pr_meta.outputs.has_changes == 'true'
|
||||
env:
|
||||
BRANCH_NAME: ${{ steps.pr_meta.outputs.branch_name }}
|
||||
FILES_IN_SCOPE: ${{ steps.context.outputs.CHANGED_FILES }}
|
||||
TARGET_LANGS: ${{ steps.context.outputs.TARGET_LANGS }}
|
||||
SYNC_MODE: ${{ steps.context.outputs.SYNC_MODE }}
|
||||
CHANGES_SOURCE: ${{ steps.context.outputs.CHANGES_SOURCE }}
|
||||
BASE_SHA: ${{ steps.context.outputs.BASE_SHA }}
|
||||
HEAD_SHA: ${{ steps.context.outputs.HEAD_SHA }}
|
||||
REPO_NAME: ${{ github.repository }}
|
||||
shell: bash
|
||||
run: |
|
||||
PR_BODY_FILE=/tmp/i18n-pr-body.md
|
||||
LANG_COUNT=$(printf '%s\n' "$TARGET_LANGS" | wc -w | tr -d ' ')
|
||||
if [ "$LANG_COUNT" = "0" ]; then
|
||||
LANG_COUNT="0"
|
||||
fi
|
||||
export LANG_COUNT
|
||||
|
||||
node <<'NODE' > "$PR_BODY_FILE"
|
||||
const fs = require('node:fs')
|
||||
|
||||
const changesPath = '/tmp/i18n-changes.json'
|
||||
const changes = fs.existsSync(changesPath)
|
||||
? JSON.parse(fs.readFileSync(changesPath, 'utf8'))
|
||||
: { changes: {} }
|
||||
|
||||
const filesInScope = (process.env.FILES_IN_SCOPE || '').split(/\s+/).filter(Boolean)
|
||||
const lines = [
|
||||
'## Summary',
|
||||
'',
|
||||
`- **Files synced**: \`${process.env.FILES_IN_SCOPE || '<none>'}\``,
|
||||
`- **Languages updated**: ${process.env.TARGET_LANGS || '<none>'} (${process.env.LANG_COUNT} languages)`,
|
||||
`- **Sync mode**: ${process.env.SYNC_MODE}${process.env.BASE_SHA ? ` (base: \`${process.env.BASE_SHA.slice(0, 10)}\`, head: \`${process.env.HEAD_SHA.slice(0, 10)}\`)` : ` (head: \`${process.env.HEAD_SHA.slice(0, 10)}\`)`}`,
|
||||
'',
|
||||
'### Key changes',
|
||||
]
|
||||
|
||||
for (const fileName of filesInScope) {
|
||||
const fileChange = changes.changes?.[fileName] || { added: {}, updated: {}, deleted: [], fileDeleted: false }
|
||||
const addedKeys = Object.keys(fileChange.added || {})
|
||||
const updatedKeys = Object.keys(fileChange.updated || {})
|
||||
const deletedKeys = fileChange.deleted || []
|
||||
lines.push(`- \`${fileName}\`: +${addedKeys.length} / ~${updatedKeys.length} / -${deletedKeys.length}${fileChange.fileDeleted ? ' (file deleted in en-US)' : ''}`)
|
||||
}
|
||||
|
||||
lines.push(
|
||||
'',
|
||||
'## Verification',
|
||||
'',
|
||||
`- \`pnpm --dir web run i18n:check --file ${process.env.FILES_IN_SCOPE} --lang ${process.env.TARGET_LANGS}\``,
|
||||
`- \`pnpm --dir web lint:fix --quiet -- <edited i18n files>\``,
|
||||
'',
|
||||
'## Notes',
|
||||
'',
|
||||
'- This PR was generated from structured en-US key changes produced by `trigger-i18n-sync.yml`.',
|
||||
`- Structured change source: ${process.env.CHANGES_SOURCE || 'unknown'}.`,
|
||||
'- Branch name is deterministic for the head SHA and scope, so reruns update the same PR instead of opening duplicates.',
|
||||
'',
|
||||
'🤖 Generated with [Claude Code](https://claude.com/claude-code)'
|
||||
)
|
||||
|
||||
process.stdout.write(lines.join('\n'))
|
||||
NODE
|
||||
|
||||
EXISTING_PR_NUMBER=$(gh pr list --repo "$REPO_NAME" --head "$BRANCH_NAME" --state open --json number --jq '.[0].number')
|
||||
|
||||
if [ -n "$EXISTING_PR_NUMBER" ] && [ "$EXISTING_PR_NUMBER" != "null" ]; then
|
||||
gh pr edit "$EXISTING_PR_NUMBER" --repo "$REPO_NAME" --title "chore(i18n): sync translations with en-US" --body-file "$PR_BODY_FILE"
|
||||
else
|
||||
gh pr create --repo "$REPO_NAME" --head "$BRANCH_NAME" --base main --title "chore(i18n): sync translations with en-US" --body-file "$PR_BODY_FILE"
|
||||
fi
|
||||
|
||||
171
.github/workflows/trigger-i18n-sync.yml
vendored
Normal file
171
.github/workflows/trigger-i18n-sync.yml
vendored
Normal file
@@ -0,0 +1,171 @@
|
||||
name: Trigger i18n Sync on Push
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
paths:
|
||||
- 'web/i18n/en-US/*.json'
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
concurrency:
|
||||
group: trigger-i18n-sync-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
trigger:
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Detect changed files and build structured change set
|
||||
id: detect
|
||||
shell: bash
|
||||
run: |
|
||||
BASE_SHA="${{ github.event.before }}"
|
||||
if [ -z "$BASE_SHA" ] || [ "$BASE_SHA" = "0000000000000000000000000000000000000000" ]; then
|
||||
BASE_SHA=$(git rev-parse HEAD~1 2>/dev/null || true)
|
||||
fi
|
||||
HEAD_SHA="${{ github.sha }}"
|
||||
|
||||
if [ -n "$BASE_SHA" ]; then
|
||||
CHANGED_FILES=$(git diff --name-only "$BASE_SHA" "$HEAD_SHA" -- 'web/i18n/en-US/*.json' 2>/dev/null | sed -n 's@^.*/@@p' | sed 's/\.json$//' | tr '\n' ' ' | sed 's/[[:space:]]*$//')
|
||||
else
|
||||
CHANGED_FILES=$(find web/i18n/en-US -maxdepth 1 -type f -name '*.json' -print | sed -n 's@^.*/@@p' | sed 's/\.json$//' | sort | tr '\n' ' ' | sed 's/[[:space:]]*$//')
|
||||
fi
|
||||
|
||||
export BASE_SHA HEAD_SHA CHANGED_FILES
|
||||
node <<'NODE'
|
||||
const { execFileSync } = require('node:child_process')
|
||||
const fs = require('node:fs')
|
||||
const path = require('node:path')
|
||||
|
||||
const repoRoot = process.cwd()
|
||||
const baseSha = process.env.BASE_SHA || ''
|
||||
const headSha = process.env.HEAD_SHA || ''
|
||||
const files = (process.env.CHANGED_FILES || '').split(/\s+/).filter(Boolean)
|
||||
|
||||
const englishPath = fileStem => path.join(repoRoot, 'web', 'i18n', 'en-US', `${fileStem}.json`)
|
||||
|
||||
const readCurrentJson = (fileStem) => {
|
||||
const filePath = englishPath(fileStem)
|
||||
if (!fs.existsSync(filePath))
|
||||
return null
|
||||
|
||||
return JSON.parse(fs.readFileSync(filePath, 'utf8'))
|
||||
}
|
||||
|
||||
const readBaseJson = (fileStem) => {
|
||||
if (!baseSha)
|
||||
return null
|
||||
|
||||
try {
|
||||
const relativePath = `web/i18n/en-US/${fileStem}.json`
|
||||
const content = execFileSync('git', ['show', `${baseSha}:${relativePath}`], { encoding: 'utf8' })
|
||||
return JSON.parse(content)
|
||||
}
|
||||
catch (error) {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
const compareJson = (beforeValue, afterValue) => JSON.stringify(beforeValue) === JSON.stringify(afterValue)
|
||||
|
||||
const changes = {}
|
||||
|
||||
for (const fileStem of files) {
|
||||
const beforeJson = readBaseJson(fileStem) || {}
|
||||
const afterJson = readCurrentJson(fileStem) || {}
|
||||
const added = {}
|
||||
const updated = {}
|
||||
const deleted = []
|
||||
|
||||
for (const [key, value] of Object.entries(afterJson)) {
|
||||
if (!(key in beforeJson)) {
|
||||
added[key] = value
|
||||
continue
|
||||
}
|
||||
|
||||
if (!compareJson(beforeJson[key], value)) {
|
||||
updated[key] = {
|
||||
before: beforeJson[key],
|
||||
after: value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const key of Object.keys(beforeJson)) {
|
||||
if (!(key in afterJson))
|
||||
deleted.push(key)
|
||||
}
|
||||
|
||||
changes[fileStem] = {
|
||||
fileDeleted: readCurrentJson(fileStem) === null,
|
||||
added,
|
||||
updated,
|
||||
deleted,
|
||||
}
|
||||
}
|
||||
|
||||
fs.writeFileSync(
|
||||
'/tmp/i18n-changes.json',
|
||||
JSON.stringify({
|
||||
baseSha,
|
||||
headSha,
|
||||
files,
|
||||
changes,
|
||||
})
|
||||
)
|
||||
NODE
|
||||
|
||||
if [ -n "$CHANGED_FILES" ]; then
|
||||
echo "has_changes=true" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "has_changes=false" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
echo "base_sha=$BASE_SHA" >> "$GITHUB_OUTPUT"
|
||||
echo "head_sha=$HEAD_SHA" >> "$GITHUB_OUTPUT"
|
||||
echo "changed_files=$CHANGED_FILES" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Trigger i18n sync workflow
|
||||
if: steps.detect.outputs.has_changes == 'true'
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
env:
|
||||
BASE_SHA: ${{ steps.detect.outputs.base_sha }}
|
||||
HEAD_SHA: ${{ steps.detect.outputs.head_sha }}
|
||||
CHANGED_FILES: ${{ steps.detect.outputs.changed_files }}
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
const fs = require('fs')
|
||||
|
||||
const changesJson = fs.readFileSync('/tmp/i18n-changes.json', 'utf8')
|
||||
const changesBase64 = Buffer.from(changesJson).toString('base64')
|
||||
const maxEmbeddedChangesChars = 48000
|
||||
const changesEmbedded = changesBase64.length <= maxEmbeddedChangesChars
|
||||
|
||||
if (!changesEmbedded) {
|
||||
console.log(`Structured change set too large to embed safely (${changesBase64.length} chars). Downstream workflow will regenerate it from git history.`)
|
||||
}
|
||||
|
||||
await github.rest.repos.createDispatchEvent({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
event_type: 'i18n-sync',
|
||||
client_payload: {
|
||||
changed_files: process.env.CHANGED_FILES,
|
||||
changes_base64: changesEmbedded ? changesBase64 : '',
|
||||
changes_embedded: changesEmbedded,
|
||||
sync_mode: 'incremental',
|
||||
base_sha: process.env.BASE_SHA,
|
||||
head_sha: process.env.HEAD_SHA,
|
||||
},
|
||||
})
|
||||
95
.github/workflows/vdb-tests-full.yml
vendored
Normal file
95
.github/workflows/vdb-tests-full.yml
vendored
Normal file
@@ -0,0 +1,95 @@
|
||||
name: Run Full VDB Tests
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 3 * * 1'
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: vdb-tests-full-${{ github.ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Full VDB Tests
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version:
|
||||
- "3.12"
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Free Disk Space
|
||||
uses: endersonmenezes/free-disk-space@7901478139cff6e9d44df5972fd8ab8fcade4db1 # v3.2.2
|
||||
with:
|
||||
remove_dotnet: true
|
||||
remove_haskell: true
|
||||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache-dependency-glob: api/uv.lock
|
||||
|
||||
- name: Check UV lockfile
|
||||
run: uv lock --project api --check
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Set up dotenvs
|
||||
run: |
|
||||
cp docker/.env.example docker/.env
|
||||
cp docker/middleware.env.example docker/middleware.env
|
||||
|
||||
- name: Expose Service Ports
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
# - name: Set up Vector Store (TiDB)
|
||||
# uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
# with:
|
||||
# compose-file: docker/tidb/docker-compose.yaml
|
||||
# services: |
|
||||
# tidb
|
||||
# tiflash
|
||||
|
||||
- name: Set up Full Vector Store Matrix
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.yaml
|
||||
services: |
|
||||
weaviate
|
||||
qdrant
|
||||
couchbase-server
|
||||
etcd
|
||||
minio
|
||||
milvus-standalone
|
||||
pgvecto-rs
|
||||
pgvector
|
||||
chroma
|
||||
elasticsearch
|
||||
oceanbase
|
||||
|
||||
- name: setup test config
|
||||
run: |
|
||||
echo $(pwd)
|
||||
ls -lah .
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
# - name: Check VDB Ready (TiDB)
|
||||
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: uv run --project api bash dev/pytest/pytest_vdb.sh
|
||||
28
.github/workflows/vdb-tests.yml
vendored
28
.github/workflows/vdb-tests.yml
vendored
@@ -1,10 +1,10 @@
|
||||
name: Run VDB Tests
|
||||
name: Run VDB Smoke Tests
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
|
||||
env:
|
||||
FORCE_JAVASCRIPT_ACTIONS_TO_NODE24: true
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: vdb-tests-${{ github.head_ref || github.run_id }}
|
||||
@@ -12,7 +12,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: VDB Tests
|
||||
name: VDB Smoke Tests
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
@@ -33,7 +33,7 @@ jobs:
|
||||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@@ -61,23 +61,18 @@ jobs:
|
||||
# tidb
|
||||
# tiflash
|
||||
|
||||
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase)
|
||||
- name: Set up Vector Stores for Smoke Coverage
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.yaml
|
||||
services: |
|
||||
db_postgres
|
||||
redis
|
||||
weaviate
|
||||
qdrant
|
||||
couchbase-server
|
||||
etcd
|
||||
minio
|
||||
milvus-standalone
|
||||
pgvecto-rs
|
||||
pgvector
|
||||
chroma
|
||||
elasticsearch
|
||||
oceanbase
|
||||
|
||||
- name: setup test config
|
||||
run: |
|
||||
@@ -89,4 +84,9 @@ jobs:
|
||||
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: uv run --project api bash dev/pytest/pytest_vdb.sh
|
||||
run: |
|
||||
uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \
|
||||
api/tests/integration_tests/vdb/chroma \
|
||||
api/tests/integration_tests/vdb/pgvector \
|
||||
api/tests/integration_tests/vdb/qdrant \
|
||||
api/tests/integration_tests/vdb/weaviate
|
||||
|
||||
6
.github/workflows/web-e2e.yml
vendored
6
.github/workflows/web-e2e.yml
vendored
@@ -27,12 +27,8 @@ jobs:
|
||||
- name: Setup web dependencies
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Install E2E package dependencies
|
||||
working-directory: ./e2e
|
||||
run: vp install --frozen-lockfile
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
|
||||
33
.github/workflows/web-tests.yml
vendored
33
.github/workflows/web-tests.yml
vendored
@@ -83,40 +83,9 @@ jobs:
|
||||
|
||||
- name: Report coverage
|
||||
if: ${{ env.CODECOV_TOKEN != '' }}
|
||||
uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3
|
||||
uses: codecov/codecov-action@57e3a136b779b570ffcdbf80b3bdc90e7fab3de2 # v6.0.0
|
||||
with:
|
||||
directory: web/coverage
|
||||
flags: web
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
|
||||
|
||||
web-build:
|
||||
name: Web Build
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ./web
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
.github/workflows/web-tests.yml
|
||||
.github/actions/setup-web/**
|
||||
|
||||
- name: Setup web environment
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Web build check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: vp run build
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -212,6 +212,8 @@ api/.vscode
|
||||
|
||||
# pnpm
|
||||
/.pnpm-store
|
||||
node_modules
|
||||
.vite-hooks/_
|
||||
|
||||
# plugin migrate
|
||||
plugins.jsonl
|
||||
@@ -239,4 +241,4 @@ scripts/stress-test/reports/
|
||||
*.local.md
|
||||
|
||||
# Code Agent Folder
|
||||
.qoder/*
|
||||
.qoder/*
|
||||
|
||||
8
web/.husky/pre-commit → .vite-hooks/pre-commit
Normal file → Executable file
8
web/.husky/pre-commit → .vite-hooks/pre-commit
Normal file → Executable file
@@ -77,7 +77,7 @@ if $web_modified; then
|
||||
fi
|
||||
|
||||
cd ./web || exit 1
|
||||
lint-staged
|
||||
vp staged
|
||||
|
||||
if $web_ts_modified; then
|
||||
echo "Running TypeScript type-check:tsgo"
|
||||
@@ -89,6 +89,12 @@ if $web_modified; then
|
||||
echo "No staged TypeScript changes detected, skipping type-check:tsgo"
|
||||
fi
|
||||
|
||||
echo "Running knip"
|
||||
if ! pnpm run knip; then
|
||||
echo "Knip check failed. Please run 'pnpm run knip' to fix the errors."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Running unit tests check"
|
||||
modified_files=$(git diff --cached --name-only -- utils | grep -v '\.spec\.ts$' || true)
|
||||
|
||||
6
Makefile
6
Makefile
@@ -24,8 +24,8 @@ prepare-docker:
|
||||
# Step 2: Prepare web environment
|
||||
prepare-web:
|
||||
@echo "🌐 Setting up web environment..."
|
||||
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
|
||||
@cd web && pnpm install
|
||||
@cp -n web/.env.example web/.env.local 2>/dev/null || echo "Web .env.local already exists"
|
||||
@pnpm install
|
||||
@echo "✅ Web environment prepared (not started)"
|
||||
|
||||
# Step 3: Prepare API environment
|
||||
@@ -93,7 +93,7 @@ test:
|
||||
# Build Docker images
|
||||
build-web:
|
||||
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
|
||||
docker build -t $(WEB_IMAGE):$(VERSION) ./web
|
||||
docker build -f web/Dockerfile -t $(WEB_IMAGE):$(VERSION) .
|
||||
@echo "Web Docker image built successfully: $(WEB_IMAGE):$(VERSION)"
|
||||
|
||||
build-api:
|
||||
|
||||
@@ -115,12 +115,6 @@ ignore = [
|
||||
"controllers/console/human_input_form.py" = ["TID251"]
|
||||
"controllers/web/human_input_form.py" = ["TID251"]
|
||||
|
||||
[lint.pyflakes]
|
||||
allowed-unused-imports = [
|
||||
"tests.integration_tests",
|
||||
"tests.unit_tests",
|
||||
]
|
||||
|
||||
[lint.flake8-tidy-imports]
|
||||
|
||||
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]
|
||||
|
||||
@@ -40,6 +40,8 @@ The scripts resolve paths relative to their location, so you can run them from a
|
||||
./dev/start-web
|
||||
```
|
||||
|
||||
`./dev/setup` and `./dev/start-web` install JavaScript dependencies through the repository root workspace, so you do not need a separate `cd web && pnpm install` step.
|
||||
|
||||
1. Set up your application by visiting `http://localhost:3000`.
|
||||
|
||||
1. Start the worker service (async and scheduler tasks, runs from `api`).
|
||||
|
||||
@@ -7,15 +7,16 @@ UUID_NIL = "00000000-0000-0000-0000-000000000000"
|
||||
|
||||
DEFAULT_FILE_NUMBER_LIMITS = 3
|
||||
|
||||
IMAGE_EXTENSIONS = convert_to_lower_and_upper_set({"jpg", "jpeg", "png", "webp", "gif", "svg"})
|
||||
_IMAGE_EXTENSION_BASE: frozenset[str] = frozenset(("jpg", "jpeg", "png", "webp", "gif", "svg"))
|
||||
_VIDEO_EXTENSION_BASE: frozenset[str] = frozenset(("mp4", "mov", "mpeg", "webm"))
|
||||
_AUDIO_EXTENSION_BASE: frozenset[str] = frozenset(("mp3", "m4a", "wav", "amr", "mpga"))
|
||||
|
||||
VIDEO_EXTENSIONS = convert_to_lower_and_upper_set({"mp4", "mov", "mpeg", "webm"})
|
||||
IMAGE_EXTENSIONS: frozenset[str] = frozenset(convert_to_lower_and_upper_set(_IMAGE_EXTENSION_BASE))
|
||||
VIDEO_EXTENSIONS: frozenset[str] = frozenset(convert_to_lower_and_upper_set(_VIDEO_EXTENSION_BASE))
|
||||
AUDIO_EXTENSIONS: frozenset[str] = frozenset(convert_to_lower_and_upper_set(_AUDIO_EXTENSION_BASE))
|
||||
|
||||
AUDIO_EXTENSIONS = convert_to_lower_and_upper_set({"mp3", "m4a", "wav", "amr", "mpga"})
|
||||
|
||||
_doc_extensions: set[str]
|
||||
if dify_config.ETL_TYPE == "Unstructured":
|
||||
_doc_extensions = {
|
||||
_UNSTRUCTURED_DOCUMENT_EXTENSION_BASE: frozenset[str] = frozenset(
|
||||
(
|
||||
"txt",
|
||||
"markdown",
|
||||
"md",
|
||||
@@ -35,11 +36,10 @@ if dify_config.ETL_TYPE == "Unstructured":
|
||||
"pptx",
|
||||
"xml",
|
||||
"epub",
|
||||
}
|
||||
if dify_config.UNSTRUCTURED_API_URL:
|
||||
_doc_extensions.add("ppt")
|
||||
else:
|
||||
_doc_extensions = {
|
||||
)
|
||||
)
|
||||
_DEFAULT_DOCUMENT_EXTENSION_BASE: frozenset[str] = frozenset(
|
||||
(
|
||||
"txt",
|
||||
"markdown",
|
||||
"md",
|
||||
@@ -53,8 +53,17 @@ else:
|
||||
"csv",
|
||||
"vtt",
|
||||
"properties",
|
||||
}
|
||||
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
|
||||
)
|
||||
)
|
||||
|
||||
_doc_extensions: set[str]
|
||||
if dify_config.ETL_TYPE == "Unstructured":
|
||||
_doc_extensions = set(_UNSTRUCTURED_DOCUMENT_EXTENSION_BASE)
|
||||
if dify_config.UNSTRUCTURED_API_URL:
|
||||
_doc_extensions.add("ppt")
|
||||
else:
|
||||
_doc_extensions = set(_DEFAULT_DOCUMENT_EXTENSION_BASE)
|
||||
DOCUMENT_EXTENSIONS: frozenset[str] = frozenset(convert_to_lower_and_upper_set(_doc_extensions))
|
||||
|
||||
# console
|
||||
COOKIE_NAME_ACCESS_TOKEN = "access_token"
|
||||
|
||||
@@ -10,7 +10,7 @@ import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from typing import Any, Protocol, TypeVar, final, runtime_checkable
|
||||
from typing import Any, Protocol, final, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -188,8 +188,6 @@ class ExecutionContextBuilder:
|
||||
_capturer: Callable[[], IExecutionContext] | None = None
|
||||
_tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {}
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class ContextProviderNotFoundError(KeyError):
|
||||
"""Raised when a tenant-scoped context provider is missing."""
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
from contextvars import ContextVar
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class HiddenValue:
|
||||
@@ -11,7 +8,7 @@ class HiddenValue:
|
||||
_default = HiddenValue()
|
||||
|
||||
|
||||
class RecyclableContextVar(Generic[T]):
|
||||
class RecyclableContextVar[T]:
|
||||
"""
|
||||
RecyclableContextVar is a wrapper around ContextVar
|
||||
It's safe to use in gunicorn with thread recycling, but features like `reset` are not available for now
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, TypeAlias
|
||||
from typing import Any
|
||||
|
||||
from graphon.file import helpers as file_helpers
|
||||
from pydantic import BaseModel, ConfigDict, computed_field
|
||||
|
||||
from models.model import IconType
|
||||
|
||||
JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
|
||||
JSONObject: TypeAlias = dict[str, Any]
|
||||
type JSONValue = str | int | float | bool | None | dict[str, Any] | list[Any]
|
||||
type JSONObject = dict[str, Any]
|
||||
|
||||
|
||||
class SystemParameters(BaseModel):
|
||||
|
||||
@@ -4,8 +4,8 @@ from urllib.parse import quote
|
||||
|
||||
from flask import Response
|
||||
|
||||
HTML_MIME_TYPES = frozenset({"text/html", "application/xhtml+xml"})
|
||||
HTML_EXTENSIONS = frozenset({"html", "htm"})
|
||||
HTML_MIME_TYPES: frozenset[str] = frozenset(("text/html", "application/xhtml+xml"))
|
||||
HTML_EXTENSIONS: frozenset[str] = frozenset(("html", "htm"))
|
||||
|
||||
|
||||
def _normalize_mime_type(mime_type: str | None) -> str:
|
||||
|
||||
@@ -2,7 +2,6 @@ import csv
|
||||
import io
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@@ -20,9 +19,6 @@ from libs.token import extract_access_token
|
||||
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
||||
from services.billing_service import BillingService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
@@ -72,9 +68,9 @@ console_ns.schema_model(
|
||||
)
|
||||
|
||||
|
||||
def admin_required(view: Callable[P, R]):
|
||||
def admin_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not dify_config.ADMIN_API_KEY:
|
||||
raise Unauthorized("API key is invalid.")
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import flask_restx
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from flask_restx._http import HTTPStatus
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from extensions.ext_database import db
|
||||
@@ -34,7 +34,7 @@ api_key_list_model = console_ns.model(
|
||||
|
||||
|
||||
def _get_resource(resource_id, tenant_id, resource_model):
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
resource = session.execute(
|
||||
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
|
||||
).scalar_one_or_none()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, TypeAlias
|
||||
from typing import Any, Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@@ -9,7 +9,7 @@ from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.file import helpers as file_helpers
|
||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.common.helpers import FileInfo
|
||||
@@ -152,7 +152,7 @@ class AppTracePayload(BaseModel):
|
||||
return value
|
||||
|
||||
|
||||
JSONValue: TypeAlias = Any
|
||||
type JSONValue = Any
|
||||
|
||||
|
||||
class ResponseModel(BaseModel):
|
||||
@@ -642,7 +642,7 @@ class AppCopyApi(Resource):
|
||||
|
||||
args = CopyAppPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
import_service = AppDslService(session)
|
||||
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
|
||||
result = import_service.import_app(
|
||||
@@ -655,7 +655,6 @@ class AppCopyApi(Resource):
|
||||
icon=args.icon,
|
||||
icon_background=args.icon_background,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Inherit web app permission from original app
|
||||
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import (
|
||||
@@ -87,7 +87,6 @@ class AppImportApi(Resource):
|
||||
icon_background=args.icon_background,
|
||||
app_id=args.app_id,
|
||||
)
|
||||
session.commit()
|
||||
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
|
||||
# update web app setting as private
|
||||
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
|
||||
@@ -112,12 +111,11 @@ class AppImportConfirmApi(Resource):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
import_service = AppDslService(session)
|
||||
# Confirm import
|
||||
account = current_user
|
||||
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
if result.status == ImportStatus.FAILED:
|
||||
@@ -134,7 +132,7 @@ class AppImportCheckDependenciesApi(Resource):
|
||||
@marshal_with(app_import_check_dependencies_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
import_service = AppDslService(session)
|
||||
result = import_service.check_dependencies(app_model=app_model)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
@@ -69,7 +69,7 @@ class ConversationVariablesApi(Resource):
|
||||
page_size = 100
|
||||
stmt = stmt.limit(page_size).offset((page - 1) * page_size)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
rows = session.scalars(stmt).all()
|
||||
|
||||
return {
|
||||
|
||||
@@ -9,8 +9,8 @@ from graphon.enums import NodeType
|
||||
from graphon.file import File
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
@@ -268,22 +268,18 @@ class DraftWorkflowApi(Resource):
|
||||
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
|
||||
payload_data: dict[str, Any] | None = None
|
||||
if "application/json" in content_type:
|
||||
payload_data = request.get_json(silent=True)
|
||||
if not isinstance(payload_data, dict):
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
args_model = SyncDraftWorkflowPayload.model_validate(payload_data)
|
||||
elif "text/plain" in content_type:
|
||||
try:
|
||||
payload_data = json.loads(request.data.decode("utf-8"))
|
||||
except json.JSONDecodeError:
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
if not isinstance(payload_data, dict):
|
||||
args_model = SyncDraftWorkflowPayload.model_validate_json(request.data)
|
||||
except (ValueError, ValidationError):
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
else:
|
||||
abort(415)
|
||||
|
||||
args_model = SyncDraftWorkflowPayload.model_validate(payload_data)
|
||||
args = args_model.model_dump()
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
@@ -840,7 +836,7 @@ class PublishedWorkflowApi(Resource):
|
||||
args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
workflow = workflow_service.publish_workflow(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
@@ -858,8 +854,6 @@ class PublishedWorkflowApi(Resource):
|
||||
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
|
||||
session.commit()
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"created_at": workflow_created_at,
|
||||
@@ -982,7 +976,7 @@ class PublishedAllWorkflowApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
@@ -1072,7 +1066,7 @@ class WorkflowByIdApi(Resource):
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
# Create a session and manage the transaction
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
workflow = workflow_service.update_workflow(
|
||||
session=session,
|
||||
workflow_id=workflow_id,
|
||||
@@ -1084,9 +1078,6 @@ class WorkflowByIdApi(Resource):
|
||||
if not workflow:
|
||||
raise NotFound("Workflow not found")
|
||||
|
||||
# Commit the transaction in the controller
|
||||
session.commit()
|
||||
|
||||
return workflow
|
||||
|
||||
@setup_required
|
||||
@@ -1101,13 +1092,11 @@ class WorkflowByIdApi(Resource):
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
# Create a session and manage the transaction
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
try:
|
||||
workflow_service.delete_workflow(
|
||||
session=session, workflow_id=workflow_id, tenant_id=app_model.tenant_id
|
||||
)
|
||||
# Commit the transaction in the controller
|
||||
session.commit()
|
||||
except WorkflowInUseError as e:
|
||||
abort(400, description=str(e))
|
||||
except DraftWorkflowDeletionError as e:
|
||||
|
||||
@@ -5,7 +5,7 @@ from flask import request
|
||||
from flask_restx import Resource, marshal_with
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
@@ -87,7 +87,7 @@ class WorkflowAppLogApi(Resource):
|
||||
|
||||
# get paginate workflow app logs
|
||||
workflow_app_service = WorkflowAppService()
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
@@ -124,7 +124,7 @@ class WorkflowArchivedLogApi(Resource):
|
||||
args = WorkflowAppLogQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
workflow_app_service = WorkflowAppService()
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_archive_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any, NoReturn, ParamSpec, TypeVar
|
||||
from typing import Any
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
@@ -10,7 +10,7 @@ from graphon.variables.segment_group import SegmentGroup
|
||||
from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from graphon.variables.types import SegmentType
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
@@ -192,11 +192,8 @@ workflow_draft_variable_list_model = console_ns.model(
|
||||
"WorkflowDraftVariableList", workflow_draft_variable_list_fields_copy
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def _api_prerequisite(f: Callable[P, R]):
|
||||
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
@@ -213,7 +210,7 @@ def _api_prerequisite(f: Callable[P, R]):
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@wraps(f)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
@@ -244,7 +241,7 @@ class WorkflowVariableCollectionApi(Resource):
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
@@ -270,7 +267,7 @@ class WorkflowVariableCollectionApi(Resource):
|
||||
return Response("", 204)
|
||||
|
||||
|
||||
def validate_node_id(node_id: str) -> NoReturn | None:
|
||||
def validate_node_id(node_id: str) -> None:
|
||||
if node_id in [
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
SYSTEM_VARIABLE_NODE_ID,
|
||||
@@ -285,7 +282,6 @@ def validate_node_id(node_id: str) -> NoReturn | None:
|
||||
raise InvalidArgumentError(
|
||||
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||
@@ -298,7 +294,7 @@ class NodeVariableCollectionApi(Resource):
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, app_model: App, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
@@ -465,7 +461,7 @@ class VariableResetApi(Resource):
|
||||
|
||||
|
||||
def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
@@ -64,7 +64,7 @@ class WebhookTriggerApi(Resource):
|
||||
|
||||
node_id = args.node_id
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
# Get webhook trigger for this app and node
|
||||
webhook_trigger = (
|
||||
session.query(WorkflowWebhookTrigger)
|
||||
@@ -95,7 +95,7 @@ class AppTriggersApi(Resource):
|
||||
assert isinstance(current_user, Account)
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
# Get all triggers for this app using select API
|
||||
triggers = (
|
||||
session.execute(
|
||||
@@ -137,7 +137,7 @@ class AppTriggerEnableApi(Resource):
|
||||
assert current_user.current_tenant_id is not None
|
||||
|
||||
trigger_id = args.trigger_id
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
# Find the trigger using select
|
||||
trigger = session.execute(
|
||||
select(AppTrigger).where(
|
||||
@@ -153,9 +153,6 @@ class AppTriggerEnableApi(Resource):
|
||||
# Update status based on enable_trigger boolean
|
||||
trigger.status = AppTriggerStatus.ENABLED if args.enable_trigger else AppTriggerStatus.DISABLED
|
||||
|
||||
session.commit()
|
||||
session.refresh(trigger)
|
||||
|
||||
# Add computed icon field
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
if trigger.trigger_type == "trigger-plugin":
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar, Union
|
||||
from typing import overload
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@@ -9,11 +9,6 @@ from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
from models import App, AppMode
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
P1 = ParamSpec("P1")
|
||||
R1 = TypeVar("R1")
|
||||
|
||||
|
||||
def _load_app_model(app_id: str) -> App | None:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@@ -28,10 +23,30 @@ def _load_app_model_with_trial(app_id: str) -> App | None:
|
||||
return app_model
|
||||
|
||||
|
||||
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||
def decorator(view_func: Callable[P1, R1]):
|
||||
@overload
|
||||
def get_app_model[**P, R](
|
||||
view: Callable[P, R],
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[P, R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def get_app_model[**P, R](
|
||||
view: None = None,
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
|
||||
|
||||
|
||||
def get_app_model[**P, R](
|
||||
view: Callable[P, R] | None = None,
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P1.args, **kwargs: P1.kwargs):
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not kwargs.get("app_id"):
|
||||
raise ValueError("missing app_id in path parameters")
|
||||
|
||||
@@ -69,10 +84,30 @@ def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, li
|
||||
return decorator(view)
|
||||
|
||||
|
||||
def get_app_model_with_trial(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
|
||||
def decorator(view_func: Callable[P, R]):
|
||||
@overload
|
||||
def get_app_model_with_trial[**P, R](
|
||||
view: Callable[P, R],
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[P, R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def get_app_model_with_trial[**P, R](
|
||||
view: None = None,
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
|
||||
|
||||
|
||||
def get_app_model_with_trial[**P, R](
|
||||
view: Callable[P, R] | None = None,
|
||||
*,
|
||||
mode: AppMode | list[AppMode] | None = None,
|
||||
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not kwargs.get("app_id"):
|
||||
raise ValueError("missing app_id in path parameters")
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
from typing import Concatenate
|
||||
|
||||
from flask import jsonify, request
|
||||
from flask.typing import ResponseReturnValue
|
||||
from flask_restx import Resource
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel
|
||||
@@ -16,10 +17,6 @@ from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType,
|
||||
|
||||
from .. import console_ns
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class OAuthClientPayload(BaseModel):
|
||||
client_id: str
|
||||
@@ -39,9 +36,11 @@ class OAuthTokenRequest(BaseModel):
|
||||
refresh_token: str | None = None
|
||||
|
||||
|
||||
def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderApp, P], R]):
|
||||
def oauth_server_client_id_required[T, **P, R](
|
||||
view: Callable[Concatenate[T, OAuthProviderApp, P], R],
|
||||
) -> Callable[Concatenate[T, P], R]:
|
||||
@wraps(view)
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
json_data = request.get_json()
|
||||
if json_data is None:
|
||||
raise BadRequest("client_id is required")
|
||||
@@ -58,9 +57,13 @@ def oauth_server_client_id_required(view: Callable[Concatenate[T, OAuthProviderA
|
||||
return decorated
|
||||
|
||||
|
||||
def oauth_server_access_token_required(view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R]):
|
||||
def oauth_server_access_token_required[T, **P, R](
|
||||
view: Callable[Concatenate[T, OAuthProviderApp, Account, P], R],
|
||||
) -> Callable[Concatenate[T, OAuthProviderApp, P], R | ResponseReturnValue]:
|
||||
@wraps(view)
|
||||
def decorated(self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(
|
||||
self: T, oauth_provider_app: OAuthProviderApp, *args: P.args, **kwargs: P.kwargs
|
||||
) -> R | ResponseReturnValue:
|
||||
if not isinstance(oauth_provider_app, OAuthProviderApp):
|
||||
raise BadRequest("Invalid oauth_provider_app")
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ class Subscription(Resource):
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = SubscriptionQuery.model_validate(request.args.to_dict(flat=True))
|
||||
BillingService.is_tenant_owner_or_admin(current_user)
|
||||
return BillingService.get_subscription(args.plan, args.interval, current_user.email, current_tenant_id)
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ class ComplianceApi(Resource):
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
args = ComplianceDownloadQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
device_info = request.headers.get("User-Agent", "Unknown device")
|
||||
|
||||
@@ -6,7 +6,7 @@ from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import get_or_create_model, register_schema_model
|
||||
@@ -158,10 +158,11 @@ class DataSourceApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, binding_id, action: Literal["enable", "disable"]):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
binding_id = str(binding_id)
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
data_source_binding = session.execute(
|
||||
select(DataSourceOauthBinding).filter_by(id=binding_id)
|
||||
select(DataSourceOauthBinding).filter_by(id=binding_id, tenant_id=current_tenant_id)
|
||||
).scalar_one_or_none()
|
||||
if data_source_binding is None:
|
||||
raise NotFound("Data source binding not found.")
|
||||
@@ -211,7 +212,7 @@ class DataSourceNotionListApi(Resource):
|
||||
if not credential:
|
||||
raise NotFound("Credential not found.")
|
||||
exist_page_ids = []
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
# import notion in the exist dataset
|
||||
if query.dataset_id:
|
||||
dataset = DatasetService.get_dataset(query.dataset_id)
|
||||
|
||||
@@ -173,8 +173,11 @@ class ExternalApiTemplateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, external_knowledge_api_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
external_knowledge_api_id = str(external_knowledge_api_id)
|
||||
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id)
|
||||
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(
|
||||
external_knowledge_api_id, current_tenant_id
|
||||
)
|
||||
if external_knowledge_api is None:
|
||||
raise NotFound("API template not found.")
|
||||
|
||||
|
||||
@@ -120,7 +120,8 @@ class DatasourceOAuthCallback(Resource):
|
||||
if context is None:
|
||||
raise Forbidden("Invalid context_id")
|
||||
|
||||
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
|
||||
user_id: str = context["user_id"]
|
||||
tenant_id: str = context["tenant_id"]
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
@@ -141,7 +142,7 @@ class DatasourceOAuthCallback(Resource):
|
||||
system_credentials=oauth_client_params,
|
||||
request=request,
|
||||
)
|
||||
credential_id = context.get("credential_id")
|
||||
credential_id: str | None = context.get("credential_id")
|
||||
if credential_id:
|
||||
datasource_provider_service.reauthorize_datasource_oauth_provider(
|
||||
tenant_id=tenant_id,
|
||||
@@ -150,7 +151,7 @@ class DatasourceOAuthCallback(Resource):
|
||||
name=oauth_response.metadata.get("name") or None,
|
||||
expire_at=oauth_response.expires_at,
|
||||
credentials=dict(oauth_response.credentials),
|
||||
credential_id=context.get("credential_id"),
|
||||
credential_id=credential_id,
|
||||
)
|
||||
else:
|
||||
datasource_provider_service.add_datasource_oauth_provider(
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
@@ -85,7 +85,7 @@ class CustomizedPipelineTemplateApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def post(self, template_id: str):
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
template = (
|
||||
session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from flask_restx import Resource, marshal
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
@@ -54,7 +54,7 @@ class CreateRagPipelineDatasetApi(Resource):
|
||||
yaml_content=payload.yaml_content,
|
||||
)
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
||||
import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset(
|
||||
tenant_id=current_tenant_id,
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any, NoReturn
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, marshal, marshal_with
|
||||
from graphon.variables.types import SegmentType
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
@@ -55,7 +56,7 @@ class WorkflowDraftVariablePatchPayload(BaseModel):
|
||||
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
|
||||
|
||||
|
||||
def _api_prerequisite(f):
|
||||
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
@@ -70,7 +71,7 @@ def _api_prerequisite(f):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def wrapper(*args, **kwargs):
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
@@ -96,7 +97,7 @@ class RagPipelineVariableCollectionApi(Resource):
|
||||
raise DraftWorkflowNotExist()
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
@@ -143,7 +144,7 @@ class RagPipelineNodeVariableCollectionApi(Resource):
|
||||
@marshal_with(workflow_draft_variable_list_model)
|
||||
def get(self, pipeline: Pipeline, node_id: str):
|
||||
validate_node_id(node_id)
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
@@ -289,7 +290,7 @@ class RagPipelineVariableResetApi(Resource):
|
||||
|
||||
|
||||
def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session:
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=session,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with # type: ignore
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import get_or_create_model, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
@@ -68,7 +68,7 @@ class RagPipelineImportApi(Resource):
|
||||
payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
# Import app
|
||||
account = current_user
|
||||
@@ -80,7 +80,6 @@ class RagPipelineImportApi(Resource):
|
||||
pipeline_id=payload.pipeline_id,
|
||||
dataset_name=payload.name,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
status = result.status
|
||||
@@ -102,12 +101,11 @@ class RagPipelineImportConfirmApi(Resource):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
# Confirm import
|
||||
account = current_user
|
||||
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
if result.status == ImportStatus.FAILED:
|
||||
@@ -124,7 +122,7 @@ class RagPipelineImportCheckDependenciesApi(Resource):
|
||||
@edit_permission_required
|
||||
@marshal_with(pipeline_import_check_dependencies_model)
|
||||
def get(self, pipeline: Pipeline):
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
result = import_service.check_dependencies(pipeline=pipeline)
|
||||
|
||||
@@ -142,7 +140,7 @@ class RagPipelineExportApi(Resource):
|
||||
# Add include_secret params
|
||||
query = IncludeSecretQuery.model_validate(request.args.to_dict())
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
export_service = RagPipelineDslService(session)
|
||||
result = export_service.export_rag_pipeline_dsl(
|
||||
pipeline=pipeline, include_secret=query.include_secret == "true"
|
||||
|
||||
@@ -5,8 +5,8 @@ from typing import Any, Literal, cast
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, marshal_with # type: ignore
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
@@ -186,29 +186,14 @@ class DraftRagPipelineApi(Resource):
|
||||
|
||||
if "application/json" in content_type:
|
||||
payload_dict = console_ns.payload or {}
|
||||
payload = DraftWorkflowSyncPayload.model_validate(payload_dict)
|
||||
elif "text/plain" in content_type:
|
||||
try:
|
||||
data = json.loads(request.data.decode("utf-8"))
|
||||
if "graph" not in data or "features" not in data:
|
||||
raise ValueError("graph or features not found in data")
|
||||
|
||||
if not isinstance(data.get("graph"), dict):
|
||||
raise ValueError("graph is not a dict")
|
||||
|
||||
payload_dict = {
|
||||
"graph": data.get("graph"),
|
||||
"features": data.get("features"),
|
||||
"hash": data.get("hash"),
|
||||
"environment_variables": data.get("environment_variables"),
|
||||
"conversation_variables": data.get("conversation_variables"),
|
||||
"rag_pipeline_variables": data.get("rag_pipeline_variables"),
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
payload = DraftWorkflowSyncPayload.model_validate_json(request.data)
|
||||
except (ValueError, ValidationError):
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
else:
|
||||
abort(415)
|
||||
|
||||
payload = DraftWorkflowSyncPayload.model_validate(payload_dict)
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
|
||||
try:
|
||||
@@ -608,19 +593,15 @@ class PublishedRagPipelineApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
with Session(db.engine) as session:
|
||||
pipeline = session.merge(pipeline)
|
||||
workflow = rag_pipeline_service.publish_workflow(
|
||||
session=session,
|
||||
pipeline=pipeline,
|
||||
account=current_user,
|
||||
)
|
||||
pipeline.is_published = True
|
||||
pipeline.workflow_id = workflow.id
|
||||
session.add(pipeline)
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
|
||||
session.commit()
|
||||
workflow = rag_pipeline_service.publish_workflow(
|
||||
session=db.session, # type: ignore[reportArgumentType,arg-type]
|
||||
pipeline=pipeline,
|
||||
account=current_user,
|
||||
)
|
||||
pipeline.is_published = True
|
||||
pipeline.workflow_id = workflow.id
|
||||
db.session.commit()
|
||||
workflow_created_at = TimestampField().format(workflow.created_at)
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
@@ -695,7 +676,7 @@ class PublishedAllRagPipelineApi(Resource):
|
||||
raise Forbidden()
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
workflows, has_more = rag_pipeline_service.get_all_published_workflow(
|
||||
session=session,
|
||||
pipeline=pipeline,
|
||||
@@ -767,7 +748,7 @@ class RagPipelineByIdApi(Resource):
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
|
||||
# Create a session and manage the transaction
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
workflow = rag_pipeline_service.update_workflow(
|
||||
session=session,
|
||||
workflow_id=workflow_id,
|
||||
@@ -779,9 +760,6 @@ class RagPipelineByIdApi(Resource):
|
||||
if not workflow:
|
||||
raise NotFound("Workflow not found")
|
||||
|
||||
# Commit the transaction in the controller
|
||||
session.commit()
|
||||
|
||||
return workflow
|
||||
|
||||
@setup_required
|
||||
@@ -798,14 +776,13 @@ class RagPipelineByIdApi(Resource):
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
try:
|
||||
workflow_service.delete_workflow(
|
||||
session=session,
|
||||
workflow_id=workflow_id,
|
||||
tenant_id=pipeline.tenant_id,
|
||||
)
|
||||
session.commit()
|
||||
except WorkflowInUseError as e:
|
||||
abort(400, description=str(e))
|
||||
except DraftWorkflowDeletionError as e:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@@ -9,13 +8,10 @@ from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.dataset import Pipeline
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def get_rag_pipeline(view_func: Callable[P, R]):
|
||||
def get_rag_pipeline[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not kwargs.get("pipeline_id"):
|
||||
raise ValueError("missing pipeline_id in path parameters")
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Any
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, TypeAdapter, model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
@@ -74,7 +74,7 @@ class ConversationListApi(InstalledAppResource):
|
||||
try:
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
pagination = WebConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
from typing import Concatenate
|
||||
|
||||
from flask import abort
|
||||
from flask_restx import Resource
|
||||
@@ -15,12 +15,8 @@ from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | None = None):
|
||||
def installed_app_required[**P, R](view: Callable[Concatenate[InstalledApp, P], R] | None = None):
|
||||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
@@ -49,7 +45,7 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non
|
||||
return decorator
|
||||
|
||||
|
||||
def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] | None = None):
|
||||
def user_allowed_to_access_app[**P, R](view: Callable[Concatenate[InstalledApp, P], R] | None = None):
|
||||
def decorator(view: Callable[Concatenate[InstalledApp, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(installed_app: InstalledApp, *args: P.args, **kwargs: P.kwargs):
|
||||
@@ -73,7 +69,7 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
|
||||
return decorator
|
||||
|
||||
|
||||
def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
|
||||
def trial_app_required[**P, R](view: Callable[Concatenate[App, P], R] | None = None):
|
||||
def decorator(view: Callable[Concatenate[App, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
|
||||
@@ -106,7 +102,7 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
|
||||
return decorator
|
||||
|
||||
|
||||
def trial_feature_enable(view: Callable[P, R]):
|
||||
def trial_feature_enable[**P, R](view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
@@ -117,7 +113,7 @@ def trial_feature_enable(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def explore_banner_enabled(view: Callable[P, R]):
|
||||
def explore_banner_enabled[**P, R](view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
|
||||
@@ -1,30 +1,26 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.account import TenantPluginPermission
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def plugin_permission_required(
|
||||
install_required: bool = False,
|
||||
debug_required: bool = False,
|
||||
):
|
||||
def interceptor(view: Callable[P, R]):
|
||||
def interceptor[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
user = current_user
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
permission = (
|
||||
session.query(TenantPluginPermission)
|
||||
.where(
|
||||
|
||||
@@ -8,7 +8,7 @@ from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
@@ -519,7 +519,7 @@ class EducationAutoCompleteApi(Resource):
|
||||
@cloud_edition_billing_enabled
|
||||
@marshal_with(data_fields)
|
||||
def get(self):
|
||||
payload = request.args.to_dict(flat=True) # type: ignore
|
||||
payload = request.args.to_dict(flat=True)
|
||||
args = EducationAutocompleteQuery.model_validate(payload)
|
||||
|
||||
return BillingService.EducationIdentity.autocomplete(args.keywords, args.page, args.limit)
|
||||
@@ -562,7 +562,7 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
|
||||
user_email = current_user.email
|
||||
else:
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
|
||||
if account is None:
|
||||
raise AccountNotFound()
|
||||
|
||||
@@ -99,7 +99,7 @@ class ModelProviderListApi(Resource):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
payload = request.args.to_dict(flat=True) # type: ignore
|
||||
payload = request.args.to_dict(flat=True)
|
||||
args = ParserModelList.model_validate(payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
@@ -118,7 +118,7 @@ class ModelProviderCredentialApi(Resource):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
tenant_id = current_tenant_id
|
||||
# if credential_id is not provided, return current used credential
|
||||
payload = request.args.to_dict(flat=True) # type: ignore
|
||||
payload = request.args.to_dict(flat=True)
|
||||
args = ParserCredentialId.model_validate(payload)
|
||||
|
||||
model_provider_service = ModelProviderService()
|
||||
|
||||
@@ -287,12 +287,10 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
provider=provider,
|
||||
)
|
||||
else:
|
||||
# Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM)
|
||||
normalized_model_type = args.model_type.to_origin_model_type()
|
||||
available_credentials = model_provider_service.get_provider_model_available_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model_type=normalized_model_type,
|
||||
model_type=args.model_type,
|
||||
model=args.model,
|
||||
)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from flask import make_response, redirect, request, send_file
|
||||
from flask_restx import Resource
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
@@ -832,7 +832,8 @@ class ToolOAuthCallback(Resource):
|
||||
tool_provider = ToolProviderID(provider)
|
||||
plugin_id = tool_provider.plugin_id
|
||||
provider_name = tool_provider.provider_name
|
||||
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
|
||||
user_id: str = context["user_id"]
|
||||
tenant_id: str = context["tenant_id"]
|
||||
|
||||
oauth_handler = OAuthHandler()
|
||||
oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id, provider)
|
||||
@@ -1018,7 +1019,7 @@ class ToolProviderMCPApi(Resource):
|
||||
|
||||
# Step 1: Get provider data for URL validation (short-lived session, no network I/O)
|
||||
validation_data = None
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
validation_data = service.get_provider_for_url_validation(
|
||||
tenant_id=current_tenant_id, provider_id=payload.provider_id
|
||||
@@ -1033,7 +1034,7 @@ class ToolProviderMCPApi(Resource):
|
||||
)
|
||||
|
||||
# Step 3: Perform database update in a transaction
|
||||
with Session(db.engine) as session, session.begin():
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
service.update_provider(
|
||||
tenant_id=current_tenant_id,
|
||||
@@ -1060,7 +1061,7 @@ class ToolProviderMCPApi(Resource):
|
||||
payload = MCPProviderDeletePayload.model_validate(console_ns.payload or {})
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session, session.begin():
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
service.delete_provider(tenant_id=current_tenant_id, provider_id=payload.provider_id)
|
||||
|
||||
@@ -1078,7 +1079,7 @@ class ToolMCPAuthApi(Resource):
|
||||
provider_id = payload.provider_id
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session, session.begin():
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
if not db_provider:
|
||||
@@ -1099,7 +1100,7 @@ class ToolMCPAuthApi(Resource):
|
||||
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||
):
|
||||
# Update credentials in new transaction
|
||||
with Session(db.engine) as session, session.begin():
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
service.update_provider_credentials(
|
||||
provider_id=provider_id,
|
||||
@@ -1117,17 +1118,17 @@ class ToolMCPAuthApi(Resource):
|
||||
resource_metadata_url=e.resource_metadata_url,
|
||||
scope_hint=e.scope_hint,
|
||||
)
|
||||
with Session(db.engine) as session, session.begin():
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
response = service.execute_auth_actions(auth_result)
|
||||
return response
|
||||
except MCPRefreshTokenError as e:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
|
||||
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
|
||||
except (MCPError, ValueError) as e:
|
||||
with Session(db.engine) as session, session.begin():
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}") from e
|
||||
@@ -1140,7 +1141,7 @@ class ToolMCPDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, provider_id):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
with Session(db.engine) as session, session.begin():
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
|
||||
@@ -1154,7 +1155,7 @@ class ToolMCPListAllApi(Resource):
|
||||
def get(self):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
|
||||
with Session(db.engine) as session, session.begin():
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
# Skip sensitive data decryption for list view to improve performance
|
||||
tools = service.list_providers(tenant_id=tenant_id, include_sensitive=False)
|
||||
@@ -1169,7 +1170,7 @@ class ToolMCPUpdateApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, provider_id):
|
||||
_, tenant_id = current_account_with_tenant()
|
||||
with Session(db.engine) as session, session.begin():
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
service = MCPToolManageService(session=session)
|
||||
tools = service.list_provider_tools(
|
||||
tenant_id=tenant_id,
|
||||
@@ -1187,7 +1188,7 @@ class ToolMCPCallbackApi(Resource):
|
||||
authorization_code = query.code
|
||||
|
||||
# Create service instance for handle_callback
|
||||
with Session(db.engine) as session, session.begin():
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
# handle_callback now returns state data and tokens
|
||||
state_data, tokens = handle_callback(state_key, authorization_code)
|
||||
|
||||
@@ -5,7 +5,7 @@ from flask import make_response, redirect, request
|
||||
from flask_restx import Resource
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel, model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
@@ -375,7 +375,7 @@ class TriggerSubscriptionDeleteApi(Resource):
|
||||
assert user.current_tenant_id is not None
|
||||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
# Delete trigger provider subscription
|
||||
TriggerProviderService.delete_trigger_provider(
|
||||
session=session,
|
||||
@@ -388,7 +388,6 @@ class TriggerSubscriptionDeleteApi(Resource):
|
||||
tenant_id=user.current_tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
)
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
except ValueError as e:
|
||||
raise BadRequest(str(e))
|
||||
@@ -499,9 +498,9 @@ class TriggerOAuthCallbackApi(Resource):
|
||||
provider_id = TriggerProviderID(provider)
|
||||
plugin_id = provider_id.plugin_id
|
||||
provider_name = provider_id.provider_name
|
||||
user_id = context.get("user_id")
|
||||
tenant_id = context.get("tenant_id")
|
||||
subscription_builder_id = context.get("subscription_builder_id")
|
||||
user_id: str = context["user_id"]
|
||||
tenant_id: str = context["tenant_id"]
|
||||
subscription_builder_id: str = context["subscription_builder_id"]
|
||||
|
||||
# Get OAuth client configuration
|
||||
oauth_client_params = TriggerProviderService.get_oauth_client(
|
||||
|
||||
@@ -155,7 +155,7 @@ class WorkspaceListApi(Resource):
|
||||
@setup_required
|
||||
@admin_required
|
||||
def get(self):
|
||||
payload = request.args.to_dict(flat=True) # type: ignore
|
||||
payload = request.args.to_dict(flat=True)
|
||||
args = WorkspaceListQuery.model_validate(payload)
|
||||
|
||||
stmt = select(Tenant).order_by(Tenant.created_at.desc())
|
||||
|
||||
@@ -4,7 +4,6 @@ import os
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import abort, request
|
||||
from sqlalchemy import select
|
||||
@@ -25,9 +24,6 @@ from services.operation_service import OperationService
|
||||
|
||||
from .error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
# Field names for decryption
|
||||
FIELD_NAME_PASSWORD = "password"
|
||||
FIELD_NAME_CODE = "code"
|
||||
@@ -37,7 +33,7 @@ ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data"
|
||||
ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code"
|
||||
|
||||
|
||||
def account_initialization_required(view: Callable[P, R]) -> Callable[P, R]:
|
||||
def account_initialization_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# check account initialization
|
||||
@@ -50,7 +46,7 @@ def account_initialization_required(view: Callable[P, R]) -> Callable[P, R]:
|
||||
return decorated
|
||||
|
||||
|
||||
def only_edition_cloud(view: Callable[P, R]):
|
||||
def only_edition_cloud[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if dify_config.EDITION != "CLOUD":
|
||||
@@ -61,7 +57,7 @@ def only_edition_cloud(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def only_edition_enterprise(view: Callable[P, R]):
|
||||
def only_edition_enterprise[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if not dify_config.ENTERPRISE_ENABLED:
|
||||
@@ -72,7 +68,7 @@ def only_edition_enterprise(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def only_edition_self_hosted(view: Callable[P, R]):
|
||||
def only_edition_self_hosted[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
if dify_config.EDITION != "SELF_HOSTED":
|
||||
@@ -83,7 +79,7 @@ def only_edition_self_hosted(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def cloud_edition_billing_enabled(view: Callable[P, R]):
|
||||
def cloud_edition_billing_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@@ -95,7 +91,7 @@ def cloud_edition_billing_enabled(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def cloud_edition_billing_resource_check(resource: str):
|
||||
def cloud_edition_billing_resource_check[**P, R](resource: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
@@ -137,7 +133,9 @@ def cloud_edition_billing_resource_check(resource: str):
|
||||
return interceptor
|
||||
|
||||
|
||||
def cloud_edition_billing_knowledge_limit_check(resource: str):
|
||||
def cloud_edition_billing_knowledge_limit_check[**P, R](
|
||||
resource: str,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
@@ -160,7 +158,7 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
|
||||
return interceptor
|
||||
|
||||
|
||||
def cloud_edition_billing_rate_limit_check(resource: str):
|
||||
def cloud_edition_billing_rate_limit_check[**P, R](resource: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
@@ -196,7 +194,7 @@ def cloud_edition_billing_rate_limit_check(resource: str):
|
||||
return interceptor
|
||||
|
||||
|
||||
def cloud_utm_record(view: Callable[P, R]):
|
||||
def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
with contextlib.suppress(Exception):
|
||||
@@ -215,7 +213,7 @@ def cloud_utm_record(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def setup_required(view: Callable[P, R]) -> Callable[P, R]:
|
||||
def setup_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# check setup
|
||||
@@ -229,7 +227,7 @@ def setup_required(view: Callable[P, R]) -> Callable[P, R]:
|
||||
return decorated
|
||||
|
||||
|
||||
def enterprise_license_required(view: Callable[P, R]):
|
||||
def enterprise_license_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
settings = FeatureService.get_system_features()
|
||||
@@ -241,7 +239,7 @@ def enterprise_license_required(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def email_password_login_enabled(view: Callable[P, R]):
|
||||
def email_password_login_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
@@ -254,7 +252,7 @@ def email_password_login_enabled(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def email_register_enabled(view: Callable[P, R]):
|
||||
def email_register_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
@@ -267,7 +265,7 @@ def email_register_enabled(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def enable_change_email(view: Callable[P, R]):
|
||||
def enable_change_email[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
features = FeatureService.get_system_features()
|
||||
@@ -280,7 +278,7 @@ def enable_change_email(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def is_allow_transfer_owner(view: Callable[P, R]):
|
||||
def is_allow_transfer_owner[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
from libs.workspace_permission import check_workspace_owner_transfer_permission
|
||||
@@ -293,7 +291,7 @@ def is_allow_transfer_owner(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def knowledge_pipeline_publish_enabled(view: Callable[P, R]):
|
||||
def knowledge_pipeline_publish_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@@ -305,7 +303,7 @@ def knowledge_pipeline_publish_enabled(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def edit_permission_required(f: Callable[P, R]):
|
||||
def edit_permission_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(f)
|
||||
def decorated_function(*args: P.args, **kwargs: P.kwargs):
|
||||
from werkzeug.exceptions import Forbidden
|
||||
@@ -323,7 +321,7 @@ def edit_permission_required(f: Callable[P, R]):
|
||||
return decorated_function
|
||||
|
||||
|
||||
def is_admin_or_owner_required(f: Callable[P, R]):
|
||||
def is_admin_or_owner_required[**P, R](f: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(f)
|
||||
def decorated_function(*args: P.args, **kwargs: P.kwargs):
|
||||
from werkzeug.exceptions import Forbidden
|
||||
@@ -339,7 +337,7 @@ def is_admin_or_owner_required(f: Callable[P, R]):
|
||||
return decorated_function
|
||||
|
||||
|
||||
def annotation_import_rate_limit(view: Callable[P, R]):
|
||||
def annotation_import_rate_limit[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
"""
|
||||
Rate limiting decorator for annotation import operations.
|
||||
|
||||
@@ -388,7 +386,7 @@ def annotation_import_rate_limit(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def annotation_import_concurrency_limit(view: Callable[P, R]):
|
||||
def annotation_import_concurrency_limit[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
"""
|
||||
Concurrency control decorator for annotation import operations.
|
||||
|
||||
@@ -455,7 +453,7 @@ def _decrypt_field(field_name: str, error_class: type[Exception], error_message:
|
||||
payload[field_name] = decoded_value
|
||||
|
||||
|
||||
def decrypt_password_field(view: Callable[P, R]):
|
||||
def decrypt_password_field[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
"""
|
||||
Decorator to decrypt password field in request payload.
|
||||
|
||||
@@ -477,7 +475,7 @@ def decrypt_password_field(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def decrypt_code_field(view: Callable[P, R]):
|
||||
def decrypt_code_field[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
"""
|
||||
Decorator to decrypt verification code field in request payload.
|
||||
|
||||
|
||||
@@ -1,21 +1,17 @@
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from models.account import Tenant
|
||||
from models.model import DefaultEndUserSessionID, EndUser
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class TenantUserPayload(BaseModel):
|
||||
tenant_id: str
|
||||
@@ -33,7 +29,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
user_model = None
|
||||
|
||||
if is_anonymous:
|
||||
@@ -56,7 +52,7 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
session_id=user_id,
|
||||
)
|
||||
session.add(user_model)
|
||||
session.commit()
|
||||
session.flush()
|
||||
session.refresh(user_model)
|
||||
|
||||
except Exception:
|
||||
@@ -65,9 +61,9 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
|
||||
return user_model
|
||||
|
||||
|
||||
def get_user_tenant(view_func: Callable[P, R]):
|
||||
def get_user_tenant[**P, R](view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
payload = TenantUserPayload.model_validate(request.get_json(silent=True) or {})
|
||||
|
||||
user_id = payload.user_id
|
||||
@@ -97,10 +93,14 @@ def get_user_tenant(view_func: Callable[P, R]):
|
||||
return decorated_view
|
||||
|
||||
|
||||
def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
|
||||
def decorator(view_func: Callable[P, R]):
|
||||
def plugin_data[**P, R](
|
||||
view: Callable[P, R] | None = None,
|
||||
*,
|
||||
payload_type: type[BaseModel],
|
||||
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view_func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
try:
|
||||
data = request.get_json()
|
||||
except Exception:
|
||||
|
||||
@@ -3,10 +3,7 @@ from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from hashlib import sha1
|
||||
from hmac import new as hmac_new
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
from flask import abort, request
|
||||
|
||||
from configs import dify_config
|
||||
@@ -14,9 +11,9 @@ from extensions.ext_database import db
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
def billing_inner_api_only(view: Callable[P, R]):
|
||||
def billing_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not dify_config.INNER_API:
|
||||
abort(404)
|
||||
|
||||
@@ -30,9 +27,9 @@ def billing_inner_api_only(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def enterprise_inner_api_only(view: Callable[P, R]):
|
||||
def enterprise_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not dify_config.INNER_API:
|
||||
abort(404)
|
||||
|
||||
@@ -46,9 +43,9 @@ def enterprise_inner_api_only(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def enterprise_inner_api_user_auth(view: Callable[P, R]):
|
||||
def enterprise_inner_api_user_auth[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not dify_config.INNER_API:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
@@ -82,9 +79,9 @@ def enterprise_inner_api_user_auth(view: Callable[P, R]):
|
||||
return decorated
|
||||
|
||||
|
||||
def plugin_inner_api_only(view: Callable[P, R]):
|
||||
def plugin_inner_api_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not dify_config.PLUGIN_DAEMON_KEY:
|
||||
abort(404)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from flask import Response
|
||||
from flask_restx import Resource
|
||||
from graphon.variables.input_entities import VariableEntity
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.mcp import mcp_ns
|
||||
@@ -67,7 +67,7 @@ class MCPAppApi(Resource):
|
||||
request_id: Union[int, str] | None = args.id
|
||||
mcp_request = self._parse_mcp_request(args.model_dump(exclude_none=True))
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
# Get MCP server and app
|
||||
mcp_server, app = self._get_mcp_server_and_app(server_code, session)
|
||||
self._validate_server_status(mcp_server)
|
||||
@@ -174,6 +174,7 @@ class MCPAppApi(Resource):
|
||||
required=variable.get("required", False),
|
||||
max_length=variable.get("max_length"),
|
||||
options=variable.get("options") or [],
|
||||
json_schema=variable.get("json_schema"),
|
||||
)
|
||||
|
||||
def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
|
||||
@@ -188,7 +189,7 @@ class MCPAppApi(Resource):
|
||||
|
||||
def _retrieve_end_user(self, tenant_id: str, mcp_server_id: str) -> EndUser | None:
|
||||
"""Get end user - manages its own database session"""
|
||||
with Session(db.engine, expire_on_commit=False) as session, session.begin():
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
return (
|
||||
session.query(EndUser)
|
||||
.where(EndUser.tenant_id == tenant_id)
|
||||
@@ -228,9 +229,7 @@ class MCPAppApi(Resource):
|
||||
if not end_user and isinstance(mcp_request.root, mcp_types.InitializeRequest):
|
||||
client_info = mcp_request.root.params.clientInfo
|
||||
client_name = f"{client_info.name}@{client_info.version}"
|
||||
# Commit the session before creating end user to avoid transaction conflicts
|
||||
session.commit()
|
||||
with Session(db.engine, expire_on_commit=False) as create_session, create_session.begin():
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as create_session:
|
||||
end_user = self._create_end_user(client_name, app.tenant_id, app.id, mcp_server.id, create_session)
|
||||
|
||||
return handle_mcp_request(app, mcp_request, user_input_form, mcp_server, end_user, request_id)
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Literal
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
import services
|
||||
@@ -116,7 +116,7 @@ class ConversationApi(Resource):
|
||||
last_id = str(query_args.last_id) if query_args.last_id else None
|
||||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
pagination = ConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
|
||||
@@ -8,7 +8,7 @@ from graphon.enums import WorkflowExecutionStatus
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
@@ -314,7 +314,7 @@ class WorkflowAppLogApi(Resource):
|
||||
|
||||
# get paginate workflow app logs
|
||||
workflow_app_service = WorkflowAppService()
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
|
||||
@@ -29,6 +29,31 @@ from services.entities.knowledge_entities.knowledge_entities import SegmentUpdat
|
||||
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
|
||||
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
|
||||
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
|
||||
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict:
|
||||
"""Marshal a single segment and enrich it with summary content."""
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
|
||||
segment_dict["summary"] = summary.summary_content if summary else None
|
||||
return segment_dict
|
||||
|
||||
|
||||
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]:
|
||||
"""Marshal multiple segments and enrich them with summary content (batch query)."""
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
summaries: dict = {}
|
||||
if segment_ids:
|
||||
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
|
||||
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
|
||||
|
||||
result = []
|
||||
for segment in segments:
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
segment_dict["summary"] = summaries.get(segment.id)
|
||||
result.append(segment_dict)
|
||||
return result
|
||||
|
||||
|
||||
class SegmentCreatePayload(BaseModel):
|
||||
@@ -132,7 +157,7 @@ class SegmentApi(DatasetApiResource):
|
||||
for args_item in payload.segments:
|
||||
SegmentService.segment_create_args_validate(args_item, document)
|
||||
segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
|
||||
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
|
||||
return {"data": _marshal_segments_with_summary(segments, dataset_id), "doc_form": document.doc_form}, 200
|
||||
else:
|
||||
return {"error": "Segments is required"}, 400
|
||||
|
||||
@@ -196,7 +221,7 @@ class SegmentApi(DatasetApiResource):
|
||||
)
|
||||
|
||||
response = {
|
||||
"data": marshal(segments, segment_fields),
|
||||
"data": _marshal_segments_with_summary(segments, dataset_id),
|
||||
"doc_form": document.doc_form,
|
||||
"total": total,
|
||||
"has_more": len(segments) == limit,
|
||||
@@ -296,7 +321,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
|
||||
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
return {"data": _marshal_segment_with_summary(updated_segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||
|
||||
@service_api_ns.doc("get_segment")
|
||||
@service_api_ns.doc(description="Get a specific segment by ID")
|
||||
@@ -326,7 +351,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
return {"data": _marshal_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||
|
||||
|
||||
@service_api_ns.route(
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import inspect
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from enum import StrEnum, auto
|
||||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar, cast, overload
|
||||
from typing import cast, overload
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
@@ -23,10 +24,6 @@ from services.api_token_service import ApiTokenCache, fetch_token_with_single_fl
|
||||
from services.end_user_service import EndUserService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -46,16 +43,16 @@ class FetchUserArg(BaseModel):
|
||||
|
||||
|
||||
@overload
|
||||
def validate_app_token(view: Callable[P, R]) -> Callable[P, R]: ...
|
||||
def validate_app_token[**P, R](view: Callable[P, R]) -> Callable[P, R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def validate_app_token(
|
||||
def validate_app_token[**P, R](
|
||||
view: None = None, *, fetch_user_arg: FetchUserArg | None = None
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
|
||||
|
||||
|
||||
def validate_app_token(
|
||||
def validate_app_token[**P, R](
|
||||
view: Callable[P, R] | None = None, *, fetch_user_arg: FetchUserArg | None = None
|
||||
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def decorator(view_func: Callable[P, R]) -> Callable[P, R]:
|
||||
@@ -136,7 +133,10 @@ def validate_app_token(
|
||||
return decorator(view)
|
||||
|
||||
|
||||
def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
|
||||
def cloud_edition_billing_resource_check[**P, R](
|
||||
resource: str,
|
||||
api_token_type: str,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def interceptor(view: Callable[P, R]):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
api_token = validate_and_get_api_token(api_token_type)
|
||||
@@ -166,7 +166,10 @@ def cloud_edition_billing_resource_check(resource: str, api_token_type: str):
|
||||
return interceptor
|
||||
|
||||
|
||||
def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: str):
|
||||
def cloud_edition_billing_knowledge_limit_check[**P, R](
|
||||
resource: str,
|
||||
api_token_type: str,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
@@ -188,7 +191,10 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
|
||||
return interceptor
|
||||
|
||||
|
||||
def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
|
||||
def cloud_edition_billing_rate_limit_check[**P, R](
|
||||
resource: str,
|
||||
api_token_type: str,
|
||||
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def interceptor(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
@@ -225,99 +231,73 @@ def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
|
||||
return interceptor
|
||||
|
||||
|
||||
@overload
|
||||
def validate_dataset_token(view: Callable[Concatenate[T, P], R]) -> Callable[P, R]: ...
|
||||
def validate_dataset_token[R](view: Callable[..., R]) -> Callable[..., R]:
|
||||
positional_parameters = [
|
||||
parameter
|
||||
for parameter in inspect.signature(view).parameters.values()
|
||||
if parameter.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
|
||||
]
|
||||
expects_bound_instance = bool(positional_parameters and positional_parameters[0].name in {"self", "cls"})
|
||||
|
||||
@wraps(view)
|
||||
def decorated(*args: object, **kwargs: object) -> R:
|
||||
api_token = validate_and_get_api_token("dataset")
|
||||
|
||||
@overload
|
||||
def validate_dataset_token(view: None = None) -> Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]: ...
|
||||
# Flask may pass URL path parameters positionally, so inspect both kwargs and args.
|
||||
dataset_id = kwargs.get("dataset_id")
|
||||
|
||||
if not dataset_id and args:
|
||||
potential_id = args[0]
|
||||
try:
|
||||
str_id = str(potential_id)
|
||||
if len(str_id) == 36 and str_id.count("-") == 4:
|
||||
dataset_id = str_id
|
||||
except Exception:
|
||||
logger.exception("Failed to parse dataset_id from positional args")
|
||||
|
||||
def validate_dataset_token(
|
||||
view: Callable[Concatenate[T, P], R] | None = None,
|
||||
) -> Callable[P, R] | Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]:
|
||||
def decorator(view_func: Callable[Concatenate[T, P], R]) -> Callable[P, R]:
|
||||
@wraps(view_func)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
api_token = validate_and_get_api_token("dataset")
|
||||
|
||||
# get url path dataset_id from positional args or kwargs
|
||||
# Flask passes URL path parameters as positional arguments
|
||||
dataset_id = None
|
||||
|
||||
# First try to get from kwargs (explicit parameter)
|
||||
dataset_id = kwargs.get("dataset_id")
|
||||
|
||||
# If not in kwargs, try to extract from positional args
|
||||
if not dataset_id and args:
|
||||
# For class methods: args[0] is self, args[1] is dataset_id (if exists)
|
||||
# Check if first arg is likely a class instance (has __dict__ or __class__)
|
||||
if len(args) > 1 and hasattr(args[0], "__dict__"):
|
||||
# This is a class method, dataset_id should be in args[1]
|
||||
potential_id = args[1]
|
||||
# Validate it's a string-like UUID, not another object
|
||||
try:
|
||||
# Try to convert to string and check if it's a valid UUID format
|
||||
str_id = str(potential_id)
|
||||
# Basic check: UUIDs are 36 chars with hyphens
|
||||
if len(str_id) == 36 and str_id.count("-") == 4:
|
||||
dataset_id = str_id
|
||||
except Exception:
|
||||
logger.exception("Failed to parse dataset_id from class method args")
|
||||
elif len(args) > 0:
|
||||
# Not a class method, check if args[0] looks like a UUID
|
||||
potential_id = args[0]
|
||||
try:
|
||||
str_id = str(potential_id)
|
||||
if len(str_id) == 36 and str_id.count("-") == 4:
|
||||
dataset_id = str_id
|
||||
except Exception:
|
||||
logger.exception("Failed to parse dataset_id from positional args")
|
||||
|
||||
# Validate dataset if dataset_id is provided
|
||||
if dataset_id:
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset)
|
||||
.where(
|
||||
Dataset.id == dataset_id,
|
||||
Dataset.tenant_id == api_token.tenant_id,
|
||||
)
|
||||
.limit(1)
|
||||
if dataset_id:
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset)
|
||||
.where(
|
||||
Dataset.id == dataset_id,
|
||||
Dataset.tenant_id == api_token.tenant_id,
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
if not dataset.enable_api:
|
||||
raise Forbidden("Dataset api access is not enabled.")
|
||||
tenant_account_join = db.session.execute(
|
||||
select(Tenant, TenantAccountJoin)
|
||||
.where(Tenant.id == api_token.tenant_id)
|
||||
.where(TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(TenantAccountJoin.role.in_(["owner"]))
|
||||
.where(Tenant.status == TenantStatus.NORMAL)
|
||||
).one_or_none() # TODO: only owner information is required, so only one is returned.
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
account = db.session.get(Account, ta.account_id)
|
||||
# Login admin
|
||||
if account:
|
||||
account.current_tenant = tenant
|
||||
current_app.login_manager._update_request_context_with_user(account) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
|
||||
else:
|
||||
raise Unauthorized("Tenant owner account does not exist.")
|
||||
.limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
if not dataset.enable_api:
|
||||
raise Forbidden("Dataset api access is not enabled.")
|
||||
|
||||
tenant_account_join = db.session.execute(
|
||||
select(Tenant, TenantAccountJoin)
|
||||
.where(Tenant.id == api_token.tenant_id)
|
||||
.where(TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.where(TenantAccountJoin.role.in_(["owner"]))
|
||||
.where(Tenant.status == TenantStatus.NORMAL)
|
||||
).one_or_none() # TODO: only owner information is required, so only one is returned.
|
||||
if tenant_account_join:
|
||||
tenant, ta = tenant_account_join
|
||||
account = db.session.get(Account, ta.account_id)
|
||||
# Login admin
|
||||
if account:
|
||||
account.current_tenant = tenant
|
||||
current_app.login_manager._update_request_context_with_user(account) # type: ignore
|
||||
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
|
||||
else:
|
||||
raise Unauthorized("Tenant does not exist.")
|
||||
return view_func(api_token.tenant_id, *args, **kwargs) # type: ignore[arg-type]
|
||||
raise Unauthorized("Tenant owner account does not exist.")
|
||||
else:
|
||||
raise Unauthorized("Tenant does not exist.")
|
||||
|
||||
return decorated
|
||||
if expects_bound_instance:
|
||||
if not args:
|
||||
raise TypeError("validate_dataset_token expected a bound resource instance.")
|
||||
return view(args[0], api_token.tenant_id, *args[1:], **kwargs)
|
||||
|
||||
if view:
|
||||
return decorator(view)
|
||||
return view(api_token.tenant_id, *args, **kwargs)
|
||||
|
||||
# if view is None, it means that the decorator is used without parentheses
|
||||
# use the decorator as a function for method_decorators
|
||||
return decorator
|
||||
return decorated
|
||||
|
||||
|
||||
def validate_and_get_api_token(scope: str | None = None):
|
||||
|
||||
@@ -7,7 +7,7 @@ from werkzeug.exceptions import NotFound, RequestEntityTooLarge
|
||||
from controllers.trigger import bp
|
||||
from core.trigger.debug.event_bus import TriggerDebugEventBus
|
||||
from core.trigger.debug.events import WebhookDebugEvent, build_webhook_pool_key
|
||||
from services.trigger.webhook_service import WebhookService
|
||||
from services.trigger.webhook_service import RawWebhookDataDict, WebhookService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -23,6 +23,7 @@ def _prepare_webhook_execution(webhook_id: str, is_debug: bool = False):
|
||||
webhook_id, is_debug=is_debug
|
||||
)
|
||||
|
||||
webhook_data: RawWebhookDataDict
|
||||
try:
|
||||
# Use new unified extraction and validation
|
||||
webhook_data = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
@@ -99,7 +99,7 @@ class ConversationListApi(WebApiResource):
|
||||
query = ConversationListQuery.model_validate(raw_args)
|
||||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
pagination = WebConversationService.pagination_by_last_id(
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
|
||||
@@ -4,7 +4,7 @@ import secrets
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.auth.error import (
|
||||
@@ -81,7 +81,7 @@ class ForgotPasswordSendEmailApi(Resource):
|
||||
else:
|
||||
language = "en-US"
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
|
||||
token = None
|
||||
if account is None:
|
||||
@@ -180,18 +180,17 @@ class ForgotPasswordResetApi(Resource):
|
||||
|
||||
email = reset_data.get("email", "")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
|
||||
if account:
|
||||
self._update_existing_account(account, password_hashed, salt, session)
|
||||
self._update_existing_account(account, password_hashed, salt)
|
||||
else:
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
def _update_existing_account(self, account: Account, password_hashed, salt, session):
|
||||
def _update_existing_account(self, account: Account, password_hashed, salt):
|
||||
# Update existing account credentials
|
||||
account.password = base64.b64encode(password_hashed).decode()
|
||||
account.password_salt = base64.b64encode(salt).decode()
|
||||
session.commit()
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from collections.abc import Callable
|
||||
from datetime import UTC, datetime
|
||||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
from typing import Concatenate
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
|
||||
from constants import HEADER_NAME_APP_CODE
|
||||
@@ -20,14 +20,13 @@ from services.enterprise.enterprise_service import EnterpriseService, WebAppSett
|
||||
from services.feature_service import FeatureService
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def validate_jwt_token(view: Callable[Concatenate[App, EndUser, P], R] | None = None):
|
||||
def decorator(view: Callable[Concatenate[App, EndUser, P], R]):
|
||||
def validate_jwt_token[**P, R](
|
||||
view: Callable[Concatenate[App, EndUser, P], R] | None = None,
|
||||
) -> Callable[P, R] | Callable[[Callable[Concatenate[App, EndUser, P], R]], Callable[P, R]]:
|
||||
def decorator(view: Callable[Concatenate[App, EndUser, P], R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
app_model, end_user = decode_jwt_token()
|
||||
return view(app_model, end_user, *args, **kwargs)
|
||||
|
||||
@@ -38,7 +37,7 @@ def validate_jwt_token(view: Callable[Concatenate[App, EndUser, P], R] | None =
|
||||
return decorator
|
||||
|
||||
|
||||
def decode_jwt_token(app_code: str | None = None, user_id: str | None = None):
|
||||
def decode_jwt_token(app_code: str | None = None, user_id: str | None = None) -> tuple[App, EndUser]:
|
||||
system_features = FeatureService.get_system_features()
|
||||
if not app_code:
|
||||
app_code = str(request.headers.get(HEADER_NAME_APP_CODE))
|
||||
@@ -49,7 +48,7 @@ def decode_jwt_token(app_code: str | None = None, user_id: str | None = None):
|
||||
decoded = PassportService().verify(tk)
|
||||
app_code = decoded.get("app_code")
|
||||
app_id = decoded.get("app_id")
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
app_model = session.scalar(select(App).where(App.id == app_id))
|
||||
site = session.scalar(select(Site).where(Site.code == app_code))
|
||||
if not app_model:
|
||||
|
||||
@@ -5,7 +5,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal, Union, overload
|
||||
from typing import TYPE_CHECKING, Any, Literal, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@@ -68,7 +68,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_run_id: str,
|
||||
@@ -81,7 +81,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_run_id: str,
|
||||
@@ -94,7 +94,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_run_id: str,
|
||||
@@ -106,7 +106,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_run_id: str,
|
||||
@@ -239,7 +239,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
@@ -271,9 +271,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user: Account | EndUser,
|
||||
args: Mapping,
|
||||
args: Mapping[str, Any],
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@@ -359,7 +359,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
user: Account | EndUser,
|
||||
args: LoopNodeRunPayload,
|
||||
streaming: bool = True,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@@ -439,7 +439,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
*,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
invoke_from: InvokeFrom,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
@@ -451,7 +451,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@@ -653,10 +653,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: ConversationSnapshot,
|
||||
message: MessageSnapshot,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
stream: bool = False,
|
||||
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
) -> ChatbotAppBlockingResponse | Generator[ChatbotAppStreamResponse, None, None]:
|
||||
"""
|
||||
Handle response.
|
||||
:param application_generate_entity: application generate entity
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Literal, Union, overload
|
||||
from typing import Any, Literal, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
@@ -37,7 +37,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
@@ -48,7 +48,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
@@ -59,21 +59,21 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
) -> Union[Mapping, Generator[Mapping | str, None, None]]: ...
|
||||
) -> Mapping | Generator[Mapping | str, None, None]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
|
||||
) -> Mapping | Generator[Mapping | str, None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Literal, Union, overload
|
||||
from typing import Any, Literal, overload
|
||||
|
||||
from flask import Flask, copy_current_request_context, current_app
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
@@ -36,7 +36,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
@@ -46,7 +46,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
@@ -56,20 +56,20 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
|
||||
) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
|
||||
) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Literal, Union, overload
|
||||
from typing import Any, Literal, overload
|
||||
|
||||
from flask import Flask, copy_current_request_context, current_app
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
@@ -36,7 +36,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
@@ -46,7 +46,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
@@ -56,20 +56,20 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = False,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ...
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@@ -244,10 +244,10 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
self,
|
||||
app_model: App,
|
||||
message_id: str,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
) -> Union[Mapping, Generator[Mapping | str, None, None]]:
|
||||
) -> Mapping | Generator[Mapping | str, None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Literal, Union, cast, overload
|
||||
from typing import Any, Literal, cast, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
@@ -62,7 +62,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
*,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
@@ -77,7 +77,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
*,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
@@ -92,28 +92,28 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
*,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: str | None,
|
||||
is_retry: bool = False,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
|
||||
) -> Mapping[str, Any] | Generator[Mapping | str, None, None]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
is_retry: bool = False,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
|
||||
) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None:
|
||||
# Add null check for dataset
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
@@ -278,7 +278,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
context: contextvars.Context,
|
||||
pipeline: Pipeline,
|
||||
workflow_id: str,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
application_generate_entity: RagPipelineGenerateEntity,
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
@@ -286,7 +286,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
streaming: bool = True,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@@ -302,7 +302,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
"""
|
||||
with preserve_flask_contexts(flask_app, context_vars=context):
|
||||
# init queue manager
|
||||
workflow = db.session.query(Workflow).where(Workflow.id == workflow_id).first()
|
||||
workflow = db.session.get(Workflow, workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow not found: {workflow_id}")
|
||||
queue_manager = PipelineQueueManager(
|
||||
@@ -624,10 +624,10 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
application_generate_entity: RagPipelineGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
stream: bool = False,
|
||||
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
|
||||
"""
|
||||
Handle response.
|
||||
:param application_generate_entity: application generate entity
|
||||
@@ -668,7 +668,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
datasource_info: Mapping[str, Any],
|
||||
created_from: str,
|
||||
position: int,
|
||||
account: Union[Account, EndUser],
|
||||
account: Account | EndUser,
|
||||
batch: str,
|
||||
document_form: str,
|
||||
):
|
||||
@@ -715,7 +715,7 @@ class PipelineGenerator(BaseAppGenerator):
|
||||
pipeline: Pipeline,
|
||||
workflow: Workflow,
|
||||
start_node_id: str,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
) -> list[Mapping[str, Any]]:
|
||||
"""
|
||||
Format datasource info list.
|
||||
|
||||
@@ -9,6 +9,7 @@ from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent
|
||||
from graphon.runtime import GraphRuntimeState, VariablePool
|
||||
from graphon.variable_loader import VariableLoader
|
||||
from graphon.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
|
||||
@@ -84,13 +85,13 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
|
||||
user_id = None
|
||||
if invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
|
||||
end_user = db.session.query(EndUser).where(EndUser.id == self.application_generate_entity.user_id).first()
|
||||
end_user = db.session.get(EndUser, self.application_generate_entity.user_id)
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = self.application_generate_entity.user_id
|
||||
|
||||
pipeline = db.session.query(Pipeline).where(Pipeline.id == app_config.app_id).first()
|
||||
pipeline = db.session.get(Pipeline, app_config.app_id)
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
|
||||
@@ -213,10 +214,10 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
Get workflow
|
||||
"""
|
||||
# fetch workflow by workflow_id
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
workflow = db.session.scalar(
|
||||
select(Workflow)
|
||||
.where(Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
# return workflow
|
||||
@@ -297,10 +298,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
"""
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
if document_id and dataset_id:
|
||||
document = (
|
||||
db.session.query(Document)
|
||||
.where(Document.id == document_id, Document.dataset_id == dataset_id)
|
||||
.first()
|
||||
document = db.session.scalar(
|
||||
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
|
||||
)
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
|
||||
@@ -5,7 +5,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal, Union, overload
|
||||
from typing import TYPE_CHECKING, Any, Literal, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
@@ -64,7 +64,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
@@ -82,7 +82,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
@@ -100,7 +100,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
@@ -110,14 +110,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
|
||||
) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]: ...
|
||||
|
||||
def generate(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
@@ -127,7 +127,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
|
||||
) -> Mapping[str, Any] | Generator[Mapping[str, Any] | str, None, None]:
|
||||
with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from):
|
||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||
|
||||
@@ -237,7 +237,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
@@ -245,7 +245,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
"""
|
||||
Resume a paused workflow execution using the persisted runtime state.
|
||||
"""
|
||||
@@ -269,7 +269,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
@@ -280,7 +280,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
||||
@@ -609,10 +609,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
workflow: Workflow,
|
||||
queue_manager: AppQueueManager,
|
||||
user: Union[Account, EndUser],
|
||||
user: Account | EndUser,
|
||||
draft_var_saver_factory: DraftVariableSaverFactory,
|
||||
stream: bool = False,
|
||||
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||
) -> WorkflowAppBlockingResponse | Generator[WorkflowAppStreamResponse, None, None]:
|
||||
"""
|
||||
Handle response.
|
||||
:param application_generate_entity: application generate entity
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Literal, Self, TypeAlias
|
||||
from typing import Annotated, Literal, Self
|
||||
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent
|
||||
@@ -27,7 +27,7 @@ class _AdvancedChatAppGenerateEntityWrapper(BaseModel):
|
||||
entity: AdvancedChatAppGenerateEntity
|
||||
|
||||
|
||||
_GenerateEntityUnion: TypeAlias = Annotated[
|
||||
type _GenerateEntityUnion = Annotated[
|
||||
_WorkflowGenerateEntityWrapper | _AdvancedChatAppGenerateEntityWrapper,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
@@ -81,7 +81,7 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from threading import Thread
|
||||
from typing import Any, Union, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from graphon.file import FileTransferMethod
|
||||
from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
@@ -72,14 +72,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
"""
|
||||
|
||||
_task_state: EasyUITaskState
|
||||
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
|
||||
_application_generate_entity: ChatAppGenerateEntity | CompletionAppGenerateEntity | AgentChatAppGenerateEntity
|
||||
_precomputed_event_type: StreamEvent | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: Union[
|
||||
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity
|
||||
],
|
||||
application_generate_entity: ChatAppGenerateEntity | CompletionAppGenerateEntity | AgentChatAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
@@ -117,11 +115,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
|
||||
def process(
|
||||
self,
|
||||
) -> Union[
|
||||
ChatbotAppBlockingResponse,
|
||||
CompletionAppBlockingResponse,
|
||||
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
|
||||
]:
|
||||
) -> (
|
||||
ChatbotAppBlockingResponse
|
||||
| CompletionAppBlockingResponse
|
||||
| Generator[ChatbotAppStreamResponse | CompletionAppStreamResponse, None, None]
|
||||
):
|
||||
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
|
||||
# start generate conversation name thread
|
||||
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
|
||||
@@ -136,7 +134,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
|
||||
def _to_blocking_response(
|
||||
self, generator: Generator[StreamResponse, None, None]
|
||||
) -> Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]:
|
||||
) -> ChatbotAppBlockingResponse | CompletionAppBlockingResponse:
|
||||
"""
|
||||
Process blocking response.
|
||||
:return:
|
||||
@@ -148,7 +146,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
extras = {"usage": self._task_state.llm_result.usage.model_dump()}
|
||||
if self._task_state.metadata:
|
||||
extras["metadata"] = self._task_state.metadata.model_dump()
|
||||
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
|
||||
response: ChatbotAppBlockingResponse | CompletionAppBlockingResponse
|
||||
if self._conversation_mode == AppMode.COMPLETION:
|
||||
response = CompletionAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@@ -183,7 +181,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
|
||||
def _to_stream_response(
|
||||
self, generator: Generator[StreamResponse, None, None]
|
||||
) -> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]:
|
||||
) -> Generator[ChatbotAppStreamResponse | CompletionAppStreamResponse, None, None]:
|
||||
"""
|
||||
To stream response.
|
||||
:return:
|
||||
|
||||
@@ -5,14 +5,13 @@ This layer centralizes model-quota deduction outside node implementations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, cast, final
|
||||
from typing import TYPE_CHECKING, cast, final, override
|
||||
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from graphon.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent
|
||||
from graphon.nodes.base.node import Node
|
||||
from typing_extensions import override
|
||||
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
|
||||
|
||||
@@ -8,8 +8,9 @@ associates with the node span.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from contextvars import Token
|
||||
from dataclasses import dataclass
|
||||
from typing import cast, final
|
||||
from typing import cast, final, override
|
||||
|
||||
from graphon.enums import BuiltinNodeTypes, NodeType
|
||||
from graphon.graph_engine.layers import GraphEngineLayer
|
||||
@@ -17,7 +18,6 @@ from graphon.graph_events import GraphNodeEventBase
|
||||
from graphon.nodes.base.node import Node
|
||||
from opentelemetry import context as context_api
|
||||
from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context
|
||||
from typing_extensions import override
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.otel.parser import (
|
||||
@@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
|
||||
@dataclass(slots=True)
|
||||
class _NodeSpanContext:
|
||||
span: "Span"
|
||||
token: object
|
||||
token: Token[context_api.Context]
|
||||
|
||||
|
||||
@final
|
||||
|
||||
@@ -153,7 +153,7 @@ class DatasourceFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == id).first()
|
||||
upload_file: UploadFile | None = db.session.get(UploadFile, id)
|
||||
|
||||
if not upload_file:
|
||||
return None
|
||||
@@ -171,7 +171,7 @@ class DatasourceFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
message_file: MessageFile | None = db.session.query(MessageFile).where(MessageFile.id == id).first()
|
||||
message_file: MessageFile | None = db.session.get(MessageFile, id)
|
||||
|
||||
# Check if message_file is not None
|
||||
if message_file is not None:
|
||||
@@ -185,7 +185,7 @@ class DatasourceFileManager:
|
||||
else:
|
||||
tool_file_id = None
|
||||
|
||||
tool_file: ToolFile | None = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first()
|
||||
tool_file: ToolFile | None = db.session.get(ToolFile, tool_file_id)
|
||||
|
||||
if not tool_file:
|
||||
return None
|
||||
@@ -203,7 +203,7 @@ class DatasourceFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
|
||||
upload_file: UploadFile | None = db.session.get(UploadFile, upload_file_id)
|
||||
|
||||
if not upload_file:
|
||||
return None, None
|
||||
|
||||
@@ -44,7 +44,8 @@ class HumanInputContent(BaseModel):
|
||||
type: ExecutionContentType = Field(default=ExecutionContentType.HUMAN_INPUT)
|
||||
|
||||
|
||||
ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent
|
||||
# Keep a runtime alias here: callers and tests expect identity with HumanInputContent.
|
||||
ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent # noqa: UP040
|
||||
|
||||
__all__ = [
|
||||
"ExecutionExtraContentDomainModel",
|
||||
|
||||
@@ -403,7 +403,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -753,7 +753,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModel.tenant_id == self.tenant_id,
|
||||
ProviderModel.provider_name.in_(provider_names),
|
||||
ProviderModel.model_name == model,
|
||||
ProviderModel.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModel.model_type == model_type,
|
||||
)
|
||||
|
||||
return session.execute(stmt).scalar_one_or_none()
|
||||
@@ -778,7 +778,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
@@ -825,7 +825,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
ProviderModelCredential.credential_name == credential_name,
|
||||
)
|
||||
if exclude_id:
|
||||
@@ -901,7 +901,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
credential_record = s.execute(stmt).scalar_one_or_none()
|
||||
original_credentials = (
|
||||
@@ -970,7 +970,7 @@ class ProviderConfiguration(BaseModel):
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_name=model,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
encrypted_config=json.dumps(credentials),
|
||||
credential_name=credential_name,
|
||||
)
|
||||
@@ -983,7 +983,7 @@ class ProviderConfiguration(BaseModel):
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_name=model,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
credential_id=credential.id,
|
||||
is_valid=True,
|
||||
)
|
||||
@@ -1038,7 +1038,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
if not credential_record:
|
||||
@@ -1083,7 +1083,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
if not credential_record:
|
||||
@@ -1116,7 +1116,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
available_credentials_count = session.execute(count_stmt).scalar() or 0
|
||||
session.delete(credential_record)
|
||||
@@ -1156,7 +1156,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
if not credential_record:
|
||||
@@ -1171,7 +1171,7 @@ class ProviderConfiguration(BaseModel):
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_name=model,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
is_valid=True,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
@@ -1207,7 +1207,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
if not credential_record:
|
||||
@@ -1263,7 +1263,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderModelSetting).where(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_type == model_type,
|
||||
ProviderModelSetting.model_name == model,
|
||||
)
|
||||
return session.execute(stmt).scalars().first()
|
||||
@@ -1286,7 +1286,7 @@ class ProviderConfiguration(BaseModel):
|
||||
model_setting = ProviderModelSetting(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
model_name=model,
|
||||
enabled=True,
|
||||
)
|
||||
@@ -1312,7 +1312,7 @@ class ProviderConfiguration(BaseModel):
|
||||
model_setting = ProviderModelSetting(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
model_name=model,
|
||||
enabled=False,
|
||||
)
|
||||
@@ -1348,7 +1348,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(func.count(LoadBalancingModelConfig.id)).where(
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name.in_(provider_names),
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_type == model_type,
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
)
|
||||
load_balancing_config_count = session.execute(stmt).scalar() or 0
|
||||
@@ -1364,7 +1364,7 @@ class ProviderConfiguration(BaseModel):
|
||||
model_setting = ProviderModelSetting(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
model_name=model,
|
||||
load_balancing_enabled=True,
|
||||
)
|
||||
@@ -1391,7 +1391,7 @@ class ProviderConfiguration(BaseModel):
|
||||
model_setting = ProviderModelSetting(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
model_name=model,
|
||||
load_balancing_enabled=False,
|
||||
)
|
||||
|
||||
@@ -17,7 +17,7 @@ class CSVSanitizer:
|
||||
"""
|
||||
|
||||
# Characters that can start a formula in Excel/LibreOffice/Google Sheets
|
||||
FORMULA_CHARS = frozenset({"=", "+", "-", "@", "\t", "\r"})
|
||||
FORMULA_CHARS = frozenset(("=", "+", "-", "@", "\t", "\r"))
|
||||
|
||||
@classmethod
|
||||
def sanitize_value(cls, value: Any) -> str:
|
||||
|
||||
@@ -2,12 +2,13 @@ import importlib.util
|
||||
import logging
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from typing import AnyStr
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_lazy_loader: bool = False) -> ModuleType:
|
||||
def import_module_from_source[T: (str, bytes)](
|
||||
*, module_name: str, py_file_path: T, use_lazy_loader: bool = False
|
||||
) -> ModuleType:
|
||||
"""
|
||||
Importing a module from the source file directly
|
||||
"""
|
||||
|
||||
@@ -2,7 +2,6 @@ import os
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Callable
|
||||
from functools import lru_cache
|
||||
from typing import TypeVar
|
||||
|
||||
from configs import dify_config
|
||||
from core.tools.utils.yaml_utils import load_yaml_file_cached
|
||||
@@ -65,10 +64,7 @@ def pin_position_map(original_position_map: dict[str, int], pin_list: list[str])
|
||||
return position_map
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def is_filtered(
|
||||
def is_filtered[T](
|
||||
include_set: set[str],
|
||||
exclude_set: set[str],
|
||||
data: T,
|
||||
@@ -97,11 +93,11 @@ def is_filtered(
|
||||
return False
|
||||
|
||||
|
||||
def sort_by_position_map(
|
||||
def sort_by_position_map[T](
|
||||
position_map: dict[str, int],
|
||||
data: list[T],
|
||||
name_func: Callable[[T], str],
|
||||
):
|
||||
) -> list[T]:
|
||||
"""
|
||||
Sort the objects by the position map.
|
||||
If the name of the object is not in the position map, it will be put at the end.
|
||||
@@ -116,11 +112,11 @@ def sort_by_position_map(
|
||||
return sorted(data, key=lambda x: position_map.get(name_func(x), float("inf")))
|
||||
|
||||
|
||||
def sort_to_dict_by_position_map(
|
||||
def sort_to_dict_by_position_map[T](
|
||||
position_map: dict[str, int],
|
||||
data: list[T],
|
||||
name_func: Callable[[T], str],
|
||||
):
|
||||
) -> OrderedDict[str, T]:
|
||||
"""
|
||||
Sort the objects into a ordered dict by the position map.
|
||||
If the name of the object is not in the position map, it will be put at the end.
|
||||
|
||||
@@ -4,7 +4,7 @@ Proxy requests to avoid SSRF
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, TypeAlias
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
@@ -20,8 +20,8 @@ SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
|
||||
BACKOFF_FACTOR = 0.5
|
||||
STATUS_FORCELIST = [429, 500, 502, 503, 504]
|
||||
|
||||
Headers: TypeAlias = dict[str, str]
|
||||
_HEADERS_ADAPTER = TypeAdapter(Headers)
|
||||
type Headers = dict[str, str]
|
||||
_HEADERS_ADAPTER: TypeAdapter[Headers] = TypeAdapter(Headers)
|
||||
|
||||
_SSL_VERIFIED_POOL_KEY = "ssrf:verified"
|
||||
_SSL_UNVERIFIED_POOL_KEY = "ssrf:unverified"
|
||||
|
||||
@@ -10,7 +10,7 @@ from typing import Any
|
||||
|
||||
from flask import Flask, current_app
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import delete, func, select, update
|
||||
from sqlalchemy.orm.exc import ObjectDeletedError
|
||||
|
||||
from configs import dify_config
|
||||
@@ -78,7 +78,7 @@ class IndexingRunner:
|
||||
continue
|
||||
|
||||
# get dataset
|
||||
dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
|
||||
dataset = db.session.get(Dataset, requeried_document.dataset_id)
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("no dataset found")
|
||||
@@ -95,7 +95,7 @@ class IndexingRunner:
|
||||
text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
|
||||
|
||||
# transform
|
||||
current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first()
|
||||
current_user = db.session.get(Account, requeried_document.created_by)
|
||||
if not current_user:
|
||||
raise ValueError("no current user found")
|
||||
current_user.set_tenant_id(dataset.tenant_id)
|
||||
@@ -137,23 +137,24 @@ class IndexingRunner:
|
||||
return
|
||||
|
||||
# get dataset
|
||||
dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
|
||||
dataset = db.session.get(Dataset, requeried_document.dataset_id)
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("no dataset found")
|
||||
|
||||
# get exist document_segment list and delete
|
||||
document_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter_by(dataset_id=dataset.id, document_id=requeried_document.id)
|
||||
.all()
|
||||
)
|
||||
document_segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == requeried_document.id,
|
||||
)
|
||||
).all()
|
||||
|
||||
for document_segment in document_segments:
|
||||
db.session.delete(document_segment)
|
||||
if requeried_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
# delete child chunks
|
||||
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
|
||||
db.session.execute(delete(ChildChunk).where(ChildChunk.segment_id == document_segment.id))
|
||||
db.session.commit()
|
||||
# get the process rule
|
||||
stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == requeried_document.dataset_process_rule_id)
|
||||
@@ -167,7 +168,7 @@ class IndexingRunner:
|
||||
text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
|
||||
|
||||
# transform
|
||||
current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first()
|
||||
current_user = db.session.get(Account, requeried_document.created_by)
|
||||
if not current_user:
|
||||
raise ValueError("no current user found")
|
||||
current_user.set_tenant_id(dataset.tenant_id)
|
||||
@@ -207,17 +208,18 @@ class IndexingRunner:
|
||||
return
|
||||
|
||||
# get dataset
|
||||
dataset = db.session.query(Dataset).filter_by(id=requeried_document.dataset_id).first()
|
||||
dataset = db.session.get(Dataset, requeried_document.dataset_id)
|
||||
|
||||
if not dataset:
|
||||
raise ValueError("no dataset found")
|
||||
|
||||
# get exist document_segment list and delete
|
||||
document_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter_by(dataset_id=dataset.id, document_id=requeried_document.id)
|
||||
.all()
|
||||
)
|
||||
document_segments = db.session.scalars(
|
||||
select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.document_id == requeried_document.id,
|
||||
)
|
||||
).all()
|
||||
|
||||
documents = []
|
||||
if document_segments:
|
||||
@@ -289,7 +291,7 @@ class IndexingRunner:
|
||||
|
||||
embedding_model_instance = None
|
||||
if dataset_id:
|
||||
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
dataset = db.session.get(Dataset, dataset_id)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found.")
|
||||
if IndexTechniqueType.HIGH_QUALITY in {dataset.indexing_technique, indexing_technique}:
|
||||
@@ -652,24 +654,26 @@ class IndexingRunner:
|
||||
@staticmethod
|
||||
def _process_keyword_index(flask_app, dataset_id, document_id, documents):
|
||||
with flask_app.app_context():
|
||||
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
dataset = db.session.get(Dataset, dataset_id)
|
||||
if not dataset:
|
||||
raise ValueError("no dataset found")
|
||||
keyword = Keyword(dataset)
|
||||
keyword.create(documents)
|
||||
if dataset.indexing_technique != IndexTechniqueType.HIGH_QUALITY:
|
||||
document_ids = [document.metadata["doc_id"] for document in documents]
|
||||
db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.document_id == document_id,
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.index_node_id.in_(document_ids),
|
||||
DocumentSegment.status == SegmentStatus.INDEXING,
|
||||
).update(
|
||||
{
|
||||
DocumentSegment.status: SegmentStatus.COMPLETED,
|
||||
DocumentSegment.enabled: True,
|
||||
DocumentSegment.completed_at: naive_utc_now(),
|
||||
}
|
||||
db.session.execute(
|
||||
update(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == document_id,
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.index_node_id.in_(document_ids),
|
||||
DocumentSegment.status == SegmentStatus.INDEXING,
|
||||
)
|
||||
.values(
|
||||
status=SegmentStatus.COMPLETED,
|
||||
enabled=True,
|
||||
completed_at=naive_utc_now(),
|
||||
)
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
@@ -703,17 +707,19 @@ class IndexingRunner:
|
||||
)
|
||||
|
||||
document_ids = [document.metadata["doc_id"] for document in chunk_documents]
|
||||
db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.index_node_id.in_(document_ids),
|
||||
DocumentSegment.status == SegmentStatus.INDEXING,
|
||||
).update(
|
||||
{
|
||||
DocumentSegment.status: SegmentStatus.COMPLETED,
|
||||
DocumentSegment.enabled: True,
|
||||
DocumentSegment.completed_at: naive_utc_now(),
|
||||
}
|
||||
db.session.execute(
|
||||
update(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.index_node_id.in_(document_ids),
|
||||
DocumentSegment.status == SegmentStatus.INDEXING,
|
||||
)
|
||||
.values(
|
||||
status=SegmentStatus.COMPLETED,
|
||||
enabled=True,
|
||||
completed_at=naive_utc_now(),
|
||||
)
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
@@ -734,10 +740,17 @@ class IndexingRunner:
|
||||
"""
|
||||
Update the document indexing status.
|
||||
"""
|
||||
count = db.session.query(DatasetDocument).filter_by(id=document_id, is_paused=True).count()
|
||||
count = (
|
||||
db.session.scalar(
|
||||
select(func.count())
|
||||
.select_from(DatasetDocument)
|
||||
.where(DatasetDocument.id == document_id, DatasetDocument.is_paused == True)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
if count > 0:
|
||||
raise DocumentIsPausedError()
|
||||
document = db.session.query(DatasetDocument).filter_by(id=document_id).first()
|
||||
document = db.session.get(DatasetDocument, document_id)
|
||||
if not document:
|
||||
raise DocumentIsDeletedPausedError()
|
||||
|
||||
@@ -745,7 +758,7 @@ class IndexingRunner:
|
||||
|
||||
if extra_update_params:
|
||||
update_params.update(extra_update_params)
|
||||
db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params) # type: ignore
|
||||
db.session.execute(update(DatasetDocument).where(DatasetDocument.id == document_id).values(update_params)) # type: ignore
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
@@ -753,7 +766,9 @@ class IndexingRunner:
|
||||
"""
|
||||
Update the document segment by document id.
|
||||
"""
|
||||
db.session.query(DocumentSegment).filter_by(document_id=dataset_document_id).update(update_params)
|
||||
db.session.execute(
|
||||
update(DocumentSegment).where(DocumentSegment.document_id == dataset_document_id).values(update_params)
|
||||
)
|
||||
db.session.commit()
|
||||
|
||||
def _transform(
|
||||
|
||||
@@ -3,13 +3,19 @@
|
||||
import logging
|
||||
import traceback
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import orjson
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
class IdentityDict(TypedDict, total=False):
|
||||
tenant_id: str
|
||||
user_id: str
|
||||
user_type: str
|
||||
|
||||
|
||||
class StructuredJSONFormatter(logging.Formatter):
|
||||
"""
|
||||
JSON log formatter following the specified schema:
|
||||
@@ -84,7 +90,7 @@ class StructuredJSONFormatter(logging.Formatter):
|
||||
|
||||
return log_dict
|
||||
|
||||
def _extract_identity(self, record: logging.LogRecord) -> dict[str, str] | None:
|
||||
def _extract_identity(self, record: logging.LogRecord) -> IdentityDict | None:
|
||||
tenant_id = getattr(record, "tenant_id", None)
|
||||
user_id = getattr(record, "user_id", None)
|
||||
user_type = getattr(record, "user_type", None)
|
||||
@@ -92,7 +98,7 @@ class StructuredJSONFormatter(logging.Formatter):
|
||||
if not any([tenant_id, user_id, user_type]):
|
||||
return None
|
||||
|
||||
identity: dict[str, str] = {}
|
||||
identity: IdentityDict = {}
|
||||
if tenant_id:
|
||||
identity["tenant_id"] = tenant_id
|
||||
if user_id:
|
||||
|
||||
@@ -3,7 +3,7 @@ import queue
|
||||
from collections.abc import Generator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, TypeAlias, final
|
||||
from typing import Any, final
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import httpx
|
||||
@@ -33,9 +33,9 @@ class _StatusError:
|
||||
|
||||
|
||||
# Type aliases for better readability
|
||||
ReadQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
|
||||
WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
|
||||
StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError]
|
||||
type ReadQueue = queue.Queue[SessionMessage | Exception | None]
|
||||
type WriteQueue = queue.Queue[SessionMessage | Exception | None]
|
||||
type StatusQueue = queue.Queue[_StatusReady | _StatusError]
|
||||
|
||||
|
||||
class SSETransport:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Any, Generic, TypeVar
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -9,13 +9,12 @@ from core.mcp.types import LATEST_PROTOCOL_VERSION, OAuthClientInformation, OAut
|
||||
|
||||
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION]
|
||||
|
||||
|
||||
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
|
||||
LifespanContextT = TypeVar("LifespanContextT")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext(Generic[SessionT, LifespanContextT]):
|
||||
class RequestContext[SessionT: BaseSession[Any, Any, Any, Any, Any], LifespanContextT]:
|
||||
request_id: RequestId
|
||||
meta: RequestParams.Meta | None
|
||||
session: SessionT
|
||||
|
||||
@@ -260,4 +260,12 @@ def convert_input_form_to_parameters(
|
||||
parameters[item.variable]["enum"] = item.options
|
||||
elif item.type == VariableEntityType.NUMBER:
|
||||
parameters[item.variable]["type"] = "number"
|
||||
elif item.type == VariableEntityType.CHECKBOX:
|
||||
parameters[item.variable]["type"] = "boolean"
|
||||
elif item.type == VariableEntityType.JSON_OBJECT:
|
||||
parameters[item.variable]["type"] = "object"
|
||||
if item.json_schema:
|
||||
for key in ("properties", "required", "additionalProperties"):
|
||||
if key in item.json_schema:
|
||||
parameters[item.variable][key] = item.json_schema[key]
|
||||
return parameters, required
|
||||
|
||||
@@ -4,7 +4,7 @@ from collections.abc import Callable
|
||||
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
|
||||
from datetime import timedelta
|
||||
from types import TracebackType
|
||||
from typing import Any, Generic, Self, TypeVar
|
||||
from typing import Any, Self
|
||||
|
||||
from httpx import HTTPStatusError
|
||||
from pydantic import BaseModel
|
||||
@@ -34,16 +34,10 @@ from core.mcp.types import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
|
||||
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
|
||||
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
|
||||
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
|
||||
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
|
||||
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
|
||||
DEFAULT_RESPONSE_READ_TIMEOUT = 1.0
|
||||
|
||||
|
||||
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||
class RequestResponder[ReceiveRequestT: ClientRequest | ServerRequest, SendResultT: ClientResult | ServerResult]:
|
||||
"""Handles responding to MCP requests and manages request lifecycle.
|
||||
|
||||
This class MUST be used as a context manager to ensure proper cleanup and
|
||||
@@ -60,7 +54,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||
"""
|
||||
|
||||
request: ReceiveRequestT
|
||||
_session: Any
|
||||
_session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]"
|
||||
_on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any]
|
||||
|
||||
def __init__(
|
||||
@@ -68,7 +62,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||
request_id: RequestId,
|
||||
request_meta: RequestParams.Meta | None,
|
||||
request: ReceiveRequestT,
|
||||
session: """BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]""",
|
||||
session: "BaseSession[Any, Any, SendResultT, ReceiveRequestT, Any]",
|
||||
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
|
||||
):
|
||||
self.request_id = request_id
|
||||
@@ -111,7 +105,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||
|
||||
self.completed = True
|
||||
|
||||
self._session._send_response(request_id=self.request_id, response=response)
|
||||
self._session.send_response(request_id=self.request_id, response=response)
|
||||
|
||||
def cancel(self):
|
||||
"""Cancel this request and mark it as completed."""
|
||||
@@ -120,21 +114,19 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||
|
||||
self.completed = True # Mark as completed so it's removed from in_flight
|
||||
# Send an error response to indicate cancellation
|
||||
self._session._send_response(
|
||||
self._session.send_response(
|
||||
request_id=self.request_id,
|
||||
response=ErrorData(code=0, message="Request cancelled", data=None),
|
||||
)
|
||||
|
||||
|
||||
class BaseSession(
|
||||
Generic[
|
||||
SendRequestT,
|
||||
SendNotificationT,
|
||||
SendResultT,
|
||||
ReceiveRequestT,
|
||||
ReceiveNotificationT,
|
||||
],
|
||||
):
|
||||
class BaseSession[
|
||||
SendRequestT: ClientRequest | ServerRequest,
|
||||
SendNotificationT: ClientNotification | ServerNotification,
|
||||
SendResultT: ClientResult | ServerResult,
|
||||
ReceiveRequestT: ClientRequest | ServerRequest,
|
||||
ReceiveNotificationT: ClientNotification | ServerNotification,
|
||||
]:
|
||||
"""
|
||||
Implements an MCP "session" on top of read/write streams, including features
|
||||
like request/response linking, notifications, and progress.
|
||||
@@ -204,13 +196,13 @@ class BaseSession(
|
||||
# The receiver thread should have already exited due to the None message in the queue
|
||||
self._executor.shutdown(wait=False)
|
||||
|
||||
def send_request(
|
||||
def send_request[T: BaseModel](
|
||||
self,
|
||||
request: SendRequestT,
|
||||
result_type: type[ReceiveResultT],
|
||||
result_type: type[T],
|
||||
request_read_timeout_seconds: timedelta | None = None,
|
||||
metadata: MessageMetadata | None = None,
|
||||
) -> ReceiveResultT:
|
||||
) -> T:
|
||||
"""
|
||||
Sends a request and wait for a response. Raises an McpError if the
|
||||
response contains an error. If a request read timeout is provided, it
|
||||
@@ -299,7 +291,7 @@ class BaseSession(
|
||||
)
|
||||
self._write_stream.put(session_message)
|
||||
|
||||
def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData):
|
||||
def send_response(self, request_id: RequestId, response: SendResultT | ErrorData):
|
||||
if isinstance(response, ErrorData):
|
||||
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
|
||||
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
|
||||
@@ -350,7 +342,7 @@ class BaseSession(
|
||||
responder = RequestResponder[ReceiveRequestT, SendResultT](
|
||||
request_id=message.message.root.id,
|
||||
request_meta=validated_request.root.params.meta if validated_request.root.params else None,
|
||||
request=validated_request,
|
||||
request=validated_request, # type: ignore[arg-type] # mypy can't narrow constrained TypeVar from model_validate
|
||||
session=self,
|
||||
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
|
||||
)
|
||||
@@ -372,8 +364,8 @@ class BaseSession(
|
||||
if cancelled_id in self._in_flight:
|
||||
self._in_flight[cancelled_id].cancel()
|
||||
else:
|
||||
self._received_notification(notification)
|
||||
self._handle_incoming(notification)
|
||||
self._received_notification(notification) # type: ignore[arg-type]
|
||||
self._handle_incoming(notification) # type: ignore[arg-type]
|
||||
except Exception as e:
|
||||
# For other validation errors, log and continue
|
||||
logger.warning("Failed to validate notification: %s. Message was: %s", e, message.message.root)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel
|
||||
from pydantic.networks import AnyUrl, UrlConstraints
|
||||
@@ -31,7 +31,7 @@ ProgressToken = str | int
|
||||
Cursor = str
|
||||
Role = Literal["user", "assistant"]
|
||||
RequestId = Annotated[int | str, Field(union_mode="left_to_right")]
|
||||
AnyFunction: TypeAlias = Callable[..., Any]
|
||||
type AnyFunction = Callable[..., Any]
|
||||
|
||||
|
||||
class RequestParams(BaseModel):
|
||||
@@ -68,12 +68,7 @@ class NotificationParams(BaseModel):
|
||||
"""
|
||||
|
||||
|
||||
RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None)
|
||||
NotificationParamsT = TypeVar("NotificationParamsT", bound=NotificationParams | dict[str, Any] | None)
|
||||
MethodT = TypeVar("MethodT", bound=str)
|
||||
|
||||
|
||||
class Request(BaseModel, Generic[RequestParamsT, MethodT]):
|
||||
class Request[RequestParamsT: RequestParams | dict[str, Any] | None, MethodT: str](BaseModel):
|
||||
"""Base class for JSON-RPC requests."""
|
||||
|
||||
method: MethodT
|
||||
@@ -81,14 +76,14 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]):
|
||||
class PaginatedRequest[T: str](Request[PaginatedRequestParams | None, T]):
|
||||
"""Base class for paginated requests,
|
||||
matching the schema's PaginatedRequest interface."""
|
||||
|
||||
params: PaginatedRequestParams | None = None
|
||||
|
||||
|
||||
class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
|
||||
class Notification[NotificationParamsT: NotificationParams | dict[str, Any] | None, MethodT: str](BaseModel):
|
||||
"""Base class for JSON-RPC notifications."""
|
||||
|
||||
method: MethodT
|
||||
@@ -736,7 +731,7 @@ class ResourceLink(Resource):
|
||||
ContentBlock = TextContent | ImageContent | AudioContent | ResourceLink | EmbeddedResource
|
||||
"""A content block that can be used in prompts and tool results."""
|
||||
|
||||
Content: TypeAlias = ContentBlock
|
||||
type Content = ContentBlock
|
||||
# """DEPRECATED: Content is deprecated, you should use ContentBlock directly."""
|
||||
|
||||
|
||||
|
||||
@@ -16,7 +16,13 @@ from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExport
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import ReadableSpan
|
||||
from opentelemetry.sdk.util.instrumentation import InstrumentationScope
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
from opentelemetry.semconv._incubating.attributes.deployment_attributes import ( # type: ignore[import-untyped]
|
||||
DEPLOYMENT_ENVIRONMENT,
|
||||
)
|
||||
from opentelemetry.semconv._incubating.attributes.host_attributes import ( # type: ignore[import-untyped]
|
||||
HOST_NAME,
|
||||
)
|
||||
from opentelemetry.semconv.attributes import service_attributes
|
||||
from opentelemetry.trace import Link, SpanContext, TraceFlags
|
||||
|
||||
from configs import dify_config
|
||||
@@ -45,10 +51,10 @@ class TraceClient:
|
||||
self.endpoint = endpoint
|
||||
self.resource = Resource(
|
||||
attributes={
|
||||
ResourceAttributes.SERVICE_NAME: service_name,
|
||||
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
|
||||
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
|
||||
ResourceAttributes.HOST_NAME: socket.gethostname(),
|
||||
service_attributes.SERVICE_NAME: service_name,
|
||||
service_attributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
|
||||
DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
|
||||
HOST_NAME: socket.gethostname(),
|
||||
ACS_ARMS_SERVICE_FEATURE: "genai_app",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -19,7 +19,7 @@ from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExport
|
||||
from opentelemetry.sdk import trace as trace_sdk
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
||||
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
|
||||
from opentelemetry.semconv.attributes import exception_attributes
|
||||
from opentelemetry.trace import Span, Status, StatusCode, set_span_in_context, use_span
|
||||
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
||||
from opentelemetry.util.types import AttributeValue
|
||||
@@ -38,6 +38,7 @@ from core.ops.entities.trace_entity import (
|
||||
TraceTaskName,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.utils import JSON_DICT_ADAPTER
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser, MessageFile
|
||||
@@ -134,10 +135,10 @@ def set_span_status(current_span: Span, error: Exception | str | None = None):
|
||||
if not exception_message:
|
||||
exception_message = repr(error)
|
||||
attributes: dict[str, AttributeValue] = {
|
||||
OTELSpanAttributes.EXCEPTION_TYPE: exception_type,
|
||||
OTELSpanAttributes.EXCEPTION_MESSAGE: exception_message,
|
||||
OTELSpanAttributes.EXCEPTION_ESCAPED: False,
|
||||
OTELSpanAttributes.EXCEPTION_STACKTRACE: error_string,
|
||||
exception_attributes.EXCEPTION_TYPE: exception_type,
|
||||
exception_attributes.EXCEPTION_MESSAGE: exception_message,
|
||||
exception_attributes.EXCEPTION_ESCAPED: False,
|
||||
exception_attributes.EXCEPTION_STACKTRACE: error_string,
|
||||
}
|
||||
current_span.add_event(name="exception", attributes=attributes)
|
||||
else:
|
||||
@@ -469,7 +470,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
llm_attributes[SpanAttributes.LLM_PROVIDER] = trace_info.message_data.model_provider
|
||||
|
||||
if trace_info.message_data and trace_info.message_data.message_metadata:
|
||||
metadata_dict = json.loads(trace_info.message_data.message_metadata)
|
||||
metadata_dict = JSON_DICT_ADAPTER.validate_json(trace_info.message_data.message_metadata)
|
||||
if model_params := metadata_dict.get("model_parameters"):
|
||||
llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps(model_params)
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user