mirror of
https://github.com/langgenius/dify.git
synced 2026-03-09 17:25:10 +00:00
Compare commits
117 Commits
feat/enter
...
feat/model
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b3c98e417d | ||
|
|
a59c54b3e7 | ||
|
|
dfe389c017 | ||
|
|
b364b06e51 | ||
|
|
7737bdc699 | ||
|
|
65637fc6b7 | ||
|
|
be6f7b8712 | ||
|
|
b257e8ed44 | ||
|
|
176d3c8c3a | ||
|
|
ce0197b107 | ||
|
|
164cefc65c | ||
|
|
c72ac8a434 | ||
|
|
f6d80b9fa7 | ||
|
|
497feac48e | ||
|
|
e845fa7e6a | ||
|
|
8906ab8e52 | ||
|
|
bab7bd5ecc | ||
|
|
cfb02bceaf | ||
|
|
694ca840e1 | ||
|
|
03dcbeafdf | ||
|
|
2d979e2cec | ||
|
|
5cee7cf8ce | ||
|
|
bbfa28e8a7 | ||
|
|
6c19e75969 | ||
|
|
9970f4449a | ||
|
|
0c17823c8b | ||
|
|
49c6696d08 | ||
|
|
292c98a8f3 | ||
|
|
0e0a6ad043 | ||
|
|
456c95adb1 | ||
|
|
1abbaf9fd5 | ||
|
|
1a26e1669b | ||
|
|
02444af2e3 | ||
|
|
56038e3684 | ||
|
|
eb9341e7ec | ||
|
|
e40b31b9c4 | ||
|
|
b89ee4807f | ||
|
|
9907cf9e06 | ||
|
|
208a31719f | ||
|
|
3d1ef1f7f5 | ||
|
|
24b14e2c1a | ||
|
|
53f122f717 | ||
|
|
fced2f9e65 | ||
|
|
0c08c4016d | ||
|
|
ff4e4a8d64 | ||
|
|
948efa129f | ||
|
|
e371bfd676 | ||
|
|
6d612c0909 | ||
|
|
56e0dc0ae6 | ||
|
|
975eca00c3 | ||
|
|
f049bafcc3 | ||
|
|
dd9c526447 | ||
|
|
922dc71e36 | ||
|
|
f03ec7f671 | ||
|
|
29f275442d | ||
|
|
c9532ffd43 | ||
|
|
840dc33b8b | ||
|
|
cae58a0649 | ||
|
|
1752edc047 | ||
|
|
7471c32612 | ||
|
|
2d333bbbe5 | ||
|
|
4af6788ce0 | ||
|
|
24b072def9 | ||
|
|
909c8c3350 | ||
|
|
80e9c8bee0 | ||
|
|
15b7b304d2 | ||
|
|
61e2672b59 | ||
|
|
5f4ed4c6f6 | ||
|
|
4a1032c628 | ||
|
|
423c97a47e | ||
|
|
a7e3fb2e33 | ||
|
|
ce34937a1c | ||
|
|
ad9ac6978e | ||
|
|
57c1ba3543 | ||
|
|
d7a5af2b9a | ||
|
|
d45edffaa3 | ||
|
|
530515b6ef | ||
|
|
f13f0d1f9a | ||
|
|
b597d52c11 | ||
|
|
34c42fe666 | ||
|
|
dc109c99f0 | ||
|
|
223b9d89c1 | ||
|
|
dd119eb44f | ||
|
|
970493fa85 | ||
|
|
ab87ac333a | ||
|
|
b8b70da9ad | ||
|
|
77d81aebe8 | ||
|
|
deb4cd3ece | ||
|
|
648d9ef1f9 | ||
|
|
5ed4797078 | ||
|
|
62631658e9 | ||
|
|
22a4100dd7 | ||
|
|
0f7ed6f67e | ||
|
|
4d9fcbec57 | ||
|
|
4d7a9bc798 | ||
|
|
d6d04ed657 | ||
|
|
f594a71dae | ||
|
|
04e0ab7eda | ||
|
|
784bda9c86 | ||
|
|
1af1fb6913 | ||
|
|
1f0c36e9f7 | ||
|
|
455ae65025 | ||
|
|
d44682e957 | ||
|
|
8c4afc0c18 | ||
|
|
539cbcae6a | ||
|
|
8d257fea7c | ||
|
|
c3364ac350 | ||
|
|
f991644989 | ||
|
|
29e344ac8b | ||
|
|
1ad9305732 | ||
|
|
17f38f171d | ||
|
|
802088c8eb | ||
|
|
cad6d94491 | ||
|
|
621d0fb2c9 | ||
|
|
a92fb3244b | ||
|
|
97508f8d7b | ||
|
|
70e677a6ac |
33
.github/actions/setup-web/action.yml
vendored
Normal file
33
.github/actions/setup-web/action.yml
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
name: Setup Web Environment
|
||||
description: Setup pnpm, Node.js, and install web dependencies.
|
||||
|
||||
inputs:
|
||||
node-version:
|
||||
description: Node.js version to use
|
||||
required: false
|
||||
default: "22"
|
||||
install-dependencies:
|
||||
description: Whether to install web dependencies after setting up Node.js
|
||||
required: false
|
||||
default: "true"
|
||||
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@41ff72655975bd51cab0327fa583b6e92b6d3061 # v4.2.0
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0
|
||||
with:
|
||||
node-version: ${{ inputs.node-version }}
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
if: ${{ inputs.install-dependencies == 'true' }}
|
||||
shell: bash
|
||||
run: pnpm --dir web install --frozen-lockfile
|
||||
16
.github/dependabot.yml
vendored
16
.github/dependabot.yml
vendored
@@ -24,6 +24,18 @@ updates:
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 2
|
||||
ignore:
|
||||
- dependency-name: "ky"
|
||||
- dependency-name: "tailwind-merge"
|
||||
update-types: ["version-update:semver-major"]
|
||||
- dependency-name: "tailwindcss"
|
||||
update-types: ["version-update:semver-major"]
|
||||
- dependency-name: "react-markdown"
|
||||
update-types: ["version-update:semver-major"]
|
||||
- dependency-name: "react-syntax-highlighter"
|
||||
update-types: ["version-update:semver-major"]
|
||||
- dependency-name: "react-window"
|
||||
update-types: ["version-update:semver-major"]
|
||||
groups:
|
||||
lexical:
|
||||
patterns:
|
||||
@@ -33,6 +45,9 @@ updates:
|
||||
patterns:
|
||||
- "storybook"
|
||||
- "@storybook/*"
|
||||
eslint-group:
|
||||
patterns:
|
||||
- "*eslint*"
|
||||
npm-dependencies:
|
||||
patterns:
|
||||
- "*"
|
||||
@@ -41,3 +56,4 @@ updates:
|
||||
- "@lexical/*"
|
||||
- "storybook"
|
||||
- "@storybook/*"
|
||||
- "*eslint*"
|
||||
|
||||
6
.github/workflows/api-tests.yml
vendored
6
.github/workflows/api-tests.yml
vendored
@@ -22,12 +22,12 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@@ -51,7 +51,7 @@ jobs:
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
- name: Set up Sandbox
|
||||
uses: hoverkraft-tech/compose-action@v2
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
|
||||
20
.github/workflows/autofix.yml
vendored
20
.github/workflows/autofix.yml
vendored
@@ -12,22 +12,22 @@ jobs:
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
|
||||
- name: Check Docker Compose inputs
|
||||
id: docker-compose-changes
|
||||
uses: tj-actions/changed-files@v47
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
with:
|
||||
files: |
|
||||
docker/generate_docker_compose
|
||||
docker/.env.example
|
||||
docker/docker-compose-template.yaml
|
||||
docker/docker-compose.yaml
|
||||
- uses: actions/setup-python@v6
|
||||
- uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- uses: astral-sh/setup-uv@v7
|
||||
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
|
||||
- name: Generate Docker Compose
|
||||
if: steps.docker-compose-changes.outputs.any_changed == 'true'
|
||||
@@ -84,4 +84,14 @@ jobs:
|
||||
run: |
|
||||
uvx --python 3.13 mdformat . --exclude ".agents/skills/**"
|
||||
|
||||
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
with:
|
||||
node-version: "24"
|
||||
|
||||
- name: ESLint autofix
|
||||
run: |
|
||||
cd web
|
||||
pnpm eslint --concurrency=2 --prune-suppressions
|
||||
|
||||
- uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3
|
||||
|
||||
18
.github/workflows/build-push.yml
vendored
18
.github/workflows/build-push.yml
vendored
@@ -53,26 +53,26 @@ jobs:
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
||||
with:
|
||||
username: ${{ env.DOCKERHUB_USER }}
|
||||
password: ${{ env.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
|
||||
- name: Extract metadata for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6.0.0
|
||||
with:
|
||||
images: ${{ env[matrix.image_name_env] }}
|
||||
|
||||
- name: Build Docker image
|
||||
id: build
|
||||
uses: docker/build-push-action@v6
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
|
||||
with:
|
||||
context: "{{defaultContext}}:${{ matrix.context }}"
|
||||
platforms: ${{ matrix.platform }}
|
||||
@@ -91,7 +91,7 @@ jobs:
|
||||
touch "/tmp/digests/${sanitized_digest}"
|
||||
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
with:
|
||||
name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }}
|
||||
path: /tmp/digests/*
|
||||
@@ -113,21 +113,21 @@ jobs:
|
||||
context: "web"
|
||||
steps:
|
||||
- name: Download digests
|
||||
uses: actions/download-artifact@v7
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
|
||||
with:
|
||||
path: /tmp/digests
|
||||
pattern: digests-${{ matrix.context }}-*
|
||||
merge-multiple: true
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
||||
with:
|
||||
username: ${{ env.DOCKERHUB_USER }}
|
||||
password: ${{ env.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Extract metadata for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6.0.0
|
||||
with:
|
||||
images: ${{ env[matrix.image_name_env] }}
|
||||
tags: |
|
||||
|
||||
12
.github/workflows/db-migration-test.yml
vendored
12
.github/workflows/db-migration-test.yml
vendored
@@ -13,13 +13,13 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
@@ -40,7 +40,7 @@ jobs:
|
||||
cp middleware.env.example middleware.env
|
||||
|
||||
- name: Set up Middlewares
|
||||
uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
@@ -63,13 +63,13 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
@@ -94,7 +94,7 @@ jobs:
|
||||
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env
|
||||
|
||||
- name: Set up Middlewares
|
||||
uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
|
||||
2
.github/workflows/deploy-agent-dev.yml
vendored
2
.github/workflows/deploy-agent-dev.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
github.event.workflow_run.head_branch == 'deploy/agent-dev'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v1
|
||||
uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5
|
||||
with:
|
||||
host: ${{ secrets.AGENT_DEV_SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
|
||||
2
.github/workflows/deploy-dev.yml
vendored
2
.github/workflows/deploy-dev.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
github.event.workflow_run.head_branch == 'deploy/dev'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v1
|
||||
uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5
|
||||
with:
|
||||
host: ${{ secrets.SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
|
||||
2
.github/workflows/deploy-hitl.yml
vendored
2
.github/workflows/deploy-hitl.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
github.event.workflow_run.head_branch == 'build/feat/hitl'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v1
|
||||
uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5
|
||||
with:
|
||||
host: ${{ secrets.HITL_SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
|
||||
6
.github/workflows/docker-build.yml
vendored
6
.github/workflows/docker-build.yml
vendored
@@ -32,13 +32,13 @@ jobs:
|
||||
context: "web"
|
||||
steps:
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
|
||||
- name: Build Docker Image
|
||||
uses: docker/build-push-action@v6
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
|
||||
with:
|
||||
push: false
|
||||
context: "{{defaultContext}}:${{ matrix.context }}"
|
||||
|
||||
2
.github/workflows/labeler.yml
vendored
2
.github/workflows/labeler.yml
vendored
@@ -9,6 +9,6 @@ jobs:
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/labeler@v6
|
||||
- uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1
|
||||
with:
|
||||
sync-labels: true
|
||||
|
||||
5
.github/workflows/main-ci.yml
vendored
5
.github/workflows/main-ci.yml
vendored
@@ -27,8 +27,8 @@ jobs:
|
||||
vdb-changed: ${{ steps.changes.outputs.vdb }}
|
||||
migration-changed: ${{ steps.changes.outputs.migration }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: dorny/paths-filter@v3
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
|
||||
id: changes
|
||||
with:
|
||||
filters: |
|
||||
@@ -39,6 +39,7 @@ jobs:
|
||||
web:
|
||||
- 'web/**'
|
||||
- '.github/workflows/web-tests.yml'
|
||||
- '.github/actions/setup-web/**'
|
||||
vdb:
|
||||
- 'api/core/rag/datasource/**'
|
||||
- 'docker/**'
|
||||
|
||||
4
.github/workflows/pyrefly-diff-comment.yml
vendored
4
.github/workflows/pyrefly-diff-comment.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }}
|
||||
steps:
|
||||
- name: Download pyrefly diff artifact
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
@@ -49,7 +49,7 @@ jobs:
|
||||
run: unzip -o pyrefly_diff.zip
|
||||
|
||||
- name: Post comment
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
|
||||
8
.github/workflows/pyrefly-diff.yml
vendored
8
.github/workflows/pyrefly-diff.yml
vendored
@@ -17,12 +17,12 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Checkout PR branch
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@v5
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
@@ -55,7 +55,7 @@ jobs:
|
||||
echo ${{ github.event.pull_request.number }} > pr_number.txt
|
||||
|
||||
- name: Upload pyrefly diff
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
with:
|
||||
name: pyrefly_diff
|
||||
path: |
|
||||
@@ -64,7 +64,7 @@ jobs:
|
||||
|
||||
- name: Comment PR with pyrefly diff
|
||||
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
|
||||
uses: actions/github-script@v8
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
|
||||
2
.github/workflows/semantic-pull-request.yml
vendored
2
.github/workflows/semantic-pull-request.yml
vendored
@@ -16,6 +16,6 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check title
|
||||
uses: amannn/action-semantic-pull-request@v6.1.1
|
||||
uses: amannn/action-semantic-pull-request@48f256284bd46cdaab1048c3721360e808335d50 # v6.1.1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
2
.github/workflows/stale.yml
vendored
2
.github/workflows/stale.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/stale@v10
|
||||
- uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0
|
||||
with:
|
||||
days-before-issue-stale: 15
|
||||
days-before-issue-close: 3
|
||||
|
||||
36
.github/workflows/style.yml
vendored
36
.github/workflows/style.yml
vendored
@@ -19,13 +19,13 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v47
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
with:
|
||||
files: |
|
||||
api/**
|
||||
@@ -33,7 +33,7 @@ jobs:
|
||||
|
||||
- name: Setup UV and Python
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
with:
|
||||
enable-cache: false
|
||||
python-version: "3.12"
|
||||
@@ -67,36 +67,22 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v47
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
.github/workflows/style.yml
|
||||
.github/actions/setup-web/**
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup NodeJS
|
||||
uses: actions/setup-node@v6
|
||||
- name: Setup web environment
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
node-version: 22
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Web dependencies
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm install --frozen-lockfile
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Web style check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
@@ -134,14 +120,14 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v47
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
with:
|
||||
files: |
|
||||
**.sh
|
||||
@@ -152,7 +138,7 @@ jobs:
|
||||
.editorconfig
|
||||
|
||||
- name: Super-linter
|
||||
uses: super-linter/super-linter/slim@v8
|
||||
uses: super-linter/super-linter/slim@61abc07d755095a68f4987d1c2c3d1d64408f1f9 # v8.5.0
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
env:
|
||||
BASH_SEVERITY: warning
|
||||
|
||||
4
.github/workflows/tool-test-sdks.yaml
vendored
4
.github/workflows/tool-test-sdks.yaml
vendored
@@ -21,12 +21,12 @@ jobs:
|
||||
working-directory: sdks/nodejs-client
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Use Node.js
|
||||
uses: actions/setup-node@v6
|
||||
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0
|
||||
with:
|
||||
node-version: 22
|
||||
cache: ''
|
||||
|
||||
18
.github/workflows/translate-i18n-claude.yml
vendored
18
.github/workflows/translate-i18n-claude.yml
vendored
@@ -38,7 +38,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
@@ -48,18 +48,10 @@ jobs:
|
||||
git config --global user.name "github-actions[bot]"
|
||||
git config --global user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 22
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
install-dependencies: "false"
|
||||
|
||||
- name: Detect changed files and generate diff
|
||||
id: detect_changes
|
||||
@@ -130,7 +122,7 @@ jobs:
|
||||
|
||||
- name: Run Claude Code for Translation Sync
|
||||
if: steps.detect_changes.outputs.CHANGED_FILES != ''
|
||||
uses: anthropics/claude-code-action@v1
|
||||
uses: anthropics/claude-code-action@26ec041249acb0a944c0a47b6c0c13f05dbc5b44 # v1.0.70
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
4
.github/workflows/trigger-i18n-sync.yml
vendored
4
.github/workflows/trigger-i18n-sync.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -59,7 +59,7 @@ jobs:
|
||||
|
||||
- name: Trigger i18n sync workflow
|
||||
if: steps.detect.outputs.has_changes == 'true'
|
||||
uses: peter-evans/repository-dispatch@v3
|
||||
uses: peter-evans/repository-dispatch@28959ce8df70de7be546dd1250a005dd32156697 # v4.0.1
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
event-type: i18n-sync
|
||||
|
||||
8
.github/workflows/vdb-tests.yml
vendored
8
.github/workflows/vdb-tests.yml
vendored
@@ -19,19 +19,19 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Free Disk Space
|
||||
uses: endersonmenezes/free-disk-space@v3
|
||||
uses: endersonmenezes/free-disk-space@7901478139cff6e9d44df5972fd8ab8fcade4db1 # v3.2.2
|
||||
with:
|
||||
remove_dotnet: true
|
||||
remove_haskell: true
|
||||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@@ -60,7 +60,7 @@ jobs:
|
||||
# tiflash
|
||||
|
||||
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase)
|
||||
uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.yaml
|
||||
|
||||
68
.github/workflows/web-tests.yml
vendored
68
.github/workflows/web-tests.yml
vendored
@@ -26,32 +26,19 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 22
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Run tests
|
||||
run: pnpm vitest run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
|
||||
|
||||
- name: Upload blob report
|
||||
if: ${{ !cancelled() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
with:
|
||||
name: blob-report-${{ matrix.shardIndex }}
|
||||
path: web/.vitest-reports/*
|
||||
@@ -70,28 +57,15 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 22
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Download blob reports
|
||||
uses: actions/download-artifact@v6
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
|
||||
with:
|
||||
path: web/.vitest-reports
|
||||
pattern: blob-report-*
|
||||
@@ -419,7 +393,7 @@ jobs:
|
||||
|
||||
- name: Upload Coverage Artifact
|
||||
if: steps.coverage-summary.outputs.has_coverage == 'true'
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
with:
|
||||
name: web-coverage-report
|
||||
path: web/coverage
|
||||
@@ -435,36 +409,22 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@v47
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
.github/workflows/web-tests.yml
|
||||
.github/actions/setup-web/**
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup NodeJS
|
||||
uses: actions/setup-node@v6
|
||||
- name: Setup web environment
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
with:
|
||||
node-version: 22
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Web dependencies
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm install --frozen-lockfile
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Web build check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
|
||||
@@ -44,7 +44,6 @@ forbidden_modules =
|
||||
allow_indirect_imports = True
|
||||
ignore_imports =
|
||||
dify_graph.nodes.agent.agent_node -> extensions.ext_database
|
||||
dify_graph.nodes.llm.file_saver -> extensions.ext_database
|
||||
dify_graph.nodes.llm.node -> extensions.ext_database
|
||||
dify_graph.nodes.tool.tool_node -> extensions.ext_database
|
||||
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
|
||||
@@ -114,7 +113,6 @@ ignore_imports =
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
|
||||
dify_graph.nodes.tool.tool_node -> models
|
||||
dify_graph.nodes.agent.agent_node -> models.model
|
||||
dify_graph.nodes.llm.file_saver -> core.helper.ssrf_proxy
|
||||
dify_graph.nodes.llm.node -> core.helper.code_executor
|
||||
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
|
||||
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
|
||||
@@ -135,7 +133,6 @@ ignore_imports =
|
||||
dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.errors
|
||||
dify_graph.nodes.agent.agent_node -> extensions.ext_database
|
||||
dify_graph.nodes.llm.file_saver -> extensions.ext_database
|
||||
dify_graph.nodes.llm.node -> extensions.ext_database
|
||||
dify_graph.nodes.tool.tool_node -> extensions.ext_database
|
||||
dify_graph.nodes.agent.agent_node -> models
|
||||
|
||||
@@ -807,7 +807,7 @@ class DatasetApiKeyApi(Resource):
|
||||
console_ns.abort(
|
||||
400,
|
||||
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
||||
code="max_keys_exceeded",
|
||||
custom="max_keys_exceeded",
|
||||
)
|
||||
|
||||
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
||||
|
||||
@@ -10,7 +10,6 @@ from controllers.common.file_response import enforce_download_for_html
|
||||
from controllers.files import files_ns
|
||||
from core.tools.signature import verify_tool_file_signature
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from extensions.ext_database import db as global_db
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
@@ -57,7 +56,7 @@ class ToolFileApi(Resource):
|
||||
raise Forbidden("Invalid request.")
|
||||
|
||||
try:
|
||||
tool_file_manager = ToolFileManager(engine=global_db.engine)
|
||||
tool_file_manager = ToolFileManager()
|
||||
stream, tool_file = tool_file_manager.get_file_generator_by_tool_file_id(
|
||||
file_id,
|
||||
)
|
||||
|
||||
@@ -239,7 +239,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
||||
def get(self, app_model, end_user, message_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotCompletionAppError()
|
||||
raise NotChatAppError()
|
||||
|
||||
message_id = str(message_id)
|
||||
|
||||
|
||||
@@ -10,28 +10,18 @@ from typing import Union
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db as global_db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import MessageFile
|
||||
from models.tools import ToolFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
|
||||
class ToolFileManager:
|
||||
_engine: Engine
|
||||
|
||||
def __init__(self, engine: Engine | None = None):
|
||||
if engine is None:
|
||||
engine = global_db.engine
|
||||
self._engine = engine
|
||||
|
||||
@staticmethod
|
||||
def sign_file(tool_file_id: str, extension: str) -> str:
|
||||
"""
|
||||
@@ -89,7 +79,7 @@ class ToolFileManager:
|
||||
filepath = f"tools/{tenant_id}/{unique_filename}"
|
||||
storage.save(filepath, file_binary)
|
||||
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
@@ -132,7 +122,7 @@ class ToolFileManager:
|
||||
filename = f"{unique_name}{extension}"
|
||||
filepath = f"tools/{tenant_id}/{filename}"
|
||||
storage.save(filepath, blob)
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
@@ -157,7 +147,7 @@ class ToolFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.where(
|
||||
@@ -181,7 +171,7 @@ class ToolFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
message_file: MessageFile | None = (
|
||||
session.query(MessageFile)
|
||||
.where(
|
||||
@@ -225,7 +215,7 @@ class ToolFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
with session_factory.create_session() as session:
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.where(
|
||||
|
||||
@@ -250,6 +250,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
model_factory=self._llm_model_factory,
|
||||
model_instance=model_instance,
|
||||
memory=memory,
|
||||
http_client=self._http_request_http_client,
|
||||
)
|
||||
|
||||
if node_type == NodeType.DATASOURCE:
|
||||
@@ -292,6 +293,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
model_factory=self._llm_model_factory,
|
||||
model_instance=model_instance,
|
||||
memory=memory,
|
||||
http_client=self._http_request_http_client,
|
||||
)
|
||||
|
||||
if node_type == NodeType.PARAMETER_EXTRACTOR:
|
||||
|
||||
@@ -14,7 +14,6 @@ from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base import LLMUsageTrackingMixin
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
|
||||
from dify_graph.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source
|
||||
from dify_graph.variables import (
|
||||
ArrayFileSegment,
|
||||
@@ -47,8 +46,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
# Output variable for file
|
||||
_file_outputs: list["File"]
|
||||
|
||||
_llm_file_saver: LLMFileSaver
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
@@ -56,8 +53,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
rag_retrieval: RAGRetrievalProtocol,
|
||||
*,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
@@ -69,14 +64,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
self._file_outputs = []
|
||||
self._rag_retrieval = rag_retrieval
|
||||
|
||||
if llm_file_saver is None:
|
||||
dify_ctx = self.require_dify_context()
|
||||
llm_file_saver = FileSaverImpl(
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
@classmethod
|
||||
def version(cls):
|
||||
return "1"
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
import mimetypes
|
||||
import typing as tp
|
||||
|
||||
from sqlalchemy import Engine
|
||||
|
||||
from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from dify_graph.file import File, FileTransferMethod, FileType
|
||||
from extensions.ext_database import db as global_db
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
|
||||
|
||||
class LLMFileSaver(tp.Protocol):
|
||||
@@ -59,30 +56,20 @@ class LLMFileSaver(tp.Protocol):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
EngineFactory: tp.TypeAlias = tp.Callable[[], Engine]
|
||||
|
||||
|
||||
class FileSaverImpl(LLMFileSaver):
|
||||
_engine_factory: EngineFactory
|
||||
_tenant_id: str
|
||||
_user_id: str
|
||||
|
||||
def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None):
|
||||
if engine_factory is None:
|
||||
|
||||
def _factory():
|
||||
return global_db.engine
|
||||
|
||||
engine_factory = _factory
|
||||
self._engine_factory = engine_factory
|
||||
def __init__(self, user_id: str, tenant_id: str, http_client: HttpClientProtocol):
|
||||
self._user_id = user_id
|
||||
self._tenant_id = tenant_id
|
||||
self._http_client = http_client
|
||||
|
||||
def _get_tool_file_manager(self):
|
||||
return ToolFileManager(engine=self._engine_factory())
|
||||
return ToolFileManager()
|
||||
|
||||
def save_remote_url(self, url: str, file_type: FileType) -> File:
|
||||
http_response = ssrf_proxy.get(url)
|
||||
http_response = self._http_client.get(url)
|
||||
http_response.raise_for_status()
|
||||
data = http_response.content
|
||||
mime_type_from_header = http_response.headers.get("Content-Type")
|
||||
|
||||
@@ -64,6 +64,7 @@ from dify_graph.nodes.base.entities import VariableSelector
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
from dify_graph.runtime import VariablePool
|
||||
from dify_graph.variables import (
|
||||
ArrayFileSegment,
|
||||
@@ -127,6 +128,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
credentials_provider: CredentialsProvider,
|
||||
model_factory: ModelFactory,
|
||||
model_instance: ModelInstance,
|
||||
http_client: HttpClientProtocol,
|
||||
memory: PromptMessageMemory | None = None,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
):
|
||||
@@ -149,6 +151,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
llm_file_saver = FileSaverImpl(
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
http_client=http_client,
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ from dify_graph.nodes.llm import (
|
||||
)
|
||||
from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
|
||||
from .entities import QuestionClassifierNodeData
|
||||
@@ -68,6 +69,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
credentials_provider: "CredentialsProvider",
|
||||
model_factory: "ModelFactory",
|
||||
model_instance: ModelInstance,
|
||||
http_client: HttpClientProtocol,
|
||||
memory: PromptMessageMemory | None = None,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
):
|
||||
@@ -90,6 +92,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
llm_file_saver = FileSaverImpl(
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
http_client=http_client,
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
|
||||
@@ -21,6 +21,10 @@ celery_redis = Redis(
|
||||
ssl_cert_reqs=getattr(dify_config, "REDIS_SSL_CERT_REQS", None) if dify_config.BROKER_USE_SSL else None,
|
||||
ssl_certfile=getattr(dify_config, "REDIS_SSL_CERTFILE", None) if dify_config.BROKER_USE_SSL else None,
|
||||
ssl_keyfile=getattr(dify_config, "REDIS_SSL_KEYFILE", None) if dify_config.BROKER_USE_SSL else None,
|
||||
# Add conservative socket timeouts and health checks to avoid long-lived half-open sockets
|
||||
socket_timeout=5,
|
||||
socket_connect_timeout=5,
|
||||
health_check_interval=30,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -3,6 +3,7 @@ import math
|
||||
import time
|
||||
from collections.abc import Iterable, Sequence
|
||||
|
||||
from celery import group
|
||||
from sqlalchemy import ColumnElement, and_, func, or_, select
|
||||
from sqlalchemy.engine.row import Row
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -85,20 +86,25 @@ def trigger_provider_refresh() -> None:
|
||||
lock_keys: list[str] = build_trigger_refresh_lock_keys(subscriptions)
|
||||
acquired: list[bool] = _acquire_locks(keys=lock_keys, ttl_seconds=lock_ttl)
|
||||
|
||||
enqueued: int = 0
|
||||
for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired):
|
||||
if not is_locked:
|
||||
continue
|
||||
trigger_subscription_refresh.delay(tenant_id=tenant_id, subscription_id=subscription_id)
|
||||
enqueued += 1
|
||||
if not any(acquired):
|
||||
continue
|
||||
|
||||
jobs = [
|
||||
trigger_subscription_refresh.s(tenant_id=tenant_id, subscription_id=subscription_id)
|
||||
for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired)
|
||||
if is_locked
|
||||
]
|
||||
result = group(jobs).apply_async()
|
||||
enqueued = len(jobs)
|
||||
|
||||
logger.info(
|
||||
"Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d",
|
||||
"Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d result=%s",
|
||||
page + 1,
|
||||
pages,
|
||||
len(subscriptions),
|
||||
sum(1 for x in acquired if x),
|
||||
enqueued,
|
||||
result,
|
||||
)
|
||||
|
||||
logger.info("Trigger refresh scan done: due=%d", total_due)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
from celery import group, shared_task
|
||||
from celery import current_app, group, shared_task
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
@@ -29,31 +29,27 @@ def poll_workflow_schedules() -> None:
|
||||
with session_factory() as session:
|
||||
total_dispatched = 0
|
||||
|
||||
# Process in batches until we've handled all due schedules or hit the limit
|
||||
while True:
|
||||
due_schedules = _fetch_due_schedules(session)
|
||||
|
||||
if not due_schedules:
|
||||
break
|
||||
|
||||
dispatched_count = _process_schedules(session, due_schedules)
|
||||
total_dispatched += dispatched_count
|
||||
with current_app.producer_or_acquire() as producer: # type: ignore
|
||||
dispatched_count = _process_schedules(session, due_schedules, producer)
|
||||
total_dispatched += dispatched_count
|
||||
|
||||
logger.debug("Batch processed: %d dispatched", dispatched_count)
|
||||
|
||||
# Circuit breaker: check if we've hit the per-tick limit (if enabled)
|
||||
if (
|
||||
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK > 0
|
||||
and total_dispatched >= dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK
|
||||
):
|
||||
logger.warning(
|
||||
"Circuit breaker activated: reached dispatch limit (%d), will continue next tick",
|
||||
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK,
|
||||
)
|
||||
break
|
||||
logger.debug("Batch processed: %d dispatched", dispatched_count)
|
||||
|
||||
# Circuit breaker: check if we've hit the per-tick limit (if enabled)
|
||||
if 0 < dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK <= total_dispatched:
|
||||
logger.warning(
|
||||
"Circuit breaker activated: reached dispatch limit (%d), will continue next tick",
|
||||
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK,
|
||||
)
|
||||
break
|
||||
if total_dispatched > 0:
|
||||
logger.info("Total processed: %d dispatched", total_dispatched)
|
||||
logger.info("Total processed: %d workflow schedule(s) dispatched", total_dispatched)
|
||||
|
||||
|
||||
def _fetch_due_schedules(session: Session) -> list[WorkflowSchedulePlan]:
|
||||
@@ -90,7 +86,7 @@ def _fetch_due_schedules(session: Session) -> list[WorkflowSchedulePlan]:
|
||||
return list(due_schedules)
|
||||
|
||||
|
||||
def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan]) -> int:
|
||||
def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan], producer=None) -> int:
|
||||
"""Process schedules: check quota, update next run time and dispatch to Celery in parallel."""
|
||||
if not schedules:
|
||||
return 0
|
||||
@@ -107,7 +103,7 @@ def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan])
|
||||
|
||||
if tasks_to_dispatch:
|
||||
job = group(run_schedule_trigger.s(schedule_id) for schedule_id in tasks_to_dispatch)
|
||||
job.apply_async()
|
||||
job.apply_async(producer=producer)
|
||||
|
||||
logger.debug("Dispatched %d tasks in parallel", len(tasks_to_dispatch))
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable, Sequence
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Protocol
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from celery import current_app, shared_task
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
@@ -19,6 +20,12 @@ from tasks.generate_summary_index_task import generate_summary_index_task
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CeleryTaskLike(Protocol):
|
||||
def delay(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
def apply_async(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
def document_indexing_task(dataset_id: str, document_ids: list):
|
||||
"""
|
||||
@@ -179,8 +186,8 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
||||
|
||||
|
||||
def _document_indexing_with_tenant_queue(
|
||||
tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: Callable[[str, str, Sequence[str]], None]
|
||||
):
|
||||
tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: CeleryTaskLike
|
||||
) -> None:
|
||||
try:
|
||||
_document_indexing(dataset_id, document_ids)
|
||||
except Exception:
|
||||
@@ -201,16 +208,20 @@ def _document_indexing_with_tenant_queue(
|
||||
logger.info("document indexing tenant isolation queue %s next tasks: %s", tenant_id, next_tasks)
|
||||
|
||||
if next_tasks:
|
||||
for next_task in next_tasks:
|
||||
document_task = DocumentTask(**next_task)
|
||||
# Process the next waiting task
|
||||
# Keep the flag set to indicate a task is running
|
||||
tenant_isolated_task_queue.set_task_waiting_time()
|
||||
task_func.delay( # type: ignore
|
||||
tenant_id=document_task.tenant_id,
|
||||
dataset_id=document_task.dataset_id,
|
||||
document_ids=document_task.document_ids,
|
||||
)
|
||||
with current_app.producer_or_acquire() as producer: # type: ignore
|
||||
for next_task in next_tasks:
|
||||
document_task = DocumentTask(**next_task)
|
||||
# Keep the flag set to indicate a task is running
|
||||
tenant_isolated_task_queue.set_task_waiting_time()
|
||||
task_func.apply_async(
|
||||
kwargs={
|
||||
"tenant_id": document_task.tenant_id,
|
||||
"dataset_id": document_task.dataset_id,
|
||||
"document_ids": document_task.document_ids,
|
||||
},
|
||||
producer=producer,
|
||||
)
|
||||
|
||||
else:
|
||||
# No more waiting tasks, clear the flag
|
||||
tenant_isolated_task_queue.delete_task_key()
|
||||
|
||||
@@ -3,12 +3,13 @@ import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Mapping, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from itertools import islice
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
from celery import shared_task # type: ignore
|
||||
from celery import group, shared_task
|
||||
from flask import current_app, g
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
@@ -27,6 +28,11 @@ from services.file_service import FileService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def chunked(iterable: Sequence, size: int):
|
||||
it = iter(iterable)
|
||||
return iter(lambda: list(islice(it, size)), [])
|
||||
|
||||
|
||||
@shared_task(queue="pipeline")
|
||||
def rag_pipeline_run_task(
|
||||
rag_pipeline_invoke_entities_file_id: str,
|
||||
@@ -83,16 +89,24 @@ def rag_pipeline_run_task(
|
||||
logger.info("rag pipeline tenant isolation queue %s next files: %s", tenant_id, next_file_ids)
|
||||
|
||||
if next_file_ids:
|
||||
for next_file_id in next_file_ids:
|
||||
# Process the next waiting task
|
||||
# Keep the flag set to indicate a task is running
|
||||
tenant_isolated_task_queue.set_task_waiting_time()
|
||||
rag_pipeline_run_task.delay( # type: ignore
|
||||
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
|
||||
if isinstance(next_file_id, bytes)
|
||||
else next_file_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
for batch in chunked(next_file_ids, 100):
|
||||
jobs = []
|
||||
for next_file_id in batch:
|
||||
tenant_isolated_task_queue.set_task_waiting_time()
|
||||
|
||||
file_id = (
|
||||
next_file_id.decode("utf-8") if isinstance(next_file_id, (bytes, bytearray)) else next_file_id
|
||||
)
|
||||
|
||||
jobs.append(
|
||||
rag_pipeline_run_task.s(
|
||||
rag_pipeline_invoke_entities_file_id=file_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
if jobs:
|
||||
group(jobs).apply_async()
|
||||
else:
|
||||
# No more waiting tasks, clear the flag
|
||||
tenant_isolated_task_queue.delete_task_key()
|
||||
|
||||
@@ -11,6 +11,7 @@ from dify_graph.enums import WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import StreamCompletedEvent
|
||||
from dify_graph.nodes.llm.node import LLMNode
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
@@ -74,6 +75,7 @@ def init_llm_node(config: dict) -> LLMNode:
|
||||
credentials_provider=MagicMock(spec=CredentialsProvider),
|
||||
model_factory=MagicMock(spec=ModelFactory),
|
||||
model_instance=MagicMock(spec=ModelInstance),
|
||||
http_client=MagicMock(spec=HttpClientProtocol),
|
||||
)
|
||||
|
||||
return node
|
||||
|
||||
@@ -322,11 +322,14 @@ class TestDatasetIndexingTaskIntegration:
|
||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||
|
||||
# Assert
|
||||
task_dispatch_spy.delay.assert_called_once_with(
|
||||
tenant_id=next_task["tenant_id"],
|
||||
dataset_id=next_task["dataset_id"],
|
||||
document_ids=next_task["document_ids"],
|
||||
)
|
||||
# apply_async is used by implementation; assert it was called once with expected kwargs
|
||||
assert task_dispatch_spy.apply_async.call_count == 1
|
||||
call_kwargs = task_dispatch_spy.apply_async.call_args.kwargs.get("kwargs", {})
|
||||
assert call_kwargs == {
|
||||
"tenant_id": next_task["tenant_id"],
|
||||
"dataset_id": next_task["dataset_id"],
|
||||
"document_ids": next_task["document_ids"],
|
||||
}
|
||||
set_waiting_spy.assert_called_once()
|
||||
delete_key_spy.assert_not_called()
|
||||
|
||||
@@ -352,7 +355,7 @@ class TestDatasetIndexingTaskIntegration:
|
||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||
|
||||
# Assert
|
||||
task_dispatch_spy.delay.assert_not_called()
|
||||
task_dispatch_spy.apply_async.assert_not_called()
|
||||
delete_key_spy.assert_called_once()
|
||||
|
||||
def test_validation_failure_sets_error_status_when_vector_space_at_limit(
|
||||
@@ -447,7 +450,7 @@ class TestDatasetIndexingTaskIntegration:
|
||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||
|
||||
# Assert
|
||||
task_dispatch_spy.delay.assert_called_once()
|
||||
task_dispatch_spy.apply_async.assert_called_once()
|
||||
|
||||
def test_sessions_close_on_successful_indexing(
|
||||
self,
|
||||
@@ -534,7 +537,7 @@ class TestDatasetIndexingTaskIntegration:
|
||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||
|
||||
# Assert
|
||||
assert task_dispatch_spy.delay.call_count == concurrency_limit
|
||||
assert task_dispatch_spy.apply_async.call_count == concurrency_limit
|
||||
assert set_waiting_spy.call_count == concurrency_limit
|
||||
|
||||
def test_task_queue_fifo_ordering(self, db_session_with_containers, patched_external_dependencies):
|
||||
@@ -565,9 +568,10 @@ class TestDatasetIndexingTaskIntegration:
|
||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||
|
||||
# Assert
|
||||
assert task_dispatch_spy.delay.call_count == 3
|
||||
assert task_dispatch_spy.apply_async.call_count == 3
|
||||
for index, expected_task in enumerate(ordered_tasks):
|
||||
assert task_dispatch_spy.delay.call_args_list[index].kwargs["document_ids"] == expected_task["document_ids"]
|
||||
call_kwargs = task_dispatch_spy.apply_async.call_args_list[index].kwargs.get("kwargs", {})
|
||||
assert call_kwargs.get("document_ids") == expected_task["document_ids"]
|
||||
|
||||
def test_billing_disabled_skips_limit_checks(self, db_session_with_containers, patched_external_dependencies):
|
||||
"""Skip limit checks when billing feature is disabled."""
|
||||
|
||||
@@ -762,11 +762,12 @@ class TestDocumentIndexingTasks:
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify task function was called for each waiting task
|
||||
assert mock_task_func.delay.call_count == 1
|
||||
assert mock_task_func.apply_async.call_count == 1
|
||||
|
||||
# Verify correct parameters for each call
|
||||
calls = mock_task_func.delay.call_args_list
|
||||
assert calls[0][1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
|
||||
calls = mock_task_func.apply_async.call_args_list
|
||||
sent_kwargs = calls[0][1]["kwargs"]
|
||||
assert sent_kwargs == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
|
||||
|
||||
# Verify queue is empty after processing (tasks were pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10) # Pull more than we added
|
||||
@@ -830,11 +831,15 @@ class TestDocumentIndexingTasks:
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
# Verify waiting task was still processed despite core processing error
|
||||
mock_task_func.delay.assert_called_once()
|
||||
mock_task_func.apply_async.assert_called_once()
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call = mock_task_func.delay.call_args
|
||||
assert call[1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
|
||||
call = mock_task_func.apply_async.call_args
|
||||
assert call[1]["kwargs"] == {
|
||||
"tenant_id": tenant_id,
|
||||
"dataset_id": dataset_id,
|
||||
"document_ids": ["waiting-doc-1"],
|
||||
}
|
||||
|
||||
# Verify queue is empty after processing (task was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
@@ -896,9 +901,13 @@ class TestDocumentIndexingTasks:
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify only tenant1's waiting task was processed
|
||||
mock_task_func.delay.assert_called_once()
|
||||
call = mock_task_func.delay.call_args
|
||||
assert call[1] == {"tenant_id": tenant1_id, "dataset_id": dataset1_id, "document_ids": ["tenant1-doc-1"]}
|
||||
mock_task_func.apply_async.assert_called_once()
|
||||
call = mock_task_func.apply_async.call_args
|
||||
assert call[1]["kwargs"] == {
|
||||
"tenant_id": tenant1_id,
|
||||
"dataset_id": dataset1_id,
|
||||
"document_ids": ["tenant1-doc-1"],
|
||||
}
|
||||
|
||||
# Verify tenant1's queue is empty
|
||||
remaining_tasks1 = queue1.pull_tasks(count=10)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
@@ -388,8 +388,10 @@ class TestRagPipelineRunTasks:
|
||||
# Set the task key to indicate there are waiting tasks (legacy behavior)
|
||||
redis_client.set(legacy_task_key, 1, ex=60 * 60)
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Mock the Celery group scheduling used by the implementation
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
|
||||
mock_group.return_value.apply_async = MagicMock()
|
||||
|
||||
# Act: Execute the priority task with new code but legacy queue data
|
||||
rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
@@ -398,13 +400,14 @@ class TestRagPipelineRunTasks:
|
||||
mock_file_service["delete_file"].assert_called_once_with(file_id)
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify waiting tasks were processed, pull 1 task a time by default
|
||||
assert mock_delay.call_count == 1
|
||||
# Verify waiting tasks were processed via group, pull 1 task a time by default
|
||||
assert mock_group.return_value.apply_async.called
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0]
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
# Verify correct parameters for the first scheduled job signature
|
||||
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
|
||||
first_kwargs = jobs[0].kwargs if jobs else {}
|
||||
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0]
|
||||
assert first_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify that new code can process legacy queue entries
|
||||
# The new TenantIsolatedTaskQueue should be able to read from the legacy format
|
||||
@@ -446,8 +449,10 @@ class TestRagPipelineRunTasks:
|
||||
waiting_file_ids = [str(uuid.uuid4()) for _ in range(3)]
|
||||
queue.push_tasks(waiting_file_ids)
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Mock the Celery group scheduling used by the implementation
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
|
||||
mock_group.return_value.apply_async = MagicMock()
|
||||
|
||||
# Act: Execute the regular task
|
||||
rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
@@ -456,13 +461,14 @@ class TestRagPipelineRunTasks:
|
||||
mock_file_service["delete_file"].assert_called_once_with(file_id)
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify waiting tasks were processed, pull 1 task a time by default
|
||||
assert mock_delay.call_count == 1
|
||||
# Verify waiting tasks were processed via group.apply_async
|
||||
assert mock_group.return_value.apply_async.called
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
# Verify correct parameters for the first scheduled job signature
|
||||
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
|
||||
first_kwargs = jobs[0].kwargs if jobs else {}
|
||||
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
|
||||
assert first_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify queue still has remaining tasks (only 1 was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
@@ -557,8 +563,10 @@ class TestRagPipelineRunTasks:
|
||||
waiting_file_id = str(uuid.uuid4())
|
||||
queue.push_tasks([waiting_file_id])
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Mock the Celery group scheduling used by the implementation
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
|
||||
mock_group.return_value.apply_async = MagicMock()
|
||||
|
||||
# Act: Execute the regular task (should not raise exception)
|
||||
rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
@@ -569,12 +577,13 @@ class TestRagPipelineRunTasks:
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify waiting task was still processed despite core processing error
|
||||
mock_delay.assert_called_once()
|
||||
assert mock_group.return_value.apply_async.called
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
# Verify correct parameters for the first scheduled job signature
|
||||
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
|
||||
first_kwargs = jobs[0].kwargs if jobs else {}
|
||||
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
||||
assert first_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify queue is empty after processing (task was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
@@ -684,8 +693,10 @@ class TestRagPipelineRunTasks:
|
||||
queue1.push_tasks([waiting_file_id1])
|
||||
queue2.push_tasks([waiting_file_id2])
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Mock the Celery group scheduling used by the implementation
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
|
||||
mock_group.return_value.apply_async = MagicMock()
|
||||
|
||||
# Act: Execute the regular task for tenant1 only
|
||||
rag_pipeline_run_task(file_id1, tenant1.id)
|
||||
|
||||
@@ -694,11 +705,12 @@ class TestRagPipelineRunTasks:
|
||||
assert mock_file_service["delete_file"].call_count == 1
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify only tenant1's waiting task was processed
|
||||
mock_delay.assert_called_once()
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
|
||||
assert call_kwargs.get("tenant_id") == tenant1.id
|
||||
# Verify only tenant1's waiting task was processed (via group)
|
||||
assert mock_group.return_value.apply_async.called
|
||||
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
|
||||
first_kwargs = jobs[0].kwargs if jobs else {}
|
||||
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
|
||||
assert first_kwargs.get("tenant_id") == tenant1.id
|
||||
|
||||
# Verify tenant1's queue is empty
|
||||
remaining_tasks1 = queue1.pull_tasks(count=10)
|
||||
@@ -913,8 +925,10 @@ class TestRagPipelineRunTasks:
|
||||
waiting_file_id = str(uuid.uuid4())
|
||||
queue.push_tasks([waiting_file_id])
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Mock the Celery group scheduling used by the implementation
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
|
||||
mock_group.return_value.apply_async = MagicMock()
|
||||
|
||||
# Act & Assert: Execute the regular task (should raise Exception)
|
||||
with pytest.raises(Exception, match="File not found"):
|
||||
rag_pipeline_run_task(file_id, tenant.id)
|
||||
@@ -924,12 +938,13 @@ class TestRagPipelineRunTasks:
|
||||
mock_pipeline_generator.assert_not_called()
|
||||
|
||||
# Verify waiting task was still processed despite file error
|
||||
mock_delay.assert_called_once()
|
||||
assert mock_group.return_value.apply_async.called
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
# Verify correct parameters for the first scheduled job signature
|
||||
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
|
||||
first_kwargs = jobs[0].kwargs if jobs else {}
|
||||
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
||||
assert first_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify queue is empty after processing (task was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
|
||||
@@ -105,18 +105,26 @@ def app_model(
|
||||
|
||||
|
||||
class MockCeleryGroup:
|
||||
"""Mock for celery group() function that collects dispatched tasks."""
|
||||
"""Mock for celery group() function that collects dispatched tasks.
|
||||
|
||||
Matches the Celery group API loosely, accepting arbitrary kwargs on apply_async
|
||||
(e.g. producer) so production code can pass broker-related options without
|
||||
breaking tests.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.collected: list[dict[str, Any]] = []
|
||||
self._applied = False
|
||||
self.last_apply_async_kwargs: dict[str, Any] | None = None
|
||||
|
||||
def __call__(self, items: Any) -> MockCeleryGroup:
|
||||
self.collected = list(items)
|
||||
return self
|
||||
|
||||
def apply_async(self) -> None:
|
||||
def apply_async(self, **kwargs: Any) -> None:
|
||||
# Accept arbitrary kwargs like producer to be compatible with Celery
|
||||
self._applied = True
|
||||
self.last_apply_async_kwargs = kwargs
|
||||
|
||||
@property
|
||||
def applied(self) -> bool:
|
||||
|
||||
@@ -0,0 +1,817 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.datasource_auth import (
|
||||
DatasourceAuth,
|
||||
DatasourceAuthDefaultApi,
|
||||
DatasourceAuthDeleteApi,
|
||||
DatasourceAuthListApi,
|
||||
DatasourceAuthOauthCustomClient,
|
||||
DatasourceAuthUpdateApi,
|
||||
DatasourceHardCodeAuthListApi,
|
||||
DatasourceOAuthCallback,
|
||||
DatasourcePluginOAuthAuthorizationUrl,
|
||||
DatasourceUpdateProviderNameApi,
|
||||
)
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestDatasourcePluginOAuthAuthorizationUrl:
|
||||
def test_get_success(self, app):
|
||||
api = DatasourcePluginOAuthAuthorizationUrl()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="user-1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/?credential_id=cred-1"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value={"client_id": "abc"},
|
||||
),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"create_proxy_context",
|
||||
return_value="ctx-1",
|
||||
),
|
||||
patch.object(
|
||||
OAuthHandler,
|
||||
"get_authorization_url",
|
||||
return_value={"url": "http://auth"},
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_get_no_oauth_config(self, app):
|
||||
api = DatasourcePluginOAuthAuthorizationUrl()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
def test_get_without_credential_id_sets_cookie(self, app):
|
||||
api = DatasourcePluginOAuthAuthorizationUrl()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="user-1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value={"client_id": "abc"},
|
||||
),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"create_proxy_context",
|
||||
return_value="ctx-123",
|
||||
),
|
||||
patch.object(
|
||||
OAuthHandler,
|
||||
"get_authorization_url",
|
||||
return_value={"url": "http://auth"},
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "context_id" in response.headers.get("Set-Cookie")
|
||||
|
||||
|
||||
class TestDatasourceOAuthCallback:
|
||||
def test_callback_success_new_credential(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
oauth_response = MagicMock()
|
||||
oauth_response.credentials = {"token": "abc"}
|
||||
oauth_response.expires_at = None
|
||||
oauth_response.metadata = {"name": "test"}
|
||||
|
||||
context = {
|
||||
"user_id": "user-1",
|
||||
"tenant_id": "tenant-1",
|
||||
"credential_id": None,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/?context_id=ctx"),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"use_proxy_context",
|
||||
return_value=context,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value={"client_id": "abc"},
|
||||
),
|
||||
patch.object(
|
||||
OAuthHandler,
|
||||
"get_credentials",
|
||||
return_value=oauth_response,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"add_datasource_oauth_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
|
||||
assert response.status_code == 302
|
||||
|
||||
def test_callback_missing_context(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "notion")
|
||||
|
||||
def test_callback_invalid_context(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?context_id=bad"),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"use_proxy_context",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "notion")
|
||||
|
||||
def test_callback_oauth_config_not_found(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
context = {"user_id": "u", "tenant_id": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/?context_id=ctx"),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"use_proxy_context",
|
||||
return_value=context,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "notion")
|
||||
|
||||
def test_callback_reauthorize_existing_credential(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
oauth_response = MagicMock()
|
||||
oauth_response.credentials = {"token": "abc"}
|
||||
oauth_response.expires_at = None
|
||||
oauth_response.metadata = {} # avatar + name missing
|
||||
|
||||
context = {
|
||||
"user_id": "user-1",
|
||||
"tenant_id": "tenant-1",
|
||||
"credential_id": "cred-1",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/?context_id=ctx"),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"use_proxy_context",
|
||||
return_value=context,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value={"client_id": "abc"},
|
||||
),
|
||||
patch.object(
|
||||
OAuthHandler,
|
||||
"get_credentials",
|
||||
return_value=oauth_response,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"reauthorize_datasource_oauth_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
|
||||
assert response.status_code == 302
|
||||
assert "/oauth-callback" in response.location
|
||||
|
||||
def test_callback_context_id_from_cookie(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
oauth_response = MagicMock()
|
||||
oauth_response.credentials = {"token": "abc"}
|
||||
oauth_response.expires_at = None
|
||||
oauth_response.metadata = {}
|
||||
|
||||
context = {
|
||||
"user_id": "user-1",
|
||||
"tenant_id": "tenant-1",
|
||||
"credential_id": None,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", headers={"Cookie": "context_id=ctx"}),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"use_proxy_context",
|
||||
return_value=context,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value={"client_id": "abc"},
|
||||
),
|
||||
patch.object(
|
||||
OAuthHandler,
|
||||
"get_credentials",
|
||||
return_value=oauth_response,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"add_datasource_oauth_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
|
||||
assert response.status_code == 302
|
||||
|
||||
|
||||
class TestDatasourceAuth:
|
||||
def test_post_success(self, app):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {"key": "val"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"add_datasource_api_key_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_post_invalid_credentials(self, app):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {"key": "bad"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"add_datasource_api_key_provider",
|
||||
side_effect=CredentialsValidateFailedError("invalid"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
def test_get_success(self, app):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"list_datasource_credentials",
|
||||
return_value=[{"id": "1"}],
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
assert response["result"]
|
||||
|
||||
def test_post_missing_credentials(self, app):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
def test_get_empty_list(self, app):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"list_datasource_credentials",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == []
|
||||
|
||||
|
||||
class TestDatasourceAuthDeleteApi:
|
||||
def test_delete_success(self, app):
|
||||
api = DatasourceAuthDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "cred-1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"remove_datasource_credentials",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_delete_missing_credential_id(self, app):
|
||||
api = DatasourceAuthDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
|
||||
class TestDatasourceAuthUpdateApi:
|
||||
def test_update_success(self, app):
|
||||
api = DatasourceAuthUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "credentials": {"k": "v"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_credentials",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 201
|
||||
|
||||
def test_update_with_credentials_none(self, app):
|
||||
api = DatasourceAuthUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "credentials": None}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_credentials",
|
||||
return_value=None,
|
||||
) as update_mock,
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
update_mock.assert_called_once()
|
||||
assert status == 201
|
||||
|
||||
def test_update_name_only(self, app):
|
||||
api = DatasourceAuthUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "name": "New Name"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_credentials",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
_, status = method(api, "notion")
|
||||
|
||||
assert status == 201
|
||||
|
||||
def test_update_with_empty_credentials_dict(self, app):
|
||||
api = DatasourceAuthUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "credentials": {}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_credentials",
|
||||
return_value=None,
|
||||
) as update_mock,
|
||||
):
|
||||
_, status = method(api, "notion")
|
||||
|
||||
update_mock.assert_called_once()
|
||||
assert status == 201
|
||||
|
||||
|
||||
class TestDatasourceAuthListApi:
|
||||
def test_list_success(self, app):
|
||||
api = DatasourceAuthListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_all_datasource_credentials",
|
||||
return_value=[{"id": "1"}],
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_auth_list_empty(self, app):
|
||||
api = DatasourceAuthListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_all_datasource_credentials",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == []
|
||||
|
||||
def test_hardcode_list_empty(self, app):
|
||||
api = DatasourceHardCodeAuthListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_hard_code_datasource_credentials",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == []
|
||||
|
||||
|
||||
class TestDatasourceHardCodeAuthListApi:
|
||||
def test_list_success(self, app):
|
||||
api = DatasourceHardCodeAuthListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_hard_code_datasource_credentials",
|
||||
return_value=[{"id": "1"}],
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
|
||||
|
||||
class TestDatasourceAuthOauthCustomClient:
|
||||
def test_post_success(self, app):
|
||||
api = DatasourceAuthOauthCustomClient()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"client_params": {}, "enable_oauth_custom_client": True}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"setup_oauth_custom_client_params",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_delete_success(self, app):
|
||||
api = DatasourceAuthOauthCustomClient()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"remove_oauth_custom_client_params",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_post_empty_payload(self, app):
|
||||
api = DatasourceAuthOauthCustomClient()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"setup_oauth_custom_client_params",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
_, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_post_disabled_flag(self, app):
|
||||
api = DatasourceAuthOauthCustomClient()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"client_params": {"a": 1},
|
||||
"enable_oauth_custom_client": False,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"setup_oauth_custom_client_params",
|
||||
return_value=None,
|
||||
) as setup_mock,
|
||||
):
|
||||
_, status = method(api, "notion")
|
||||
|
||||
setup_mock.assert_called_once()
|
||||
assert status == 200
|
||||
|
||||
|
||||
class TestDatasourceAuthDefaultApi:
|
||||
def test_set_default_success(self, app):
|
||||
api = DatasourceAuthDefaultApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"id": "cred-1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"set_default_datasource_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_default_missing_id(self, app):
|
||||
api = DatasourceAuthDefaultApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
|
||||
class TestDatasourceUpdateProviderNameApi:
|
||||
def test_update_name_success(self, app):
|
||||
api = DatasourceUpdateProviderNameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "name": "New Name"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_provider_name",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_update_name_too_long(self, app):
|
||||
api = DatasourceUpdateProviderNameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"credential_id": "id",
|
||||
"name": "x" * 101,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
def test_update_name_missing_credential_id(self, app):
|
||||
api = DatasourceUpdateProviderNameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"name": "Valid"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
@@ -0,0 +1,143 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.datasource_content_preview import (
|
||||
DataSourceContentPreviewApi,
|
||||
)
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestDataSourceContentPreviewApi:
|
||||
def _valid_payload(self):
|
||||
return {
|
||||
"inputs": {"query": "hello"},
|
||||
"datasource_type": "notion",
|
||||
"credential_id": "cred-1",
|
||||
}
|
||||
|
||||
def test_post_success(self, app):
|
||||
api = DataSourceContentPreviewApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._valid_payload()
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
node_id = "node-1"
|
||||
account = MagicMock(spec=Account)
|
||||
|
||||
preview_result = {"content": "preview data"}
|
||||
|
||||
service_instance = MagicMock()
|
||||
service_instance.run_datasource_node_preview.return_value = preview_result
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
|
||||
account,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.RagPipelineService",
|
||||
return_value=service_instance,
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline, node_id)
|
||||
|
||||
service_instance.run_datasource_node_preview.assert_called_once_with(
|
||||
pipeline=pipeline,
|
||||
node_id=node_id,
|
||||
user_inputs=payload["inputs"],
|
||||
account=account,
|
||||
datasource_type=payload["datasource_type"],
|
||||
is_published=True,
|
||||
credential_id=payload["credential_id"],
|
||||
)
|
||||
assert status == 200
|
||||
assert response == preview_result
|
||||
|
||||
def test_post_forbidden_non_account_user(self, app):
|
||||
api = DataSourceContentPreviewApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._valid_payload()
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
|
||||
MagicMock(), # NOT Account
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, pipeline, "node-1")
|
||||
|
||||
def test_post_invalid_payload(self, app):
|
||||
api = DataSourceContentPreviewApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"inputs": {"query": "hello"},
|
||||
# datasource_type missing
|
||||
}
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
account = MagicMock(spec=Account)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
|
||||
account,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, pipeline, "node-1")
|
||||
|
||||
def test_post_without_credential_id(self, app):
|
||||
api = DataSourceContentPreviewApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"inputs": {"query": "hello"},
|
||||
"datasource_type": "notion",
|
||||
"credential_id": None,
|
||||
}
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
account = MagicMock(spec=Account)
|
||||
|
||||
service_instance = MagicMock()
|
||||
service_instance.run_datasource_node_preview.return_value = {"ok": True}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
|
||||
account,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.RagPipelineService",
|
||||
return_value=service_instance,
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline, "node-1")
|
||||
|
||||
service_instance.run_datasource_node_preview.assert_called_once()
|
||||
assert status == 200
|
||||
assert response == {"ok": True}
|
||||
@@ -0,0 +1,187 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline import (
|
||||
CustomizedPipelineTemplateApi,
|
||||
PipelineTemplateDetailApi,
|
||||
PipelineTemplateListApi,
|
||||
PublishCustomizedPipelineTemplateApi,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestPipelineTemplateListApi:
|
||||
def test_get_success(self, app):
|
||||
api = PipelineTemplateListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
templates = [{"id": "t1"}]
|
||||
|
||||
with (
|
||||
app.test_request_context("/?type=built-in&language=en-US"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.get_pipeline_templates",
|
||||
return_value=templates,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert response == templates
|
||||
|
||||
|
||||
class TestPipelineTemplateDetailApi:
|
||||
def test_get_success(self, app):
|
||||
api = PipelineTemplateDetailApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
template = {"id": "tpl-1"}
|
||||
|
||||
service = MagicMock()
|
||||
service.get_pipeline_template_detail.return_value = template
|
||||
|
||||
with (
|
||||
app.test_request_context("/?type=built-in"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "tpl-1")
|
||||
|
||||
assert status == 200
|
||||
assert response == template
|
||||
|
||||
|
||||
class TestCustomizedPipelineTemplateApi:
|
||||
def test_patch_success(self, app):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
payload = {
|
||||
"name": "Template",
|
||||
"description": "Desc",
|
||||
"icon_info": {"icon": "📘"},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.update_customized_pipeline_template"
|
||||
) as update_mock,
|
||||
):
|
||||
response = method(api, "tpl-1")
|
||||
|
||||
update_mock.assert_called_once()
|
||||
assert response == 200
|
||||
|
||||
def test_delete_success(self, app):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.delete_customized_pipeline_template"
|
||||
) as delete_mock,
|
||||
):
|
||||
response = method(api, "tpl-1")
|
||||
|
||||
delete_mock.assert_called_once_with("tpl-1")
|
||||
assert response == 200
|
||||
|
||||
def test_post_success(self, app):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
template = MagicMock()
|
||||
template.yaml_content = "yaml-data"
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = template
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "tpl-1")
|
||||
|
||||
assert status == 200
|
||||
assert response == {"data": "yaml-data"}
|
||||
|
||||
def test_post_template_not_found(self, app):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "tpl-1")
|
||||
|
||||
|
||||
class TestPublishCustomizedPipelineTemplateApi:
|
||||
def test_post_success(self, app):
|
||||
api = PublishCustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"name": "Template",
|
||||
"description": "Desc",
|
||||
"icon_info": {"icon": "📘"},
|
||||
}
|
||||
|
||||
service = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response = method(api, "pipeline-1")
|
||||
|
||||
service.publish_customized_pipeline_template.assert_called_once()
|
||||
assert response == {"result": "success"}
|
||||
@@ -0,0 +1,187 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_datasets import (
|
||||
CreateEmptyRagPipelineDatasetApi,
|
||||
CreateRagPipelineDatasetApi,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestCreateRagPipelineDatasetApi:
|
||||
def _valid_payload(self):
|
||||
return {"yaml_content": "name: test"}
|
||||
|
||||
def test_post_success(self, app):
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._valid_payload()
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
import_info = {"dataset_id": "ds-1"}
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service.create_rag_pipeline_dataset.return_value = import_info
|
||||
|
||||
mock_session_ctx = MagicMock()
|
||||
mock_session_ctx.__enter__.return_value = MagicMock()
|
||||
mock_session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session",
|
||||
return_value=mock_session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
|
||||
return_value=mock_service,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 201
|
||||
assert response == import_info
|
||||
|
||||
def test_post_forbidden_non_editor(self, app):
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._valid_payload()
|
||||
user = MagicMock(is_dataset_editor=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
|
||||
def test_post_dataset_name_duplicate(self, app):
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._valid_payload()
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service.create_rag_pipeline_dataset.side_effect = services.errors.dataset.DatasetNameDuplicateError()
|
||||
|
||||
mock_session_ctx = MagicMock()
|
||||
mock_session_ctx.__enter__.return_value = MagicMock()
|
||||
mock_session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session",
|
||||
return_value=mock_session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
|
||||
return_value=mock_service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(DatasetNameDuplicateError):
|
||||
method(api)
|
||||
|
||||
def test_post_invalid_payload(self, app):
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestCreateEmptyRagPipelineDatasetApi:
|
||||
def test_post_success(self, app):
|
||||
api = CreateEmptyRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
dataset = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.DatasetService.create_empty_rag_pipeline_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.marshal",
|
||||
return_value={"id": "ds-1"},
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 201
|
||||
assert response == {"id": "ds-1"}
|
||||
|
||||
def test_post_forbidden_non_editor(self, app):
|
||||
api = CreateEmptyRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock(is_dataset_editor=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
@@ -0,0 +1,324 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Response
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import DraftWorkflowNotExist
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable import (
|
||||
RagPipelineEnvironmentVariableCollectionApi,
|
||||
RagPipelineNodeVariableCollectionApi,
|
||||
RagPipelineSystemVariableCollectionApi,
|
||||
RagPipelineVariableApi,
|
||||
RagPipelineVariableCollectionApi,
|
||||
RagPipelineVariableResetApi,
|
||||
)
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from dify_graph.variables.types import SegmentType
|
||||
from models.account import Account
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_db():
|
||||
db = MagicMock()
|
||||
db.engine = MagicMock()
|
||||
db.session.return_value = MagicMock()
|
||||
return db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def editor_user():
|
||||
user = MagicMock(spec=Account)
|
||||
user.has_edit_permission = True
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def restx_config(app):
|
||||
return patch.dict(app.config, {"RESTX_MASK_HEADER": "X-Fields"})
|
||||
|
||||
|
||||
class TestRagPipelineVariableCollectionApi:
|
||||
def test_get_variables_success(self, app, fake_db, editor_user, restx_config):
|
||||
api = RagPipelineVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
|
||||
rag_srv = MagicMock()
|
||||
rag_srv.is_workflow_exist.return_value = True
|
||||
|
||||
# IMPORTANT: RESTX expects .variables
|
||||
var_list = MagicMock()
|
||||
var_list.variables = []
|
||||
|
||||
draft_srv = MagicMock()
|
||||
draft_srv.list_variables_without_values.return_value = var_list
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&limit=10"),
|
||||
restx_config,
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
|
||||
return_value=rag_srv,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=draft_srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert result["items"] == []
|
||||
|
||||
def test_get_variables_workflow_not_exist(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
|
||||
rag_srv = MagicMock()
|
||||
rag_srv.is_workflow_exist.return_value = False
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
|
||||
return_value=rag_srv,
|
||||
),
|
||||
):
|
||||
with pytest.raises(DraftWorkflowNotExist):
|
||||
method(api, pipeline)
|
||||
|
||||
def test_delete_variables_success(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableCollectionApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService"),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.status_code == 204
|
||||
|
||||
|
||||
class TestRagPipelineNodeVariableCollectionApi:
|
||||
def test_get_node_variables_success(self, app, fake_db, editor_user, restx_config):
|
||||
api = RagPipelineNodeVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
|
||||
var_list = MagicMock()
|
||||
var_list.variables = []
|
||||
|
||||
srv = MagicMock()
|
||||
srv.list_node_variables.return_value = var_list
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
restx_config,
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "node1")
|
||||
|
||||
assert result["items"] == []
|
||||
|
||||
def test_get_node_variables_invalid_node(self, app, editor_user):
|
||||
api = RagPipelineNodeVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
):
|
||||
with pytest.raises(InvalidArgumentError):
|
||||
method(api, MagicMock(), SYSTEM_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
class TestRagPipelineVariableApi:
|
||||
def test_get_variable_not_found(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
srv = MagicMock()
|
||||
srv.get_variable.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFoundError):
|
||||
method(api, MagicMock(), "v1")
|
||||
|
||||
def test_patch_variable_invalid_file_payload(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
pipeline = MagicMock(id="p1", tenant_id="t1")
|
||||
variable = MagicMock(app_id="p1", value_type=SegmentType.FILE)
|
||||
|
||||
srv = MagicMock()
|
||||
srv.get_variable.return_value = variable
|
||||
|
||||
payload = {"value": "invalid"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvalidArgumentError):
|
||||
method(api, pipeline, "v1")
|
||||
|
||||
def test_delete_variable_success(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
variable = MagicMock(app_id="p1")
|
||||
|
||||
srv = MagicMock()
|
||||
srv.get_variable.return_value = variable
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "v1")
|
||||
|
||||
assert result.status_code == 204
|
||||
|
||||
|
||||
class TestRagPipelineVariableResetApi:
|
||||
def test_reset_variable_success(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableResetApi()
|
||||
method = unwrap(api.put)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
workflow = MagicMock()
|
||||
variable = MagicMock(app_id="p1")
|
||||
|
||||
srv = MagicMock()
|
||||
srv.get_variable.return_value = variable
|
||||
srv.reset_variable.return_value = variable
|
||||
|
||||
rag_srv = MagicMock()
|
||||
rag_srv.get_draft_workflow.return_value = workflow
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
|
||||
return_value=rag_srv,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.marshal",
|
||||
return_value={"id": "v1"},
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "v1")
|
||||
|
||||
assert result == {"id": "v1"}
|
||||
|
||||
|
||||
class TestSystemAndEnvironmentVariablesApi:
|
||||
def test_system_variables_success(self, app, fake_db, editor_user, restx_config):
|
||||
api = RagPipelineSystemVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
|
||||
var_list = MagicMock()
|
||||
var_list.variables = []
|
||||
|
||||
srv = MagicMock()
|
||||
srv.list_system_variables.return_value = var_list
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
restx_config,
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert result["items"] == []
|
||||
|
||||
def test_environment_variables_success(self, app, editor_user):
|
||||
api = RagPipelineEnvironmentVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
env_var = MagicMock(
|
||||
id="e1",
|
||||
name="ENV",
|
||||
description="d",
|
||||
selector="s",
|
||||
value_type=MagicMock(value="string"),
|
||||
value="x",
|
||||
)
|
||||
|
||||
workflow = MagicMock(environment_variables=[env_var])
|
||||
pipeline = MagicMock(id="p1")
|
||||
|
||||
rag_srv = MagicMock()
|
||||
rag_srv.get_draft_workflow.return_value = workflow
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
|
||||
return_value=rag_srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert len(result["items"]) == 1
|
||||
@@ -0,0 +1,329 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_import import (
|
||||
RagPipelineExportApi,
|
||||
RagPipelineImportApi,
|
||||
RagPipelineImportCheckDependenciesApi,
|
||||
RagPipelineImportConfirmApi,
|
||||
)
|
||||
from models.dataset import Pipeline
|
||||
from services.app_dsl_service import ImportStatus
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestRagPipelineImportApi:
|
||||
def _payload(self, mode="create"):
|
||||
return {
|
||||
"mode": mode,
|
||||
"yaml_content": "content",
|
||||
"name": "Test",
|
||||
}
|
||||
|
||||
def test_post_success_200(self, app):
|
||||
api = RagPipelineImportApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._payload()
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = "completed"
|
||||
result.model_dump.return_value = {"status": "success"}
|
||||
|
||||
service = MagicMock()
|
||||
service.import_rag_pipeline.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert response == {"status": "success"}
|
||||
|
||||
def test_post_failed_400(self, app):
|
||||
api = RagPipelineImportApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._payload()
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = ImportStatus.FAILED
|
||||
result.model_dump.return_value = {"status": "failed"}
|
||||
|
||||
service = MagicMock()
|
||||
service.import_rag_pipeline.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 400
|
||||
assert response == {"status": "failed"}
|
||||
|
||||
def test_post_pending_202(self, app):
|
||||
api = RagPipelineImportApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._payload()
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = ImportStatus.PENDING
|
||||
result.model_dump.return_value = {"status": "pending"}
|
||||
|
||||
service = MagicMock()
|
||||
service.import_rag_pipeline.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 202
|
||||
assert response == {"status": "pending"}
|
||||
|
||||
|
||||
class TestRagPipelineImportConfirmApi:
|
||||
def test_confirm_success(self, app):
|
||||
api = RagPipelineImportConfirmApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = "completed"
|
||||
result.model_dump.return_value = {"ok": True}
|
||||
|
||||
service = MagicMock()
|
||||
service.confirm_import.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "import-1")
|
||||
|
||||
assert status == 200
|
||||
assert response == {"ok": True}
|
||||
|
||||
def test_confirm_failed(self, app):
|
||||
api = RagPipelineImportConfirmApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = ImportStatus.FAILED
|
||||
result.model_dump.return_value = {"ok": False}
|
||||
|
||||
service = MagicMock()
|
||||
service.confirm_import.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "import-1")
|
||||
|
||||
assert status == 400
|
||||
assert response == {"ok": False}
|
||||
|
||||
|
||||
class TestRagPipelineImportCheckDependenciesApi:
|
||||
def test_get_success(self, app):
|
||||
api = RagPipelineImportCheckDependenciesApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
result = MagicMock()
|
||||
result.model_dump.return_value = {"deps": []}
|
||||
|
||||
service = MagicMock()
|
||||
service.check_dependencies.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline)
|
||||
|
||||
assert status == 200
|
||||
assert response == {"deps": []}
|
||||
|
||||
|
||||
class TestRagPipelineExportApi:
|
||||
def test_get_with_include_secret(self, app):
|
||||
api = RagPipelineExportApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
service = MagicMock()
|
||||
service.export_rag_pipeline_dsl.return_value = {"yaml": "data"}
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/?include_secret=true"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline)
|
||||
|
||||
assert status == 200
|
||||
assert response == {"data": {"yaml": "data"}}
|
||||
@@ -0,0 +1,688 @@
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import (
|
||||
DefaultRagPipelineBlockConfigApi,
|
||||
DraftRagPipelineApi,
|
||||
DraftRagPipelineRunApi,
|
||||
PublishedAllRagPipelineApi,
|
||||
PublishedRagPipelineApi,
|
||||
PublishedRagPipelineRunApi,
|
||||
RagPipelineByIdApi,
|
||||
RagPipelineDatasourceVariableApi,
|
||||
RagPipelineDraftNodeRunApi,
|
||||
RagPipelineDraftRunIterationNodeApi,
|
||||
RagPipelineDraftRunLoopNodeApi,
|
||||
RagPipelineRecommendedPluginApi,
|
||||
RagPipelineTaskStopApi,
|
||||
RagPipelineTransformApi,
|
||||
RagPipelineWorkflowLastRunApi,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestDraftWorkflowApi:
|
||||
def test_get_draft_success(self, app):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
workflow = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.get_draft_workflow.return_value = workflow
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
assert result == workflow
|
||||
|
||||
def test_get_draft_not_exist(self, app):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
service = MagicMock()
|
||||
service.get_draft_workflow.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(DraftWorkflowNotExist):
|
||||
method(api, pipeline)
|
||||
|
||||
def test_sync_hash_not_match(self, app):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.sync_draft_workflow.side_effect = WorkflowHashNotEqualError()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"graph": {}, "features": {}}),
|
||||
patch.object(type(console_ns), "payload", {"graph": {}, "features": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(DraftWorkflowNotSync):
|
||||
method(api, pipeline)
|
||||
|
||||
def test_sync_invalid_text_plain(self, app):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", data="bad-json", headers={"Content-Type": "text/plain"}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline)
|
||||
assert status == 400
|
||||
|
||||
|
||||
class TestDraftRunNodes:
|
||||
def test_iteration_node_success(self, app):
|
||||
api = RagPipelineDraftRunIterationNodeApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"inputs": {}}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "node")
|
||||
assert result == {"ok": True}
|
||||
|
||||
def test_iteration_node_conversation_not_exists(self, app):
|
||||
api = RagPipelineDraftRunIterationNodeApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"inputs": {}}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration",
|
||||
side_effect=services.errors.conversation.ConversationNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, pipeline, "node")
|
||||
|
||||
def test_loop_node_success(self, app):
|
||||
api = RagPipelineDraftRunLoopNodeApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"inputs": {}}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_loop",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, pipeline, "node") == {"ok": True}
|
||||
|
||||
|
||||
class TestPipelineRunApis:
|
||||
def test_draft_run_success(self, app):
|
||||
api = DraftRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
payload = {
|
||||
"inputs": {},
|
||||
"datasource_type": "x",
|
||||
"datasource_info_list": [],
|
||||
"start_node_id": "n",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, pipeline) == {"ok": True}
|
||||
|
||||
def test_draft_run_rate_limit(self, app):
|
||||
api = DraftRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/", json={"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"}
|
||||
),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
{"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
|
||||
side_effect=InvokeRateLimitError("limit"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvokeRateLimitHttpError):
|
||||
method(api, pipeline)
|
||||
|
||||
|
||||
class TestDraftNodeRun:
|
||||
def test_execution_not_found(self, app):
|
||||
api = RagPipelineDraftNodeRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.run_draft_workflow_node.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"inputs": {}}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, pipeline, "node")
|
||||
|
||||
|
||||
class TestPublishedPipelineApis:
|
||||
def test_publish_success(self, app):
|
||||
api = PublishedRagPipelineApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
workflow = MagicMock(
|
||||
id="w1",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
session = MagicMock()
|
||||
session.merge.return_value = pipeline
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
service = MagicMock()
|
||||
service.publish_workflow.return_value = workflow
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert result["result"] == "success"
|
||||
assert "created_at" in result
|
||||
|
||||
|
||||
class TestMiscApis:
|
||||
def test_task_stop(self, app):
|
||||
api = RagPipelineTaskStopApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.AppQueueManager.set_stop_flag"
|
||||
) as stop_mock,
|
||||
):
|
||||
result = method(api, pipeline, "task-1")
|
||||
stop_mock.assert_called_once()
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_transform_forbidden(self, app):
|
||||
api = RagPipelineTransformApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock(has_edit_permission=False, is_dataset_operator=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "ds1")
|
||||
|
||||
def test_recommended_plugins(self, app):
|
||||
api = RagPipelineRecommendedPluginApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
service = MagicMock()
|
||||
service.get_recommended_plugins.return_value = [{"id": "p1"}]
|
||||
|
||||
with (
|
||||
app.test_request_context("/?type=all"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
assert result == [{"id": "p1"}]
|
||||
|
||||
|
||||
class TestPublishedRagPipelineRunApi:
|
||||
def test_published_run_success(self, app):
|
||||
api = PublishedRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
payload = {
|
||||
"inputs": {},
|
||||
"datasource_type": "x",
|
||||
"datasource_info_list": [],
|
||||
"start_node_id": "n",
|
||||
"response_mode": "blocking",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
assert result == {"ok": True}
|
||||
|
||||
def test_published_run_rate_limit(self, app):
|
||||
api = PublishedRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
payload = {
|
||||
"inputs": {},
|
||||
"datasource_type": "x",
|
||||
"datasource_info_list": [],
|
||||
"start_node_id": "n",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
|
||||
side_effect=InvokeRateLimitError("limit"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvokeRateLimitHttpError):
|
||||
method(api, pipeline)
|
||||
|
||||
|
||||
class TestDefaultBlockConfigApi:
|
||||
def test_get_block_config_success(self, app):
|
||||
api = DefaultRagPipelineBlockConfigApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.get_default_block_config.return_value = {"k": "v"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/?q={}"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "llm")
|
||||
assert result == {"k": "v"}
|
||||
|
||||
def test_get_block_config_invalid_json(self, app):
|
||||
api = DefaultRagPipelineBlockConfigApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
|
||||
with app.test_request_context("/?q=bad-json"):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, pipeline, "llm")
|
||||
|
||||
|
||||
class TestPublishedAllRagPipelineApi:
|
||||
def test_get_published_workflows_success(self, app):
|
||||
api = PublishedAllRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
service = MagicMock()
|
||||
service.get_all_published_workflow.return_value = ([{"id": "w1"}], False)
|
||||
|
||||
session = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert result["items"] == [{"id": "w1"}]
|
||||
assert result["has_more"] is False
|
||||
|
||||
def test_get_published_workflows_forbidden(self, app):
|
||||
api = PublishedAllRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/?user_id=u2"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, pipeline)
|
||||
|
||||
|
||||
class TestRagPipelineByIdApi:
|
||||
def test_patch_success(self, app):
|
||||
api = RagPipelineByIdApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
pipeline = MagicMock(tenant_id="t1")
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
workflow = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.update_workflow.return_value = workflow
|
||||
|
||||
session = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
payload = {"marked_name": "test"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "w1")
|
||||
|
||||
assert result == workflow
|
||||
|
||||
def test_patch_no_fields(self, app):
|
||||
api = RagPipelineByIdApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
patch.object(type(console_ns), "payload", {}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
):
|
||||
result, status = method(api, pipeline, "w1")
|
||||
assert status == 400
|
||||
|
||||
|
||||
class TestRagPipelineWorkflowLastRunApi:
|
||||
def test_last_run_success(self, app):
|
||||
api = RagPipelineWorkflowLastRunApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
workflow = MagicMock()
|
||||
node_exec = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.get_draft_workflow.return_value = workflow
|
||||
service.get_node_last_run.return_value = node_exec
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "node1")
|
||||
assert result == node_exec
|
||||
|
||||
def test_last_run_not_found(self, app):
|
||||
api = RagPipelineWorkflowLastRunApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.get_draft_workflow.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, pipeline, "node1")
|
||||
|
||||
|
||||
class TestRagPipelineDatasourceVariableApi:
|
||||
def test_set_datasource_variables_success(self, app):
|
||||
api = RagPipelineDatasourceVariableApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
payload = {
|
||||
"datasource_type": "db",
|
||||
"datasource_info": {},
|
||||
"start_node_id": "n1",
|
||||
"start_node_title": "Node",
|
||||
}
|
||||
|
||||
service = MagicMock()
|
||||
service.set_datasource_variables.return_value = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
assert result is not None
|
||||
@@ -0,0 +1,444 @@
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console.datasets import data_source
|
||||
from controllers.console.datasets.data_source import (
|
||||
DataSourceApi,
|
||||
DataSourceNotionApi,
|
||||
DataSourceNotionDatasetSyncApi,
|
||||
DataSourceNotionDocumentSyncApi,
|
||||
DataSourceNotionListApi,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_ctx():
|
||||
return (MagicMock(id="u1"), "tenant-1")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_tenant(tenant_ctx):
|
||||
with patch(
|
||||
"controllers.console.datasets.data_source.current_account_with_tenant",
|
||||
return_value=tenant_ctx,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_engine():
|
||||
with patch.object(
|
||||
type(data_source.db),
|
||||
"engine",
|
||||
new_callable=PropertyMock,
|
||||
return_value=MagicMock(),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
class TestDataSourceApi:
|
||||
def test_get_success(self, app, patch_tenant):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
binding = MagicMock(
|
||||
id="b1",
|
||||
provider="notion",
|
||||
created_at="now",
|
||||
disabled=False,
|
||||
source_info={},
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.db.session.scalars",
|
||||
return_value=MagicMock(all=lambda: [binding]),
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert response["data"][0]["is_bound"] is True
|
||||
|
||||
def test_get_no_bindings(self, app, patch_tenant):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.db.session.scalars",
|
||||
return_value=MagicMock(all=lambda: []),
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert response["data"] == []
|
||||
|
||||
def test_patch_enable_binding(self, app, patch_tenant, mock_engine):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
binding = MagicMock(id="b1", disabled=True)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
|
||||
patch("controllers.console.datasets.data_source.db.session.add"),
|
||||
patch("controllers.console.datasets.data_source.db.session.commit"),
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
|
||||
|
||||
response, status = method(api, "b1", "enable")
|
||||
|
||||
assert status == 200
|
||||
assert binding.disabled is False
|
||||
|
||||
def test_patch_disable_binding(self, app, patch_tenant, mock_engine):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
binding = MagicMock(id="b1", disabled=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
|
||||
patch("controllers.console.datasets.data_source.db.session.add"),
|
||||
patch("controllers.console.datasets.data_source.db.session.commit"),
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
|
||||
|
||||
response, status = method(api, "b1", "disable")
|
||||
|
||||
assert status == 200
|
||||
assert binding.disabled is True
|
||||
|
||||
def test_patch_binding_not_found(self, app, patch_tenant, mock_engine):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "b1", "enable")
|
||||
|
||||
def test_patch_enable_already_enabled(self, app, patch_tenant, mock_engine):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
binding = MagicMock(id="b1", disabled=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "b1", "enable")
|
||||
|
||||
def test_patch_disable_already_disabled(self, app, patch_tenant, mock_engine):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
binding = MagicMock(id="b1", disabled=True)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "b1", "disable")
|
||||
|
||||
|
||||
class TestDataSourceNotionListApi:
|
||||
def test_get_credential_not_found(self, app, patch_tenant):
|
||||
api = DataSourceNotionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?credential_id=c1"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api)
|
||||
|
||||
def test_get_success_no_dataset_id(self, app, patch_tenant, mock_engine):
|
||||
api = DataSourceNotionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
page = MagicMock(
|
||||
page_id="p1",
|
||||
page_name="Page 1",
|
||||
type="page",
|
||||
parent_id="parent",
|
||||
page_icon=None,
|
||||
)
|
||||
|
||||
online_document_message = MagicMock(
|
||||
result=[
|
||||
MagicMock(
|
||||
workspace_id="w1",
|
||||
workspace_name="My Workspace",
|
||||
workspace_icon="icon",
|
||||
pages=[page],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?credential_id=c1"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
|
||||
return_value={"token": "t"},
|
||||
),
|
||||
patch(
|
||||
"core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime",
|
||||
return_value=MagicMock(
|
||||
get_online_document_pages=lambda **kw: iter([online_document_message]),
|
||||
datasource_provider_type=lambda: None,
|
||||
),
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_get_success_with_dataset_id(self, app, patch_tenant, mock_engine):
|
||||
api = DataSourceNotionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
page = MagicMock(
|
||||
page_id="p1",
|
||||
page_name="Page 1",
|
||||
type="page",
|
||||
parent_id="parent",
|
||||
page_icon=None,
|
||||
)
|
||||
|
||||
online_document_message = MagicMock(
|
||||
result=[
|
||||
MagicMock(
|
||||
workspace_id="w1",
|
||||
workspace_name="My Workspace",
|
||||
workspace_icon="icon",
|
||||
pages=[page],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
dataset = MagicMock(data_source_type="notion_import")
|
||||
document = MagicMock(data_source_info='{"notion_page_id": "p1"}')
|
||||
|
||||
with (
|
||||
app.test_request_context("/?credential_id=c1&dataset_id=ds1"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
|
||||
return_value={"token": "t"},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
|
||||
patch(
|
||||
"core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime",
|
||||
return_value=MagicMock(
|
||||
get_online_document_pages=lambda **kw: iter([online_document_message]),
|
||||
datasource_provider_type=lambda: None,
|
||||
),
|
||||
),
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalars.return_value.all.return_value = [document]
|
||||
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_get_invalid_dataset_type(self, app, patch_tenant, mock_engine):
|
||||
api = DataSourceNotionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
dataset = MagicMock(data_source_type="other_type")
|
||||
|
||||
with (
|
||||
app.test_request_context("/?credential_id=c1&dataset_id=ds1"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
|
||||
return_value={"token": "t"},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch("controllers.console.datasets.data_source.Session"),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestDataSourceNotionApi:
|
||||
def test_get_preview_success(self, app, patch_tenant):
|
||||
api = DataSourceNotionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
extractor = MagicMock(extract=lambda: [MagicMock(page_content="hello")])
|
||||
|
||||
with (
|
||||
app.test_request_context("/?credential_id=c1"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
|
||||
return_value={"integration_secret": "t"},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.NotionExtractor",
|
||||
return_value=extractor,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "p1", "page")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_post_indexing_estimate_success(self, app, patch_tenant):
|
||||
api = DataSourceNotionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"notion_info_list": [
|
||||
{
|
||||
"workspace_id": "w1",
|
||||
"credential_id": "c1",
|
||||
"pages": [{"page_id": "p1", "type": "page"}],
|
||||
}
|
||||
],
|
||||
"process_rule": {"rules": {}},
|
||||
"doc_form": "text_model",
|
||||
"doc_language": "English",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", method="POST", json=payload, headers={"Content-Type": "application/json"}),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DocumentService.estimate_args_validate",
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.IndexingRunner.indexing_estimate",
|
||||
return_value=MagicMock(model_dump=lambda: {"total_pages": 1}),
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
|
||||
|
||||
class TestDataSourceNotionDatasetSyncApi:
|
||||
def test_get_success(self, app, patch_tenant):
|
||||
api = DataSourceNotionDatasetSyncApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasetService.get_dataset",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DocumentService.get_document_by_dataset_id",
|
||||
return_value=[MagicMock(id="d1")],
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.document_indexing_sync_task.delay",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "ds-1")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_get_dataset_not_found(self, app, patch_tenant):
|
||||
api = DataSourceNotionDatasetSyncApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasetService.get_dataset",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "ds-1")
|
||||
|
||||
|
||||
class TestDataSourceNotionDocumentSyncApi:
|
||||
def test_get_success(self, app, patch_tenant):
|
||||
api = DataSourceNotionDocumentSyncApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasetService.get_dataset",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DocumentService.get_document",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.document_indexing_sync_task.delay",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "ds-1", "doc-1")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_get_document_not_found(self, app, patch_tenant):
|
||||
api = DataSourceNotionDocumentSyncApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasetService.get_dataset",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DocumentService.get_document",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "ds-1", "doc-1")
|
||||
1926
api/tests/unit_tests/controllers/console/datasets/test_datasets.py
Normal file
1926
api/tests/unit_tests/controllers/console/datasets/test_datasets.py
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,399 @@
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.datasets.external import (
|
||||
BedrockRetrievalApi,
|
||||
ExternalApiTemplateApi,
|
||||
ExternalApiTemplateListApi,
|
||||
ExternalDatasetCreateApi,
|
||||
ExternalKnowledgeHitTestingApi,
|
||||
)
|
||||
from services.dataset_service import DatasetService
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
from services.knowledge_service import ExternalDatasetTestService
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask("test_external_dataset")
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def current_user():
|
||||
user = MagicMock()
|
||||
user.id = "user-1"
|
||||
user.is_dataset_editor = True
|
||||
user.has_edit_permission = True
|
||||
user.is_dataset_operator = True
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth(mocker, current_user):
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.external.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
)
|
||||
|
||||
|
||||
class TestExternalApiTemplateListApi:
|
||||
def test_get_success(self, app):
|
||||
api = ExternalApiTemplateListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
api_item = MagicMock()
|
||||
api_item.to_dict.return_value = {"id": "1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&limit=20"),
|
||||
patch.object(
|
||||
ExternalDatasetService,
|
||||
"get_external_knowledge_apis",
|
||||
return_value=([api_item], 1),
|
||||
),
|
||||
):
|
||||
resp, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert resp["total"] == 1
|
||||
assert resp["data"][0]["id"] == "1"
|
||||
|
||||
def test_post_forbidden(self, app, current_user):
|
||||
current_user.is_dataset_editor = False
|
||||
api = ExternalApiTemplateListApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"name": "x", "settings": {"k": "v"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch.object(ExternalDatasetService, "validate_api_list"),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
|
||||
def test_post_duplicate_name(self, app):
|
||||
api = ExternalApiTemplateListApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"name": "x", "settings": {"k": "v"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch.object(ExternalDatasetService, "validate_api_list"),
|
||||
patch.object(
|
||||
ExternalDatasetService,
|
||||
"create_external_knowledge_api",
|
||||
side_effect=services.errors.dataset.DatasetNameDuplicateError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(DatasetNameDuplicateError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestExternalApiTemplateApi:
|
||||
def test_get_not_found(self, app):
|
||||
api = ExternalApiTemplateApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
ExternalDatasetService,
|
||||
"get_external_knowledge_api",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "api-id")
|
||||
|
||||
def test_delete_forbidden(self, app, current_user):
|
||||
current_user.has_edit_permission = False
|
||||
current_user.is_dataset_operator = False
|
||||
|
||||
api = ExternalApiTemplateApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "api-id")
|
||||
|
||||
|
||||
class TestExternalDatasetCreateApi:
|
||||
def test_create_success(self, app):
|
||||
api = ExternalDatasetCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"external_knowledge_api_id": "api",
|
||||
"external_knowledge_id": "kid",
|
||||
"name": "dataset",
|
||||
}
|
||||
|
||||
dataset = MagicMock()
|
||||
|
||||
dataset.embedding_available = False
|
||||
dataset.built_in_field_enabled = False
|
||||
dataset.is_published = False
|
||||
dataset.enable_api = False
|
||||
dataset.enable_qa = False
|
||||
dataset.enable_vector_store = False
|
||||
dataset.vector_store_setting = None
|
||||
dataset.is_multimodal = False
|
||||
|
||||
dataset.retrieval_model_dict = {}
|
||||
dataset.tags = []
|
||||
dataset.external_knowledge_info = None
|
||||
dataset.external_retrieval_model = None
|
||||
dataset.doc_metadata = []
|
||||
dataset.icon_info = None
|
||||
|
||||
dataset.summary_index_setting = MagicMock()
|
||||
dataset.summary_index_setting.enable = False
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch.object(
|
||||
ExternalDatasetService,
|
||||
"create_external_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
):
|
||||
_, status = method(api)
|
||||
|
||||
assert status == 201
|
||||
|
||||
def test_create_forbidden(self, app, current_user):
|
||||
current_user.is_dataset_editor = False
|
||||
api = ExternalDatasetCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"external_knowledge_api_id": "api",
|
||||
"external_knowledge_id": "kid",
|
||||
"name": "dataset",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestExternalKnowledgeHitTestingApi:
|
||||
def test_hit_testing_dataset_not_found(self, app):
|
||||
api = ExternalKnowledgeHitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "dataset-id")
|
||||
|
||||
def test_hit_testing_success(self, app):
|
||||
api = ExternalKnowledgeHitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"query": "hello"}
|
||||
|
||||
dataset = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch.object(DatasetService, "get_dataset", return_value=dataset),
|
||||
patch.object(DatasetService, "check_dataset_permission"),
|
||||
patch.object(
|
||||
HitTestingService,
|
||||
"external_retrieve",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
resp = method(api, "dataset-id")
|
||||
|
||||
assert resp["ok"] is True
|
||||
|
||||
|
||||
class TestBedrockRetrievalApi:
|
||||
def test_bedrock_retrieval(self, app):
|
||||
api = BedrockRetrievalApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"retrieval_setting": {},
|
||||
"query": "hello",
|
||||
"knowledge_id": "kid",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch.object(
|
||||
ExternalDatasetTestService,
|
||||
"knowledge_retrieval",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
resp, status = method()
|
||||
|
||||
assert status == 200
|
||||
assert resp["ok"] is True
|
||||
|
||||
|
||||
class TestExternalApiTemplateListApiAdvanced:
|
||||
def test_post_duplicate_name_error(self, app, mock_auth, current_user):
|
||||
api = ExternalApiTemplateListApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"name": "duplicate_api", "settings": {"key": "value"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch("controllers.console.datasets.external.ExternalDatasetService.validate_api_list"),
|
||||
patch(
|
||||
"controllers.console.datasets.external.ExternalDatasetService.create_external_knowledge_api",
|
||||
side_effect=services.errors.dataset.DatasetNameDuplicateError("Duplicate"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(DatasetNameDuplicateError):
|
||||
method(api)
|
||||
|
||||
def test_get_with_pagination(self, app, mock_auth, current_user):
|
||||
api = ExternalApiTemplateListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
templates = [MagicMock(id=f"api-{i}") for i in range(3)]
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&limit=20"),
|
||||
patch(
|
||||
"controllers.console.datasets.external.ExternalDatasetService.get_external_knowledge_apis",
|
||||
return_value=(templates, 25),
|
||||
),
|
||||
):
|
||||
resp, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert resp["total"] == 25
|
||||
assert len(resp["data"]) == 3
|
||||
|
||||
|
||||
class TestExternalDatasetCreateApiAdvanced:
|
||||
def test_create_forbidden(self, app, mock_auth, current_user):
|
||||
"""Test creating external dataset without permission"""
|
||||
api = ExternalDatasetCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
current_user.is_dataset_editor = False
|
||||
|
||||
payload = {
|
||||
"external_knowledge_api_id": "api-1",
|
||||
"external_knowledge_id": "ek-1",
|
||||
"name": "new_dataset",
|
||||
"description": "A dataset",
|
||||
}
|
||||
|
||||
with app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestExternalKnowledgeHitTestingApiAdvanced:
|
||||
def test_hit_testing_dataset_not_found(self, app, mock_auth, current_user):
|
||||
"""Test hit testing on non-existent dataset"""
|
||||
api = ExternalKnowledgeHitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"query": "test query",
|
||||
"external_retrieval_model": None,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.external.DatasetService.get_dataset",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "ds-1")
|
||||
|
||||
def test_hit_testing_with_custom_retrieval_model(self, app, mock_auth, current_user):
|
||||
api = ExternalKnowledgeHitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
dataset = MagicMock()
|
||||
payload = {
|
||||
"query": "test query",
|
||||
"external_retrieval_model": {"type": "bm25"},
|
||||
"metadata_filtering_conditions": {"status": "active"},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.external.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch("controllers.console.datasets.external.DatasetService.check_dataset_permission"),
|
||||
patch(
|
||||
"controllers.console.datasets.external.HitTestingService.external_retrieve",
|
||||
return_value={"results": []},
|
||||
),
|
||||
):
|
||||
resp = method(api, "ds-1")
|
||||
|
||||
assert resp["results"] == []
|
||||
|
||||
|
||||
class TestBedrockRetrievalApiAdvanced:
|
||||
def test_bedrock_retrieval_with_invalid_setting(self, app, mock_auth, current_user):
|
||||
api = BedrockRetrievalApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"retrieval_setting": {},
|
||||
"query": "test",
|
||||
"knowledge_id": "k-1",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.external.ExternalDatasetTestService.knowledge_retrieval",
|
||||
side_effect=ValueError("Invalid settings"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method()
|
||||
@@ -0,0 +1,160 @@
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.hit_testing import HitTestingApi
|
||||
from controllers.console.datasets.hit_testing_base import HitTestingPayload
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
"""Recursively unwrap decorated functions."""
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask("test_hit_testing")
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_id():
|
||||
return uuid.uuid4()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset():
|
||||
return MagicMock(id="dataset-1")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def bypass_decorators(mocker):
|
||||
"""Bypass all decorators on the API method."""
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.hit_testing.setup_required",
|
||||
lambda f: f,
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.hit_testing.login_required",
|
||||
return_value=lambda f: f,
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.hit_testing.account_initialization_required",
|
||||
return_value=lambda f: f,
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.hit_testing.cloud_edition_billing_rate_limit_check",
|
||||
return_value=lambda *_: (lambda f: f),
|
||||
)
|
||||
|
||||
|
||||
class TestHitTestingApi:
|
||||
def test_hit_testing_success(self, app, dataset, dataset_id):
|
||||
api = HitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"query": "what is vector search",
|
||||
"top_k": 3,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch.object(
|
||||
HitTestingPayload,
|
||||
"model_validate",
|
||||
return_value=MagicMock(model_dump=lambda **_: payload),
|
||||
),
|
||||
patch.object(
|
||||
HitTestingApi,
|
||||
"get_and_validate_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
HitTestingApi,
|
||||
"hit_testing_args_check",
|
||||
),
|
||||
patch.object(
|
||||
HitTestingApi,
|
||||
"perform_hit_testing",
|
||||
return_value={"query": "what is vector search", "records": []},
|
||||
),
|
||||
):
|
||||
result = method(api, dataset_id)
|
||||
|
||||
assert "query" in result
|
||||
assert "records" in result
|
||||
assert result["records"] == []
|
||||
|
||||
def test_hit_testing_dataset_not_found(self, app, dataset_id):
|
||||
api = HitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"query": "test",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch.object(
|
||||
HitTestingApi,
|
||||
"get_and_validate_dataset",
|
||||
side_effect=NotFound("Dataset not found"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound, match="Dataset not found"):
|
||||
method(api, dataset_id)
|
||||
|
||||
def test_hit_testing_invalid_args(self, app, dataset, dataset_id):
|
||||
api = HitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"query": "",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch.object(
|
||||
HitTestingPayload,
|
||||
"model_validate",
|
||||
return_value=MagicMock(model_dump=lambda **_: payload),
|
||||
),
|
||||
patch.object(
|
||||
HitTestingApi,
|
||||
"get_and_validate_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
HitTestingApi,
|
||||
"hit_testing_args_check",
|
||||
side_effect=ValueError("Invalid parameters"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Invalid parameters"):
|
||||
method(api, dataset_id)
|
||||
@@ -0,0 +1,207 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.datasets.error import DatasetNotInitializedError
|
||||
from controllers.console.datasets.hit_testing_base import (
|
||||
DatasetsHitTestingBase,
|
||||
)
|
||||
from core.errors.error import (
|
||||
LLMBadRequestError,
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def account():
|
||||
acc = MagicMock(spec=Account)
|
||||
return acc
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_current_user(mocker, account):
|
||||
"""Patch current_user to a valid Account."""
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.hit_testing_base.current_user",
|
||||
account,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset():
|
||||
return MagicMock(id="dataset-1")
|
||||
|
||||
|
||||
class TestGetAndValidateDataset:
|
||||
def test_success(self, dataset):
|
||||
with (
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"check_dataset_permission",
|
||||
),
|
||||
):
|
||||
result = DatasetsHitTestingBase.get_and_validate_dataset("dataset-1")
|
||||
|
||||
assert result == dataset
|
||||
|
||||
def test_dataset_not_found(self):
|
||||
with patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=None,
|
||||
):
|
||||
with pytest.raises(NotFound, match="Dataset not found"):
|
||||
DatasetsHitTestingBase.get_and_validate_dataset("dataset-1")
|
||||
|
||||
def test_permission_denied(self, dataset):
|
||||
with (
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"check_dataset_permission",
|
||||
side_effect=services.errors.account.NoPermissionError("no access"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden, match="no access"):
|
||||
DatasetsHitTestingBase.get_and_validate_dataset("dataset-1")
|
||||
|
||||
|
||||
class TestHitTestingArgsCheck:
|
||||
def test_args_check_called(self):
|
||||
args = {"query": "test"}
|
||||
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"hit_testing_args_check",
|
||||
) as check_mock:
|
||||
DatasetsHitTestingBase.hit_testing_args_check(args)
|
||||
|
||||
check_mock.assert_called_once_with(args)
|
||||
|
||||
|
||||
class TestParseArgs:
|
||||
def test_parse_args_success(self):
|
||||
payload = {"query": "hello"}
|
||||
|
||||
result = DatasetsHitTestingBase.parse_args(payload)
|
||||
|
||||
assert result["query"] == "hello"
|
||||
|
||||
def test_parse_args_invalid(self):
|
||||
payload = {"query": "x" * 300}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
DatasetsHitTestingBase.parse_args(payload)
|
||||
|
||||
|
||||
class TestPerformHitTesting:
|
||||
def test_success(self, dataset):
|
||||
response = {
|
||||
"query": "hello",
|
||||
"records": [],
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
return_value=response,
|
||||
):
|
||||
result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
assert result["query"] == "hello"
|
||||
assert result["records"] == []
|
||||
|
||||
def test_index_not_initialized(self, dataset):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=services.errors.index.IndexNotInitializedError(),
|
||||
):
|
||||
with pytest.raises(DatasetNotInitializedError):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
def test_provider_token_not_init(self, dataset):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=ProviderTokenNotInitError("token missing"),
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
def test_quota_exceeded(self, dataset):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=QuotaExceededError(),
|
||||
):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
def test_model_not_supported(self, dataset):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=ModelCurrentlyNotSupportError(),
|
||||
):
|
||||
with pytest.raises(ProviderModelCurrentlyNotSupportError):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
def test_llm_bad_request(self, dataset):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=LLMBadRequestError("bad request"),
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
def test_invoke_error(self, dataset):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=InvokeError("invoke failed"),
|
||||
):
|
||||
with pytest.raises(CompletionRequestError):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
def test_value_error(self, dataset):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=ValueError("bad args"),
|
||||
):
|
||||
with pytest.raises(ValueError, match="bad args"):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
def test_unexpected_error(self, dataset):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=Exception("boom"),
|
||||
):
|
||||
with pytest.raises(InternalServerError, match="boom"):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
@@ -0,0 +1,362 @@
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.metadata import (
|
||||
DatasetMetadataApi,
|
||||
DatasetMetadataBuiltInFieldActionApi,
|
||||
DatasetMetadataBuiltInFieldApi,
|
||||
DatasetMetadataCreateApi,
|
||||
DocumentMetadataEditApi,
|
||||
)
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
MetadataArgs,
|
||||
MetadataOperationData,
|
||||
)
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
"""Recursively unwrap decorated functions."""
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask("test_dataset_metadata")
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def current_user():
|
||||
user = MagicMock()
|
||||
user.id = "user-1"
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset():
|
||||
ds = MagicMock()
|
||||
ds.id = "dataset-1"
|
||||
return ds
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_id():
|
||||
return uuid.uuid4()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def metadata_id():
|
||||
return uuid.uuid4()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def bypass_decorators(mocker):
|
||||
"""Bypass setup/login/license decorators."""
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.metadata.setup_required",
|
||||
lambda f: f,
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.metadata.login_required",
|
||||
lambda f: f,
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.metadata.account_initialization_required",
|
||||
lambda f: f,
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.metadata.enterprise_license_required",
|
||||
lambda f: f,
|
||||
)
|
||||
|
||||
|
||||
class TestDatasetMetadataCreateApi:
|
||||
def test_create_metadata_success(self, app, current_user, dataset, dataset_id):
|
||||
api = DatasetMetadataCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"name": "author"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.metadata.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
MetadataArgs,
|
||||
"model_validate",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"check_dataset_permission",
|
||||
),
|
||||
patch.object(
|
||||
MetadataService,
|
||||
"create_metadata",
|
||||
return_value={"id": "m1", "name": "author"},
|
||||
),
|
||||
):
|
||||
result, status = method(api, dataset_id)
|
||||
|
||||
assert status == 201
|
||||
assert result["name"] == "author"
|
||||
|
||||
def test_create_metadata_dataset_not_found(self, app, current_user, dataset_id):
|
||||
api = DatasetMetadataCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
valid_payload = {
|
||||
"type": "string",
|
||||
"name": "author",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=valid_payload,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.metadata.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
MetadataArgs,
|
||||
"model_validate",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound, match="Dataset not found"):
|
||||
method(api, dataset_id)
|
||||
|
||||
|
||||
class TestDatasetMetadataGetApi:
|
||||
def test_get_metadata_success(self, app, dataset, dataset_id):
|
||||
api = DatasetMetadataCreateApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
MetadataService,
|
||||
"get_dataset_metadatas",
|
||||
return_value=[{"id": "m1"}],
|
||||
),
|
||||
):
|
||||
result, status = method(api, dataset_id)
|
||||
|
||||
assert status == 200
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_get_metadata_dataset_not_found(self, app, dataset_id):
|
||||
api = DatasetMetadataCreateApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, dataset_id)
|
||||
|
||||
|
||||
class TestDatasetMetadataApi:
|
||||
def test_update_metadata_success(self, app, current_user, dataset, dataset_id, metadata_id):
|
||||
api = DatasetMetadataApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
payload = {"name": "updated-name"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.metadata.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"check_dataset_permission",
|
||||
),
|
||||
patch.object(
|
||||
MetadataService,
|
||||
"update_metadata_name",
|
||||
return_value={"id": "m1", "name": "updated-name"},
|
||||
),
|
||||
):
|
||||
result, status = method(api, dataset_id, metadata_id)
|
||||
|
||||
assert status == 200
|
||||
assert result["name"] == "updated-name"
|
||||
|
||||
def test_delete_metadata_success(self, app, current_user, dataset, dataset_id, metadata_id):
|
||||
api = DatasetMetadataApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.metadata.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"check_dataset_permission",
|
||||
),
|
||||
patch.object(
|
||||
MetadataService,
|
||||
"delete_metadata",
|
||||
),
|
||||
):
|
||||
result, status = method(api, dataset_id, metadata_id)
|
||||
|
||||
assert status == 204
|
||||
assert result["result"] == "success"
|
||||
|
||||
|
||||
class TestDatasetMetadataBuiltInFieldApi:
|
||||
def test_get_built_in_fields(self, app):
|
||||
api = DatasetMetadataBuiltInFieldApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
MetadataService,
|
||||
"get_built_in_fields",
|
||||
return_value=["title", "source"],
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["fields"] == ["title", "source"]
|
||||
|
||||
|
||||
class TestDatasetMetadataBuiltInFieldActionApi:
|
||||
def test_enable_built_in_field(self, app, current_user, dataset, dataset_id):
|
||||
api = DatasetMetadataBuiltInFieldActionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.metadata.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"check_dataset_permission",
|
||||
),
|
||||
patch.object(
|
||||
MetadataService,
|
||||
"enable_built_in_field",
|
||||
),
|
||||
):
|
||||
result, status = method(api, dataset_id, "enable")
|
||||
|
||||
assert status == 200
|
||||
assert result["result"] == "success"
|
||||
|
||||
|
||||
class TestDocumentMetadataEditApi:
|
||||
def test_update_document_metadata_success(self, app, current_user, dataset, dataset_id):
|
||||
api = DocumentMetadataEditApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"operation": "add", "metadata": {}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.metadata.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"check_dataset_permission",
|
||||
),
|
||||
patch.object(
|
||||
MetadataOperationData,
|
||||
"model_validate",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch.object(
|
||||
MetadataService,
|
||||
"update_documents_metadata",
|
||||
),
|
||||
):
|
||||
result, status = method(api, dataset_id)
|
||||
|
||||
assert status == 200
|
||||
assert result["result"] == "success"
|
||||
@@ -0,0 +1,233 @@
|
||||
from unittest.mock import Mock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import WebsiteCrawlError
|
||||
from controllers.console.datasets.website import (
|
||||
WebsiteCrawlApi,
|
||||
WebsiteCrawlStatusApi,
|
||||
)
|
||||
from services.website_service import (
|
||||
WebsiteCrawlApiRequest,
|
||||
WebsiteCrawlStatusApiRequest,
|
||||
WebsiteService,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
"""Recursively unwrap decorated functions."""
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask("test_website_crawl")
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def bypass_auth_and_setup(mocker):
|
||||
"""Bypass setup/login/account decorators."""
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.website.login_required",
|
||||
lambda f: f,
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.website.setup_required",
|
||||
lambda f: f,
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.website.account_initialization_required",
|
||||
lambda f: f,
|
||||
)
|
||||
|
||||
|
||||
class TestWebsiteCrawlApi:
|
||||
def test_crawl_success(self, app, mocker):
|
||||
api = WebsiteCrawlApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"provider": "firecrawl",
|
||||
"url": "https://example.com",
|
||||
"options": {"depth": 1},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
):
|
||||
mock_request = Mock(spec=WebsiteCrawlApiRequest)
|
||||
mocker.patch.object(
|
||||
WebsiteCrawlApiRequest,
|
||||
"from_args",
|
||||
return_value=mock_request,
|
||||
)
|
||||
|
||||
mocker.patch.object(
|
||||
WebsiteService,
|
||||
"crawl_url",
|
||||
return_value={"job_id": "job-1"},
|
||||
)
|
||||
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["job_id"] == "job-1"
|
||||
|
||||
def test_crawl_invalid_payload(self, app, mocker):
|
||||
api = WebsiteCrawlApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"provider": "firecrawl",
|
||||
"url": "bad-url",
|
||||
"options": {},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
):
|
||||
mocker.patch.object(
|
||||
WebsiteCrawlApiRequest,
|
||||
"from_args",
|
||||
side_effect=ValueError("invalid payload"),
|
||||
)
|
||||
|
||||
with pytest.raises(WebsiteCrawlError, match="invalid payload"):
|
||||
method(api)
|
||||
|
||||
def test_crawl_service_error(self, app, mocker):
|
||||
api = WebsiteCrawlApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"provider": "firecrawl",
|
||||
"url": "https://example.com",
|
||||
"options": {},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
):
|
||||
mock_request = Mock(spec=WebsiteCrawlApiRequest)
|
||||
mocker.patch.object(
|
||||
WebsiteCrawlApiRequest,
|
||||
"from_args",
|
||||
return_value=mock_request,
|
||||
)
|
||||
|
||||
mocker.patch.object(
|
||||
WebsiteService,
|
||||
"crawl_url",
|
||||
side_effect=Exception("crawl failed"),
|
||||
)
|
||||
|
||||
with pytest.raises(WebsiteCrawlError, match="crawl failed"):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestWebsiteCrawlStatusApi:
|
||||
def test_get_status_success(self, app, mocker):
|
||||
api = WebsiteCrawlStatusApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
job_id = "job-123"
|
||||
args = {"provider": "firecrawl"}
|
||||
|
||||
with app.test_request_context("/?provider=firecrawl"):
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.website.request.args.to_dict",
|
||||
return_value=args,
|
||||
)
|
||||
|
||||
mock_request = Mock(spec=WebsiteCrawlStatusApiRequest)
|
||||
mocker.patch.object(
|
||||
WebsiteCrawlStatusApiRequest,
|
||||
"from_args",
|
||||
return_value=mock_request,
|
||||
)
|
||||
|
||||
mocker.patch.object(
|
||||
WebsiteService,
|
||||
"get_crawl_status_typed",
|
||||
return_value={"status": "completed"},
|
||||
)
|
||||
|
||||
result, status = method(api, job_id)
|
||||
|
||||
assert status == 200
|
||||
assert result["status"] == "completed"
|
||||
|
||||
def test_get_status_invalid_provider(self, app, mocker):
|
||||
api = WebsiteCrawlStatusApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
job_id = "job-123"
|
||||
args = {"provider": "firecrawl"}
|
||||
|
||||
with app.test_request_context("/?provider=firecrawl"):
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.website.request.args.to_dict",
|
||||
return_value=args,
|
||||
)
|
||||
|
||||
mocker.patch.object(
|
||||
WebsiteCrawlStatusApiRequest,
|
||||
"from_args",
|
||||
side_effect=ValueError("invalid provider"),
|
||||
)
|
||||
|
||||
with pytest.raises(WebsiteCrawlError, match="invalid provider"):
|
||||
method(api, job_id)
|
||||
|
||||
def test_get_status_service_error(self, app, mocker):
|
||||
api = WebsiteCrawlStatusApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
job_id = "job-123"
|
||||
args = {"provider": "firecrawl"}
|
||||
|
||||
with app.test_request_context("/?provider=firecrawl"):
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.website.request.args.to_dict",
|
||||
return_value=args,
|
||||
)
|
||||
|
||||
mock_request = Mock(spec=WebsiteCrawlStatusApiRequest)
|
||||
mocker.patch.object(
|
||||
WebsiteCrawlStatusApiRequest,
|
||||
"from_args",
|
||||
return_value=mock_request,
|
||||
)
|
||||
|
||||
mocker.patch.object(
|
||||
WebsiteService,
|
||||
"get_crawl_status_typed",
|
||||
side_effect=Exception("status lookup failed"),
|
||||
)
|
||||
|
||||
with pytest.raises(WebsiteCrawlError, match="status lookup failed"):
|
||||
method(api, job_id)
|
||||
117
api/tests/unit_tests/controllers/console/datasets/test_wraps.py
Normal file
117
api/tests/unit_tests/controllers/console/datasets/test_wraps.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.datasets.error import PipelineNotFoundError
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from models.dataset import Pipeline
|
||||
|
||||
|
||||
class TestGetRagPipeline:
|
||||
def test_missing_pipeline_id(self):
|
||||
@get_rag_pipeline
|
||||
def dummy_view(**kwargs):
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(ValueError, match="missing pipeline_id"):
|
||||
dummy_view()
|
||||
|
||||
def test_pipeline_not_found(self, mocker):
|
||||
@get_rag_pipeline
|
||||
def dummy_view(**kwargs):
|
||||
return "ok"
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.current_account_with_tenant",
|
||||
return_value=(Mock(), "tenant-1"),
|
||||
)
|
||||
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.first.return_value = None
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.db.session.query",
|
||||
return_value=mock_query,
|
||||
)
|
||||
|
||||
with pytest.raises(PipelineNotFoundError):
|
||||
dummy_view(pipeline_id="pipeline-1")
|
||||
|
||||
def test_pipeline_found_and_injected(self, mocker):
|
||||
pipeline = Mock(spec=Pipeline)
|
||||
pipeline.id = "pipeline-1"
|
||||
pipeline.tenant_id = "tenant-1"
|
||||
|
||||
@get_rag_pipeline
|
||||
def dummy_view(**kwargs):
|
||||
return kwargs["pipeline"]
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.current_account_with_tenant",
|
||||
return_value=(Mock(), "tenant-1"),
|
||||
)
|
||||
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.first.return_value = pipeline
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.db.session.query",
|
||||
return_value=mock_query,
|
||||
)
|
||||
|
||||
result = dummy_view(pipeline_id="pipeline-1")
|
||||
|
||||
assert result is pipeline
|
||||
|
||||
def test_pipeline_id_removed_from_kwargs(self, mocker):
|
||||
pipeline = Mock(spec=Pipeline)
|
||||
|
||||
@get_rag_pipeline
|
||||
def dummy_view(**kwargs):
|
||||
assert "pipeline_id" not in kwargs
|
||||
return "ok"
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.current_account_with_tenant",
|
||||
return_value=(Mock(), "tenant-1"),
|
||||
)
|
||||
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.first.return_value = pipeline
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.db.session.query",
|
||||
return_value=mock_query,
|
||||
)
|
||||
|
||||
result = dummy_view(pipeline_id="pipeline-1")
|
||||
|
||||
assert result == "ok"
|
||||
|
||||
def test_pipeline_id_cast_to_string(self, mocker):
|
||||
pipeline = Mock(spec=Pipeline)
|
||||
|
||||
@get_rag_pipeline
|
||||
def dummy_view(**kwargs):
|
||||
return kwargs["pipeline"]
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.current_account_with_tenant",
|
||||
return_value=(Mock(), "tenant-1"),
|
||||
)
|
||||
|
||||
def where_side_effect(*args, **kwargs):
|
||||
assert args[0].right.value == "123"
|
||||
return Mock(first=lambda: pipeline)
|
||||
|
||||
mock_query = Mock()
|
||||
mock_query.where.side_effect = where_side_effect
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.db.session.query",
|
||||
return_value=mock_query,
|
||||
)
|
||||
|
||||
result = dummy_view(pipeline_id=123)
|
||||
|
||||
assert result is pipeline
|
||||
@@ -0,0 +1,341 @@
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailAlreadyInUseError,
|
||||
EmailCodeError,
|
||||
)
|
||||
from controllers.console.error import AccountInFreezeError
|
||||
from controllers.console.workspace.account import (
|
||||
AccountAvatarApi,
|
||||
AccountDeleteApi,
|
||||
AccountDeleteVerifyApi,
|
||||
AccountInitApi,
|
||||
AccountIntegrateApi,
|
||||
AccountInterfaceLanguageApi,
|
||||
AccountInterfaceThemeApi,
|
||||
AccountNameApi,
|
||||
AccountPasswordApi,
|
||||
AccountProfileApi,
|
||||
AccountTimezoneApi,
|
||||
ChangeEmailCheckApi,
|
||||
ChangeEmailResetApi,
|
||||
CheckEmailUnique,
|
||||
)
|
||||
from controllers.console.workspace.error import (
|
||||
AccountAlreadyInitedError,
|
||||
CurrentPasswordIncorrectError,
|
||||
InvalidAccountDeletionCodeError,
|
||||
)
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServicePwdError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestAccountInitApi:
|
||||
def test_init_success(self, app):
|
||||
api = AccountInitApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
account = MagicMock(status="inactive")
|
||||
payload = {
|
||||
"interface_language": "en-US",
|
||||
"timezone": "UTC",
|
||||
"invitation_code": "code123",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/account/init", json=payload),
|
||||
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
|
||||
patch("controllers.console.workspace.account.db.session.commit", return_value=None),
|
||||
patch("controllers.console.workspace.account.dify_config.EDITION", "CLOUD"),
|
||||
patch("controllers.console.workspace.account.db.session.query") as query_mock,
|
||||
):
|
||||
query_mock.return_value.where.return_value.first.return_value = MagicMock(status="unused")
|
||||
resp = method(api)
|
||||
|
||||
assert resp["result"] == "success"
|
||||
|
||||
def test_init_already_initialized(self, app):
|
||||
api = AccountInitApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
account = MagicMock(status="active")
|
||||
|
||||
with (
|
||||
app.test_request_context("/account/init"),
|
||||
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
|
||||
):
|
||||
with pytest.raises(AccountAlreadyInitedError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestAccountProfileApi:
|
||||
def test_get_profile_success(self, app):
|
||||
api = AccountProfileApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock()
|
||||
user.id = "u1"
|
||||
user.name = "John"
|
||||
user.email = "john@test.com"
|
||||
user.avatar = "avatar.png"
|
||||
user.interface_language = "en-US"
|
||||
user.interface_theme = "light"
|
||||
user.timezone = "UTC"
|
||||
user.last_login_ip = "127.0.0.1"
|
||||
|
||||
with (
|
||||
app.test_request_context("/account/profile"),
|
||||
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["id"] == "u1"
|
||||
|
||||
|
||||
class TestAccountUpdateApis:
|
||||
@pytest.mark.parametrize(
|
||||
("api_cls", "payload"),
|
||||
[
|
||||
(AccountNameApi, {"name": "test"}),
|
||||
(AccountAvatarApi, {"avatar": "img.png"}),
|
||||
(AccountInterfaceLanguageApi, {"interface_language": "en-US"}),
|
||||
(AccountInterfaceThemeApi, {"interface_theme": "dark"}),
|
||||
(AccountTimezoneApi, {"timezone": "UTC"}),
|
||||
],
|
||||
)
|
||||
def test_update_success(self, app, api_cls, payload):
|
||||
api = api_cls()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock()
|
||||
user.id = "u1"
|
||||
user.name = "John"
|
||||
user.email = "john@test.com"
|
||||
user.avatar = "avatar.png"
|
||||
user.interface_language = "en-US"
|
||||
user.interface_theme = "light"
|
||||
user.timezone = "UTC"
|
||||
user.last_login_ip = "127.0.0.1"
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.account.AccountService.update_account", return_value=user),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["id"] == "u1"
|
||||
|
||||
|
||||
class TestAccountPasswordApi:
|
||||
def test_password_success(self, app):
|
||||
api = AccountPasswordApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"password": "old",
|
||||
"new_password": "new123",
|
||||
"repeat_new_password": "new123",
|
||||
}
|
||||
|
||||
user = MagicMock()
|
||||
user.id = "u1"
|
||||
user.name = "John"
|
||||
user.email = "john@test.com"
|
||||
user.avatar = "avatar.png"
|
||||
user.interface_language = "en-US"
|
||||
user.interface_theme = "light"
|
||||
user.timezone = "UTC"
|
||||
user.last_login_ip = "127.0.0.1"
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.account.AccountService.update_account_password", return_value=None),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["id"] == "u1"
|
||||
|
||||
def test_password_wrong_current(self, app):
|
||||
api = AccountPasswordApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"password": "bad",
|
||||
"new_password": "new123",
|
||||
"repeat_new_password": "new123",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.account.AccountService.update_account_password",
|
||||
side_effect=ServicePwdError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(CurrentPasswordIncorrectError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestAccountIntegrateApi:
|
||||
def test_get_integrates(self, app):
|
||||
api = AccountIntegrateApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
account = MagicMock(id="acc1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
|
||||
patch("controllers.console.workspace.account.db.session.scalars") as scalars_mock,
|
||||
):
|
||||
scalars_mock.return_value.all.return_value = []
|
||||
result = method(api)
|
||||
|
||||
assert "data" in result
|
||||
assert len(result["data"]) == 2
|
||||
|
||||
|
||||
class TestAccountDeleteApi:
|
||||
def test_delete_verify_success(self, app):
|
||||
api = AccountDeleteVerifyApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.account.AccountService.generate_account_deletion_verification_code",
|
||||
return_value=("token", "1234"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.account.AccountService.send_account_deletion_verification_email",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_delete_invalid_code(self, app):
|
||||
api = AccountDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"token": "t", "code": "x"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.account.AccountService.verify_account_deletion_code",
|
||||
return_value=False,
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvalidAccountDeletionCodeError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestChangeEmailApis:
|
||||
def test_check_email_code_invalid(self, app):
|
||||
api = ChangeEmailCheckApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"email": "a@test.com", "code": "x", "token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit",
|
||||
return_value=False,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.account.AccountService.get_change_email_data",
|
||||
return_value={"email": "a@test.com", "code": "y"},
|
||||
),
|
||||
):
|
||||
with pytest.raises(EmailCodeError):
|
||||
method(api)
|
||||
|
||||
def test_reset_email_already_used(self, app):
|
||||
api = ChangeEmailResetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"new_email": "x@test.com", "token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=False),
|
||||
patch("controllers.console.workspace.account.AccountService.check_email_unique", return_value=False),
|
||||
):
|
||||
with pytest.raises(EmailAlreadyInUseError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestCheckEmailUniqueApi:
|
||||
def test_email_unique_success(self, app):
|
||||
api = CheckEmailUnique()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"email": "ok@test.com"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=False),
|
||||
patch("controllers.console.workspace.account.AccountService.check_email_unique", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_email_in_freeze(self, app):
|
||||
api = CheckEmailUnique()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"email": "x@test.com"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=True),
|
||||
):
|
||||
with pytest.raises(AccountInFreezeError):
|
||||
method(api)
|
||||
@@ -0,0 +1,139 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.error import AccountNotFound
|
||||
from controllers.console.workspace.agent_providers import (
|
||||
AgentProviderApi,
|
||||
AgentProviderListApi,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestAgentProviderListApi:
|
||||
def test_get_success(self, app):
|
||||
api = AgentProviderListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="user1")
|
||||
tenant_id = "tenant1"
|
||||
providers = [{"name": "openai"}, {"name": "anthropic"}]
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.current_account_with_tenant",
|
||||
return_value=(user, tenant_id),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.AgentService.list_agent_providers",
|
||||
return_value=providers,
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result == providers
|
||||
|
||||
def test_get_empty_list(self, app):
|
||||
api = AgentProviderListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="user1")
|
||||
tenant_id = "tenant1"
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.current_account_with_tenant",
|
||||
return_value=(user, tenant_id),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.AgentService.list_agent_providers",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_get_account_not_found(self, app):
|
||||
api = AgentProviderListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.current_account_with_tenant",
|
||||
side_effect=AccountNotFound(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(AccountNotFound):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestAgentProviderApi:
|
||||
def test_get_success(self, app):
|
||||
api = AgentProviderApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="user1")
|
||||
tenant_id = "tenant1"
|
||||
provider_name = "openai"
|
||||
provider_data = {"name": "openai", "models": ["gpt-4"]}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.current_account_with_tenant",
|
||||
return_value=(user, tenant_id),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.AgentService.get_agent_provider",
|
||||
return_value=provider_data,
|
||||
),
|
||||
):
|
||||
result = method(api, provider_name)
|
||||
|
||||
assert result == provider_data
|
||||
|
||||
def test_get_provider_not_found(self, app):
|
||||
api = AgentProviderApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="user1")
|
||||
tenant_id = "tenant1"
|
||||
provider_name = "unknown"
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.current_account_with_tenant",
|
||||
return_value=(user, tenant_id),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.AgentService.get_agent_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = method(api, provider_name)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_get_account_not_found(self, app):
|
||||
api = AgentProviderApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.current_account_with_tenant",
|
||||
side_effect=AccountNotFound(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(AccountNotFound):
|
||||
method(api, "openai")
|
||||
@@ -0,0 +1,305 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.workspace.endpoint import (
|
||||
EndpointCreateApi,
|
||||
EndpointDeleteApi,
|
||||
EndpointDisableApi,
|
||||
EndpointEnableApi,
|
||||
EndpointListApi,
|
||||
EndpointListForSinglePluginApi,
|
||||
EndpointUpdateApi,
|
||||
)
|
||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_and_tenant():
|
||||
return MagicMock(id="u1"), "t1"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_current_account(user_and_tenant):
|
||||
with patch(
|
||||
"controllers.console.workspace.endpoint.current_account_with_tenant",
|
||||
return_value=user_and_tenant,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointCreateApi:
|
||||
def test_create_success(self, app):
|
||||
api = EndpointCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"plugin_unique_identifier": "plugin-1",
|
||||
"name": "endpoint",
|
||||
"settings": {"a": 1},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.create_endpoint", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
def test_create_permission_denied(self, app):
|
||||
api = EndpointCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"plugin_unique_identifier": "plugin-1",
|
||||
"name": "endpoint",
|
||||
"settings": {},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.endpoint.EndpointService.create_endpoint",
|
||||
side_effect=PluginPermissionDeniedError("denied"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
def test_create_validation_error(self, app):
|
||||
api = EndpointCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"plugin_unique_identifier": "p1",
|
||||
"name": "",
|
||||
"settings": {},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointListApi:
|
||||
def test_list_success(self, app):
|
||||
api = EndpointListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&page_size=10"),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.list_endpoints", return_value=[{"id": "e1"}]),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert "endpoints" in result
|
||||
assert len(result["endpoints"]) == 1
|
||||
|
||||
def test_list_invalid_query(self, app):
|
||||
api = EndpointListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=0&page_size=10"),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointListForSinglePluginApi:
|
||||
def test_list_for_plugin_success(self, app):
|
||||
api = EndpointListForSinglePluginApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&page_size=10&plugin_id=p1"),
|
||||
patch(
|
||||
"controllers.console.workspace.endpoint.EndpointService.list_endpoints_for_single_plugin",
|
||||
return_value=[{"id": "e1"}],
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert "endpoints" in result
|
||||
|
||||
def test_list_for_plugin_missing_param(self, app):
|
||||
api = EndpointListForSinglePluginApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&page_size=10"),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointDeleteApi:
|
||||
def test_delete_success(self, app):
|
||||
api = EndpointDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
def test_delete_invalid_payload(self, app):
|
||||
api = EndpointDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
def test_delete_service_failure(self, app):
|
||||
api = EndpointDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=False),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointUpdateApi:
|
||||
def test_update_success(self, app):
|
||||
api = EndpointUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"endpoint_id": "e1",
|
||||
"name": "new-name",
|
||||
"settings": {"x": 1},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
def test_update_validation_error(self, app):
|
||||
api = EndpointUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1", "settings": {}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
def test_update_service_failure(self, app):
|
||||
api = EndpointUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"endpoint_id": "e1",
|
||||
"name": "n",
|
||||
"settings": {},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=False),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointEnableApi:
|
||||
def test_enable_success(self, app):
|
||||
api = EndpointEnableApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.enable_endpoint", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
def test_enable_invalid_payload(self, app):
|
||||
api = EndpointEnableApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
def test_enable_service_failure(self, app):
|
||||
api = EndpointEnableApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.enable_endpoint", return_value=False),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointDisableApi:
|
||||
def test_disable_success(self, app):
|
||||
api = EndpointDisableApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.disable_endpoint", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
def test_disable_invalid_payload(self, app):
|
||||
api = EndpointDisableApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
@@ -0,0 +1,607 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import HTTPException
|
||||
|
||||
import services
|
||||
from controllers.console.auth.error import (
|
||||
CannotTransferOwnerToSelfError,
|
||||
EmailCodeError,
|
||||
InvalidEmailError,
|
||||
InvalidTokenError,
|
||||
MemberNotInTenantError,
|
||||
NotOwnerError,
|
||||
OwnerTransferLimitError,
|
||||
)
|
||||
from controllers.console.error import EmailSendIpLimitError, WorkspaceMembersLimitExceeded
|
||||
from controllers.console.workspace.members import (
|
||||
DatasetOperatorMemberListApi,
|
||||
MemberCancelInviteApi,
|
||||
MemberInviteEmailApi,
|
||||
MemberListApi,
|
||||
MemberUpdateRoleApi,
|
||||
OwnerTransfer,
|
||||
OwnerTransferCheckApi,
|
||||
SendOwnerTransferEmailApi,
|
||||
)
|
||||
from services.errors.account import AccountAlreadyInTenantError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestMemberListApi:
|
||||
def test_get_success(self, app):
|
||||
api = MemberListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
member.id = "m1"
|
||||
member.name = "Member"
|
||||
member.email = "member@test.com"
|
||||
member.avatar = "avatar.png"
|
||||
member.role = "admin"
|
||||
member.status = "active"
|
||||
members = [member]
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.get_tenant_members", return_value=members),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert len(result["accounts"]) == 1
|
||||
|
||||
def test_get_no_tenant(self, app):
|
||||
api = MemberListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(current_tenant=None)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestMemberInviteEmailApi:
|
||||
def test_invite_success(self, app):
|
||||
api = MemberInviteEmailApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
features = MagicMock()
|
||||
features.workspace_members.is_available.return_value = True
|
||||
|
||||
payload = {
|
||||
"emails": ["a@test.com"],
|
||||
"role": "normal",
|
||||
"language": "en-US",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
|
||||
patch("controllers.console.workspace.members.RegisterService.invite_new_member", return_value="token"),
|
||||
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 201
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_invite_limit_exceeded(self, app):
|
||||
api = MemberInviteEmailApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
features = MagicMock()
|
||||
features.workspace_members.is_available.return_value = False
|
||||
|
||||
payload = {
|
||||
"emails": ["a@test.com"],
|
||||
"role": "normal",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
|
||||
):
|
||||
with pytest.raises(WorkspaceMembersLimitExceeded):
|
||||
method(api)
|
||||
|
||||
def test_invite_already_member(self, app):
|
||||
api = MemberInviteEmailApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
features = MagicMock()
|
||||
features.workspace_members.is_available.return_value = True
|
||||
|
||||
payload = {
|
||||
"emails": ["a@test.com"],
|
||||
"role": "normal",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
|
||||
patch(
|
||||
"controllers.console.workspace.members.RegisterService.invite_new_member",
|
||||
side_effect=AccountAlreadyInTenantError(),
|
||||
),
|
||||
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert result["invitation_results"][0]["status"] == "success"
|
||||
|
||||
def test_invite_invalid_role(self, app):
|
||||
api = MemberInviteEmailApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"emails": ["a@test.com"],
|
||||
"role": "owner",
|
||||
}
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 400
|
||||
assert result["code"] == "invalid-role"
|
||||
|
||||
def test_invite_generic_exception(self, app):
|
||||
api = MemberInviteEmailApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
features = MagicMock()
|
||||
features.workspace_members.is_available.return_value = True
|
||||
|
||||
payload = {
|
||||
"emails": ["a@test.com"],
|
||||
"role": "normal",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
|
||||
patch(
|
||||
"controllers.console.workspace.members.RegisterService.invite_new_member",
|
||||
side_effect=Exception("boom"),
|
||||
),
|
||||
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"),
|
||||
):
|
||||
result, _ = method(api)
|
||||
|
||||
assert result["invitation_results"][0]["status"] == "failed"
|
||||
|
||||
|
||||
class TestMemberCancelInviteApi:
|
||||
def test_cancel_success(self, app):
|
||||
api = MemberCancelInviteApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.query") as q,
|
||||
patch("controllers.console.workspace.members.TenantService.remove_member_from_tenant"),
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = member
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 200
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_cancel_not_found(self, app):
|
||||
api = MemberCancelInviteApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.query") as q,
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(HTTPException):
|
||||
method(api, "x")
|
||||
|
||||
def test_cancel_cannot_operate_self(self, app):
|
||||
api = MemberCancelInviteApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.query") as q,
|
||||
patch(
|
||||
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
|
||||
side_effect=services.errors.account.CannotOperateSelfError("x"),
|
||||
),
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = member
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 400
|
||||
|
||||
def test_cancel_no_permission(self, app):
|
||||
api = MemberCancelInviteApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.query") as q,
|
||||
patch(
|
||||
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
|
||||
side_effect=services.errors.account.NoPermissionError("x"),
|
||||
),
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = member
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 403
|
||||
|
||||
def test_cancel_member_not_in_tenant(self, app):
|
||||
api = MemberCancelInviteApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.query") as q,
|
||||
patch(
|
||||
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
|
||||
side_effect=services.errors.account.MemberNotInTenantError(),
|
||||
),
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = member
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 404
|
||||
|
||||
|
||||
class TestMemberUpdateRoleApi:
|
||||
def test_update_success(self, app):
|
||||
api = MemberUpdateRoleApi()
|
||||
method = unwrap(api.put)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
|
||||
payload = {"role": "normal"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.get", return_value=member),
|
||||
patch("controllers.console.workspace.members.TenantService.update_member_role"),
|
||||
):
|
||||
result = method(api, "id")
|
||||
|
||||
if isinstance(result, tuple):
|
||||
result = result[0]
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_update_invalid_role(self, app):
|
||||
api = MemberUpdateRoleApi()
|
||||
method = unwrap(api.put)
|
||||
|
||||
payload = {"role": "invalid-role"}
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
result, status = method(api, "id")
|
||||
|
||||
assert status == 400
|
||||
|
||||
def test_update_member_not_found(self, app):
|
||||
api = MemberUpdateRoleApi()
|
||||
method = unwrap(api.put)
|
||||
|
||||
payload = {"role": "normal"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.members.current_account_with_tenant",
|
||||
return_value=(MagicMock(current_tenant=MagicMock()), "t1"),
|
||||
),
|
||||
patch("controllers.console.workspace.members.db.session.get", return_value=None),
|
||||
):
|
||||
with pytest.raises(HTTPException):
|
||||
method(api, "id")
|
||||
|
||||
|
||||
class TestDatasetOperatorMemberListApi:
|
||||
def test_get_success(self, app):
|
||||
api = DatasetOperatorMemberListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
member.id = "op1"
|
||||
member.name = "Operator"
|
||||
member.email = "operator@test.com"
|
||||
member.avatar = "avatar.png"
|
||||
member.role = "operator"
|
||||
member.status = "active"
|
||||
members = [member]
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.members.TenantService.get_dataset_operator_members", return_value=members
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert len(result["accounts"]) == 1
|
||||
|
||||
def test_get_no_tenant(self, app):
|
||||
api = DatasetOperatorMemberListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(current_tenant=None)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestSendOwnerTransferEmailApi:
|
||||
def test_send_success(self, app):
|
||||
api = SendOwnerTransferEmailApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(name="ws")
|
||||
user = MagicMock(email="a@test.com", current_tenant=tenant)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"),
|
||||
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=False),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.send_owner_transfer_email", return_value="token"
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_send_ip_limit(self, app):
|
||||
api = SendOwnerTransferEmailApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"),
|
||||
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=True),
|
||||
):
|
||||
with pytest.raises(EmailSendIpLimitError):
|
||||
method(api)
|
||||
|
||||
def test_send_not_owner(self, app):
|
||||
api = SendOwnerTransferEmailApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"),
|
||||
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=False),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=False),
|
||||
):
|
||||
with pytest.raises(NotOwnerError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestOwnerTransferCheckApi:
|
||||
def test_check_invalid_code(self, app):
|
||||
api = OwnerTransferCheckApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(email="a@test.com", current_tenant=tenant)
|
||||
|
||||
payload = {"code": "x", "token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
|
||||
return_value=False,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.get_owner_transfer_data",
|
||||
return_value={"email": "a@test.com", "code": "y"},
|
||||
),
|
||||
):
|
||||
with pytest.raises(EmailCodeError):
|
||||
method(api)
|
||||
|
||||
def test_rate_limited(self, app):
|
||||
api = OwnerTransferCheckApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(email="a@test.com", current_tenant=tenant)
|
||||
|
||||
payload = {"code": "x", "token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
|
||||
return_value=True,
|
||||
),
|
||||
):
|
||||
with pytest.raises(OwnerTransferLimitError):
|
||||
method(api)
|
||||
|
||||
def test_invalid_token(self, app):
|
||||
api = OwnerTransferCheckApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(email="a@test.com", current_tenant=tenant)
|
||||
|
||||
payload = {"code": "x", "token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
|
||||
return_value=False,
|
||||
),
|
||||
patch("controllers.console.workspace.members.AccountService.get_owner_transfer_data", return_value=None),
|
||||
):
|
||||
with pytest.raises(InvalidTokenError):
|
||||
method(api)
|
||||
|
||||
def test_invalid_email(self, app):
|
||||
api = OwnerTransferCheckApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(email="a@test.com", current_tenant=tenant)
|
||||
|
||||
payload = {"code": "x", "token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
|
||||
return_value=False,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.get_owner_transfer_data",
|
||||
return_value={"email": "b@test.com", "code": "x"},
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvalidEmailError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestOwnerTransferApi:
|
||||
def test_transfer_self(self, app):
|
||||
api = OwnerTransfer()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(id="1", email="a@test.com", current_tenant=tenant)
|
||||
|
||||
payload = {"token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
):
|
||||
with pytest.raises(CannotTransferOwnerToSelfError):
|
||||
method(api, "1")
|
||||
|
||||
def test_invalid_token(self, app):
|
||||
api = OwnerTransfer()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(id="1", email="a@test.com", current_tenant=tenant)
|
||||
|
||||
payload = {"token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch("controllers.console.workspace.members.AccountService.get_owner_transfer_data", return_value=None),
|
||||
):
|
||||
with pytest.raises(InvalidTokenError):
|
||||
method(api, "2")
|
||||
|
||||
def test_member_not_in_tenant(self, app):
|
||||
api = OwnerTransfer()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(id="1", email="a@test.com", current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
|
||||
payload = {"token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.get_owner_transfer_data",
|
||||
return_value={"email": "a@test.com"},
|
||||
),
|
||||
patch("controllers.console.workspace.members.db.session.get", return_value=member),
|
||||
patch("controllers.console.workspace.members.TenantService.is_member", return_value=False),
|
||||
):
|
||||
with pytest.raises(MemberNotInTenantError):
|
||||
method(api, "2")
|
||||
@@ -0,0 +1,388 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic_core import ValidationError
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.workspace.model_providers import (
|
||||
ModelProviderCredentialApi,
|
||||
ModelProviderCredentialSwitchApi,
|
||||
ModelProviderIconApi,
|
||||
ModelProviderListApi,
|
||||
ModelProviderPaymentCheckoutUrlApi,
|
||||
ModelProviderValidateApi,
|
||||
PreferredProviderTypeUpdateApi,
|
||||
)
|
||||
from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
|
||||
VALID_UUID = "123e4567-e89b-12d3-a456-426614174000"
|
||||
INVALID_UUID = "123"
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestModelProviderListApi:
|
||||
def test_get_success(self, app):
|
||||
api = ModelProviderListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?model_type=llm"),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.get_provider_list",
|
||||
return_value=[{"name": "openai"}],
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert "data" in result
|
||||
|
||||
|
||||
class TestModelProviderCredentialApi:
|
||||
def test_get_success(self, app):
|
||||
api = ModelProviderCredentialApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context(f"/?credential_id={VALID_UUID}"),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.get_provider_credential",
|
||||
return_value={"key": "value"},
|
||||
),
|
||||
):
|
||||
result = method(api, provider="openai")
|
||||
|
||||
assert "credentials" in result
|
||||
|
||||
def test_get_invalid_uuid(self, app):
|
||||
api = ModelProviderCredentialApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context(f"/?credential_id={INVALID_UUID}"),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValidationError):
|
||||
method(api, provider="openai")
|
||||
|
||||
def test_post_create_success(self, app):
|
||||
api = ModelProviderCredentialApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {"a": "b"}, "name": "test"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.create_provider_credential",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result, status = method(api, provider="openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
assert status == 201
|
||||
|
||||
def test_post_create_validation_error(self, app):
|
||||
api = ModelProviderCredentialApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {"a": "b"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.create_provider_credential",
|
||||
side_effect=CredentialsValidateFailedError("bad"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, provider="openai")
|
||||
|
||||
def test_put_update_success(self, app):
|
||||
api = ModelProviderCredentialApi()
|
||||
method = unwrap(api.put)
|
||||
|
||||
payload = {"credential_id": VALID_UUID, "credentials": {"a": "b"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.update_provider_credential",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = method(api, provider="openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_put_invalid_uuid(self, app):
|
||||
api = ModelProviderCredentialApi()
|
||||
method = unwrap(api.put)
|
||||
|
||||
payload = {"credential_id": INVALID_UUID, "credentials": {"a": "b"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValidationError):
|
||||
method(api, provider="openai")
|
||||
|
||||
def test_delete_success(self, app):
|
||||
api = ModelProviderCredentialApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
payload = {"credential_id": VALID_UUID}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.remove_provider_credential",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result, status = method(api, provider="openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
assert status == 204
|
||||
|
||||
|
||||
class TestModelProviderCredentialSwitchApi:
|
||||
def test_switch_success(self, app):
|
||||
api = ModelProviderCredentialSwitchApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": VALID_UUID}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.switch_active_provider_credential",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = method(api, provider="openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_switch_invalid_uuid(self, app):
|
||||
api = ModelProviderCredentialSwitchApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": INVALID_UUID}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValidationError):
|
||||
method(api, provider="openai")
|
||||
|
||||
|
||||
class TestModelProviderValidateApi:
|
||||
def test_validate_success(self, app):
|
||||
api = ModelProviderValidateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {"a": "b"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.validate_provider_credentials",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = method(api, provider="openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_validate_failure(self, app):
|
||||
api = ModelProviderValidateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {"a": "b"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.validate_provider_credentials",
|
||||
side_effect=CredentialsValidateFailedError("bad"),
|
||||
),
|
||||
):
|
||||
result = method(api, provider="openai")
|
||||
|
||||
assert result["result"] == "error"
|
||||
|
||||
|
||||
class TestModelProviderIconApi:
|
||||
def test_icon_success(self, app):
|
||||
api = ModelProviderIconApi()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.get_model_provider_icon",
|
||||
return_value=(b"123", "image/png"),
|
||||
),
|
||||
):
|
||||
response = api.get("t1", "openai", "logo", "en")
|
||||
|
||||
assert response.mimetype == "image/png"
|
||||
|
||||
def test_icon_not_found(self, app):
|
||||
api = ModelProviderIconApi()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.get_model_provider_icon",
|
||||
return_value=(None, None),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
api.get("t1", "openai", "logo", "en")
|
||||
|
||||
|
||||
class TestPreferredProviderTypeUpdateApi:
|
||||
def test_update_success(self, app):
|
||||
api = PreferredProviderTypeUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"preferred_provider_type": "custom"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.switch_preferred_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = method(api, provider="openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_invalid_enum(self, app):
|
||||
api = PreferredProviderTypeUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"preferred_provider_type": "invalid"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValidationError):
|
||||
method(api, provider="openai")
|
||||
|
||||
|
||||
class TestModelProviderPaymentCheckoutUrlApi:
|
||||
def test_checkout_success(self, app):
|
||||
api = ModelProviderPaymentCheckoutUrlApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="u1", email="x@test.com")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(user, "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.BillingService.is_tenant_owner_or_admin",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.BillingService.get_model_provider_payment_link",
|
||||
return_value={"url": "x"},
|
||||
),
|
||||
):
|
||||
result = method(api, provider="anthropic")
|
||||
|
||||
assert "url" in result
|
||||
|
||||
def test_invalid_provider(self, app):
|
||||
api = ModelProviderPaymentCheckoutUrlApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, provider="openai")
|
||||
|
||||
def test_permission_denied(self, app):
|
||||
api = ModelProviderPaymentCheckoutUrlApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="u1", email="x@test.com")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(user, "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.BillingService.is_tenant_owner_or_admin",
|
||||
side_effect=Forbidden(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, provider="anthropic")
|
||||
@@ -0,0 +1,447 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.workspace.models import (
|
||||
DefaultModelApi,
|
||||
ModelProviderAvailableModelApi,
|
||||
ModelProviderModelApi,
|
||||
ModelProviderModelCredentialApi,
|
||||
ModelProviderModelCredentialSwitchApi,
|
||||
ModelProviderModelDisableApi,
|
||||
ModelProviderModelEnableApi,
|
||||
ModelProviderModelParameterRuleApi,
|
||||
ModelProviderModelValidateApi,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestDefaultModelApi:
|
||||
def test_get_success(self, app: Flask):
|
||||
api = DefaultModelApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={"model_type": ModelType.LLM.value},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
|
||||
):
|
||||
service_mock.return_value.get_default_model_of_model_type.return_value = {"model": "gpt-4"}
|
||||
|
||||
result = method(api)
|
||||
|
||||
assert "data" in result
|
||||
|
||||
def test_post_success(self, app: Flask):
|
||||
api = DefaultModelApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"model_settings": [
|
||||
{
|
||||
"model_type": ModelType.LLM.value,
|
||||
"provider": "openai",
|
||||
"model": "gpt-4",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_get_returns_empty_when_no_default(self, app):
|
||||
api = DefaultModelApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={"model_type": ModelType.LLM.value}),
|
||||
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service,
|
||||
):
|
||||
service.return_value.get_default_model_of_model_type.return_value = None
|
||||
|
||||
result = method(api)
|
||||
|
||||
assert "data" in result
|
||||
|
||||
|
||||
class TestModelProviderModelApi:
|
||||
def test_get_models_success(self, app: Flask):
|
||||
api = ModelProviderModelApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
|
||||
):
|
||||
service_mock.return_value.get_models_by_provider.return_value = []
|
||||
|
||||
result = method(api, "openai")
|
||||
|
||||
assert "data" in result
|
||||
|
||||
def test_post_models_success(self, app: Flask):
|
||||
api = ModelProviderModelApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"load_balancing": {
|
||||
"configs": [{"weight": 1}],
|
||||
"enabled": True,
|
||||
},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
patch("controllers.console.workspace.models.ModelLoadBalancingService"),
|
||||
):
|
||||
result, status = method(api, "openai")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_delete_model_success(self, app: Flask):
|
||||
api = ModelProviderModelApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result, status = method(api, "openai")
|
||||
|
||||
assert status == 204
|
||||
|
||||
def test_get_models_returns_empty(self, app):
|
||||
api = ModelProviderModelApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service,
|
||||
):
|
||||
service.return_value.get_models_by_provider.return_value = []
|
||||
|
||||
result = method(api, "openai")
|
||||
|
||||
assert "data" in result
|
||||
|
||||
|
||||
class TestModelProviderModelCredentialApi:
|
||||
def test_get_credentials_success(self, app: Flask):
|
||||
api = ModelProviderModelCredentialApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as provider_service,
|
||||
patch("controllers.console.workspace.models.ModelLoadBalancingService") as lb_service,
|
||||
):
|
||||
provider_service.return_value.get_model_credential.return_value = {
|
||||
"credentials": {},
|
||||
"current_credential_id": None,
|
||||
"current_credential_name": None,
|
||||
}
|
||||
provider_service.return_value.provider_manager.get_provider_model_available_credentials.return_value = []
|
||||
lb_service.return_value.get_load_balancing_configs.return_value = (False, [])
|
||||
|
||||
result = method(api, "openai")
|
||||
|
||||
assert "credentials" in result
|
||||
|
||||
def test_create_credential_success(self, app: Flask):
|
||||
api = ModelProviderModelCredentialApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"credentials": {"key": "val"},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result, status = method(api, "openai")
|
||||
|
||||
assert status == 201
|
||||
|
||||
def test_get_empty_credentials(self, app):
|
||||
api = ModelProviderModelCredentialApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={"model": "gpt", "model_type": ModelType.LLM.value}),
|
||||
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service,
|
||||
patch("controllers.console.workspace.models.ModelLoadBalancingService") as lb,
|
||||
):
|
||||
service.return_value.get_model_credential.return_value = None
|
||||
service.return_value.provider_manager.get_provider_model_available_credentials.return_value = []
|
||||
lb.return_value.get_load_balancing_configs.return_value = (False, [])
|
||||
|
||||
result = method(api, "openai")
|
||||
|
||||
assert result["credentials"] == {}
|
||||
|
||||
def test_delete_success(self, app):
|
||||
api = ModelProviderModelCredentialApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
payload = {
|
||||
"model": "gpt",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"credential_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result, status = method(api, "openai")
|
||||
|
||||
assert status == 204
|
||||
|
||||
|
||||
class TestModelProviderModelCredentialSwitchApi:
|
||||
def test_switch_success(self, app: Flask):
|
||||
api = ModelProviderModelCredentialSwitchApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"credential_id": "abc",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result = method(api, "openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
|
||||
class TestModelEnableDisableApis:
|
||||
def test_enable_model(self, app: Flask):
|
||||
api = ModelProviderModelEnableApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result = method(api, "openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_disable_model(self, app: Flask):
|
||||
api = ModelProviderModelDisableApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result = method(api, "openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
|
||||
class TestModelProviderModelValidateApi:
|
||||
def test_validate_success(self, app: Flask):
|
||||
api = ModelProviderModelValidateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"credentials": {"key": "val"},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result = method(api, "openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
@pytest.mark.parametrize("model_name", ["gpt-4", "gpt"])
|
||||
def test_validate_failure(self, app: Flask, model_name: str):
|
||||
api = ModelProviderModelValidateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"model": model_name,
|
||||
"model_type": ModelType.LLM.value,
|
||||
"credentials": {},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
|
||||
):
|
||||
service_mock.return_value.validate_model_credentials.side_effect = CredentialsValidateFailedError("invalid")
|
||||
|
||||
result = method(api, "openai")
|
||||
|
||||
assert result["result"] == "error"
|
||||
|
||||
|
||||
class TestParameterAndAvailableModels:
|
||||
def test_parameter_rules(self, app: Flask):
|
||||
api = ModelProviderModelParameterRuleApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={"model": "gpt-4"}),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
|
||||
):
|
||||
service_mock.return_value.get_model_parameter_rules.return_value = []
|
||||
|
||||
result = method(api, "openai")
|
||||
|
||||
assert "data" in result
|
||||
|
||||
def test_available_models(self, app: Flask):
|
||||
api = ModelProviderAvailableModelApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
|
||||
):
|
||||
service_mock.return_value.get_models_by_model_type.return_value = []
|
||||
|
||||
result = method(api, ModelType.LLM.value)
|
||||
|
||||
assert "data" in result
|
||||
|
||||
def test_empty_rules(self, app):
|
||||
api = ModelProviderModelParameterRuleApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={"model": "gpt"}),
|
||||
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service,
|
||||
):
|
||||
service.return_value.get_model_parameter_rules.return_value = []
|
||||
|
||||
result = method(api, "openai")
|
||||
|
||||
assert result["data"] == []
|
||||
|
||||
def test_no_models(self, app):
|
||||
api = ModelProviderAvailableModelApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service,
|
||||
):
|
||||
service.return_value.get_models_by_model_type.return_value = []
|
||||
|
||||
result = method(api, ModelType.LLM.value)
|
||||
|
||||
assert result["data"] == []
|
||||
1019
api/tests/unit_tests/controllers/console/workspace/test_plugin.py
Normal file
1019
api/tests/unit_tests/controllers/console/workspace/test_plugin.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -4,16 +4,52 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask_restx import Api
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.workspace.tool_providers import ToolProviderMCPApi
|
||||
from controllers.console.workspace.tool_providers import (
|
||||
ToolApiListApi,
|
||||
ToolApiProviderAddApi,
|
||||
ToolApiProviderDeleteApi,
|
||||
ToolApiProviderGetApi,
|
||||
ToolApiProviderGetRemoteSchemaApi,
|
||||
ToolApiProviderListToolsApi,
|
||||
ToolApiProviderUpdateApi,
|
||||
ToolBuiltinListApi,
|
||||
ToolBuiltinProviderAddApi,
|
||||
ToolBuiltinProviderCredentialsSchemaApi,
|
||||
ToolBuiltinProviderDeleteApi,
|
||||
ToolBuiltinProviderGetCredentialInfoApi,
|
||||
ToolBuiltinProviderGetCredentialsApi,
|
||||
ToolBuiltinProviderGetOauthClientSchemaApi,
|
||||
ToolBuiltinProviderIconApi,
|
||||
ToolBuiltinProviderInfoApi,
|
||||
ToolBuiltinProviderListToolsApi,
|
||||
ToolBuiltinProviderSetDefaultApi,
|
||||
ToolBuiltinProviderUpdateApi,
|
||||
ToolLabelsApi,
|
||||
ToolOAuthCallback,
|
||||
ToolOAuthCustomClient,
|
||||
ToolPluginOAuthApi,
|
||||
ToolProviderListApi,
|
||||
ToolProviderMCPApi,
|
||||
ToolWorkflowListApi,
|
||||
ToolWorkflowProviderCreateApi,
|
||||
ToolWorkflowProviderDeleteApi,
|
||||
ToolWorkflowProviderGetApi,
|
||||
ToolWorkflowProviderUpdateApi,
|
||||
is_valid_url,
|
||||
)
|
||||
from core.db.session_factory import configure_session_factory
|
||||
from extensions.ext_database import db
|
||||
from services.tools.mcp_tools_manage_service import ReconnectResult
|
||||
|
||||
|
||||
# Backward-compat fixtures referenced by @pytest.mark.usefixtures in this file.
|
||||
# They are intentionally no-ops because the test already patches the required
|
||||
# behaviors explicitly via @patch and context managers below.
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _mock_cache():
|
||||
return
|
||||
@@ -107,3 +143,602 @@ def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_
|
||||
# 若 transform 后包含 tools 字段,确保非空
|
||||
assert isinstance(body.get("tools"), list)
|
||||
assert body["tools"]
|
||||
|
||||
|
||||
class TestUtils:
|
||||
def test_is_valid_url(self):
|
||||
assert is_valid_url("https://example.com")
|
||||
assert is_valid_url("http://example.com")
|
||||
assert not is_valid_url("")
|
||||
assert not is_valid_url("ftp://example.com")
|
||||
assert not is_valid_url("not-a-url")
|
||||
assert not is_valid_url(None)
|
||||
|
||||
|
||||
class TestToolProviderListApi:
|
||||
def test_get_success(self, app):
|
||||
api = ToolProviderListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u1"), "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ToolCommonService.list_tool_providers",
|
||||
return_value=["p1"],
|
||||
),
|
||||
):
|
||||
assert method(api) == ["p1"]
|
||||
|
||||
|
||||
class TestBuiltinProviderApis:
|
||||
def test_list_tools(self, app):
|
||||
api = ToolBuiltinProviderListToolsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(None, "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tool_provider_tools",
|
||||
return_value=[{"a": 1}],
|
||||
),
|
||||
):
|
||||
assert method(api, "provider") == [{"a": 1}]
|
||||
|
||||
def test_info(self, app):
|
||||
api = ToolBuiltinProviderInfoApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(None, "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_info",
|
||||
return_value={"x": 1},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider") == {"x": 1}
|
||||
|
||||
def test_delete(self, app):
|
||||
api = ToolBuiltinProviderDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credential_id": "cid"}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(None, "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.delete_builtin_tool_provider",
|
||||
return_value={"result": "success"},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider")["result"] == "success"
|
||||
|
||||
def test_add_invalid_type(self, app):
|
||||
api = ToolBuiltinProviderAddApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credentials": {}, "type": "invalid"}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "provider")
|
||||
|
||||
def test_add_success(self, app):
|
||||
api = ToolBuiltinProviderAddApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {}, "type": "oauth2", "name": "n"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.add_builtin_tool_provider",
|
||||
return_value={"id": 1},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider")["id"] == 1
|
||||
|
||||
def test_update(self, app):
|
||||
api = ToolBuiltinProviderUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "c1", "credentials": {}, "name": "n"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.update_builtin_tool_provider",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider")["ok"]
|
||||
|
||||
def test_get_credentials(self, app):
|
||||
api = ToolBuiltinProviderGetCredentialsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(None, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credentials",
|
||||
return_value={"k": "v"},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider") == {"k": "v"}
|
||||
|
||||
def test_icon(self, app):
|
||||
api = ToolBuiltinProviderIconApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_icon",
|
||||
return_value=(b"x", "image/png"),
|
||||
),
|
||||
):
|
||||
response = method(api, "provider")
|
||||
assert response.mimetype == "image/png"
|
||||
|
||||
def test_credentials_schema(self, app):
|
||||
api = ToolBuiltinProviderCredentialsSchemaApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_provider_credentials_schema",
|
||||
return_value={"schema": {}},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider", "oauth2") == {"schema": {}}
|
||||
|
||||
def test_set_default_credential(self, app):
|
||||
api = ToolBuiltinProviderSetDefaultApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"id": "c1"}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.set_default_provider",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider")["ok"]
|
||||
|
||||
def test_get_credential_info(self, app):
|
||||
api = ToolBuiltinProviderGetCredentialInfoApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credential_info",
|
||||
return_value={"info": "x"},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider") == {"info": "x"}
|
||||
|
||||
def test_get_oauth_client_schema(self, app):
|
||||
api = ToolBuiltinProviderGetOauthClientSchemaApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema",
|
||||
return_value={"schema": {}},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider") == {"schema": {}}
|
||||
|
||||
|
||||
class TestApiProviderApis:
|
||||
def test_add(self, app):
|
||||
api = ToolApiProviderAddApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"credentials": {},
|
||||
"schema_type": "openapi",
|
||||
"schema": "{}",
|
||||
"provider": "p",
|
||||
"icon": {},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.create_api_tool_provider",
|
||||
return_value={"id": 1},
|
||||
),
|
||||
):
|
||||
assert method(api)["id"] == 1
|
||||
|
||||
def test_remote_schema(self, app):
|
||||
api = ToolApiProviderGetRemoteSchemaApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?url=http://x.com"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider_remote_schema",
|
||||
return_value={"schema": "x"},
|
||||
),
|
||||
):
|
||||
assert method(api)["schema"] == "x"
|
||||
|
||||
def test_list_tools(self, app):
|
||||
api = ToolApiProviderListToolsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?provider=p"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tool_provider_tools",
|
||||
return_value=[{"tool": 1}],
|
||||
),
|
||||
):
|
||||
assert method(api) == [{"tool": 1}]
|
||||
|
||||
def test_update(self, app):
|
||||
api = ToolApiProviderUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"credentials": {},
|
||||
"schema_type": "openapi",
|
||||
"schema": "{}",
|
||||
"provider": "p",
|
||||
"original_provider": "o",
|
||||
"icon": {},
|
||||
"privacy_policy": "",
|
||||
"custom_disclaimer": "",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.update_api_tool_provider",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api)["ok"]
|
||||
|
||||
def test_delete(self, app):
|
||||
api = ToolApiProviderDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"provider": "p"}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.delete_api_tool_provider",
|
||||
return_value={"result": "success"},
|
||||
),
|
||||
):
|
||||
assert method(api)["result"] == "success"
|
||||
|
||||
def test_get(self, app):
|
||||
api = ToolApiProviderGetApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?provider=p"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider",
|
||||
return_value={"x": 1},
|
||||
),
|
||||
):
|
||||
assert method(api) == {"x": 1}
|
||||
|
||||
|
||||
class TestWorkflowApis:
|
||||
def test_create(self, app):
|
||||
api = ToolWorkflowProviderCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"workflow_app_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"name": "n",
|
||||
"label": "l",
|
||||
"description": "d",
|
||||
"icon": {},
|
||||
"parameters": [],
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.WorkflowToolManageService.create_workflow_tool",
|
||||
return_value={"id": 1},
|
||||
),
|
||||
):
|
||||
assert method(api)["id"] == 1
|
||||
|
||||
def test_update_invalid(self, app):
|
||||
api = ToolWorkflowProviderUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"workflow_tool_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"name": "Tool",
|
||||
"label": "Tool Label",
|
||||
"description": "A tool",
|
||||
"icon": {},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.WorkflowToolManageService.update_workflow_tool",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
assert result["ok"]
|
||||
|
||||
def test_delete(self, app):
|
||||
api = ToolWorkflowProviderDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"workflow_tool_id": "123e4567-e89b-12d3-a456-426614174000"}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.WorkflowToolManageService.delete_workflow_tool",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api)["ok"]
|
||||
|
||||
def test_get_error(self, app):
|
||||
api = ToolWorkflowProviderGetApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestLists:
|
||||
def test_builtin_list(self, app):
|
||||
api = ToolBuiltinListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
m = MagicMock()
|
||||
m.to_dict.return_value = {"x": 1}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tools",
|
||||
return_value=[m],
|
||||
),
|
||||
):
|
||||
assert method(api) == [{"x": 1}]
|
||||
|
||||
def test_api_list(self, app):
|
||||
api = ToolApiListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
m = MagicMock()
|
||||
m.to_dict.return_value = {"x": 1}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(None, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tools",
|
||||
return_value=[m],
|
||||
),
|
||||
):
|
||||
assert method(api) == [{"x": 1}]
|
||||
|
||||
def test_workflow_list(self, app):
|
||||
api = ToolWorkflowListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
m = MagicMock()
|
||||
m.to_dict.return_value = {"x": 1}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.WorkflowToolManageService.list_tenant_workflow_tools",
|
||||
return_value=[m],
|
||||
),
|
||||
):
|
||||
assert method(api) == [{"x": 1}]
|
||||
|
||||
|
||||
class TestLabels:
|
||||
def test_labels(self, app):
|
||||
api = ToolLabelsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ToolLabelsService.list_tool_labels",
|
||||
return_value=["l1"],
|
||||
),
|
||||
):
|
||||
assert method(api) == ["l1"]
|
||||
|
||||
|
||||
class TestOAuth:
|
||||
def test_oauth_no_client(self, app):
|
||||
api = ToolPluginOAuthApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_oauth_client",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "provider")
|
||||
|
||||
def test_oauth_callback_no_cookie(self, app):
|
||||
api = ToolOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "provider")
|
||||
|
||||
|
||||
class TestOAuthCustomClient:
|
||||
def test_save_custom_client(self, app):
|
||||
api = ToolOAuthCustomClient()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"client_params": {"a": 1}}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.save_custom_oauth_client_params",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider")["ok"]
|
||||
|
||||
def test_get_custom_client(self, app):
|
||||
api = ToolOAuthCustomClient()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_custom_oauth_client_params",
|
||||
return_value={"client_id": "x"},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider") == {"client_id": "x"}
|
||||
|
||||
def test_delete_custom_client(self, app):
|
||||
api = ToolOAuthCustomClient()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.delete_custom_oauth_client_params",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider")["ok"]
|
||||
|
||||
@@ -0,0 +1,558 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, Forbidden
|
||||
|
||||
from controllers.console.workspace.trigger_providers import (
|
||||
TriggerOAuthAuthorizeApi,
|
||||
TriggerOAuthCallbackApi,
|
||||
TriggerOAuthClientManageApi,
|
||||
TriggerProviderIconApi,
|
||||
TriggerProviderInfoApi,
|
||||
TriggerProviderListApi,
|
||||
TriggerSubscriptionBuilderBuildApi,
|
||||
TriggerSubscriptionBuilderCreateApi,
|
||||
TriggerSubscriptionBuilderGetApi,
|
||||
TriggerSubscriptionBuilderLogsApi,
|
||||
TriggerSubscriptionBuilderUpdateApi,
|
||||
TriggerSubscriptionBuilderVerifyApi,
|
||||
TriggerSubscriptionDeleteApi,
|
||||
TriggerSubscriptionListApi,
|
||||
TriggerSubscriptionUpdateApi,
|
||||
TriggerSubscriptionVerifyApi,
|
||||
)
|
||||
from controllers.web.error import NotFoundError
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from models.account import Account
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def mock_user():
|
||||
user = MagicMock(spec=Account)
|
||||
user.id = "u1"
|
||||
user.current_tenant_id = "t1"
|
||||
return user
|
||||
|
||||
|
||||
class TestTriggerProviderApis:
|
||||
def test_icon_success(self, app):
|
||||
api = TriggerProviderIconApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerManager.get_trigger_plugin_icon",
|
||||
return_value="icon",
|
||||
),
|
||||
):
|
||||
assert method(api, "github") == "icon"
|
||||
|
||||
def test_list_providers(self, app):
|
||||
api = TriggerProviderListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_providers",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
assert method(api) == []
|
||||
|
||||
def test_provider_info(self, app):
|
||||
api = TriggerProviderInfoApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_trigger_provider",
|
||||
return_value={"id": "p1"},
|
||||
),
|
||||
):
|
||||
assert method(api, "github") == {"id": "p1"}
|
||||
|
||||
|
||||
class TestTriggerSubscriptionListApi:
|
||||
def test_list_success(self, app):
|
||||
api = TriggerSubscriptionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_provider_subscriptions",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
assert method(api, "github") == []
|
||||
|
||||
def test_list_invalid_provider(self, app):
|
||||
api = TriggerSubscriptionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_provider_subscriptions",
|
||||
side_effect=ValueError("bad"),
|
||||
),
|
||||
):
|
||||
result, status = method(api, "bad")
|
||||
assert status == 404
|
||||
|
||||
|
||||
class TestTriggerSubscriptionBuilderApis:
|
||||
def test_create_builder(self, app):
|
||||
api = TriggerSubscriptionBuilderCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credential_type": "UNAUTHORIZED"}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.create_trigger_subscription_builder",
|
||||
return_value={"id": "b1"},
|
||||
),
|
||||
):
|
||||
result = method(api, "github")
|
||||
assert "subscription_builder" in result
|
||||
|
||||
def test_get_builder(self, app):
|
||||
api = TriggerSubscriptionBuilderGetApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.get_subscription_builder_by_id",
|
||||
return_value={"id": "b1"},
|
||||
),
|
||||
):
|
||||
assert method(api, "github", "b1") == {"id": "b1"}
|
||||
|
||||
def test_verify_builder(self, app):
|
||||
api = TriggerSubscriptionBuilderVerifyApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credentials": {"a": 1}}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_verify_builder",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "github", "b1") == {"ok": True}
|
||||
|
||||
def test_verify_builder_error(self, app):
|
||||
api = TriggerSubscriptionBuilderVerifyApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credentials": {}}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_verify_builder",
|
||||
side_effect=Exception("err"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "github", "b1")
|
||||
|
||||
def test_update_builder(self, app):
|
||||
api = TriggerSubscriptionBuilderUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"name": "n"}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_trigger_subscription_builder",
|
||||
return_value={"id": "b1"},
|
||||
),
|
||||
):
|
||||
assert method(api, "github", "b1") == {"id": "b1"}
|
||||
|
||||
def test_logs(self, app):
|
||||
api = TriggerSubscriptionBuilderLogsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
log = MagicMock()
|
||||
log.model_dump.return_value = {"a": 1}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.list_logs",
|
||||
return_value=[log],
|
||||
),
|
||||
):
|
||||
assert "logs" in method(api, "github", "b1")
|
||||
|
||||
def test_build(self, app):
|
||||
api = TriggerSubscriptionBuilderBuildApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"name": "x"}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_build_builder",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
assert method(api, "github", "b1") == 200
|
||||
|
||||
|
||||
class TestTriggerSubscriptionCrud:
|
||||
def test_update_rename_only(self, app):
|
||||
api = TriggerSubscriptionUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
sub = MagicMock()
|
||||
sub.provider_id = "github"
|
||||
sub.credential_type = CredentialType.UNAUTHORIZED
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"name": "x"}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id",
|
||||
return_value=sub,
|
||||
),
|
||||
patch("controllers.console.workspace.trigger_providers.TriggerProviderService.update_trigger_subscription"),
|
||||
):
|
||||
assert method(api, "s1") == 200
|
||||
|
||||
def test_update_not_found(self, app):
|
||||
api = TriggerSubscriptionUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"name": "x"}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFoundError):
|
||||
method(api, "x")
|
||||
|
||||
def test_update_rebuild(self, app):
|
||||
api = TriggerSubscriptionUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
sub = MagicMock()
|
||||
sub.provider_id = "github"
|
||||
sub.credential_type = CredentialType.OAUTH2
|
||||
sub.credentials = {}
|
||||
sub.parameters = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credentials": {}}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id",
|
||||
return_value=sub,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.rebuild_trigger_subscription"
|
||||
),
|
||||
):
|
||||
assert method(api, "s1") == 200
|
||||
|
||||
def test_delete_subscription(self, app):
|
||||
api = TriggerSubscriptionDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch("controllers.console.workspace.trigger_providers.db") as mock_db,
|
||||
patch("controllers.console.workspace.trigger_providers.Session") as mock_session_cls,
|
||||
patch("controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider"),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription"
|
||||
),
|
||||
):
|
||||
mock_db.engine = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
result = method(api, "sub1")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_delete_subscription_value_error(self, app):
|
||||
api = TriggerSubscriptionDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch("controllers.console.workspace.trigger_providers.db") as mock_db,
|
||||
patch("controllers.console.workspace.trigger_providers.Session") as session_cls,
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider",
|
||||
side_effect=ValueError("bad"),
|
||||
),
|
||||
):
|
||||
mock_db.engine = MagicMock()
|
||||
session_cls.return_value.__enter__.return_value = MagicMock()
|
||||
|
||||
with pytest.raises(BadRequest):
|
||||
method(api, "sub1")
|
||||
|
||||
|
||||
class TestTriggerOAuthApis:
|
||||
def test_oauth_authorize_success(self, app):
|
||||
api = TriggerOAuthAuthorizeApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
|
||||
return_value={"a": 1},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.create_trigger_subscription_builder",
|
||||
return_value=MagicMock(id="b1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.OAuthProxyService.create_proxy_context",
|
||||
return_value="ctx",
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.OAuthHandler.get_authorization_url",
|
||||
return_value=MagicMock(authorization_url="url"),
|
||||
),
|
||||
):
|
||||
resp = method(api, "github")
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_oauth_authorize_no_client(self, app):
|
||||
api = TriggerOAuthAuthorizeApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFoundError):
|
||||
method(api, "github")
|
||||
|
||||
def test_oauth_callback_forbidden(self, app):
|
||||
api = TriggerOAuthCallbackApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "github")
|
||||
|
||||
def test_oauth_callback_success(self, app):
|
||||
api = TriggerOAuthCallbackApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
ctx = {
|
||||
"user_id": "u1",
|
||||
"tenant_id": "t1",
|
||||
"subscription_builder_id": "b1",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", headers={"Cookie": "context_id=ctx"}),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context", return_value=ctx
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
|
||||
return_value={"a": 1},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.OAuthHandler.get_credentials",
|
||||
return_value=MagicMock(credentials={"a": 1}, expires_at=1),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_trigger_subscription_builder"
|
||||
),
|
||||
):
|
||||
resp = method(api, "github")
|
||||
assert resp.status_code == 302
|
||||
|
||||
def test_oauth_callback_no_oauth_client(self, app):
|
||||
api = TriggerOAuthCallbackApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
ctx = {
|
||||
"user_id": "u1",
|
||||
"tenant_id": "t1",
|
||||
"subscription_builder_id": "b1",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", headers={"Cookie": "context_id=ctx"}),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context",
|
||||
return_value=ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "github")
|
||||
|
||||
def test_oauth_callback_empty_credentials(self, app):
|
||||
api = TriggerOAuthCallbackApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
ctx = {
|
||||
"user_id": "u1",
|
||||
"tenant_id": "t1",
|
||||
"subscription_builder_id": "b1",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", headers={"Cookie": "context_id=ctx"}),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context",
|
||||
return_value=ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
|
||||
return_value={"a": 1},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.OAuthHandler.get_credentials",
|
||||
return_value=MagicMock(credentials=None, expires_at=None),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "github")
|
||||
|
||||
|
||||
class TestTriggerOAuthClientManageApi:
|
||||
def test_get_client(self, app):
|
||||
api = TriggerOAuthClientManageApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_custom_oauth_client_params",
|
||||
return_value={},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.is_oauth_custom_client_enabled",
|
||||
return_value=False,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.is_oauth_system_client_exists",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerManager.get_trigger_provider",
|
||||
return_value=MagicMock(get_oauth_client_schema=lambda: {}),
|
||||
),
|
||||
):
|
||||
result = method(api, "github")
|
||||
assert "configured" in result
|
||||
|
||||
def test_post_client(self, app):
|
||||
api = TriggerOAuthClientManageApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"enabled": True}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.save_custom_oauth_client_params",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "github") == {"ok": True}
|
||||
|
||||
def test_delete_client(self, app):
|
||||
api = TriggerOAuthClientManageApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.delete_custom_oauth_client_params",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "github") == {"ok": True}
|
||||
|
||||
def test_oauth_client_post_value_error(self, app):
|
||||
api = TriggerOAuthClientManageApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"enabled": True}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.save_custom_oauth_client_params",
|
||||
side_effect=ValueError("bad"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(BadRequest):
|
||||
method(api, "github")
|
||||
|
||||
|
||||
class TestTriggerSubscriptionVerifyApi:
|
||||
def test_verify_success(self, app):
|
||||
api = TriggerSubscriptionVerifyApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credentials": {}}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.verify_subscription_credentials",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "github", "s1") == {"ok": True}
|
||||
|
||||
@pytest.mark.parametrize("raised_exception", [ValueError("bad"), Exception("boom")])
|
||||
def test_verify_errors(self, app, raised_exception):
|
||||
api = TriggerSubscriptionVerifyApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credentials": {}}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.verify_subscription_credentials",
|
||||
side_effect=raised_exception,
|
||||
),
|
||||
):
|
||||
with pytest.raises(BadRequest):
|
||||
method(api, "github", "s1")
|
||||
@@ -0,0 +1,605 @@
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
import services
|
||||
from controllers.common.errors import (
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.console.error import AccountNotLinkTenantError
|
||||
from controllers.console.workspace.workspace import (
|
||||
CustomConfigWorkspaceApi,
|
||||
SwitchWorkspaceApi,
|
||||
TenantApi,
|
||||
TenantListApi,
|
||||
WebappLogoWorkspaceApi,
|
||||
WorkspaceInfoApi,
|
||||
WorkspaceListApi,
|
||||
WorkspacePermissionApi,
|
||||
)
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from models.account import TenantStatus
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestTenantListApi:
|
||||
def test_get_success(self, app):
|
||||
api = TenantListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
tenant1 = MagicMock(
|
||||
id="t1",
|
||||
name="Tenant 1",
|
||||
status="active",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
tenant2 = MagicMock(
|
||||
id="t2",
|
||||
name="Tenant 2",
|
||||
status="active",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
features = MagicMock()
|
||||
features.billing.enabled = True
|
||||
features.billing.subscription.plan = CloudPlan.SANDBOX
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
|
||||
return_value=[tenant1, tenant2],
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.FeatureService.get_features", return_value=features),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert len(result["workspaces"]) == 2
|
||||
assert result["workspaces"][0]["current"] is True
|
||||
|
||||
def test_get_billing_disabled(self, app):
|
||||
api = TenantListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
tenant = MagicMock(
|
||||
id="t1",
|
||||
name="Tenant",
|
||||
status="active",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
features = MagicMock()
|
||||
features.billing.enabled = False
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
|
||||
return_value=[tenant],
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.FeatureService.get_features",
|
||||
return_value=features,
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
|
||||
|
||||
|
||||
class TestWorkspaceListApi:
|
||||
def test_get_success(self, app):
|
||||
api = WorkspaceListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
tenant = MagicMock(id="t1", name="T", status="active", created_at=datetime.utcnow())
|
||||
|
||||
paginate_result = MagicMock(
|
||||
items=[tenant],
|
||||
has_next=False,
|
||||
total=1,
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context("/all-workspaces", query_string={"page": 1, "limit": 20}),
|
||||
patch("controllers.console.workspace.workspace.db.paginate", return_value=paginate_result),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["total"] == 1
|
||||
assert result["has_more"] is False
|
||||
|
||||
def test_get_has_next_true(self, app):
|
||||
api = WorkspaceListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
tenant = MagicMock(
|
||||
id="t1",
|
||||
name="T",
|
||||
status="active",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
paginate_result = MagicMock(
|
||||
items=[tenant],
|
||||
has_next=True,
|
||||
total=10,
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context("/all-workspaces", query_string={"page": 1, "limit": 1}),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.db.paginate",
|
||||
return_value=paginate_result,
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["has_more"] is True
|
||||
|
||||
|
||||
class TestTenantApi:
|
||||
def test_post_active_tenant(self, app):
|
||||
api = TenantApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(status="active")
|
||||
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/current"),
|
||||
patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t1"}
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["id"] == "t1"
|
||||
|
||||
def test_post_archived_with_switch(self, app):
|
||||
api = TenantApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
archived = MagicMock(status=TenantStatus.ARCHIVE)
|
||||
new_tenant = MagicMock(status="active")
|
||||
|
||||
user = MagicMock(current_tenant=archived)
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/current"),
|
||||
patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[new_tenant]),
|
||||
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "new"}
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert result["id"] == "new"
|
||||
|
||||
def test_post_archived_no_tenant(self, app):
|
||||
api = TenantApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock(current_tenant=MagicMock(status=TenantStatus.ARCHIVE))
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/current"),
|
||||
patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[]),
|
||||
):
|
||||
with pytest.raises(Unauthorized):
|
||||
method(api)
|
||||
|
||||
def test_post_info_path(self, app):
|
||||
api = TenantApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(status="active")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
|
||||
with (
|
||||
app.test_request_context("/info"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(user, "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info",
|
||||
return_value={"id": "t1"},
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.logger.warning") as warn_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
warn_mock.assert_called_once()
|
||||
assert status == 200
|
||||
|
||||
|
||||
class TestSwitchWorkspaceApi:
|
||||
def test_switch_success(self, app):
|
||||
api = SwitchWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"tenant_id": "t2"}
|
||||
tenant = MagicMock(id="t2")
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/switch", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
|
||||
patch("controllers.console.workspace.workspace.db.session.query") as query_mock,
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t2"}
|
||||
),
|
||||
):
|
||||
query_mock.return_value.get.return_value = tenant
|
||||
result = method(api)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_switch_not_linked(self, app):
|
||||
api = SwitchWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"tenant_id": "bad"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/switch", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.TenantService.switch_tenant", side_effect=Exception),
|
||||
):
|
||||
with pytest.raises(AccountNotLinkTenantError):
|
||||
method(api)
|
||||
|
||||
def test_switch_tenant_not_found(self, app):
|
||||
api = SwitchWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"tenant_id": "missing"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/switch", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
|
||||
patch("controllers.console.workspace.workspace.db.session.query") as query_mock,
|
||||
):
|
||||
query_mock.return_value.get.return_value = None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestCustomConfigWorkspaceApi:
|
||||
def test_post_success(self, app):
|
||||
api = CustomConfigWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(custom_config_dict={})
|
||||
|
||||
payload = {"remove_webapp_brand": True}
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/custom-config", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.db.get_or_404", return_value=tenant),
|
||||
patch("controllers.console.workspace.workspace.db.session.commit"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t1"}
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_logo_fallback(self, app):
|
||||
api = CustomConfigWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(custom_config_dict={"replace_webapp_logo": "old-logo"})
|
||||
|
||||
payload = {"remove_webapp_brand": False}
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/custom-config", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.db.get_or_404",
|
||||
return_value=tenant,
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.db.session.commit"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info",
|
||||
return_value={"id": "t1"},
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert tenant.custom_config_dict["replace_webapp_logo"] == "old-logo"
|
||||
assert result["result"] == "success"
|
||||
|
||||
|
||||
class TestWebappLogoWorkspaceApi:
|
||||
def test_no_file(self, app):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/upload", data={}),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
):
|
||||
with pytest.raises(NoFileUploadedError):
|
||||
method(api)
|
||||
|
||||
def test_too_many_files(self, app):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
data = {
|
||||
"file": MagicMock(),
|
||||
"extra": MagicMock(),
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/upload", data=data),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(TooManyFilesError):
|
||||
method(api)
|
||||
|
||||
def test_invalid_extension(self, app):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
file = MagicMock(filename="test.txt")
|
||||
|
||||
with (
|
||||
app.test_request_context("/upload", data={"file": file}),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
):
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
method(api)
|
||||
|
||||
def test_upload_success(self, app):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
file = FileStorage(
|
||||
stream=BytesIO(b"data"),
|
||||
filename="logo.png",
|
||||
content_type="image/png",
|
||||
)
|
||||
|
||||
upload = MagicMock(id="file1")
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/upload",
|
||||
data={"file": file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.FileService") as fs,
|
||||
patch("controllers.console.workspace.workspace.db") as mock_db,
|
||||
):
|
||||
mock_db.engine = MagicMock()
|
||||
fs.return_value.upload_file.return_value = upload
|
||||
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 201
|
||||
assert result["id"] == "file1"
|
||||
|
||||
def test_filename_missing(self, app):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
file = FileStorage(
|
||||
stream=BytesIO(b"data"),
|
||||
filename="",
|
||||
content_type="image/png",
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/upload",
|
||||
data={"file": file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(FilenameNotExistsError):
|
||||
method(api)
|
||||
|
||||
def test_file_too_large(self, app):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
file = FileStorage(
|
||||
stream=BytesIO(b"x"),
|
||||
filename="logo.png",
|
||||
content_type="image/png",
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/upload",
|
||||
data={"file": file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.FileService") as fs,
|
||||
patch("controllers.console.workspace.workspace.db") as mock_db,
|
||||
):
|
||||
mock_db.engine = MagicMock()
|
||||
fs.return_value.upload_file.side_effect = services.errors.file.FileTooLargeError("too big")
|
||||
|
||||
with pytest.raises(FileTooLargeError):
|
||||
method(api)
|
||||
|
||||
def test_service_unsupported_file(self, app):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
file = FileStorage(
|
||||
stream=BytesIO(b"x"),
|
||||
filename="logo.png",
|
||||
content_type="image/png",
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/upload",
|
||||
data={"file": file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.FileService") as fs,
|
||||
patch("controllers.console.workspace.workspace.db") as mock_db,
|
||||
):
|
||||
mock_db.engine = MagicMock()
|
||||
fs.return_value.upload_file.side_effect = services.errors.file.UnsupportedFileTypeError()
|
||||
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestWorkspaceInfoApi:
|
||||
def test_post_success(self, app):
|
||||
api = WorkspaceInfoApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
|
||||
payload = {"name": "New Name"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/info", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.db.get_or_404", return_value=tenant),
|
||||
patch("controllers.console.workspace.workspace.db.session.commit"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info",
|
||||
return_value={"name": "New Name"},
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_no_current_tenant(self, app):
|
||||
api = WorkspaceInfoApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"name": "X"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/info", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), None),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestWorkspacePermissionApi:
|
||||
def test_get_success(self, app):
|
||||
api = WorkspacePermissionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
permission = MagicMock(
|
||||
workspace_id="t1",
|
||||
allow_member_invite=True,
|
||||
allow_owner_transfer=False,
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context("/permission"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.EnterpriseService.WorkspacePermissionService.get_permission",
|
||||
return_value=permission,
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["workspace_id"] == "t1"
|
||||
|
||||
def test_no_current_tenant(self, app):
|
||||
api = WorkspacePermissionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/permission"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), None),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
@@ -0,0 +1,142 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.workspace import plugin_permission_required
|
||||
from models.account import TenantPluginPermission
|
||||
|
||||
|
||||
class _SessionStub:
|
||||
def __init__(self, permission):
|
||||
self._permission = permission
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def query(self, *_args, **_kwargs):
|
||||
return self
|
||||
|
||||
def where(self, *_args, **_kwargs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return self._permission
|
||||
|
||||
|
||||
def _workspace_module():
|
||||
return importlib.import_module(plugin_permission_required.__module__)
|
||||
|
||||
|
||||
def _patch_session(monkeypatch: pytest.MonkeyPatch, permission):
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "Session", lambda *_args, **_kwargs: _SessionStub(permission))
|
||||
monkeypatch.setattr(module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
|
||||
def test_plugin_permission_allows_without_permission(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=False)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, None)
|
||||
|
||||
@plugin_permission_required()
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
assert handler() == "ok"
|
||||
|
||||
|
||||
def test_plugin_permission_install_nobody_forbidden(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=True)
|
||||
permission = SimpleNamespace(
|
||||
install_permission=TenantPluginPermission.InstallPermission.NOBODY,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, permission)
|
||||
|
||||
@plugin_permission_required(install_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
|
||||
|
||||
def test_plugin_permission_install_admin_requires_admin(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=False)
|
||||
permission = SimpleNamespace(
|
||||
install_permission=TenantPluginPermission.InstallPermission.ADMINS,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, permission)
|
||||
|
||||
@plugin_permission_required(install_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
|
||||
|
||||
def test_plugin_permission_install_admin_allows_admin(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=True)
|
||||
permission = SimpleNamespace(
|
||||
install_permission=TenantPluginPermission.InstallPermission.ADMINS,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, permission)
|
||||
|
||||
@plugin_permission_required(install_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
assert handler() == "ok"
|
||||
|
||||
|
||||
def test_plugin_permission_debug_nobody_forbidden(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=True)
|
||||
permission = SimpleNamespace(
|
||||
install_permission=TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.NOBODY,
|
||||
)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, permission)
|
||||
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
|
||||
|
||||
def test_plugin_permission_debug_admin_requires_admin(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=False)
|
||||
permission = SimpleNamespace(
|
||||
install_permission=TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.ADMINS,
|
||||
)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, permission)
|
||||
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
0
api/tests/unit_tests/controllers/web/__init__.py
Normal file
0
api/tests/unit_tests/controllers/web/__init__.py
Normal file
85
api/tests/unit_tests/controllers/web/conftest.py
Normal file
85
api/tests/unit_tests/controllers/web/conftest.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Shared fixtures for controllers.web unit tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
"""Minimal Flask app for request contexts."""
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
class FakeSession:
|
||||
"""Stand-in for db.session that returns pre-seeded objects by model class name."""
|
||||
|
||||
def __init__(self, mapping: dict[str, Any] | None = None):
|
||||
self._mapping: dict[str, Any] = mapping or {}
|
||||
self._model_name: str | None = None
|
||||
|
||||
def query(self, model: type) -> FakeSession:
|
||||
self._model_name = model.__name__
|
||||
return self
|
||||
|
||||
def where(self, *_args: object, **_kwargs: object) -> FakeSession:
|
||||
return self
|
||||
|
||||
def first(self) -> Any:
|
||||
assert self._model_name is not None
|
||||
return self._mapping.get(self._model_name)
|
||||
|
||||
|
||||
class FakeDB:
|
||||
"""Minimal db stub exposing engine and session."""
|
||||
|
||||
def __init__(self, session: FakeSession | None = None):
|
||||
self.session = session or FakeSession()
|
||||
self.engine = object()
|
||||
|
||||
|
||||
def make_app_model(
|
||||
*,
|
||||
app_id: str = "app-1",
|
||||
tenant_id: str = "tenant-1",
|
||||
mode: str = "chat",
|
||||
enable_site: bool = True,
|
||||
status: str = "normal",
|
||||
) -> SimpleNamespace:
|
||||
"""Build a fake App model with common defaults."""
|
||||
tenant = SimpleNamespace(
|
||||
id=tenant_id,
|
||||
status="normal",
|
||||
plan="basic",
|
||||
custom_config_dict={},
|
||||
)
|
||||
return SimpleNamespace(
|
||||
id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
tenant=tenant,
|
||||
mode=mode,
|
||||
enable_site=enable_site,
|
||||
status=status,
|
||||
workflow=None,
|
||||
app_model_config=None,
|
||||
)
|
||||
|
||||
|
||||
def make_end_user(
|
||||
*,
|
||||
user_id: str = "end-user-1",
|
||||
session_id: str = "session-1",
|
||||
external_user_id: str = "ext-user-1",
|
||||
) -> SimpleNamespace:
|
||||
"""Build a fake EndUser model with common defaults."""
|
||||
return SimpleNamespace(
|
||||
id=user_id,
|
||||
session_id=session_id,
|
||||
external_user_id=external_user_id,
|
||||
)
|
||||
165
api/tests/unit_tests/controllers/web/test_app.py
Normal file
165
api/tests/unit_tests/controllers/web/test_app.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""Unit tests for controllers.web.app endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.app import AppAccessMode, AppMeta, AppParameterApi, AppWebAuthPermission
|
||||
from controllers.web.error import AppUnavailableError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AppParameterApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAppParameterApi:
|
||||
def test_advanced_chat_mode_uses_workflow(self, app: Flask) -> None:
|
||||
features_dict = {"opening_statement": "Hello"}
|
||||
workflow = SimpleNamespace(
|
||||
features_dict=features_dict,
|
||||
user_input_form=lambda to_old_structure=False: [],
|
||||
)
|
||||
app_model = SimpleNamespace(mode="advanced-chat", workflow=workflow)
|
||||
|
||||
with (
|
||||
app.test_request_context("/parameters"),
|
||||
patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as mock_params,
|
||||
patch("controllers.web.app.fields.Parameters") as mock_fields,
|
||||
):
|
||||
mock_fields.model_validate.return_value.model_dump.return_value = {"result": "ok"}
|
||||
result = AppParameterApi().get(app_model, SimpleNamespace())
|
||||
|
||||
mock_params.assert_called_once_with(features_dict=features_dict, user_input_form=[])
|
||||
assert result == {"result": "ok"}
|
||||
|
||||
def test_workflow_mode_uses_workflow(self, app: Flask) -> None:
|
||||
features_dict = {}
|
||||
workflow = SimpleNamespace(
|
||||
features_dict=features_dict,
|
||||
user_input_form=lambda to_old_structure=False: [{"var": "x"}],
|
||||
)
|
||||
app_model = SimpleNamespace(mode="workflow", workflow=workflow)
|
||||
|
||||
with (
|
||||
app.test_request_context("/parameters"),
|
||||
patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as mock_params,
|
||||
patch("controllers.web.app.fields.Parameters") as mock_fields,
|
||||
):
|
||||
mock_fields.model_validate.return_value.model_dump.return_value = {}
|
||||
AppParameterApi().get(app_model, SimpleNamespace())
|
||||
|
||||
mock_params.assert_called_once_with(features_dict=features_dict, user_input_form=[{"var": "x"}])
|
||||
|
||||
def test_advanced_chat_mode_no_workflow_raises(self, app: Flask) -> None:
|
||||
app_model = SimpleNamespace(mode="advanced-chat", workflow=None)
|
||||
with app.test_request_context("/parameters"):
|
||||
with pytest.raises(AppUnavailableError):
|
||||
AppParameterApi().get(app_model, SimpleNamespace())
|
||||
|
||||
def test_standard_mode_uses_app_model_config(self, app: Flask) -> None:
|
||||
config = SimpleNamespace(to_dict=lambda: {"user_input_form": [{"var": "y"}], "key": "val"})
|
||||
app_model = SimpleNamespace(mode="chat", app_model_config=config)
|
||||
|
||||
with (
|
||||
app.test_request_context("/parameters"),
|
||||
patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as mock_params,
|
||||
patch("controllers.web.app.fields.Parameters") as mock_fields,
|
||||
):
|
||||
mock_fields.model_validate.return_value.model_dump.return_value = {}
|
||||
AppParameterApi().get(app_model, SimpleNamespace())
|
||||
|
||||
call_kwargs = mock_params.call_args
|
||||
assert call_kwargs.kwargs["user_input_form"] == [{"var": "y"}]
|
||||
|
||||
def test_standard_mode_no_config_raises(self, app: Flask) -> None:
|
||||
app_model = SimpleNamespace(mode="chat", app_model_config=None)
|
||||
with app.test_request_context("/parameters"):
|
||||
with pytest.raises(AppUnavailableError):
|
||||
AppParameterApi().get(app_model, SimpleNamespace())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AppMeta
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAppMeta:
|
||||
@patch("controllers.web.app.AppService")
|
||||
def test_get_returns_meta(self, mock_service_cls: MagicMock, app: Flask) -> None:
|
||||
mock_service_cls.return_value.get_app_meta.return_value = {"tool_icons": {}}
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
with app.test_request_context("/meta"):
|
||||
result = AppMeta().get(app_model, SimpleNamespace())
|
||||
|
||||
assert result == {"tool_icons": {}}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AppAccessMode
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAppAccessMode:
|
||||
@patch("controllers.web.app.FeatureService.get_system_features")
|
||||
def test_returns_public_when_webapp_auth_disabled(self, mock_features: MagicMock, app: Flask) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
with app.test_request_context("/webapp/access-mode?appId=app-1"):
|
||||
result = AppAccessMode().get()
|
||||
|
||||
assert result == {"accessMode": "public"}
|
||||
|
||||
@patch("controllers.web.app.EnterpriseService.WebAppAuth.get_app_access_mode_by_id")
|
||||
@patch("controllers.web.app.FeatureService.get_system_features")
|
||||
def test_returns_access_mode_with_app_id(
|
||||
self, mock_features: MagicMock, mock_access: MagicMock, app: Flask
|
||||
) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True))
|
||||
mock_access.return_value = SimpleNamespace(access_mode="internal")
|
||||
|
||||
with app.test_request_context("/webapp/access-mode?appId=app-1"):
|
||||
result = AppAccessMode().get()
|
||||
|
||||
assert result == {"accessMode": "internal"}
|
||||
mock_access.assert_called_once_with("app-1")
|
||||
|
||||
@patch("controllers.web.app.AppService.get_app_id_by_code", return_value="resolved-id")
|
||||
@patch("controllers.web.app.EnterpriseService.WebAppAuth.get_app_access_mode_by_id")
|
||||
@patch("controllers.web.app.FeatureService.get_system_features")
|
||||
def test_resolves_app_code_to_id(
|
||||
self, mock_features: MagicMock, mock_access: MagicMock, mock_resolve: MagicMock, app: Flask
|
||||
) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True))
|
||||
mock_access.return_value = SimpleNamespace(access_mode="external")
|
||||
|
||||
with app.test_request_context("/webapp/access-mode?appCode=code1"):
|
||||
result = AppAccessMode().get()
|
||||
|
||||
mock_resolve.assert_called_once_with("code1")
|
||||
mock_access.assert_called_once_with("resolved-id")
|
||||
assert result == {"accessMode": "external"}
|
||||
|
||||
@patch("controllers.web.app.FeatureService.get_system_features")
|
||||
def test_raises_when_no_app_id_or_code(self, mock_features: MagicMock, app: Flask) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True))
|
||||
|
||||
with app.test_request_context("/webapp/access-mode"):
|
||||
with pytest.raises(ValueError, match="appId or appCode"):
|
||||
AppAccessMode().get()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AppWebAuthPermission
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAppWebAuthPermission:
|
||||
@patch("controllers.web.app.WebAppAuthService.is_app_require_permission_check", return_value=False)
|
||||
def test_returns_true_when_no_permission_check_required(self, mock_check: MagicMock, app: Flask) -> None:
|
||||
with app.test_request_context("/webapp/permission?appId=app-1", headers={"X-App-Code": "code1"}):
|
||||
result = AppWebAuthPermission().get()
|
||||
|
||||
assert result == {"result": True}
|
||||
|
||||
def test_raises_when_missing_app_id(self, app: Flask) -> None:
|
||||
with app.test_request_context("/webapp/permission", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(ValueError, match="appId"):
|
||||
AppWebAuthPermission().get()
|
||||
135
api/tests/unit_tests/controllers/web/test_audio.py
Normal file
135
api/tests/unit_tests/controllers/web/test_audio.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Unit tests for controllers.web.audio endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from io import BytesIO
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.audio import AudioApi, TextApi
|
||||
from controllers.web.error import (
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
NoAudioUploadedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
|
||||
def _app_model() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="chat")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1", external_user_id="ext-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AudioApi (audio-to-text)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAudioApi:
|
||||
@patch("controllers.web.audio.AudioService.transcript_asr", return_value={"text": "hello"})
|
||||
def test_happy_path(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
data = {"file": (BytesIO(b"fake-audio"), "test.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
result = AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
assert result == {"text": "hello"}
|
||||
|
||||
@patch("controllers.web.audio.AudioService.transcript_asr", side_effect=NoAudioUploadedServiceError())
|
||||
def test_no_audio_uploaded(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b""), "empty.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(NoAudioUploadedError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch("controllers.web.audio.AudioService.transcript_asr", side_effect=AudioTooLargeServiceError("too big"))
|
||||
def test_audio_too_large(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"big"), "big.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(AudioTooLargeError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch("controllers.web.audio.AudioService.transcript_asr", side_effect=UnsupportedAudioTypeServiceError())
|
||||
def test_unsupported_type(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"bad"), "bad.xyz")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(UnsupportedAudioTypeError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch(
|
||||
"controllers.web.audio.AudioService.transcript_asr",
|
||||
side_effect=ProviderNotSupportSpeechToTextServiceError(),
|
||||
)
|
||||
def test_provider_not_support(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"x"), "x.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(ProviderNotSupportSpeechToTextError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch(
|
||||
"controllers.web.audio.AudioService.transcript_asr",
|
||||
side_effect=ProviderTokenNotInitError(description="no token"),
|
||||
)
|
||||
def test_provider_not_init(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"x"), "x.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch("controllers.web.audio.AudioService.transcript_asr", side_effect=QuotaExceededError())
|
||||
def test_quota_exceeded(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"x"), "x.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch("controllers.web.audio.AudioService.transcript_asr", side_effect=ModelCurrentlyNotSupportError())
|
||||
def test_model_not_support(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"x"), "x.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(ProviderModelCurrentlyNotSupportError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TextApi (text-to-audio)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestTextApi:
|
||||
@patch("controllers.web.audio.AudioService.transcript_tts", return_value="audio-bytes")
|
||||
@patch("controllers.web.audio.web_ns")
|
||||
def test_happy_path(self, mock_ns: MagicMock, mock_tts: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"text": "hello", "voice": "alloy"}
|
||||
|
||||
with app.test_request_context("/text-to-audio", method="POST"):
|
||||
result = TextApi().post(_app_model(), _end_user())
|
||||
|
||||
assert result == "audio-bytes"
|
||||
mock_tts.assert_called_once()
|
||||
|
||||
@patch(
|
||||
"controllers.web.audio.AudioService.transcript_tts",
|
||||
side_effect=InvokeError(description="invoke failed"),
|
||||
)
|
||||
@patch("controllers.web.audio.web_ns")
|
||||
def test_invoke_error_mapped(self, mock_ns: MagicMock, mock_tts: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"text": "hello"}
|
||||
|
||||
with app.test_request_context("/text-to-audio", method="POST"):
|
||||
with pytest.raises(CompletionRequestError):
|
||||
TextApi().post(_app_model(), _end_user())
|
||||
161
api/tests/unit_tests/controllers/web/test_completion.py
Normal file
161
api/tests/unit_tests/controllers/web/test_completion.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Unit tests for controllers.web.completion endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
|
||||
from controllers.web.error import (
|
||||
CompletionRequestError,
|
||||
NotChatAppError,
|
||||
NotCompletionAppError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
|
||||
|
||||
def _completion_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="completion")
|
||||
|
||||
|
||||
def _chat_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="chat")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CompletionApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestCompletionApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/completion-messages", method="POST"):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
CompletionApi().post(_chat_app(), _end_user())
|
||||
|
||||
@patch("controllers.web.completion.helper.compact_generate_response", return_value={"answer": "hi"})
|
||||
@patch("controllers.web.completion.AppGenerateService.generate")
|
||||
@patch("controllers.web.completion.web_ns")
|
||||
def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}, "query": "test"}
|
||||
mock_gen.return_value = "response-obj"
|
||||
|
||||
with app.test_request_context("/completion-messages", method="POST"):
|
||||
result = CompletionApi().post(_completion_app(), _end_user())
|
||||
|
||||
assert result == {"answer": "hi"}
|
||||
|
||||
@patch(
|
||||
"controllers.web.completion.AppGenerateService.generate",
|
||||
side_effect=ProviderTokenNotInitError(description="not init"),
|
||||
)
|
||||
@patch("controllers.web.completion.web_ns")
|
||||
def test_provider_not_init_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}}
|
||||
|
||||
with app.test_request_context("/completion-messages", method="POST"):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
CompletionApi().post(_completion_app(), _end_user())
|
||||
|
||||
@patch(
|
||||
"controllers.web.completion.AppGenerateService.generate",
|
||||
side_effect=QuotaExceededError(),
|
||||
)
|
||||
@patch("controllers.web.completion.web_ns")
|
||||
def test_quota_exceeded_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}}
|
||||
|
||||
with app.test_request_context("/completion-messages", method="POST"):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
CompletionApi().post(_completion_app(), _end_user())
|
||||
|
||||
@patch(
|
||||
"controllers.web.completion.AppGenerateService.generate",
|
||||
side_effect=ModelCurrentlyNotSupportError(),
|
||||
)
|
||||
@patch("controllers.web.completion.web_ns")
|
||||
def test_model_not_support_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}}
|
||||
|
||||
with app.test_request_context("/completion-messages", method="POST"):
|
||||
with pytest.raises(ProviderModelCurrentlyNotSupportError):
|
||||
CompletionApi().post(_completion_app(), _end_user())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CompletionStopApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestCompletionStopApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/completion-messages/task-1/stop", method="POST"):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
CompletionStopApi().post(_chat_app(), _end_user(), "task-1")
|
||||
|
||||
@patch("controllers.web.completion.AppTaskService.stop_task")
|
||||
def test_stop_success(self, mock_stop: MagicMock, app: Flask) -> None:
|
||||
with app.test_request_context("/completion-messages/task-1/stop", method="POST"):
|
||||
result, status = CompletionStopApi().post(_completion_app(), _end_user(), "task-1")
|
||||
|
||||
assert status == 200
|
||||
assert result == {"result": "success"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ChatApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestChatApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/chat-messages", method="POST"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ChatApi().post(_completion_app(), _end_user())
|
||||
|
||||
@patch("controllers.web.completion.helper.compact_generate_response", return_value={"answer": "reply"})
|
||||
@patch("controllers.web.completion.AppGenerateService.generate")
|
||||
@patch("controllers.web.completion.web_ns")
|
||||
def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}, "query": "hi"}
|
||||
mock_gen.return_value = "response"
|
||||
|
||||
with app.test_request_context("/chat-messages", method="POST"):
|
||||
result = ChatApi().post(_chat_app(), _end_user())
|
||||
|
||||
assert result == {"answer": "reply"}
|
||||
|
||||
@patch(
|
||||
"controllers.web.completion.AppGenerateService.generate",
|
||||
side_effect=InvokeError(description="rate limit"),
|
||||
)
|
||||
@patch("controllers.web.completion.web_ns")
|
||||
def test_invoke_error_mapped(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}, "query": "x"}
|
||||
|
||||
with app.test_request_context("/chat-messages", method="POST"):
|
||||
with pytest.raises(CompletionRequestError):
|
||||
ChatApi().post(_chat_app(), _end_user())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ChatStopApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestChatStopApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/chat-messages/task-1/stop", method="POST"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ChatStopApi().post(_completion_app(), _end_user(), "task-1")
|
||||
|
||||
@patch("controllers.web.completion.AppTaskService.stop_task")
|
||||
def test_stop_success(self, mock_stop: MagicMock, app: Flask) -> None:
|
||||
with app.test_request_context("/chat-messages/task-1/stop", method="POST"):
|
||||
result, status = ChatStopApi().post(_chat_app(), _end_user(), "task-1")
|
||||
|
||||
assert status == 200
|
||||
assert result == {"result": "success"}
|
||||
183
api/tests/unit_tests/controllers/web/test_conversation.py
Normal file
183
api/tests/unit_tests/controllers/web/test_conversation.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""Unit tests for controllers.web.conversation endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.web.conversation import (
|
||||
ConversationApi,
|
||||
ConversationListApi,
|
||||
ConversationPinApi,
|
||||
ConversationRenameApi,
|
||||
ConversationUnPinApi,
|
||||
)
|
||||
from controllers.web.error import NotChatAppError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
|
||||
|
||||
def _chat_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="chat")
|
||||
|
||||
|
||||
def _completion_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="completion")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationListApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestConversationListApi:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/conversations"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationListApi().get(_completion_app(), _end_user())
|
||||
|
||||
@patch("controllers.web.conversation.WebConversationService.pagination_by_last_id")
|
||||
@patch("controllers.web.conversation.db")
|
||||
def test_happy_path(self, mock_db: MagicMock, mock_paginate: MagicMock, app: Flask) -> None:
|
||||
conv_id = str(uuid4())
|
||||
conv = SimpleNamespace(
|
||||
id=conv_id,
|
||||
name="Test",
|
||||
inputs={},
|
||||
status="normal",
|
||||
introduction="",
|
||||
created_at=1700000000,
|
||||
updated_at=1700000000,
|
||||
)
|
||||
mock_paginate.return_value = SimpleNamespace(limit=20, has_more=False, data=[conv])
|
||||
mock_db.engine = "engine"
|
||||
|
||||
session_mock = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/conversations?limit=20"),
|
||||
patch("controllers.web.conversation.Session", return_value=session_ctx),
|
||||
):
|
||||
result = ConversationListApi().get(_chat_app(), _end_user())
|
||||
|
||||
assert result["limit"] == 20
|
||||
assert result["has_more"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationApi (delete)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestConversationApi:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context(f"/conversations/{uuid4()}"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationApi().delete(_completion_app(), _end_user(), uuid4())
|
||||
|
||||
@patch("controllers.web.conversation.ConversationService.delete")
|
||||
def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}"):
|
||||
result, status = ConversationApi().delete(_chat_app(), _end_user(), c_id)
|
||||
|
||||
assert status == 204
|
||||
assert result["result"] == "success"
|
||||
|
||||
@patch("controllers.web.conversation.ConversationService.delete", side_effect=ConversationNotExistsError())
|
||||
def test_delete_not_found(self, mock_delete: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}"):
|
||||
with pytest.raises(NotFound, match="Conversation Not Exists"):
|
||||
ConversationApi().delete(_chat_app(), _end_user(), c_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationRenameApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestConversationRenameApi:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context(f"/conversations/{uuid4()}/name", method="POST", json={"name": "x"}):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationRenameApi().post(_completion_app(), _end_user(), uuid4())
|
||||
|
||||
@patch("controllers.web.conversation.ConversationService.rename")
|
||||
@patch("controllers.web.conversation.web_ns")
|
||||
def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
mock_ns.payload = {"name": "New Name", "auto_generate": False}
|
||||
conv = SimpleNamespace(
|
||||
id=str(c_id),
|
||||
name="New Name",
|
||||
inputs={},
|
||||
status="normal",
|
||||
introduction="",
|
||||
created_at=1700000000,
|
||||
updated_at=1700000000,
|
||||
)
|
||||
mock_rename.return_value = conv
|
||||
|
||||
with app.test_request_context(f"/conversations/{c_id}/name", method="POST", json={"name": "New Name"}):
|
||||
result = ConversationRenameApi().post(_chat_app(), _end_user(), c_id)
|
||||
|
||||
assert result["name"] == "New Name"
|
||||
|
||||
@patch(
|
||||
"controllers.web.conversation.ConversationService.rename",
|
||||
side_effect=ConversationNotExistsError(),
|
||||
)
|
||||
@patch("controllers.web.conversation.web_ns")
|
||||
def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
mock_ns.payload = {"name": "X", "auto_generate": False}
|
||||
|
||||
with app.test_request_context(f"/conversations/{c_id}/name", method="POST", json={"name": "X"}):
|
||||
with pytest.raises(NotFound, match="Conversation Not Exists"):
|
||||
ConversationRenameApi().post(_chat_app(), _end_user(), c_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationPinApi / ConversationUnPinApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestConversationPinApi:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context(f"/conversations/{uuid4()}/pin", method="PATCH"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationPinApi().patch(_completion_app(), _end_user(), uuid4())
|
||||
|
||||
@patch("controllers.web.conversation.WebConversationService.pin")
|
||||
def test_pin_success(self, mock_pin: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"):
|
||||
result = ConversationPinApi().patch(_chat_app(), _end_user(), c_id)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
@patch("controllers.web.conversation.WebConversationService.pin", side_effect=ConversationNotExistsError())
|
||||
def test_pin_not_found(self, mock_pin: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"):
|
||||
with pytest.raises(NotFound):
|
||||
ConversationPinApi().patch(_chat_app(), _end_user(), c_id)
|
||||
|
||||
|
||||
class TestConversationUnPinApi:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context(f"/conversations/{uuid4()}/unpin", method="PATCH"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationUnPinApi().patch(_completion_app(), _end_user(), uuid4())
|
||||
|
||||
@patch("controllers.web.conversation.WebConversationService.unpin")
|
||||
def test_unpin_success(self, mock_unpin: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}/unpin", method="PATCH"):
|
||||
result = ConversationUnPinApi().patch(_chat_app(), _end_user(), c_id)
|
||||
|
||||
assert result["result"] == "success"
|
||||
75
api/tests/unit_tests/controllers/web/test_error.py
Normal file
75
api/tests/unit_tests/controllers/web/test_error.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Unit tests for controllers.web.error HTTP exception classes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.web.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
AppSuggestedQuestionsAfterAnswerDisabledError,
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
InvalidArgumentError,
|
||||
InvokeRateLimitError,
|
||||
NoAudioUploadedError,
|
||||
NotChatAppError,
|
||||
NotCompletionAppError,
|
||||
NotFoundError,
|
||||
NotWorkflowAppError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
WebAppAuthAccessDeniedError,
|
||||
WebAppAuthRequiredError,
|
||||
WebFormRateLimitExceededError,
|
||||
)
|
||||
|
||||
_ERROR_SPECS: list[tuple[type, str, int]] = [
|
||||
(AppUnavailableError, "app_unavailable", 400),
|
||||
(NotCompletionAppError, "not_completion_app", 400),
|
||||
(NotChatAppError, "not_chat_app", 400),
|
||||
(NotWorkflowAppError, "not_workflow_app", 400),
|
||||
(ConversationCompletedError, "conversation_completed", 400),
|
||||
(ProviderNotInitializeError, "provider_not_initialize", 400),
|
||||
(ProviderQuotaExceededError, "provider_quota_exceeded", 400),
|
||||
(ProviderModelCurrentlyNotSupportError, "model_currently_not_support", 400),
|
||||
(CompletionRequestError, "completion_request_error", 400),
|
||||
(AppMoreLikeThisDisabledError, "app_more_like_this_disabled", 403),
|
||||
(AppSuggestedQuestionsAfterAnswerDisabledError, "app_suggested_questions_after_answer_disabled", 403),
|
||||
(NoAudioUploadedError, "no_audio_uploaded", 400),
|
||||
(AudioTooLargeError, "audio_too_large", 413),
|
||||
(UnsupportedAudioTypeError, "unsupported_audio_type", 415),
|
||||
(ProviderNotSupportSpeechToTextError, "provider_not_support_speech_to_text", 400),
|
||||
(WebAppAuthRequiredError, "web_sso_auth_required", 401),
|
||||
(WebAppAuthAccessDeniedError, "web_app_access_denied", 401),
|
||||
(InvokeRateLimitError, "rate_limit_error", 429),
|
||||
(WebFormRateLimitExceededError, "web_form_rate_limit_exceeded", 429),
|
||||
(NotFoundError, "not_found", 404),
|
||||
(InvalidArgumentError, "invalid_param", 400),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("cls", "expected_code", "expected_status"),
|
||||
_ERROR_SPECS,
|
||||
ids=[cls.__name__ for cls, _, _ in _ERROR_SPECS],
|
||||
)
|
||||
def test_error_class_attributes(cls: type, expected_code: str, expected_status: int) -> None:
|
||||
"""Each error class exposes the correct error_code and HTTP status code."""
|
||||
assert cls.error_code == expected_code
|
||||
assert cls.code == expected_status
|
||||
|
||||
|
||||
def test_error_classes_have_description() -> None:
|
||||
"""Every error class has a description (string or None for generic errors)."""
|
||||
# NotFoundError and InvalidArgumentError use None description by design
|
||||
_NO_DESCRIPTION = {NotFoundError, InvalidArgumentError}
|
||||
for cls, _, _ in _ERROR_SPECS:
|
||||
if cls in _NO_DESCRIPTION:
|
||||
continue
|
||||
assert isinstance(cls.description, str), f"{cls.__name__} missing description"
|
||||
assert len(cls.description) > 0, f"{cls.__name__} has empty description"
|
||||
38
api/tests/unit_tests/controllers/web/test_feature.py
Normal file
38
api/tests/unit_tests/controllers/web/test_feature.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Unit tests for controllers.web.feature endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.feature import SystemFeatureApi
|
||||
|
||||
|
||||
class TestSystemFeatureApi:
|
||||
@patch("controllers.web.feature.FeatureService.get_system_features")
|
||||
def test_returns_system_features(self, mock_features: MagicMock, app: Flask) -> None:
|
||||
mock_model = MagicMock()
|
||||
mock_model.model_dump.return_value = {"sso_enforced_for_signin": False, "webapp_auth": {"enabled": False}}
|
||||
mock_features.return_value = mock_model
|
||||
|
||||
with app.test_request_context("/system-features"):
|
||||
result = SystemFeatureApi().get()
|
||||
|
||||
assert result == {"sso_enforced_for_signin": False, "webapp_auth": {"enabled": False}}
|
||||
mock_features.assert_called_once()
|
||||
|
||||
@patch("controllers.web.feature.FeatureService.get_system_features")
|
||||
def test_unauthenticated_access(self, mock_features: MagicMock, app: Flask) -> None:
|
||||
"""SystemFeatureApi is unauthenticated by design — no WebApiResource decorator."""
|
||||
mock_model = MagicMock()
|
||||
mock_model.model_dump.return_value = {}
|
||||
mock_features.return_value = mock_model
|
||||
|
||||
# Verify it's a bare Resource, not WebApiResource
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.web.wraps import WebApiResource
|
||||
|
||||
assert issubclass(SystemFeatureApi, Resource)
|
||||
assert not issubclass(SystemFeatureApi, WebApiResource)
|
||||
89
api/tests/unit_tests/controllers/web/test_files.py
Normal file
89
api/tests/unit_tests/controllers/web/test_files.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Unit tests for controllers.web.files endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from io import BytesIO
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.common.errors import (
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
)
|
||||
from controllers.web.files import FileApi
|
||||
|
||||
|
||||
def _app_model() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
class TestFileApi:
|
||||
def test_no_file_uploaded(self, app: Flask) -> None:
|
||||
with app.test_request_context("/files/upload", method="POST", content_type="multipart/form-data"):
|
||||
with pytest.raises(NoFileUploadedError):
|
||||
FileApi().post(_app_model(), _end_user())
|
||||
|
||||
def test_too_many_files(self, app: Flask) -> None:
|
||||
data = {
|
||||
"file": (BytesIO(b"a"), "a.txt"),
|
||||
"file2": (BytesIO(b"b"), "b.txt"),
|
||||
}
|
||||
with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"):
|
||||
# Now has "file" key but len(request.files) > 1
|
||||
with pytest.raises(TooManyFilesError):
|
||||
FileApi().post(_app_model(), _end_user())
|
||||
|
||||
def test_filename_missing(self, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"content"), "")}
|
||||
with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(FilenameNotExistsError):
|
||||
FileApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch("controllers.web.files.FileService")
|
||||
@patch("controllers.web.files.db")
|
||||
def test_upload_success(self, mock_db: MagicMock, mock_file_svc_cls: MagicMock, app: Flask) -> None:
|
||||
mock_db.engine = "engine"
|
||||
from datetime import datetime
|
||||
|
||||
upload_file = SimpleNamespace(
|
||||
id="file-1",
|
||||
name="test.txt",
|
||||
size=100,
|
||||
extension="txt",
|
||||
mime_type="text/plain",
|
||||
created_by="eu-1",
|
||||
created_at=datetime(2024, 1, 1),
|
||||
)
|
||||
mock_file_svc_cls.return_value.upload_file.return_value = upload_file
|
||||
|
||||
data = {"file": (BytesIO(b"content"), "test.txt")}
|
||||
with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"):
|
||||
result, status = FileApi().post(_app_model(), _end_user())
|
||||
|
||||
assert status == 201
|
||||
assert result["id"] == "file-1"
|
||||
assert result["name"] == "test.txt"
|
||||
|
||||
@patch("controllers.web.files.FileService")
|
||||
@patch("controllers.web.files.db")
|
||||
def test_file_too_large_from_service(self, mock_db: MagicMock, mock_file_svc_cls: MagicMock, app: Flask) -> None:
|
||||
import services.errors.file
|
||||
|
||||
mock_db.engine = "engine"
|
||||
mock_file_svc_cls.return_value.upload_file.side_effect = services.errors.file.FileTooLargeError(
|
||||
description="max 10MB"
|
||||
)
|
||||
|
||||
data = {"file": (BytesIO(b"big"), "big.txt")}
|
||||
with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(FileTooLargeError):
|
||||
FileApi().post(_app_model(), _end_user())
|
||||
156
api/tests/unit_tests/controllers/web/test_message_endpoints.py
Normal file
156
api/tests/unit_tests/controllers/web/test_message_endpoints.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Unit tests for controllers.web.message — feedback, more-like-this, suggested questions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.web.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
NotChatAppError,
|
||||
NotCompletionAppError,
|
||||
)
|
||||
from controllers.web.message import (
|
||||
MessageFeedbackApi,
|
||||
MessageMoreLikeThisApi,
|
||||
MessageSuggestedQuestionApi,
|
||||
)
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
|
||||
def _chat_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="chat")
|
||||
|
||||
|
||||
def _completion_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="completion")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MessageFeedbackApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestMessageFeedbackApi:
|
||||
@patch("controllers.web.message.MessageService.create_feedback")
|
||||
@patch("controllers.web.message.web_ns")
|
||||
def test_feedback_success(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"rating": "like", "content": "great"}
|
||||
msg_id = uuid4()
|
||||
|
||||
with app.test_request_context(f"/messages/{msg_id}/feedbacks", method="POST"):
|
||||
result = MessageFeedbackApi().post(_chat_app(), _end_user(), msg_id)
|
||||
|
||||
assert result == {"result": "success"}
|
||||
mock_create.assert_called_once()
|
||||
|
||||
@patch("controllers.web.message.MessageService.create_feedback")
|
||||
@patch("controllers.web.message.web_ns")
|
||||
def test_feedback_null_rating(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"rating": None}
|
||||
msg_id = uuid4()
|
||||
|
||||
with app.test_request_context(f"/messages/{msg_id}/feedbacks", method="POST"):
|
||||
result = MessageFeedbackApi().post(_chat_app(), _end_user(), msg_id)
|
||||
|
||||
assert result == {"result": "success"}
|
||||
|
||||
@patch(
|
||||
"controllers.web.message.MessageService.create_feedback",
|
||||
side_effect=MessageNotExistsError(),
|
||||
)
|
||||
@patch("controllers.web.message.web_ns")
|
||||
def test_feedback_message_not_found(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"rating": "dislike"}
|
||||
msg_id = uuid4()
|
||||
|
||||
with app.test_request_context(f"/messages/{msg_id}/feedbacks", method="POST"):
|
||||
with pytest.raises(NotFound, match="Message Not Exists"):
|
||||
MessageFeedbackApi().post(_chat_app(), _end_user(), msg_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MessageMoreLikeThisApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestMessageMoreLikeThisApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
MessageMoreLikeThisApi().get(_chat_app(), _end_user(), msg_id)
|
||||
|
||||
@patch("controllers.web.message.helper.compact_generate_response", return_value={"answer": "similar"})
|
||||
@patch("controllers.web.message.AppGenerateService.generate_more_like_this")
|
||||
def test_happy_path(self, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
mock_gen.return_value = "response"
|
||||
|
||||
with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"):
|
||||
result = MessageMoreLikeThisApi().get(_completion_app(), _end_user(), msg_id)
|
||||
|
||||
assert result == {"answer": "similar"}
|
||||
|
||||
@patch(
|
||||
"controllers.web.message.AppGenerateService.generate_more_like_this",
|
||||
side_effect=MessageNotExistsError(),
|
||||
)
|
||||
def test_message_not_found(self, mock_gen: MagicMock, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"):
|
||||
with pytest.raises(NotFound, match="Message Not Exists"):
|
||||
MessageMoreLikeThisApi().get(_completion_app(), _end_user(), msg_id)
|
||||
|
||||
@patch(
|
||||
"controllers.web.message.AppGenerateService.generate_more_like_this",
|
||||
side_effect=MoreLikeThisDisabledError(),
|
||||
)
|
||||
def test_feature_disabled(self, mock_gen: MagicMock, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"):
|
||||
with pytest.raises(AppMoreLikeThisDisabledError):
|
||||
MessageMoreLikeThisApi().get(_completion_app(), _end_user(), msg_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MessageSuggestedQuestionApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestMessageSuggestedQuestionApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/messages/{msg_id}/suggested-questions"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id)
|
||||
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/messages/{msg_id}/suggested-questions"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id)
|
||||
|
||||
@patch("controllers.web.message.MessageService.get_suggested_questions_after_answer")
|
||||
def test_happy_path(self, mock_suggest: MagicMock, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
mock_suggest.return_value = ["What about X?", "Tell me more about Y."]
|
||||
|
||||
with app.test_request_context(f"/messages/{msg_id}/suggested-questions"):
|
||||
result = MessageSuggestedQuestionApi().get(_chat_app(), _end_user(), msg_id)
|
||||
|
||||
assert result["data"] == ["What about X?", "Tell me more about Y."]
|
||||
|
||||
@patch(
|
||||
"controllers.web.message.MessageService.get_suggested_questions_after_answer",
|
||||
side_effect=MessageNotExistsError(),
|
||||
)
|
||||
def test_message_not_found(self, mock_suggest: MagicMock, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/messages/{msg_id}/suggested-questions"):
|
||||
with pytest.raises(NotFound, match="Message not found"):
|
||||
MessageSuggestedQuestionApi().get(_chat_app(), _end_user(), msg_id)
|
||||
103
api/tests/unit_tests/controllers/web/test_passport.py
Normal file
103
api/tests/unit_tests/controllers/web/test_passport.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from controllers.web.error import WebAppAuthRequiredError
|
||||
from controllers.web.passport import (
|
||||
PassportService,
|
||||
decode_enterprise_webapp_user_id,
|
||||
exchange_token_for_existing_web_user,
|
||||
generate_session_id,
|
||||
)
|
||||
from services.webapp_auth_service import WebAppAuthType
|
||||
|
||||
|
||||
def test_decode_enterprise_webapp_user_id_none() -> None:
|
||||
assert decode_enterprise_webapp_user_id(None) is None
|
||||
|
||||
|
||||
def test_decode_enterprise_webapp_user_id_invalid_source(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(PassportService, "verify", lambda *_args, **_kwargs: {"token_source": "bad"})
|
||||
with pytest.raises(Unauthorized):
|
||||
decode_enterprise_webapp_user_id("token")
|
||||
|
||||
|
||||
def test_decode_enterprise_webapp_user_id_valid(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
decoded = {"token_source": "webapp_login_token", "user_id": "u1"}
|
||||
monkeypatch.setattr(PassportService, "verify", lambda *_args, **_kwargs: decoded)
|
||||
assert decode_enterprise_webapp_user_id("token") == decoded
|
||||
|
||||
|
||||
def test_exchange_token_public_flow(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal")
|
||||
app_model = SimpleNamespace(id="a1", status="normal", enable_site=True)
|
||||
|
||||
def _scalar_side_effect(*_args, **_kwargs):
|
||||
if not hasattr(_scalar_side_effect, "calls"):
|
||||
_scalar_side_effect.calls = 0
|
||||
_scalar_side_effect.calls += 1
|
||||
return site if _scalar_side_effect.calls == 1 else app_model
|
||||
|
||||
db_session = SimpleNamespace(scalar=_scalar_side_effect)
|
||||
monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session))
|
||||
monkeypatch.setattr("controllers.web.passport._exchange_for_public_app_token", lambda *_args, **_kwargs: "resp")
|
||||
|
||||
decoded = {"auth_type": "public"}
|
||||
result = exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.PUBLIC)
|
||||
assert result == "resp"
|
||||
|
||||
|
||||
def test_exchange_token_requires_external(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal")
|
||||
app_model = SimpleNamespace(id="a1", status="normal", enable_site=True)
|
||||
|
||||
def _scalar_side_effect(*_args, **_kwargs):
|
||||
if not hasattr(_scalar_side_effect, "calls"):
|
||||
_scalar_side_effect.calls = 0
|
||||
_scalar_side_effect.calls += 1
|
||||
return site if _scalar_side_effect.calls == 1 else app_model
|
||||
|
||||
db_session = SimpleNamespace(scalar=_scalar_side_effect)
|
||||
monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session))
|
||||
|
||||
decoded = {"auth_type": "internal"}
|
||||
with pytest.raises(WebAppAuthRequiredError):
|
||||
exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.EXTERNAL)
|
||||
|
||||
|
||||
def test_exchange_token_missing_session_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal")
|
||||
app_model = SimpleNamespace(id="a1", status="normal", enable_site=True, tenant_id="t1")
|
||||
|
||||
def _scalar_side_effect(*_args, **_kwargs):
|
||||
if not hasattr(_scalar_side_effect, "calls"):
|
||||
_scalar_side_effect.calls = 0
|
||||
_scalar_side_effect.calls += 1
|
||||
if _scalar_side_effect.calls == 1:
|
||||
return site
|
||||
if _scalar_side_effect.calls == 2:
|
||||
return app_model
|
||||
return None
|
||||
|
||||
db_session = SimpleNamespace(scalar=_scalar_side_effect, add=lambda *_a, **_k: None, commit=lambda: None)
|
||||
monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session))
|
||||
|
||||
decoded = {"auth_type": "internal"}
|
||||
with pytest.raises(NotFound):
|
||||
exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.INTERNAL)
|
||||
|
||||
|
||||
def test_generate_session_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
counts = [1, 0]
|
||||
|
||||
def _scalar(*_args, **_kwargs):
|
||||
return counts.pop(0)
|
||||
|
||||
db_session = SimpleNamespace(scalar=_scalar)
|
||||
monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session))
|
||||
|
||||
session_id = generate_session_id()
|
||||
assert session_id
|
||||
423
api/tests/unit_tests/controllers/web/test_pydantic_models.py
Normal file
423
api/tests/unit_tests/controllers/web/test_pydantic_models.py
Normal file
@@ -0,0 +1,423 @@
|
||||
"""Unit tests for Pydantic models defined in controllers.web modules.
|
||||
|
||||
Covers validation logic, field defaults, constraints, and custom validators
|
||||
for all ~15 Pydantic models across the web controller layer.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# app.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.app import AppAccessModeQuery
|
||||
|
||||
|
||||
class TestAppAccessModeQuery:
|
||||
def test_alias_resolution(self) -> None:
|
||||
q = AppAccessModeQuery.model_validate({"appId": "abc", "appCode": "xyz"})
|
||||
assert q.app_id == "abc"
|
||||
assert q.app_code == "xyz"
|
||||
|
||||
def test_defaults_to_none(self) -> None:
|
||||
q = AppAccessModeQuery.model_validate({})
|
||||
assert q.app_id is None
|
||||
assert q.app_code is None
|
||||
|
||||
def test_accepts_snake_case(self) -> None:
|
||||
q = AppAccessModeQuery(app_id="id1", app_code="code1")
|
||||
assert q.app_id == "id1"
|
||||
assert q.app_code == "code1"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# audio.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.audio import TextToAudioPayload
|
||||
|
||||
|
||||
class TestTextToAudioPayload:
|
||||
def test_defaults(self) -> None:
|
||||
p = TextToAudioPayload.model_validate({})
|
||||
assert p.message_id is None
|
||||
assert p.voice is None
|
||||
assert p.text is None
|
||||
assert p.streaming is None
|
||||
|
||||
def test_valid_uuid_message_id(self) -> None:
|
||||
uid = str(uuid4())
|
||||
p = TextToAudioPayload(message_id=uid)
|
||||
assert p.message_id == uid
|
||||
|
||||
def test_none_message_id_passthrough(self) -> None:
|
||||
p = TextToAudioPayload(message_id=None)
|
||||
assert p.message_id is None
|
||||
|
||||
def test_invalid_uuid_message_id(self) -> None:
|
||||
with pytest.raises(ValidationError, match="not a valid uuid"):
|
||||
TextToAudioPayload(message_id="not-a-uuid")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# completion.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.completion import ChatMessagePayload, CompletionMessagePayload
|
||||
|
||||
|
||||
class TestCompletionMessagePayload:
|
||||
def test_defaults(self) -> None:
|
||||
p = CompletionMessagePayload(inputs={})
|
||||
assert p.query == ""
|
||||
assert p.files is None
|
||||
assert p.response_mode is None
|
||||
assert p.retriever_from == "web_app"
|
||||
|
||||
def test_accepts_full_payload(self) -> None:
|
||||
p = CompletionMessagePayload(
|
||||
inputs={"key": "val"},
|
||||
query="test",
|
||||
files=[{"id": "f1"}],
|
||||
response_mode="streaming",
|
||||
)
|
||||
assert p.response_mode == "streaming"
|
||||
assert p.files == [{"id": "f1"}]
|
||||
|
||||
def test_invalid_response_mode(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
CompletionMessagePayload(inputs={}, response_mode="invalid")
|
||||
|
||||
|
||||
class TestChatMessagePayload:
|
||||
def test_valid_uuid_fields(self) -> None:
|
||||
cid = str(uuid4())
|
||||
pid = str(uuid4())
|
||||
p = ChatMessagePayload(inputs={}, query="hi", conversation_id=cid, parent_message_id=pid)
|
||||
assert p.conversation_id == cid
|
||||
assert p.parent_message_id == pid
|
||||
|
||||
def test_none_uuid_fields(self) -> None:
|
||||
p = ChatMessagePayload(inputs={}, query="hi")
|
||||
assert p.conversation_id is None
|
||||
assert p.parent_message_id is None
|
||||
|
||||
def test_invalid_conversation_id(self) -> None:
|
||||
with pytest.raises(ValidationError, match="not a valid uuid"):
|
||||
ChatMessagePayload(inputs={}, query="hi", conversation_id="bad")
|
||||
|
||||
def test_invalid_parent_message_id(self) -> None:
|
||||
with pytest.raises(ValidationError, match="not a valid uuid"):
|
||||
ChatMessagePayload(inputs={}, query="hi", parent_message_id="bad")
|
||||
|
||||
def test_query_required(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ChatMessagePayload(inputs={})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# conversation.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.conversation import ConversationListQuery, ConversationRenamePayload
|
||||
|
||||
|
||||
class TestConversationListQuery:
|
||||
def test_defaults(self) -> None:
|
||||
q = ConversationListQuery()
|
||||
assert q.last_id is None
|
||||
assert q.limit == 20
|
||||
assert q.pinned is None
|
||||
assert q.sort_by == "-updated_at"
|
||||
|
||||
def test_limit_lower_bound(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ConversationListQuery(limit=0)
|
||||
|
||||
def test_limit_upper_bound(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ConversationListQuery(limit=101)
|
||||
|
||||
def test_limit_boundaries_valid(self) -> None:
|
||||
assert ConversationListQuery(limit=1).limit == 1
|
||||
assert ConversationListQuery(limit=100).limit == 100
|
||||
|
||||
def test_valid_sort_by_options(self) -> None:
|
||||
for opt in ("created_at", "-created_at", "updated_at", "-updated_at"):
|
||||
assert ConversationListQuery(sort_by=opt).sort_by == opt
|
||||
|
||||
def test_invalid_sort_by(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ConversationListQuery(sort_by="invalid")
|
||||
|
||||
def test_valid_last_id(self) -> None:
|
||||
uid = str(uuid4())
|
||||
assert ConversationListQuery(last_id=uid).last_id == uid
|
||||
|
||||
def test_invalid_last_id(self) -> None:
|
||||
with pytest.raises(ValidationError, match="not a valid uuid"):
|
||||
ConversationListQuery(last_id="not-uuid")
|
||||
|
||||
|
||||
class TestConversationRenamePayload:
|
||||
def test_auto_generate_true_no_name_required(self) -> None:
|
||||
p = ConversationRenamePayload(auto_generate=True)
|
||||
assert p.name is None
|
||||
|
||||
def test_auto_generate_false_requires_name(self) -> None:
|
||||
with pytest.raises(ValidationError, match="name is required"):
|
||||
ConversationRenamePayload(auto_generate=False)
|
||||
|
||||
def test_auto_generate_false_blank_name_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError, match="name is required"):
|
||||
ConversationRenamePayload(auto_generate=False, name=" ")
|
||||
|
||||
def test_auto_generate_false_with_valid_name(self) -> None:
|
||||
p = ConversationRenamePayload(auto_generate=False, name="My Chat")
|
||||
assert p.name == "My Chat"
|
||||
|
||||
def test_defaults(self) -> None:
|
||||
p = ConversationRenamePayload(name="test")
|
||||
assert p.auto_generate is False
|
||||
assert p.name == "test"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# message.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.message import MessageFeedbackPayload, MessageListQuery, MessageMoreLikeThisQuery
|
||||
|
||||
|
||||
class TestMessageListQuery:
|
||||
def test_valid_query(self) -> None:
|
||||
cid = str(uuid4())
|
||||
q = MessageListQuery(conversation_id=cid)
|
||||
assert q.conversation_id == cid
|
||||
assert q.first_id is None
|
||||
assert q.limit == 20
|
||||
|
||||
def test_invalid_conversation_id(self) -> None:
|
||||
with pytest.raises(ValidationError, match="not a valid uuid"):
|
||||
MessageListQuery(conversation_id="bad")
|
||||
|
||||
def test_limit_bounds(self) -> None:
|
||||
cid = str(uuid4())
|
||||
with pytest.raises(ValidationError):
|
||||
MessageListQuery(conversation_id=cid, limit=0)
|
||||
with pytest.raises(ValidationError):
|
||||
MessageListQuery(conversation_id=cid, limit=101)
|
||||
|
||||
def test_valid_first_id(self) -> None:
|
||||
cid = str(uuid4())
|
||||
fid = str(uuid4())
|
||||
q = MessageListQuery(conversation_id=cid, first_id=fid)
|
||||
assert q.first_id == fid
|
||||
|
||||
def test_invalid_first_id(self) -> None:
|
||||
cid = str(uuid4())
|
||||
with pytest.raises(ValidationError, match="not a valid uuid"):
|
||||
MessageListQuery(conversation_id=cid, first_id="invalid")
|
||||
|
||||
|
||||
class TestMessageFeedbackPayload:
|
||||
def test_defaults(self) -> None:
|
||||
p = MessageFeedbackPayload()
|
||||
assert p.rating is None
|
||||
assert p.content is None
|
||||
|
||||
def test_valid_ratings(self) -> None:
|
||||
assert MessageFeedbackPayload(rating="like").rating == "like"
|
||||
assert MessageFeedbackPayload(rating="dislike").rating == "dislike"
|
||||
|
||||
def test_invalid_rating(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
MessageFeedbackPayload(rating="neutral")
|
||||
|
||||
|
||||
class TestMessageMoreLikeThisQuery:
|
||||
def test_valid_modes(self) -> None:
|
||||
assert MessageMoreLikeThisQuery(response_mode="blocking").response_mode == "blocking"
|
||||
assert MessageMoreLikeThisQuery(response_mode="streaming").response_mode == "streaming"
|
||||
|
||||
def test_invalid_mode(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
MessageMoreLikeThisQuery(response_mode="invalid")
|
||||
|
||||
def test_required(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
MessageMoreLikeThisQuery()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# remote_files.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.remote_files import RemoteFileUploadPayload
|
||||
|
||||
|
||||
class TestRemoteFileUploadPayload:
|
||||
def test_valid_url(self) -> None:
|
||||
p = RemoteFileUploadPayload(url="https://example.com/file.pdf")
|
||||
assert str(p.url) == "https://example.com/file.pdf"
|
||||
|
||||
def test_invalid_url(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
RemoteFileUploadPayload(url="not-a-url")
|
||||
|
||||
def test_url_required(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
RemoteFileUploadPayload()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# saved_message.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.saved_message import SavedMessageCreatePayload, SavedMessageListQuery
|
||||
|
||||
|
||||
class TestSavedMessageListQuery:
|
||||
def test_defaults(self) -> None:
|
||||
q = SavedMessageListQuery()
|
||||
assert q.last_id is None
|
||||
assert q.limit == 20
|
||||
|
||||
def test_limit_bounds(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
SavedMessageListQuery(limit=0)
|
||||
with pytest.raises(ValidationError):
|
||||
SavedMessageListQuery(limit=101)
|
||||
|
||||
def test_valid_last_id(self) -> None:
|
||||
uid = str(uuid4())
|
||||
q = SavedMessageListQuery(last_id=uid)
|
||||
assert q.last_id == uid
|
||||
|
||||
def test_empty_last_id(self) -> None:
|
||||
q = SavedMessageListQuery(last_id="")
|
||||
assert q.last_id == ""
|
||||
|
||||
|
||||
class TestSavedMessageCreatePayload:
|
||||
def test_valid_message_id(self) -> None:
|
||||
uid = str(uuid4())
|
||||
p = SavedMessageCreatePayload(message_id=uid)
|
||||
assert p.message_id == uid
|
||||
|
||||
def test_required(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
SavedMessageCreatePayload()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# workflow.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.workflow import WorkflowRunPayload
|
||||
|
||||
|
||||
class TestWorkflowRunPayload:
|
||||
def test_defaults(self) -> None:
|
||||
p = WorkflowRunPayload(inputs={})
|
||||
assert p.inputs == {}
|
||||
assert p.files is None
|
||||
|
||||
def test_with_files(self) -> None:
|
||||
p = WorkflowRunPayload(inputs={"k": "v"}, files=[{"id": "f1"}])
|
||||
assert p.files == [{"id": "f1"}]
|
||||
|
||||
def test_inputs_required(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
WorkflowRunPayload()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# forgot_password.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.forgot_password import (
|
||||
ForgotPasswordCheckPayload,
|
||||
ForgotPasswordResetPayload,
|
||||
ForgotPasswordSendPayload,
|
||||
)
|
||||
|
||||
|
||||
class TestForgotPasswordSendPayload:
|
||||
def test_valid_email(self) -> None:
|
||||
p = ForgotPasswordSendPayload(email="user@example.com")
|
||||
assert p.email == "user@example.com"
|
||||
|
||||
def test_invalid_email(self) -> None:
|
||||
with pytest.raises(ValidationError, match="not a valid email"):
|
||||
ForgotPasswordSendPayload(email="not-an-email")
|
||||
|
||||
def test_language_optional(self) -> None:
|
||||
p = ForgotPasswordSendPayload(email="a@b.com")
|
||||
assert p.language is None
|
||||
|
||||
|
||||
class TestForgotPasswordCheckPayload:
|
||||
def test_valid(self) -> None:
|
||||
p = ForgotPasswordCheckPayload(email="a@b.com", code="1234", token="tok")
|
||||
assert p.email == "a@b.com"
|
||||
assert p.code == "1234"
|
||||
assert p.token == "tok"
|
||||
|
||||
def test_empty_token_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ForgotPasswordCheckPayload(email="a@b.com", code="1234", token="")
|
||||
|
||||
|
||||
class TestForgotPasswordResetPayload:
|
||||
def test_valid_passwords(self) -> None:
|
||||
p = ForgotPasswordResetPayload(token="tok", new_password="Valid1234", password_confirm="Valid1234")
|
||||
assert p.new_password == "Valid1234"
|
||||
|
||||
def test_weak_password_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError, match="Password must contain"):
|
||||
ForgotPasswordResetPayload(token="tok", new_password="short", password_confirm="short")
|
||||
|
||||
def test_letters_only_password_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError, match="Password must contain"):
|
||||
ForgotPasswordResetPayload(token="tok", new_password="abcdefghi", password_confirm="abcdefghi")
|
||||
|
||||
def test_digits_only_password_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError, match="Password must contain"):
|
||||
ForgotPasswordResetPayload(token="tok", new_password="123456789", password_confirm="123456789")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# login.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.login import EmailCodeLoginSendPayload, EmailCodeLoginVerifyPayload, LoginPayload
|
||||
|
||||
|
||||
class TestLoginPayload:
|
||||
def test_valid(self) -> None:
|
||||
p = LoginPayload(email="a@b.com", password="Valid1234")
|
||||
assert p.email == "a@b.com"
|
||||
|
||||
def test_invalid_email(self) -> None:
|
||||
with pytest.raises(ValidationError, match="not a valid email"):
|
||||
LoginPayload(email="bad", password="Valid1234")
|
||||
|
||||
def test_weak_password(self) -> None:
|
||||
with pytest.raises(ValidationError, match="Password must contain"):
|
||||
LoginPayload(email="a@b.com", password="weak")
|
||||
|
||||
|
||||
class TestEmailCodeLoginSendPayload:
|
||||
def test_valid(self) -> None:
|
||||
p = EmailCodeLoginSendPayload(email="a@b.com")
|
||||
assert p.language is None
|
||||
|
||||
def test_with_language(self) -> None:
|
||||
p = EmailCodeLoginSendPayload(email="a@b.com", language="zh-Hans")
|
||||
assert p.language == "zh-Hans"
|
||||
|
||||
|
||||
class TestEmailCodeLoginVerifyPayload:
|
||||
def test_valid(self) -> None:
|
||||
p = EmailCodeLoginVerifyPayload(email="a@b.com", code="1234", token="tok")
|
||||
assert p.code == "1234"
|
||||
|
||||
def test_empty_token_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
EmailCodeLoginVerifyPayload(email="a@b.com", code="1234", token="")
|
||||
147
api/tests/unit_tests/controllers/web/test_remote_files.py
Normal file
147
api/tests/unit_tests/controllers/web/test_remote_files.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""Unit tests for controllers.web.remote_files endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.common.errors import FileTooLargeError, RemoteFileUploadError
|
||||
from controllers.web.remote_files import RemoteFileInfoApi, RemoteFileUploadApi
|
||||
|
||||
|
||||
def _app_model() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RemoteFileInfoApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestRemoteFileInfoApi:
|
||||
@patch("controllers.web.remote_files.ssrf_proxy")
|
||||
def test_head_success(self, mock_proxy: MagicMock, app: Flask) -> None:
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/pdf", "Content-Length": "1024"}
|
||||
mock_proxy.head.return_value = mock_resp
|
||||
|
||||
with app.test_request_context("/remote-files/https%3A%2F%2Fexample.com%2Ffile.pdf"):
|
||||
result = RemoteFileInfoApi().get(_app_model(), _end_user(), "https%3A%2F%2Fexample.com%2Ffile.pdf")
|
||||
|
||||
assert result["file_type"] == "application/pdf"
|
||||
assert result["file_length"] == 1024
|
||||
|
||||
@patch("controllers.web.remote_files.ssrf_proxy")
|
||||
def test_fallback_to_get(self, mock_proxy: MagicMock, app: Flask) -> None:
|
||||
head_resp = MagicMock()
|
||||
head_resp.status_code = 405 # Method not allowed
|
||||
get_resp = MagicMock()
|
||||
get_resp.status_code = 200
|
||||
get_resp.headers = {"Content-Type": "text/plain", "Content-Length": "42"}
|
||||
get_resp.raise_for_status = MagicMock()
|
||||
mock_proxy.head.return_value = head_resp
|
||||
mock_proxy.get.return_value = get_resp
|
||||
|
||||
with app.test_request_context("/remote-files/https%3A%2F%2Fexample.com%2Ffile.txt"):
|
||||
result = RemoteFileInfoApi().get(_app_model(), _end_user(), "https%3A%2F%2Fexample.com%2Ffile.txt")
|
||||
|
||||
assert result["file_type"] == "text/plain"
|
||||
mock_proxy.get.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RemoteFileUploadApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestRemoteFileUploadApi:
|
||||
@patch("controllers.web.remote_files.file_helpers.get_signed_file_url", return_value="https://signed-url")
|
||||
@patch("controllers.web.remote_files.FileService")
|
||||
@patch("controllers.web.remote_files.helpers.guess_file_info_from_response")
|
||||
@patch("controllers.web.remote_files.ssrf_proxy")
|
||||
@patch("controllers.web.remote_files.web_ns")
|
||||
@patch("controllers.web.remote_files.db")
|
||||
def test_upload_success(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_ns: MagicMock,
|
||||
mock_proxy: MagicMock,
|
||||
mock_guess: MagicMock,
|
||||
mock_file_svc_cls: MagicMock,
|
||||
mock_signed: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_db.engine = "engine"
|
||||
mock_ns.payload = {"url": "https://example.com/file.pdf"}
|
||||
head_resp = MagicMock()
|
||||
head_resp.status_code = 200
|
||||
head_resp.content = b"pdf-content"
|
||||
head_resp.request.method = "HEAD"
|
||||
mock_proxy.head.return_value = head_resp
|
||||
get_resp = MagicMock()
|
||||
get_resp.content = b"pdf-content"
|
||||
mock_proxy.get.return_value = get_resp
|
||||
|
||||
mock_guess.return_value = SimpleNamespace(
|
||||
filename="file.pdf", extension="pdf", mimetype="application/pdf", size=100
|
||||
)
|
||||
mock_file_svc_cls.is_file_size_within_limit.return_value = True
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
upload_file = SimpleNamespace(
|
||||
id="f-1",
|
||||
name="file.pdf",
|
||||
size=100,
|
||||
extension="pdf",
|
||||
mime_type="application/pdf",
|
||||
created_by="eu-1",
|
||||
created_at=datetime(2024, 1, 1),
|
||||
)
|
||||
mock_file_svc_cls.return_value.upload_file.return_value = upload_file
|
||||
|
||||
with app.test_request_context("/remote-files/upload", method="POST"):
|
||||
result, status = RemoteFileUploadApi().post(_app_model(), _end_user())
|
||||
|
||||
assert status == 201
|
||||
assert result["id"] == "f-1"
|
||||
|
||||
@patch("controllers.web.remote_files.FileService.is_file_size_within_limit", return_value=False)
|
||||
@patch("controllers.web.remote_files.helpers.guess_file_info_from_response")
|
||||
@patch("controllers.web.remote_files.ssrf_proxy")
|
||||
@patch("controllers.web.remote_files.web_ns")
|
||||
def test_file_too_large(
|
||||
self,
|
||||
mock_ns: MagicMock,
|
||||
mock_proxy: MagicMock,
|
||||
mock_guess: MagicMock,
|
||||
mock_size_check: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_ns.payload = {"url": "https://example.com/big.zip"}
|
||||
head_resp = MagicMock()
|
||||
head_resp.status_code = 200
|
||||
mock_proxy.head.return_value = head_resp
|
||||
mock_guess.return_value = SimpleNamespace(
|
||||
filename="big.zip", extension="zip", mimetype="application/zip", size=999999999
|
||||
)
|
||||
|
||||
with app.test_request_context("/remote-files/upload", method="POST"):
|
||||
with pytest.raises(FileTooLargeError):
|
||||
RemoteFileUploadApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch("controllers.web.remote_files.ssrf_proxy")
|
||||
@patch("controllers.web.remote_files.web_ns")
|
||||
def test_fetch_failure_raises(self, mock_ns: MagicMock, mock_proxy: MagicMock, app: Flask) -> None:
|
||||
import httpx
|
||||
|
||||
mock_ns.payload = {"url": "https://example.com/bad"}
|
||||
mock_proxy.head.side_effect = httpx.RequestError("connection failed")
|
||||
|
||||
with app.test_request_context("/remote-files/upload", method="POST"):
|
||||
with pytest.raises(RemoteFileUploadError):
|
||||
RemoteFileUploadApi().post(_app_model(), _end_user())
|
||||
97
api/tests/unit_tests/controllers/web/test_saved_message.py
Normal file
97
api/tests/unit_tests/controllers/web/test_saved_message.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Unit tests for controllers.web.saved_message endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.web.error import NotCompletionAppError
|
||||
from controllers.web.saved_message import SavedMessageApi, SavedMessageListApi
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
|
||||
def _completion_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="completion")
|
||||
|
||||
|
||||
def _chat_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="chat")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SavedMessageListApi (GET)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestSavedMessageListApiGet:
|
||||
def test_non_completion_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/saved-messages"):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
SavedMessageListApi().get(_chat_app(), _end_user())
|
||||
|
||||
@patch("controllers.web.saved_message.SavedMessageService.pagination_by_last_id")
|
||||
def test_happy_path(self, mock_paginate: MagicMock, app: Flask) -> None:
|
||||
mock_paginate.return_value = SimpleNamespace(limit=20, has_more=False, data=[])
|
||||
|
||||
with app.test_request_context("/saved-messages?limit=20"):
|
||||
result = SavedMessageListApi().get(_completion_app(), _end_user())
|
||||
|
||||
assert result["limit"] == 20
|
||||
assert result["has_more"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SavedMessageListApi (POST)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestSavedMessageListApiPost:
|
||||
def test_non_completion_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/saved-messages", method="POST"):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
SavedMessageListApi().post(_chat_app(), _end_user())
|
||||
|
||||
@patch("controllers.web.saved_message.SavedMessageService.save")
|
||||
@patch("controllers.web.saved_message.web_ns")
|
||||
def test_save_success(self, mock_ns: MagicMock, mock_save: MagicMock, app: Flask) -> None:
|
||||
msg_id = str(uuid4())
|
||||
mock_ns.payload = {"message_id": msg_id}
|
||||
|
||||
with app.test_request_context("/saved-messages", method="POST"):
|
||||
result = SavedMessageListApi().post(_completion_app(), _end_user())
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
@patch("controllers.web.saved_message.SavedMessageService.save", side_effect=MessageNotExistsError())
|
||||
@patch("controllers.web.saved_message.web_ns")
|
||||
def test_save_not_found(self, mock_ns: MagicMock, mock_save: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"message_id": str(uuid4())}
|
||||
|
||||
with app.test_request_context("/saved-messages", method="POST"):
|
||||
with pytest.raises(NotFound, match="Message Not Exists"):
|
||||
SavedMessageListApi().post(_completion_app(), _end_user())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SavedMessageApi (DELETE)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestSavedMessageApi:
|
||||
def test_non_completion_mode_raises(self, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/saved-messages/{msg_id}", method="DELETE"):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
SavedMessageApi().delete(_chat_app(), _end_user(), msg_id)
|
||||
|
||||
@patch("controllers.web.saved_message.SavedMessageService.delete")
|
||||
def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/saved-messages/{msg_id}", method="DELETE"):
|
||||
result, status = SavedMessageApi().delete(_completion_app(), _end_user(), msg_id)
|
||||
|
||||
assert status == 204
|
||||
assert result["result"] == "success"
|
||||
126
api/tests/unit_tests/controllers/web/test_site.py
Normal file
126
api/tests/unit_tests/controllers/web/test_site.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""Unit tests for controllers.web.site endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.web.site import AppSiteApi, AppSiteInfo
|
||||
|
||||
|
||||
def _tenant(*, status: str = "normal") -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id="tenant-1",
|
||||
status=status,
|
||||
plan="basic",
|
||||
custom_config_dict={"remove_webapp_brand": False, "replace_webapp_logo": False},
|
||||
)
|
||||
|
||||
|
||||
def _site() -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
title="Site",
|
||||
icon_type="emoji",
|
||||
icon="robot",
|
||||
icon_background="#fff",
|
||||
description="desc",
|
||||
default_language="en",
|
||||
chat_color_theme="light",
|
||||
chat_color_theme_inverted=False,
|
||||
copyright=None,
|
||||
privacy_policy=None,
|
||||
custom_disclaimer=None,
|
||||
prompt_public=False,
|
||||
show_workflow_steps=True,
|
||||
use_icon_as_answer_icon=False,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AppSiteApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAppSiteApi:
|
||||
@patch("controllers.web.site.FeatureService.get_features")
|
||||
@patch("controllers.web.site.db")
|
||||
def test_happy_path(self, mock_db: MagicMock, mock_features: MagicMock, app: Flask) -> None:
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
mock_features.return_value = SimpleNamespace(can_replace_logo=False)
|
||||
site_obj = _site()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = site_obj
|
||||
tenant = _tenant()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
|
||||
end_user = SimpleNamespace(id="eu-1")
|
||||
|
||||
with app.test_request_context("/site"):
|
||||
result = AppSiteApi().get(app_model, end_user)
|
||||
|
||||
# marshal_with serializes AppSiteInfo to a dict
|
||||
assert result["app_id"] == "app-1"
|
||||
assert result["plan"] == "basic"
|
||||
assert result["enable_site"] is True
|
||||
|
||||
@patch("controllers.web.site.db")
|
||||
def test_missing_site_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None:
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
tenant = _tenant()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant)
|
||||
end_user = SimpleNamespace(id="eu-1")
|
||||
|
||||
with app.test_request_context("/site"):
|
||||
with pytest.raises(Forbidden):
|
||||
AppSiteApi().get(app_model, end_user)
|
||||
|
||||
@patch("controllers.web.site.db")
|
||||
def test_archived_tenant_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None:
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
from models.account import TenantStatus
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = _site()
|
||||
tenant = SimpleNamespace(
|
||||
id="tenant-1",
|
||||
status=TenantStatus.ARCHIVE,
|
||||
plan="basic",
|
||||
custom_config_dict={},
|
||||
)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant)
|
||||
end_user = SimpleNamespace(id="eu-1")
|
||||
|
||||
with app.test_request_context("/site"):
|
||||
with pytest.raises(Forbidden):
|
||||
AppSiteApi().get(app_model, end_user)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AppSiteInfo
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAppSiteInfo:
|
||||
def test_basic_fields(self) -> None:
|
||||
tenant = _tenant()
|
||||
site_obj = _site()
|
||||
info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", False)
|
||||
|
||||
assert info.app_id == "app-1"
|
||||
assert info.end_user_id == "eu-1"
|
||||
assert info.enable_site is True
|
||||
assert info.plan == "basic"
|
||||
assert info.can_replace_logo is False
|
||||
assert info.model_config is None
|
||||
|
||||
@patch("controllers.web.site.dify_config", SimpleNamespace(FILES_URL="https://files.example.com"))
|
||||
def test_can_replace_logo_sets_custom_config(self) -> None:
|
||||
tenant = SimpleNamespace(
|
||||
id="tenant-1",
|
||||
plan="pro",
|
||||
custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": True},
|
||||
)
|
||||
site_obj = _site()
|
||||
info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", True)
|
||||
|
||||
assert info.can_replace_logo is True
|
||||
assert info.custom_config["remove_webapp_brand"] is True
|
||||
assert "webapp-logo" in info.custom_config["replace_webapp_logo"]
|
||||
@@ -5,7 +5,8 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi
|
||||
import services.errors.account
|
||||
from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi, LoginApi, LoginStatusApi, LogoutApi
|
||||
|
||||
|
||||
def encode_code(code: str) -> str:
|
||||
@@ -89,3 +90,114 @@ class TestEmailCodeLoginApi:
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_login.assert_called_once()
|
||||
mock_reset_login_rate.assert_called_once_with("user@example.com")
|
||||
|
||||
|
||||
class TestLoginApi:
|
||||
@patch("controllers.web.login.WebAppAuthService.login", return_value="access-tok")
|
||||
@patch("controllers.web.login.WebAppAuthService.authenticate")
|
||||
def test_login_success(self, mock_auth: MagicMock, mock_login: MagicMock, app: Flask) -> None:
|
||||
mock_auth.return_value = MagicMock()
|
||||
|
||||
with app.test_request_context(
|
||||
"/web/login",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()},
|
||||
):
|
||||
response = LoginApi().post()
|
||||
|
||||
assert response.get_json()["data"]["access_token"] == "access-tok"
|
||||
mock_auth.assert_called_once()
|
||||
|
||||
@patch(
|
||||
"controllers.web.login.WebAppAuthService.authenticate",
|
||||
side_effect=services.errors.account.AccountLoginError(),
|
||||
)
|
||||
def test_login_banned_account(self, mock_auth: MagicMock, app: Flask) -> None:
|
||||
from controllers.console.error import AccountBannedError
|
||||
|
||||
with app.test_request_context(
|
||||
"/web/login",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()},
|
||||
):
|
||||
with pytest.raises(AccountBannedError):
|
||||
LoginApi().post()
|
||||
|
||||
@patch(
|
||||
"controllers.web.login.WebAppAuthService.authenticate",
|
||||
side_effect=services.errors.account.AccountPasswordError(),
|
||||
)
|
||||
def test_login_wrong_password(self, mock_auth: MagicMock, app: Flask) -> None:
|
||||
from controllers.console.auth.error import AuthenticationFailedError
|
||||
|
||||
with app.test_request_context(
|
||||
"/web/login",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()},
|
||||
):
|
||||
with pytest.raises(AuthenticationFailedError):
|
||||
LoginApi().post()
|
||||
|
||||
|
||||
class TestLoginStatusApi:
|
||||
@patch("controllers.web.login.extract_webapp_access_token", return_value=None)
|
||||
def test_no_app_code_returns_logged_in_false(self, mock_extract: MagicMock, app: Flask) -> None:
|
||||
with app.test_request_context("/web/login/status"):
|
||||
result = LoginStatusApi().get()
|
||||
|
||||
assert result["logged_in"] is False
|
||||
assert result["app_logged_in"] is False
|
||||
|
||||
@patch("controllers.web.login.decode_jwt_token")
|
||||
@patch("controllers.web.login.PassportService")
|
||||
@patch("controllers.web.login.WebAppAuthService.is_app_require_permission_check", return_value=False)
|
||||
@patch("controllers.web.login.AppService.get_app_id_by_code", return_value="app-1")
|
||||
@patch("controllers.web.login.extract_webapp_access_token", return_value="tok")
|
||||
def test_public_app_user_logged_in(
|
||||
self,
|
||||
mock_extract: MagicMock,
|
||||
mock_app_id: MagicMock,
|
||||
mock_perm: MagicMock,
|
||||
mock_passport: MagicMock,
|
||||
mock_decode: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_decode.return_value = (MagicMock(), MagicMock())
|
||||
|
||||
with app.test_request_context("/web/login/status?app_code=code1"):
|
||||
result = LoginStatusApi().get()
|
||||
|
||||
assert result["logged_in"] is True
|
||||
assert result["app_logged_in"] is True
|
||||
|
||||
@patch("controllers.web.login.decode_jwt_token", side_effect=Exception("bad"))
|
||||
@patch("controllers.web.login.PassportService")
|
||||
@patch("controllers.web.login.WebAppAuthService.is_app_require_permission_check", return_value=True)
|
||||
@patch("controllers.web.login.AppService.get_app_id_by_code", return_value="app-1")
|
||||
@patch("controllers.web.login.extract_webapp_access_token", return_value="tok")
|
||||
def test_private_app_passport_fails(
|
||||
self,
|
||||
mock_extract: MagicMock,
|
||||
mock_app_id: MagicMock,
|
||||
mock_perm: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_decode: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_passport_cls.return_value.verify.side_effect = Exception("bad")
|
||||
|
||||
with app.test_request_context("/web/login/status?app_code=code1"):
|
||||
result = LoginStatusApi().get()
|
||||
|
||||
assert result["logged_in"] is False
|
||||
assert result["app_logged_in"] is False
|
||||
|
||||
|
||||
class TestLogoutApi:
|
||||
@patch("controllers.web.login.clear_webapp_access_token_from_cookie")
|
||||
def test_logout_success(self, mock_clear: MagicMock, app: Flask) -> None:
|
||||
with app.test_request_context("/web/logout", method="POST"):
|
||||
response = LogoutApi().post()
|
||||
|
||||
assert response.get_json() == {"result": "success"}
|
||||
mock_clear.assert_called_once()
|
||||
|
||||
192
api/tests/unit_tests/controllers/web/test_web_passport.py
Normal file
192
api/tests/unit_tests/controllers/web/test_web_passport.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""Unit tests for controllers.web.passport — token issuance and enterprise auth exchange."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from controllers.web.error import WebAppAuthRequiredError
|
||||
from controllers.web.passport import (
|
||||
PassportResource,
|
||||
decode_enterprise_webapp_user_id,
|
||||
exchange_token_for_existing_web_user,
|
||||
generate_session_id,
|
||||
)
|
||||
from services.webapp_auth_service import WebAppAuthType
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# decode_enterprise_webapp_user_id
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestDecodeEnterpriseWebappUserId:
|
||||
def test_none_token_returns_none(self) -> None:
|
||||
assert decode_enterprise_webapp_user_id(None) is None
|
||||
|
||||
@patch("controllers.web.passport.PassportService")
|
||||
def test_valid_token_returns_decoded(self, mock_passport_cls: MagicMock) -> None:
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"token_source": "webapp_login_token",
|
||||
"user_id": "u1",
|
||||
}
|
||||
result = decode_enterprise_webapp_user_id("valid-jwt")
|
||||
assert result["user_id"] == "u1"
|
||||
|
||||
@patch("controllers.web.passport.PassportService")
|
||||
def test_wrong_source_raises_unauthorized(self, mock_passport_cls: MagicMock) -> None:
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"token_source": "other_source",
|
||||
}
|
||||
with pytest.raises(Unauthorized, match="Expected 'webapp_login_token'"):
|
||||
decode_enterprise_webapp_user_id("bad-jwt")
|
||||
|
||||
@patch("controllers.web.passport.PassportService")
|
||||
def test_missing_source_raises_unauthorized(self, mock_passport_cls: MagicMock) -> None:
|
||||
mock_passport_cls.return_value.verify.return_value = {}
|
||||
with pytest.raises(Unauthorized, match="Expected 'webapp_login_token'"):
|
||||
decode_enterprise_webapp_user_id("no-source-jwt")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# generate_session_id
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestGenerateSessionId:
|
||||
@patch("controllers.web.passport.db")
|
||||
def test_returns_unique_session_id(self, mock_db: MagicMock) -> None:
|
||||
mock_db.session.scalar.return_value = 0
|
||||
sid = generate_session_id()
|
||||
assert isinstance(sid, str)
|
||||
assert len(sid) == 36 # UUID format
|
||||
|
||||
@patch("controllers.web.passport.db")
|
||||
def test_retries_on_collision(self, mock_db: MagicMock) -> None:
|
||||
# First call returns count=1 (collision), second returns 0
|
||||
mock_db.session.scalar.side_effect = [1, 0]
|
||||
sid = generate_session_id()
|
||||
assert isinstance(sid, str)
|
||||
assert mock_db.session.scalar.call_count == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# exchange_token_for_existing_web_user
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestExchangeTokenForExistingWebUser:
|
||||
@patch("controllers.web.passport.PassportService")
|
||||
@patch("controllers.web.passport.db")
|
||||
def test_external_auth_type_mismatch_raises(self, mock_db: MagicMock, mock_passport_cls: MagicMock) -> None:
|
||||
site = SimpleNamespace(code="code1", app_id="app-1")
|
||||
app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1")
|
||||
mock_db.session.scalar.side_effect = [site, app_model]
|
||||
|
||||
decoded = {"user_id": "u1", "auth_type": "internal"} # mismatch: expected "external"
|
||||
with pytest.raises(WebAppAuthRequiredError, match="external"):
|
||||
exchange_token_for_existing_web_user(
|
||||
app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.EXTERNAL
|
||||
)
|
||||
|
||||
@patch("controllers.web.passport.PassportService")
|
||||
@patch("controllers.web.passport.db")
|
||||
def test_internal_auth_type_mismatch_raises(self, mock_db: MagicMock, mock_passport_cls: MagicMock) -> None:
|
||||
site = SimpleNamespace(code="code1", app_id="app-1")
|
||||
app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1")
|
||||
mock_db.session.scalar.side_effect = [site, app_model]
|
||||
|
||||
decoded = {"user_id": "u1", "auth_type": "external"} # mismatch: expected "internal"
|
||||
with pytest.raises(WebAppAuthRequiredError, match="internal"):
|
||||
exchange_token_for_existing_web_user(
|
||||
app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.INTERNAL
|
||||
)
|
||||
|
||||
@patch("controllers.web.passport.PassportService")
|
||||
@patch("controllers.web.passport.db")
|
||||
def test_site_not_found_raises(self, mock_db: MagicMock, mock_passport_cls: MagicMock) -> None:
|
||||
mock_db.session.scalar.return_value = None
|
||||
decoded = {"user_id": "u1", "auth_type": "external"}
|
||||
with pytest.raises(NotFound):
|
||||
exchange_token_for_existing_web_user(
|
||||
app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.EXTERNAL
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PassportResource.get
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestPassportResource:
|
||||
@patch("controllers.web.passport.FeatureService.get_system_features")
|
||||
def test_missing_app_code_raises_unauthorized(self, mock_features: MagicMock, app: Flask) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
with app.test_request_context("/passport"):
|
||||
with pytest.raises(Unauthorized, match="X-App-Code"):
|
||||
PassportResource().get()
|
||||
|
||||
@patch("controllers.web.passport.PassportService")
|
||||
@patch("controllers.web.passport.generate_session_id", return_value="new-sess-id")
|
||||
@patch("controllers.web.passport.db")
|
||||
@patch("controllers.web.passport.FeatureService.get_system_features")
|
||||
def test_creates_new_end_user_when_no_user_id(
|
||||
self,
|
||||
mock_features: MagicMock,
|
||||
mock_db: MagicMock,
|
||||
mock_gen_session: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
site = SimpleNamespace(app_id="app-1", code="code1")
|
||||
app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1")
|
||||
mock_db.session.scalar.side_effect = [site, app_model]
|
||||
mock_passport_cls.return_value.issue.return_value = "issued-token"
|
||||
|
||||
with app.test_request_context("/passport", headers={"X-App-Code": "code1"}):
|
||||
response = PassportResource().get()
|
||||
|
||||
assert response.get_json()["access_token"] == "issued-token"
|
||||
mock_db.session.add.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@patch("controllers.web.passport.PassportService")
|
||||
@patch("controllers.web.passport.db")
|
||||
@patch("controllers.web.passport.FeatureService.get_system_features")
|
||||
def test_reuses_existing_end_user_when_user_id_provided(
|
||||
self,
|
||||
mock_features: MagicMock,
|
||||
mock_db: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
site = SimpleNamespace(app_id="app-1", code="code1")
|
||||
app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1")
|
||||
existing_user = SimpleNamespace(id="eu-1", session_id="sess-existing")
|
||||
mock_db.session.scalar.side_effect = [site, app_model, existing_user]
|
||||
mock_passport_cls.return_value.issue.return_value = "reused-token"
|
||||
|
||||
with app.test_request_context("/passport?user_id=sess-existing", headers={"X-App-Code": "code1"}):
|
||||
response = PassportResource().get()
|
||||
|
||||
assert response.get_json()["access_token"] == "reused-token"
|
||||
# Should not create a new end user
|
||||
mock_db.session.add.assert_not_called()
|
||||
|
||||
@patch("controllers.web.passport.db")
|
||||
@patch("controllers.web.passport.FeatureService.get_system_features")
|
||||
def test_site_not_found_raises(self, mock_features: MagicMock, mock_db: MagicMock, app: Flask) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
mock_db.session.scalar.return_value = None
|
||||
with app.test_request_context("/passport", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(NotFound):
|
||||
PassportResource().get()
|
||||
|
||||
@patch("controllers.web.passport.db")
|
||||
@patch("controllers.web.passport.FeatureService.get_system_features")
|
||||
def test_disabled_app_raises_not_found(self, mock_features: MagicMock, mock_db: MagicMock, app: Flask) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
site = SimpleNamespace(app_id="app-1", code="code1")
|
||||
disabled_app = SimpleNamespace(id="app-1", status="normal", enable_site=False)
|
||||
mock_db.session.scalar.side_effect = [site, disabled_app]
|
||||
with app.test_request_context("/passport", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(NotFound):
|
||||
PassportResource().get()
|
||||
95
api/tests/unit_tests/controllers/web/test_workflow.py
Normal file
95
api/tests/unit_tests/controllers/web/test_workflow.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""Unit tests for controllers.web.workflow endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.error import (
|
||||
NotWorkflowAppError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.web.workflow import WorkflowRunApi, WorkflowTaskStopApi
|
||||
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError
|
||||
|
||||
|
||||
def _workflow_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="workflow")
|
||||
|
||||
|
||||
def _chat_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="chat")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WorkflowRunApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestWorkflowRunApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/workflows/run", method="POST"):
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
WorkflowRunApi().post(_chat_app(), _end_user())
|
||||
|
||||
@patch("controllers.web.workflow.helper.compact_generate_response", return_value={"result": "ok"})
|
||||
@patch("controllers.web.workflow.AppGenerateService.generate")
|
||||
@patch("controllers.web.workflow.web_ns")
|
||||
def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {"key": "val"}}
|
||||
mock_gen.return_value = "response"
|
||||
|
||||
with app.test_request_context("/workflows/run", method="POST"):
|
||||
result = WorkflowRunApi().post(_workflow_app(), _end_user())
|
||||
|
||||
assert result == {"result": "ok"}
|
||||
|
||||
@patch(
|
||||
"controllers.web.workflow.AppGenerateService.generate",
|
||||
side_effect=ProviderTokenNotInitError(description="not init"),
|
||||
)
|
||||
@patch("controllers.web.workflow.web_ns")
|
||||
def test_provider_not_init(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}}
|
||||
|
||||
with app.test_request_context("/workflows/run", method="POST"):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
WorkflowRunApi().post(_workflow_app(), _end_user())
|
||||
|
||||
@patch(
|
||||
"controllers.web.workflow.AppGenerateService.generate",
|
||||
side_effect=QuotaExceededError(),
|
||||
)
|
||||
@patch("controllers.web.workflow.web_ns")
|
||||
def test_quota_exceeded(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}}
|
||||
|
||||
with app.test_request_context("/workflows/run", method="POST"):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
WorkflowRunApi().post(_workflow_app(), _end_user())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WorkflowTaskStopApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestWorkflowTaskStopApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"):
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
WorkflowTaskStopApi().post(_chat_app(), _end_user(), "task-1")
|
||||
|
||||
@patch("controllers.web.workflow.GraphEngineManager.send_stop_command")
|
||||
@patch("controllers.web.workflow.AppQueueManager.set_stop_flag_no_user_check")
|
||||
def test_stop_calls_both_mechanisms(self, mock_legacy: MagicMock, mock_graph: MagicMock, app: Flask) -> None:
|
||||
with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"):
|
||||
result = WorkflowTaskStopApi().post(_workflow_app(), _end_user(), "task-1")
|
||||
|
||||
assert result == {"result": "success"}
|
||||
mock_legacy.assert_called_once_with("task-1")
|
||||
mock_graph.assert_called_once_with("task-1")
|
||||
127
api/tests/unit_tests/controllers/web/test_workflow_events.py
Normal file
127
api/tests/unit_tests/controllers/web/test_workflow_events.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""Unit tests for controllers.web.workflow_events endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.error import NotFoundError
|
||||
from controllers.web.workflow_events import WorkflowEventsApi
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
|
||||
def _workflow_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", tenant_id="tenant-1", mode="workflow")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WorkflowEventsApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestWorkflowEventsApi:
|
||||
@patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
|
||||
@patch("controllers.web.workflow_events.db")
|
||||
def test_workflow_run_not_found(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None:
|
||||
mock_db.engine = "engine"
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = None
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_repo
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events"):
|
||||
with pytest.raises(NotFoundError):
|
||||
WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")
|
||||
|
||||
@patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
|
||||
@patch("controllers.web.workflow_events.db")
|
||||
def test_workflow_run_wrong_app(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None:
|
||||
mock_db.engine = "engine"
|
||||
run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="other-app",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="eu-1",
|
||||
finished_at=None,
|
||||
)
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_repo
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events"):
|
||||
with pytest.raises(NotFoundError):
|
||||
WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")
|
||||
|
||||
@patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
|
||||
@patch("controllers.web.workflow_events.db")
|
||||
def test_workflow_run_not_created_by_end_user(
|
||||
self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask
|
||||
) -> None:
|
||||
mock_db.engine = "engine"
|
||||
run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by="eu-1",
|
||||
finished_at=None,
|
||||
)
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_repo
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events"):
|
||||
with pytest.raises(NotFoundError):
|
||||
WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")
|
||||
|
||||
@patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
|
||||
@patch("controllers.web.workflow_events.db")
|
||||
def test_workflow_run_wrong_end_user(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None:
|
||||
mock_db.engine = "engine"
|
||||
run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="other-user",
|
||||
finished_at=None,
|
||||
)
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_repo
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events"):
|
||||
with pytest.raises(NotFoundError):
|
||||
WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")
|
||||
|
||||
@patch("controllers.web.workflow_events.WorkflowResponseConverter")
|
||||
@patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
|
||||
@patch("controllers.web.workflow_events.db")
|
||||
def test_finished_run_returns_sse_response(
|
||||
self, mock_db: MagicMock, mock_factory: MagicMock, mock_converter: MagicMock, app: Flask
|
||||
) -> None:
|
||||
from datetime import datetime
|
||||
|
||||
mock_db.engine = "engine"
|
||||
run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="eu-1",
|
||||
finished_at=datetime(2024, 1, 1),
|
||||
)
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_repo
|
||||
|
||||
finish_response = MagicMock()
|
||||
finish_response.model_dump.return_value = {"task_id": "run-1"}
|
||||
finish_response.event.value = "workflow_finished"
|
||||
mock_converter.workflow_run_result_to_finish_response.return_value = finish_response
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events"):
|
||||
response = WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")
|
||||
|
||||
assert response.mimetype == "text/event-stream"
|
||||
393
api/tests/unit_tests/controllers/web/test_wraps.py
Normal file
393
api/tests/unit_tests/controllers/web/test_wraps.py
Normal file
@@ -0,0 +1,393 @@
|
||||
"""Unit tests for controllers.web.wraps — JWT auth decorator and validation helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
|
||||
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
|
||||
from controllers.web.wraps import (
|
||||
_validate_user_accessibility,
|
||||
_validate_webapp_token,
|
||||
decode_jwt_token,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_webapp_token
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestValidateWebappToken:
|
||||
def test_enterprise_enabled_and_app_auth_requires_webapp_source(self) -> None:
|
||||
"""When both flags are true, a non-webapp source must raise."""
|
||||
decoded = {"token_source": "other"}
|
||||
with pytest.raises(WebAppAuthRequiredError):
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True)
|
||||
|
||||
def test_enterprise_enabled_and_app_auth_accepts_webapp_source(self) -> None:
|
||||
decoded = {"token_source": "webapp"}
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True)
|
||||
|
||||
def test_enterprise_enabled_and_app_auth_missing_source_raises(self) -> None:
|
||||
decoded = {}
|
||||
with pytest.raises(WebAppAuthRequiredError):
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True)
|
||||
|
||||
def test_public_app_rejects_webapp_source(self) -> None:
|
||||
"""When auth is not required, a webapp-sourced token must be rejected."""
|
||||
decoded = {"token_source": "webapp"}
|
||||
with pytest.raises(Unauthorized):
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False)
|
||||
|
||||
def test_public_app_accepts_non_webapp_source(self) -> None:
|
||||
decoded = {"token_source": "other"}
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False)
|
||||
|
||||
def test_public_app_accepts_no_source(self) -> None:
|
||||
decoded = {}
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False)
|
||||
|
||||
def test_system_enabled_but_app_public(self) -> None:
|
||||
"""system_webapp_auth_enabled=True but app is public — webapp source rejected."""
|
||||
decoded = {"token_source": "webapp"}
|
||||
with pytest.raises(Unauthorized):
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=True)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_user_accessibility
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestValidateUserAccessibility:
|
||||
def test_skips_when_auth_disabled(self) -> None:
|
||||
"""No checks when system or app auth is disabled."""
|
||||
_validate_user_accessibility(
|
||||
decoded={},
|
||||
app_code="code",
|
||||
app_web_auth_enabled=False,
|
||||
system_webapp_auth_enabled=False,
|
||||
webapp_settings=None,
|
||||
)
|
||||
|
||||
def test_missing_user_id_raises(self) -> None:
|
||||
decoded = {}
|
||||
with pytest.raises(WebAppAuthRequiredError):
|
||||
_validate_user_accessibility(
|
||||
decoded=decoded,
|
||||
app_code="code",
|
||||
app_web_auth_enabled=True,
|
||||
system_webapp_auth_enabled=True,
|
||||
webapp_settings=SimpleNamespace(access_mode="internal"),
|
||||
)
|
||||
|
||||
def test_missing_webapp_settings_raises(self) -> None:
|
||||
decoded = {"user_id": "u1"}
|
||||
with pytest.raises(WebAppAuthRequiredError, match="settings not found"):
|
||||
_validate_user_accessibility(
|
||||
decoded=decoded,
|
||||
app_code="code",
|
||||
app_web_auth_enabled=True,
|
||||
system_webapp_auth_enabled=True,
|
||||
webapp_settings=None,
|
||||
)
|
||||
|
||||
def test_missing_auth_type_raises(self) -> None:
|
||||
decoded = {"user_id": "u1", "granted_at": 1}
|
||||
settings = SimpleNamespace(access_mode="public")
|
||||
with pytest.raises(WebAppAuthAccessDeniedError, match="auth_type"):
|
||||
_validate_user_accessibility(
|
||||
decoded=decoded,
|
||||
app_code="code",
|
||||
app_web_auth_enabled=True,
|
||||
system_webapp_auth_enabled=True,
|
||||
webapp_settings=settings,
|
||||
)
|
||||
|
||||
def test_missing_granted_at_raises(self) -> None:
|
||||
decoded = {"user_id": "u1", "auth_type": "external"}
|
||||
settings = SimpleNamespace(access_mode="public")
|
||||
with pytest.raises(WebAppAuthAccessDeniedError, match="granted_at"):
|
||||
_validate_user_accessibility(
|
||||
decoded=decoded,
|
||||
app_code="code",
|
||||
app_web_auth_enabled=True,
|
||||
system_webapp_auth_enabled=True,
|
||||
webapp_settings=settings,
|
||||
)
|
||||
|
||||
@patch("controllers.web.wraps.EnterpriseService.get_app_sso_settings_last_update_time")
|
||||
@patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False)
|
||||
def test_external_auth_type_checks_sso_update_time(
|
||||
self, mock_perm_check: MagicMock, mock_sso_time: MagicMock
|
||||
) -> None:
|
||||
# granted_at is before SSO update time → denied
|
||||
mock_sso_time.return_value = datetime.now(UTC)
|
||||
old_granted = int((datetime.now(UTC) - timedelta(hours=1)).timestamp())
|
||||
decoded = {"user_id": "u1", "auth_type": "external", "granted_at": old_granted}
|
||||
settings = SimpleNamespace(access_mode="public")
|
||||
with pytest.raises(WebAppAuthAccessDeniedError, match="SSO settings"):
|
||||
_validate_user_accessibility(
|
||||
decoded=decoded,
|
||||
app_code="code",
|
||||
app_web_auth_enabled=True,
|
||||
system_webapp_auth_enabled=True,
|
||||
webapp_settings=settings,
|
||||
)
|
||||
|
||||
@patch("controllers.web.wraps.EnterpriseService.get_workspace_sso_settings_last_update_time")
|
||||
@patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False)
|
||||
def test_internal_auth_type_checks_workspace_sso_update_time(
|
||||
self, mock_perm_check: MagicMock, mock_workspace_sso: MagicMock
|
||||
) -> None:
|
||||
mock_workspace_sso.return_value = datetime.now(UTC)
|
||||
old_granted = int((datetime.now(UTC) - timedelta(hours=1)).timestamp())
|
||||
decoded = {"user_id": "u1", "auth_type": "internal", "granted_at": old_granted}
|
||||
settings = SimpleNamespace(access_mode="public")
|
||||
with pytest.raises(WebAppAuthAccessDeniedError, match="SSO settings"):
|
||||
_validate_user_accessibility(
|
||||
decoded=decoded,
|
||||
app_code="code",
|
||||
app_web_auth_enabled=True,
|
||||
system_webapp_auth_enabled=True,
|
||||
webapp_settings=settings,
|
||||
)
|
||||
|
||||
@patch("controllers.web.wraps.EnterpriseService.get_app_sso_settings_last_update_time")
|
||||
@patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False)
|
||||
def test_external_auth_passes_when_granted_after_sso_update(
|
||||
self, mock_perm_check: MagicMock, mock_sso_time: MagicMock
|
||||
) -> None:
|
||||
mock_sso_time.return_value = datetime.now(UTC) - timedelta(hours=2)
|
||||
recent_granted = int(datetime.now(UTC).timestamp())
|
||||
decoded = {"user_id": "u1", "auth_type": "external", "granted_at": recent_granted}
|
||||
settings = SimpleNamespace(access_mode="public")
|
||||
# Should not raise
|
||||
_validate_user_accessibility(
|
||||
decoded=decoded,
|
||||
app_code="code",
|
||||
app_web_auth_enabled=True,
|
||||
system_webapp_auth_enabled=True,
|
||||
webapp_settings=settings,
|
||||
)
|
||||
|
||||
@patch("controllers.web.wraps.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp", return_value=False)
|
||||
@patch("controllers.web.wraps.AppService.get_app_id_by_code", return_value="app-id-1")
|
||||
@patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=True)
|
||||
def test_permission_check_denies_unauthorized_user(
|
||||
self, mock_perm: MagicMock, mock_app_id: MagicMock, mock_allowed: MagicMock
|
||||
) -> None:
|
||||
decoded = {"user_id": "u1", "auth_type": "external", "granted_at": int(datetime.now(UTC).timestamp())}
|
||||
settings = SimpleNamespace(access_mode="internal")
|
||||
with pytest.raises(WebAppAuthAccessDeniedError):
|
||||
_validate_user_accessibility(
|
||||
decoded=decoded,
|
||||
app_code="code",
|
||||
app_web_auth_enabled=True,
|
||||
system_webapp_auth_enabled=True,
|
||||
webapp_settings=settings,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# decode_jwt_token
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestDecodeJwtToken:
|
||||
@patch("controllers.web.wraps._validate_user_accessibility")
|
||||
@patch("controllers.web.wraps._validate_webapp_token")
|
||||
@patch("controllers.web.wraps.EnterpriseService.WebAppAuth.get_app_access_mode_by_id")
|
||||
@patch("controllers.web.wraps.AppService.get_app_id_by_code")
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_happy_path(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
mock_app_id: MagicMock,
|
||||
mock_access_mode: MagicMock,
|
||||
mock_validate_token: MagicMock,
|
||||
mock_validate_user: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", enable_site=True)
|
||||
site = SimpleNamespace(code="code1")
|
||||
end_user = SimpleNamespace(id="eu-1", session_id="sess-1")
|
||||
|
||||
# Configure session mock to return correct objects via scalar()
|
||||
session_mock = MagicMock()
|
||||
session_mock.scalar.side_effect = [app_model, site, end_user]
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
result_app, result_user = decode_jwt_token()
|
||||
|
||||
assert result_app.id == "app-1"
|
||||
assert result_user.id == "eu-1"
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
def test_missing_token_raises_unauthorized(
|
||||
self, mock_extract: MagicMock, mock_features: MagicMock, app: Flask
|
||||
) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
mock_extract.return_value = None
|
||||
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(Unauthorized):
|
||||
decode_jwt_token()
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_missing_app_raises_not_found(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
session_mock = MagicMock()
|
||||
session_mock.scalar.return_value = None # No app found
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(NotFound):
|
||||
decode_jwt_token()
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_disabled_site_raises_bad_request(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", enable_site=False)
|
||||
|
||||
session_mock = MagicMock()
|
||||
# scalar calls: app_model, site (code found), then end_user
|
||||
session_mock.scalar.side_effect = [app_model, SimpleNamespace(code="code1"), None]
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(BadRequest, match="Site is disabled"):
|
||||
decode_jwt_token()
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_missing_end_user_raises_not_found(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", enable_site=True)
|
||||
site = SimpleNamespace(code="code1")
|
||||
|
||||
session_mock = MagicMock()
|
||||
session_mock.scalar.side_effect = [app_model, site, None] # end_user is None
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(NotFound):
|
||||
decode_jwt_token()
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_user_id_mismatch_raises_unauthorized(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", enable_site=True)
|
||||
site = SimpleNamespace(code="code1")
|
||||
end_user = SimpleNamespace(id="eu-1", session_id="sess-1")
|
||||
|
||||
session_mock = MagicMock()
|
||||
session_mock.scalar.side_effect = [app_model, site, end_user]
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(Unauthorized, match="expired"):
|
||||
decode_jwt_token(user_id="different-user")
|
||||
@@ -22,6 +22,7 @@ from dify_graph.nodes.knowledge_retrieval import KnowledgeRetrievalNode
|
||||
from dify_graph.nodes.llm import LLMNode
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
from dify_graph.nodes.question_classifier import QuestionClassifierNode
|
||||
from dify_graph.nodes.template_transform import TemplateTransformNode
|
||||
from dify_graph.nodes.template_transform.template_renderer import (
|
||||
@@ -65,6 +66,8 @@ class MockNodeMixin:
|
||||
kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider))
|
||||
kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory))
|
||||
kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance))
|
||||
# LLM-like nodes now require an http_client; provide a mock by default for tests.
|
||||
kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol))
|
||||
|
||||
# Ensure TemplateTransformNode receives a renderer now required by constructor
|
||||
if isinstance(self, TemplateTransformNode):
|
||||
|
||||
@@ -112,7 +112,6 @@ class TestKnowledgeRetrievalNode:
|
||||
# Assert
|
||||
assert node.id == node_id
|
||||
assert node._rag_retrieval == mock_rag_retrieval
|
||||
assert node._llm_file_saver is not None
|
||||
|
||||
def test_run_with_no_query_or_attachment(
|
||||
self,
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import uuid
|
||||
from typing import NamedTuple
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from sqlalchemy import Engine
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools import signature
|
||||
@@ -44,7 +44,6 @@ class TestFileSaverImpl:
|
||||
)
|
||||
mock_tool_file.id = _gen_id()
|
||||
mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager)
|
||||
mocked_engine = mock.MagicMock(spec=Engine)
|
||||
|
||||
mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file
|
||||
monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager)
|
||||
@@ -53,11 +52,12 @@ class TestFileSaverImpl:
|
||||
# Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here.
|
||||
monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file)
|
||||
mocked_sign_file.return_value = mock_signed_url
|
||||
http_client = MagicMock()
|
||||
|
||||
storage_file_manager = FileSaverImpl(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
engine_factory=mocked_engine,
|
||||
http_client=http_client,
|
||||
)
|
||||
|
||||
file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type)
|
||||
@@ -87,16 +87,18 @@ class TestFileSaverImpl:
|
||||
status_code=401,
|
||||
request=mock_request,
|
||||
)
|
||||
http_client = MagicMock()
|
||||
http_client.get.return_value = mock_response
|
||||
|
||||
file_saver = FileSaverImpl(
|
||||
user_id=_gen_id(),
|
||||
tenant_id=_gen_id(),
|
||||
http_client=http_client,
|
||||
)
|
||||
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
|
||||
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError) as exc:
|
||||
file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
|
||||
mock_get.assert_called_once_with(_TEST_URL)
|
||||
http_client.get.assert_called_once_with(_TEST_URL)
|
||||
assert exc.value.response.status_code == 401
|
||||
|
||||
def test_save_remote_url_success(self, monkeypatch: pytest.MonkeyPatch):
|
||||
@@ -112,8 +114,10 @@ class TestFileSaverImpl:
|
||||
headers={"Content-Type": mime_type},
|
||||
request=mock_request,
|
||||
)
|
||||
http_client = MagicMock()
|
||||
http_client.get.return_value = mock_response
|
||||
|
||||
file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id)
|
||||
file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id, http_client=http_client)
|
||||
mock_tool_file = ToolFile(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@@ -111,6 +111,7 @@ def llm_node(
|
||||
"id": "1",
|
||||
"data": llm_node_data.model_dump(),
|
||||
}
|
||||
http_client = mock.MagicMock()
|
||||
node = LLMNode(
|
||||
id="1",
|
||||
config=node_config,
|
||||
@@ -120,6 +121,7 @@ def llm_node(
|
||||
model_factory=mock_model_factory,
|
||||
model_instance=mock.MagicMock(spec=ModelInstance),
|
||||
llm_file_saver=mock_file_saver,
|
||||
http_client=http_client,
|
||||
)
|
||||
return node
|
||||
|
||||
@@ -632,6 +634,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
|
||||
"id": "1",
|
||||
"data": llm_node_data.model_dump(),
|
||||
}
|
||||
http_client = mock.MagicMock()
|
||||
node = LLMNode(
|
||||
id="1",
|
||||
config=node_config,
|
||||
@@ -641,6 +644,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
|
||||
model_factory=mock_model_factory,
|
||||
model_instance=mock.MagicMock(spec=ModelInstance),
|
||||
llm_file_saver=mock_file_saver,
|
||||
http_client=http_client,
|
||||
)
|
||||
return node, mock_file_saver
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -295,24 +295,7 @@ describe('Pricing Modal Flow', () => {
|
||||
})
|
||||
})
|
||||
|
||||
// ─── 6. Close Handling ───────────────────────────────────────────────────
|
||||
describe('Close handling', () => {
|
||||
it('should call onCancel when pressing ESC key', () => {
|
||||
render(<Pricing onCancel={onCancel} />)
|
||||
|
||||
// ahooks useKeyPress listens on document for keydown events
|
||||
document.dispatchEvent(new KeyboardEvent('keydown', {
|
||||
key: 'Escape',
|
||||
code: 'Escape',
|
||||
keyCode: 27,
|
||||
bubbles: true,
|
||||
}))
|
||||
|
||||
expect(onCancel).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
|
||||
// ─── 7. Pricing URL ─────────────────────────────────────────────────────
|
||||
// ─── 6. Pricing URL ─────────────────────────────────────────────────────
|
||||
describe('Pricing page URL', () => {
|
||||
it('should render pricing link with correct URL', () => {
|
||||
render(<Pricing onCancel={onCancel} />)
|
||||
|
||||
@@ -160,7 +160,7 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => {
|
||||
isShow={isShowDeleteConfirm}
|
||||
onClose={() => setIsShowDeleteConfirm(false)}
|
||||
>
|
||||
<div className="title-2xl-semi-bold mb-3 text-text-primary">{t('avatar.deleteTitle', { ns: 'common' })}</div>
|
||||
<div className="mb-3 text-text-primary title-2xl-semi-bold">{t('avatar.deleteTitle', { ns: 'common' })}</div>
|
||||
<p className="mb-8 text-text-secondary">{t('avatar.deleteDescription', { ns: 'common' })}</p>
|
||||
|
||||
<div className="flex w-full items-center justify-center gap-2">
|
||||
|
||||
@@ -209,14 +209,14 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
</div>
|
||||
{step === STEP.start && (
|
||||
<>
|
||||
<div className="title-2xl-semi-bold pb-3 text-text-primary">{t('account.changeEmail.title', { ns: 'common' })}</div>
|
||||
<div className="pb-3 text-text-primary title-2xl-semi-bold">{t('account.changeEmail.title', { ns: 'common' })}</div>
|
||||
<div className="space-y-0.5 pb-2 pt-1">
|
||||
<div className="body-md-medium text-text-warning">{t('account.changeEmail.authTip', { ns: 'common' })}</div>
|
||||
<div className="body-md-regular text-text-secondary">
|
||||
<div className="text-text-warning body-md-medium">{t('account.changeEmail.authTip', { ns: 'common' })}</div>
|
||||
<div className="text-text-secondary body-md-regular">
|
||||
<Trans
|
||||
i18nKey="account.changeEmail.content1"
|
||||
ns="common"
|
||||
components={{ email: <span className="body-md-medium text-text-primary"></span> }}
|
||||
components={{ email: <span className="text-text-primary body-md-medium"></span> }}
|
||||
values={{ email }}
|
||||
/>
|
||||
</div>
|
||||
@@ -241,19 +241,19 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
)}
|
||||
{step === STEP.verifyOrigin && (
|
||||
<>
|
||||
<div className="title-2xl-semi-bold pb-3 text-text-primary">{t('account.changeEmail.verifyEmail', { ns: 'common' })}</div>
|
||||
<div className="pb-3 text-text-primary title-2xl-semi-bold">{t('account.changeEmail.verifyEmail', { ns: 'common' })}</div>
|
||||
<div className="space-y-0.5 pb-2 pt-1">
|
||||
<div className="body-md-regular text-text-secondary">
|
||||
<div className="text-text-secondary body-md-regular">
|
||||
<Trans
|
||||
i18nKey="account.changeEmail.content2"
|
||||
ns="common"
|
||||
components={{ email: <span className="body-md-medium text-text-primary"></span> }}
|
||||
components={{ email: <span className="text-text-primary body-md-medium"></span> }}
|
||||
values={{ email }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="pt-3">
|
||||
<div className="system-sm-medium mb-1 flex h-6 items-center text-text-secondary">{t('account.changeEmail.codeLabel', { ns: 'common' })}</div>
|
||||
<div className="mb-1 flex h-6 items-center text-text-secondary system-sm-medium">{t('account.changeEmail.codeLabel', { ns: 'common' })}</div>
|
||||
<Input
|
||||
className="!w-full"
|
||||
placeholder={t('account.changeEmail.codePlaceholder', { ns: 'common' })}
|
||||
@@ -278,25 +278,25 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
{t('operation.cancel', { ns: 'common' })}
|
||||
</Button>
|
||||
</div>
|
||||
<div className="system-xs-regular mt-3 flex items-center gap-1 text-text-tertiary">
|
||||
<div className="mt-3 flex items-center gap-1 text-text-tertiary system-xs-regular">
|
||||
<span>{t('account.changeEmail.resendTip', { ns: 'common' })}</span>
|
||||
{time > 0 && (
|
||||
<span>{t('account.changeEmail.resendCount', { ns: 'common', count: time })}</span>
|
||||
)}
|
||||
{!time && (
|
||||
<span onClick={sendCodeToOriginEmail} className="system-xs-medium cursor-pointer text-text-accent-secondary">{t('account.changeEmail.resend', { ns: 'common' })}</span>
|
||||
<span onClick={sendCodeToOriginEmail} className="cursor-pointer text-text-accent-secondary system-xs-medium">{t('account.changeEmail.resend', { ns: 'common' })}</span>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
{step === STEP.newEmail && (
|
||||
<>
|
||||
<div className="title-2xl-semi-bold pb-3 text-text-primary">{t('account.changeEmail.newEmail', { ns: 'common' })}</div>
|
||||
<div className="pb-3 text-text-primary title-2xl-semi-bold">{t('account.changeEmail.newEmail', { ns: 'common' })}</div>
|
||||
<div className="space-y-0.5 pb-2 pt-1">
|
||||
<div className="body-md-regular text-text-secondary">{t('account.changeEmail.content3', { ns: 'common' })}</div>
|
||||
<div className="text-text-secondary body-md-regular">{t('account.changeEmail.content3', { ns: 'common' })}</div>
|
||||
</div>
|
||||
<div className="pt-3">
|
||||
<div className="system-sm-medium mb-1 flex h-6 items-center text-text-secondary">{t('account.changeEmail.emailLabel', { ns: 'common' })}</div>
|
||||
<div className="mb-1 flex h-6 items-center text-text-secondary system-sm-medium">{t('account.changeEmail.emailLabel', { ns: 'common' })}</div>
|
||||
<Input
|
||||
className="!w-full"
|
||||
placeholder={t('account.changeEmail.emailPlaceholder', { ns: 'common' })}
|
||||
@@ -305,10 +305,10 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
destructive={newEmailExited || unAvailableEmail}
|
||||
/>
|
||||
{newEmailExited && (
|
||||
<div className="body-xs-regular mt-1 py-0.5 text-text-destructive">{t('account.changeEmail.existingEmail', { ns: 'common' })}</div>
|
||||
<div className="mt-1 py-0.5 text-text-destructive body-xs-regular">{t('account.changeEmail.existingEmail', { ns: 'common' })}</div>
|
||||
)}
|
||||
{unAvailableEmail && (
|
||||
<div className="body-xs-regular mt-1 py-0.5 text-text-destructive">{t('account.changeEmail.unAvailableEmail', { ns: 'common' })}</div>
|
||||
<div className="mt-1 py-0.5 text-text-destructive body-xs-regular">{t('account.changeEmail.unAvailableEmail', { ns: 'common' })}</div>
|
||||
)}
|
||||
</div>
|
||||
<div className="mt-3 space-y-2">
|
||||
@@ -331,19 +331,19 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
)}
|
||||
{step === STEP.verifyNew && (
|
||||
<>
|
||||
<div className="title-2xl-semi-bold pb-3 text-text-primary">{t('account.changeEmail.verifyNew', { ns: 'common' })}</div>
|
||||
<div className="pb-3 text-text-primary title-2xl-semi-bold">{t('account.changeEmail.verifyNew', { ns: 'common' })}</div>
|
||||
<div className="space-y-0.5 pb-2 pt-1">
|
||||
<div className="body-md-regular text-text-secondary">
|
||||
<div className="text-text-secondary body-md-regular">
|
||||
<Trans
|
||||
i18nKey="account.changeEmail.content4"
|
||||
ns="common"
|
||||
components={{ email: <span className="body-md-medium text-text-primary"></span> }}
|
||||
components={{ email: <span className="text-text-primary body-md-medium"></span> }}
|
||||
values={{ email: mail }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="pt-3">
|
||||
<div className="system-sm-medium mb-1 flex h-6 items-center text-text-secondary">{t('account.changeEmail.codeLabel', { ns: 'common' })}</div>
|
||||
<div className="mb-1 flex h-6 items-center text-text-secondary system-sm-medium">{t('account.changeEmail.codeLabel', { ns: 'common' })}</div>
|
||||
<Input
|
||||
className="!w-full"
|
||||
placeholder={t('account.changeEmail.codePlaceholder', { ns: 'common' })}
|
||||
@@ -368,13 +368,13 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
{t('operation.cancel', { ns: 'common' })}
|
||||
</Button>
|
||||
</div>
|
||||
<div className="system-xs-regular mt-3 flex items-center gap-1 text-text-tertiary">
|
||||
<div className="mt-3 flex items-center gap-1 text-text-tertiary system-xs-regular">
|
||||
<span>{t('account.changeEmail.resendTip', { ns: 'common' })}</span>
|
||||
{time > 0 && (
|
||||
<span>{t('account.changeEmail.resendCount', { ns: 'common', count: time })}</span>
|
||||
)}
|
||||
{!time && (
|
||||
<span onClick={sendCodeToNewEmail} className="system-xs-medium cursor-pointer text-text-accent-secondary">{t('account.changeEmail.resend', { ns: 'common' })}</span>
|
||||
<span onClick={sendCodeToNewEmail} className="cursor-pointer text-text-accent-secondary system-xs-medium">{t('account.changeEmail.resend', { ns: 'common' })}</span>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
|
||||
@@ -138,7 +138,7 @@ export default function AccountPage() {
|
||||
imageUrl={icon_url}
|
||||
/>
|
||||
</div>
|
||||
<div className="system-sm-medium mt-[3px] text-text-secondary">{item.name}</div>
|
||||
<div className="mt-[3px] text-text-secondary system-sm-medium">{item.name}</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -146,12 +146,12 @@ export default function AccountPage() {
|
||||
return (
|
||||
<>
|
||||
<div className="pb-3 pt-2">
|
||||
<h4 className="title-2xl-semi-bold text-text-primary">{t('account.myAccount', { ns: 'common' })}</h4>
|
||||
<h4 className="text-text-primary title-2xl-semi-bold">{t('account.myAccount', { ns: 'common' })}</h4>
|
||||
</div>
|
||||
<div className="mb-8 flex items-center rounded-xl bg-gradient-to-r from-background-gradient-bg-fill-chat-bg-2 to-background-gradient-bg-fill-chat-bg-1 p-6">
|
||||
<AvatarWithEdit avatar={userProfile.avatar_url} name={userProfile.name} onSave={mutateUserProfile} size={64} />
|
||||
<div className="ml-4">
|
||||
<p className="system-xl-semibold text-text-primary">
|
||||
<p className="text-text-primary system-xl-semibold">
|
||||
{userProfile.name}
|
||||
{isEducationAccount && (
|
||||
<PremiumBadge size="s" color="blue" className="ml-1 !px-2">
|
||||
@@ -160,16 +160,16 @@ export default function AccountPage() {
|
||||
</PremiumBadge>
|
||||
)}
|
||||
</p>
|
||||
<p className="system-xs-regular text-text-tertiary">{userProfile.email}</p>
|
||||
<p className="text-text-tertiary system-xs-regular">{userProfile.email}</p>
|
||||
</div>
|
||||
</div>
|
||||
<div className="mb-8">
|
||||
<div className={titleClassName}>{t('account.name', { ns: 'common' })}</div>
|
||||
<div className="mt-2 flex w-full items-center justify-between gap-2">
|
||||
<div className="system-sm-regular flex-1 rounded-lg bg-components-input-bg-normal p-2 text-components-input-text-filled ">
|
||||
<div className="flex-1 rounded-lg bg-components-input-bg-normal p-2 text-components-input-text-filled system-sm-regular">
|
||||
<span className="pl-1">{userProfile.name}</span>
|
||||
</div>
|
||||
<div className="system-sm-medium cursor-pointer rounded-lg bg-components-button-tertiary-bg px-3 py-2 text-components-button-tertiary-text" onClick={handleEditName}>
|
||||
<div className="cursor-pointer rounded-lg bg-components-button-tertiary-bg px-3 py-2 text-components-button-tertiary-text system-sm-medium" onClick={handleEditName}>
|
||||
{t('operation.edit', { ns: 'common' })}
|
||||
</div>
|
||||
</div>
|
||||
@@ -177,11 +177,11 @@ export default function AccountPage() {
|
||||
<div className="mb-8">
|
||||
<div className={titleClassName}>{t('account.email', { ns: 'common' })}</div>
|
||||
<div className="mt-2 flex w-full items-center justify-between gap-2">
|
||||
<div className="system-sm-regular flex-1 rounded-lg bg-components-input-bg-normal p-2 text-components-input-text-filled ">
|
||||
<div className="flex-1 rounded-lg bg-components-input-bg-normal p-2 text-components-input-text-filled system-sm-regular">
|
||||
<span className="pl-1">{userProfile.email}</span>
|
||||
</div>
|
||||
{systemFeatures.enable_change_email && (
|
||||
<div className="system-sm-medium cursor-pointer rounded-lg bg-components-button-tertiary-bg px-3 py-2 text-components-button-tertiary-text" onClick={() => setShowUpdateEmail(true)}>
|
||||
<div className="cursor-pointer rounded-lg bg-components-button-tertiary-bg px-3 py-2 text-components-button-tertiary-text system-sm-medium" onClick={() => setShowUpdateEmail(true)}>
|
||||
{t('operation.change', { ns: 'common' })}
|
||||
</div>
|
||||
)}
|
||||
@@ -191,8 +191,8 @@ export default function AccountPage() {
|
||||
systemFeatures.enable_email_password_login && (
|
||||
<div className="mb-8 flex justify-between gap-2">
|
||||
<div>
|
||||
<div className="system-sm-semibold mb-1 text-text-secondary">{t('account.password', { ns: 'common' })}</div>
|
||||
<div className="body-xs-regular mb-2 text-text-tertiary">{t('account.passwordTip', { ns: 'common' })}</div>
|
||||
<div className="mb-1 text-text-secondary system-sm-semibold">{t('account.password', { ns: 'common' })}</div>
|
||||
<div className="mb-2 text-text-tertiary body-xs-regular">{t('account.passwordTip', { ns: 'common' })}</div>
|
||||
</div>
|
||||
<Button onClick={() => setEditPasswordModalVisible(true)}>{userProfile.is_password_set ? t('account.resetPassword', { ns: 'common' }) : t('account.setPassword', { ns: 'common' })}</Button>
|
||||
</div>
|
||||
@@ -219,7 +219,7 @@ export default function AccountPage() {
|
||||
onClose={() => setEditNameModalVisible(false)}
|
||||
className="!w-[420px] !p-6"
|
||||
>
|
||||
<div className="title-2xl-semi-bold mb-6 text-text-primary">{t('account.editName', { ns: 'common' })}</div>
|
||||
<div className="mb-6 text-text-primary title-2xl-semi-bold">{t('account.editName', { ns: 'common' })}</div>
|
||||
<div className={titleClassName}>{t('account.name', { ns: 'common' })}</div>
|
||||
<Input
|
||||
className="mt-2"
|
||||
@@ -249,7 +249,7 @@ export default function AccountPage() {
|
||||
}}
|
||||
className="!w-[420px] !p-6"
|
||||
>
|
||||
<div className="title-2xl-semi-bold mb-6 text-text-primary">{userProfile.is_password_set ? t('account.resetPassword', { ns: 'common' }) : t('account.setPassword', { ns: 'common' })}</div>
|
||||
<div className="mb-6 text-text-primary title-2xl-semi-bold">{userProfile.is_password_set ? t('account.resetPassword', { ns: 'common' }) : t('account.setPassword', { ns: 'common' })}</div>
|
||||
{userProfile.is_password_set && (
|
||||
<>
|
||||
<div className={titleClassName}>{t('account.currentPassword', { ns: 'common' })}</div>
|
||||
@@ -272,7 +272,7 @@ export default function AccountPage() {
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
<div className="system-sm-semibold mt-8 text-text-secondary">
|
||||
<div className="mt-8 text-text-secondary system-sm-semibold">
|
||||
{userProfile.is_password_set ? t('account.newPassword', { ns: 'common' }) : t('account.password', { ns: 'common' })}
|
||||
</div>
|
||||
<div className="relative mt-2">
|
||||
@@ -291,7 +291,7 @@ export default function AccountPage() {
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<div className="system-sm-semibold mt-8 text-text-secondary">{t('account.confirmPassword', { ns: 'common' })}</div>
|
||||
<div className="mt-8 text-text-secondary system-sm-semibold">{t('account.confirmPassword', { ns: 'common' })}</div>
|
||||
<div className="relative mt-2">
|
||||
<Input
|
||||
type={showConfirmPassword ? 'text' : 'password'}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user