mirror of
https://github.com/langgenius/dify.git
synced 2026-03-09 17:25:10 +00:00
Compare commits
5 Commits
deploy/dev
...
feat/enter
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b15281b2a0 | ||
|
|
d4ac40d0f1 | ||
|
|
2b0a3c615f | ||
|
|
9ddd2a8d49 | ||
|
|
f711e4c385 |
33
.github/actions/setup-web/action.yml
vendored
33
.github/actions/setup-web/action.yml
vendored
@@ -1,33 +0,0 @@
|
||||
name: Setup Web Environment
|
||||
description: Setup pnpm, Node.js, and install web dependencies.
|
||||
|
||||
inputs:
|
||||
node-version:
|
||||
description: Node.js version to use
|
||||
required: false
|
||||
default: "22"
|
||||
install-dependencies:
|
||||
description: Whether to install web dependencies after setting up Node.js
|
||||
required: false
|
||||
default: "true"
|
||||
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@41ff72655975bd51cab0327fa583b6e92b6d3061 # v4.2.0
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0
|
||||
with:
|
||||
node-version: ${{ inputs.node-version }}
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
if: ${{ inputs.install-dependencies == 'true' }}
|
||||
shell: bash
|
||||
run: pnpm --dir web install --frozen-lockfile
|
||||
16
.github/dependabot.yml
vendored
16
.github/dependabot.yml
vendored
@@ -24,18 +24,6 @@ updates:
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 2
|
||||
ignore:
|
||||
- dependency-name: "ky"
|
||||
- dependency-name: "tailwind-merge"
|
||||
update-types: ["version-update:semver-major"]
|
||||
- dependency-name: "tailwindcss"
|
||||
update-types: ["version-update:semver-major"]
|
||||
- dependency-name: "react-markdown"
|
||||
update-types: ["version-update:semver-major"]
|
||||
- dependency-name: "react-syntax-highlighter"
|
||||
update-types: ["version-update:semver-major"]
|
||||
- dependency-name: "react-window"
|
||||
update-types: ["version-update:semver-major"]
|
||||
groups:
|
||||
lexical:
|
||||
patterns:
|
||||
@@ -45,9 +33,6 @@ updates:
|
||||
patterns:
|
||||
- "storybook"
|
||||
- "@storybook/*"
|
||||
eslint-group:
|
||||
patterns:
|
||||
- "*eslint*"
|
||||
npm-dependencies:
|
||||
patterns:
|
||||
- "*"
|
||||
@@ -56,4 +41,3 @@ updates:
|
||||
- "@lexical/*"
|
||||
- "storybook"
|
||||
- "@storybook/*"
|
||||
- "*eslint*"
|
||||
|
||||
6
.github/workflows/api-tests.yml
vendored
6
.github/workflows/api-tests.yml
vendored
@@ -22,12 +22,12 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
uses: astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@@ -51,7 +51,7 @@ jobs:
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
- name: Set up Sandbox
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
uses: hoverkraft-tech/compose-action@v2
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
|
||||
20
.github/workflows/autofix.yml
vendored
20
.github/workflows/autofix.yml
vendored
@@ -12,22 +12,22 @@ jobs:
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Check Docker Compose inputs
|
||||
id: docker-compose-changes
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
uses: tj-actions/changed-files@v47
|
||||
with:
|
||||
files: |
|
||||
docker/generate_docker_compose
|
||||
docker/.env.example
|
||||
docker/docker-compose-template.yaml
|
||||
docker/docker-compose.yaml
|
||||
- uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
|
||||
- uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
- uses: astral-sh/setup-uv@v7
|
||||
|
||||
- name: Generate Docker Compose
|
||||
if: steps.docker-compose-changes.outputs.any_changed == 'true'
|
||||
@@ -84,14 +84,4 @@ jobs:
|
||||
run: |
|
||||
uvx --python 3.13 mdformat . --exclude ".agents/skills/**"
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
with:
|
||||
node-version: "24"
|
||||
|
||||
- name: ESLint autofix
|
||||
run: |
|
||||
cd web
|
||||
pnpm eslint --concurrency=2 --prune-suppressions
|
||||
|
||||
- uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3
|
||||
- uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27
|
||||
|
||||
18
.github/workflows/build-push.yml
vendored
18
.github/workflows/build-push.yml
vendored
@@ -53,26 +53,26 @@ jobs:
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKERHUB_USER }}
|
||||
password: ${{ env.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Extract metadata for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6.0.0
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env[matrix.image_name_env] }}
|
||||
|
||||
- name: Build Docker image
|
||||
id: build
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: "{{defaultContext}}:${{ matrix.context }}"
|
||||
platforms: ${{ matrix.platform }}
|
||||
@@ -91,7 +91,7 @@ jobs:
|
||||
touch "/tmp/digests/${sanitized_digest}"
|
||||
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }}
|
||||
path: /tmp/digests/*
|
||||
@@ -113,21 +113,21 @@ jobs:
|
||||
context: "web"
|
||||
steps:
|
||||
- name: Download digests
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
|
||||
uses: actions/download-artifact@v7
|
||||
with:
|
||||
path: /tmp/digests
|
||||
pattern: digests-${{ matrix.context }}-*
|
||||
merge-multiple: true
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKERHUB_USER }}
|
||||
password: ${{ env.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Extract metadata for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6.0.0
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env[matrix.image_name_env] }}
|
||||
tags: |
|
||||
|
||||
12
.github/workflows/db-migration-test.yml
vendored
12
.github/workflows/db-migration-test.yml
vendored
@@ -13,13 +13,13 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
uses: astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
@@ -40,7 +40,7 @@ jobs:
|
||||
cp middleware.env.example middleware.env
|
||||
|
||||
- name: Set up Middlewares
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
@@ -63,13 +63,13 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
uses: astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: "3.12"
|
||||
@@ -94,7 +94,7 @@ jobs:
|
||||
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env
|
||||
|
||||
- name: Set up Middlewares
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.middleware.yaml
|
||||
|
||||
2
.github/workflows/deploy-agent-dev.yml
vendored
2
.github/workflows/deploy-agent-dev.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
github.event.workflow_run.head_branch == 'deploy/agent-dev'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5
|
||||
uses: appleboy/ssh-action@v1
|
||||
with:
|
||||
host: ${{ secrets.AGENT_DEV_SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
|
||||
2
.github/workflows/deploy-dev.yml
vendored
2
.github/workflows/deploy-dev.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
github.event.workflow_run.head_branch == 'deploy/dev'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5
|
||||
uses: appleboy/ssh-action@v1
|
||||
with:
|
||||
host: ${{ secrets.SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
|
||||
2
.github/workflows/deploy-hitl.yml
vendored
2
.github/workflows/deploy-hitl.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
github.event.workflow_run.head_branch == 'build/feat/hitl'
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5
|
||||
uses: appleboy/ssh-action@v1
|
||||
with:
|
||||
host: ${{ secrets.HITL_SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
|
||||
6
.github/workflows/docker-build.yml
vendored
6
.github/workflows/docker-build.yml
vendored
@@ -32,13 +32,13 @@ jobs:
|
||||
context: "web"
|
||||
steps:
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Build Docker Image
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
push: false
|
||||
context: "{{defaultContext}}:${{ matrix.context }}"
|
||||
|
||||
2
.github/workflows/labeler.yml
vendored
2
.github/workflows/labeler.yml
vendored
@@ -9,6 +9,6 @@ jobs:
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1
|
||||
- uses: actions/labeler@v6
|
||||
with:
|
||||
sync-labels: true
|
||||
|
||||
5
.github/workflows/main-ci.yml
vendored
5
.github/workflows/main-ci.yml
vendored
@@ -27,8 +27,8 @@ jobs:
|
||||
vdb-changed: ${{ steps.changes.outputs.vdb }}
|
||||
migration-changed: ${{ steps.changes.outputs.migration }}
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
|
||||
- uses: actions/checkout@v6
|
||||
- uses: dorny/paths-filter@v3
|
||||
id: changes
|
||||
with:
|
||||
filters: |
|
||||
@@ -39,7 +39,6 @@ jobs:
|
||||
web:
|
||||
- 'web/**'
|
||||
- '.github/workflows/web-tests.yml'
|
||||
- '.github/actions/setup-web/**'
|
||||
vdb:
|
||||
- 'api/core/rag/datasource/**'
|
||||
- 'docker/**'
|
||||
|
||||
4
.github/workflows/pyrefly-diff-comment.yml
vendored
4
.github/workflows/pyrefly-diff-comment.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }}
|
||||
steps:
|
||||
- name: Download pyrefly diff artifact
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
@@ -49,7 +49,7 @@ jobs:
|
||||
run: unzip -o pyrefly_diff.zip
|
||||
|
||||
- name: Post comment
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
|
||||
8
.github/workflows/pyrefly-diff.yml
vendored
8
.github/workflows/pyrefly-diff.yml
vendored
@@ -17,12 +17,12 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Checkout PR branch
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python & UV
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
@@ -55,7 +55,7 @@ jobs:
|
||||
echo ${{ github.event.pull_request.number }} > pr_number.txt
|
||||
|
||||
- name: Upload pyrefly diff
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: pyrefly_diff
|
||||
path: |
|
||||
@@ -64,7 +64,7 @@ jobs:
|
||||
|
||||
- name: Comment PR with pyrefly diff
|
||||
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
|
||||
2
.github/workflows/semantic-pull-request.yml
vendored
2
.github/workflows/semantic-pull-request.yml
vendored
@@ -16,6 +16,6 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check title
|
||||
uses: amannn/action-semantic-pull-request@48f256284bd46cdaab1048c3721360e808335d50 # v6.1.1
|
||||
uses: amannn/action-semantic-pull-request@v6.1.1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
2
.github/workflows/stale.yml
vendored
2
.github/workflows/stale.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0
|
||||
- uses: actions/stale@v10
|
||||
with:
|
||||
days-before-issue-stale: 15
|
||||
days-before-issue-close: 3
|
||||
|
||||
36
.github/workflows/style.yml
vendored
36
.github/workflows/style.yml
vendored
@@ -19,13 +19,13 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
uses: tj-actions/changed-files@v47
|
||||
with:
|
||||
files: |
|
||||
api/**
|
||||
@@ -33,7 +33,7 @@ jobs:
|
||||
|
||||
- name: Setup UV and Python
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
uses: astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
python-version: "3.12"
|
||||
@@ -67,22 +67,36 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
uses: tj-actions/changed-files@v47
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
.github/workflows/style.yml
|
||||
.github/actions/setup-web/**
|
||||
|
||||
- name: Setup web environment
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup NodeJS
|
||||
uses: actions/setup-node@v6
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: ./.github/actions/setup-web
|
||||
with:
|
||||
node-version: 22
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Web dependencies
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Web style check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
@@ -120,14 +134,14 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
uses: tj-actions/changed-files@v47
|
||||
with:
|
||||
files: |
|
||||
**.sh
|
||||
@@ -138,7 +152,7 @@ jobs:
|
||||
.editorconfig
|
||||
|
||||
- name: Super-linter
|
||||
uses: super-linter/super-linter/slim@61abc07d755095a68f4987d1c2c3d1d64408f1f9 # v8.5.0
|
||||
uses: super-linter/super-linter/slim@v8
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
env:
|
||||
BASH_SEVERITY: warning
|
||||
|
||||
4
.github/workflows/tool-test-sdks.yaml
vendored
4
.github/workflows/tool-test-sdks.yaml
vendored
@@ -21,12 +21,12 @@ jobs:
|
||||
working-directory: sdks/nodejs-client
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Use Node.js
|
||||
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 22
|
||||
cache: ''
|
||||
|
||||
18
.github/workflows/translate-i18n-claude.yml
vendored
18
.github/workflows/translate-i18n-claude.yml
vendored
@@ -38,7 +38,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
@@ -48,10 +48,18 @@ jobs:
|
||||
git config --global user.name "github-actions[bot]"
|
||||
git config --global user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
install-dependencies: "false"
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 22
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Detect changed files and generate diff
|
||||
id: detect_changes
|
||||
@@ -122,7 +130,7 @@ jobs:
|
||||
|
||||
- name: Run Claude Code for Translation Sync
|
||||
if: steps.detect_changes.outputs.CHANGED_FILES != ''
|
||||
uses: anthropics/claude-code-action@26ec041249acb0a944c0a47b6c0c13f05dbc5b44 # v1.0.70
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
4
.github/workflows/trigger-i18n-sync.yml
vendored
4
.github/workflows/trigger-i18n-sync.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -59,7 +59,7 @@ jobs:
|
||||
|
||||
- name: Trigger i18n sync workflow
|
||||
if: steps.detect.outputs.has_changes == 'true'
|
||||
uses: peter-evans/repository-dispatch@28959ce8df70de7be546dd1250a005dd32156697 # v4.0.1
|
||||
uses: peter-evans/repository-dispatch@v3
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
event-type: i18n-sync
|
||||
|
||||
8
.github/workflows/vdb-tests.yml
vendored
8
.github/workflows/vdb-tests.yml
vendored
@@ -19,19 +19,19 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Free Disk Space
|
||||
uses: endersonmenezes/free-disk-space@7901478139cff6e9d44df5972fd8ab8fcade4db1 # v3.2.2
|
||||
uses: endersonmenezes/free-disk-space@v3
|
||||
with:
|
||||
remove_dotnet: true
|
||||
remove_haskell: true
|
||||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
|
||||
uses: astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@@ -60,7 +60,7 @@ jobs:
|
||||
# tiflash
|
||||
|
||||
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase)
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.yaml
|
||||
|
||||
68
.github/workflows/web-tests.yml
vendored
68
.github/workflows/web-tests.yml
vendored
@@ -26,19 +26,32 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 22
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Run tests
|
||||
run: pnpm vitest run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
|
||||
|
||||
- name: Upload blob report
|
||||
if: ${{ !cancelled() }}
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: blob-report-${{ matrix.shardIndex }}
|
||||
path: web/.vitest-reports/*
|
||||
@@ -57,15 +70,28 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 22
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Download blob reports
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
|
||||
uses: actions/download-artifact@v6
|
||||
with:
|
||||
path: web/.vitest-reports
|
||||
pattern: blob-report-*
|
||||
@@ -393,7 +419,7 @@ jobs:
|
||||
|
||||
- name: Upload Coverage Artifact
|
||||
if: steps.coverage-summary.outputs.has_coverage == 'true'
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: web-coverage-report
|
||||
path: web/coverage
|
||||
@@ -409,22 +435,36 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
uses: tj-actions/changed-files@v47
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
.github/workflows/web-tests.yml
|
||||
.github/actions/setup-web/**
|
||||
|
||||
- name: Setup web environment
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup NodeJS
|
||||
uses: actions/setup-node@v6
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: ./.github/actions/setup-web
|
||||
with:
|
||||
node-version: 22
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Web dependencies
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Web build check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
|
||||
@@ -44,6 +44,7 @@ forbidden_modules =
|
||||
allow_indirect_imports = True
|
||||
ignore_imports =
|
||||
dify_graph.nodes.agent.agent_node -> extensions.ext_database
|
||||
dify_graph.nodes.llm.file_saver -> extensions.ext_database
|
||||
dify_graph.nodes.llm.node -> extensions.ext_database
|
||||
dify_graph.nodes.tool.tool_node -> extensions.ext_database
|
||||
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
|
||||
@@ -113,6 +114,7 @@ ignore_imports =
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
|
||||
dify_graph.nodes.tool.tool_node -> models
|
||||
dify_graph.nodes.agent.agent_node -> models.model
|
||||
dify_graph.nodes.llm.file_saver -> core.helper.ssrf_proxy
|
||||
dify_graph.nodes.llm.node -> core.helper.code_executor
|
||||
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
|
||||
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
|
||||
@@ -133,6 +135,7 @@ ignore_imports =
|
||||
dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager
|
||||
dify_graph.nodes.tool.tool_node -> core.tools.errors
|
||||
dify_graph.nodes.agent.agent_node -> extensions.ext_database
|
||||
dify_graph.nodes.llm.file_saver -> extensions.ext_database
|
||||
dify_graph.nodes.llm.node -> extensions.ext_database
|
||||
dify_graph.nodes.tool.tool_node -> extensions.ext_database
|
||||
dify_graph.nodes.agent.agent_node -> models
|
||||
|
||||
@@ -18,3 +18,7 @@ class EnterpriseFeatureConfig(BaseSettings):
|
||||
description="Allow customization of the enterprise logo.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
ENTERPRISE_REQUEST_TIMEOUT: int = Field(
|
||||
ge=1, description="Maximum timeout in seconds for enterprise requests", default=5
|
||||
)
|
||||
|
||||
@@ -39,7 +39,6 @@ from . import (
|
||||
feature,
|
||||
human_input_form,
|
||||
init_validate,
|
||||
notification,
|
||||
ping,
|
||||
setup,
|
||||
spec,
|
||||
@@ -185,7 +184,6 @@ __all__ = [
|
||||
"model_config",
|
||||
"model_providers",
|
||||
"models",
|
||||
"notification",
|
||||
"oauth",
|
||||
"oauth_server",
|
||||
"ops_trace",
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import csv
|
||||
import io
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
@@ -8,7 +6,7 @@ from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
@@ -18,7 +16,6 @@ from core.db.session_factory import session_factory
|
||||
from extensions.ext_database import db
|
||||
from libs.token import extract_access_token
|
||||
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
||||
from services.billing_service import BillingService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
@@ -280,168 +277,3 @@ class DeleteExploreBannerApi(Resource):
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class LangContentPayload(BaseModel):
|
||||
lang: str = Field(..., description="Language tag: 'zh' | 'en' | 'jp'")
|
||||
title: str = Field(...)
|
||||
subtitle: str | None = Field(default=None)
|
||||
body: str = Field(...)
|
||||
title_pic_url: str | None = Field(default=None)
|
||||
|
||||
|
||||
class UpsertNotificationPayload(BaseModel):
|
||||
notification_id: str | None = Field(default=None, description="Omit to create; supply UUID to update")
|
||||
contents: list[LangContentPayload] = Field(..., min_length=1)
|
||||
start_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-01T00:00:00Z")
|
||||
end_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-20T23:59:59Z")
|
||||
frequency: str = Field(default="once", description="'once' | 'every_page_load'")
|
||||
status: str = Field(default="active", description="'active' | 'inactive'")
|
||||
|
||||
|
||||
class BatchAddNotificationAccountsPayload(BaseModel):
|
||||
notification_id: str = Field(...)
|
||||
user_email: list[str] = Field(..., description="List of account email addresses")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
UpsertNotificationPayload.__name__,
|
||||
UpsertNotificationPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
BatchAddNotificationAccountsPayload.__name__,
|
||||
BatchAddNotificationAccountsPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/admin/upsert_notification")
|
||||
class UpsertNotificationApi(Resource):
|
||||
@console_ns.doc("upsert_notification")
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Create or update an in-product notification. "
|
||||
"Supply notification_id to update an existing one; omit it to create a new one. "
|
||||
"Pass at least one language variant in contents (zh / en / jp)."
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[UpsertNotificationPayload.__name__])
|
||||
@console_ns.response(200, "Notification upserted successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
payload = UpsertNotificationPayload.model_validate(console_ns.payload)
|
||||
result = BillingService.upsert_notification(
|
||||
contents=[c.model_dump() for c in payload.contents],
|
||||
frequency=payload.frequency,
|
||||
status=payload.status,
|
||||
notification_id=payload.notification_id,
|
||||
start_time=payload.start_time,
|
||||
end_time=payload.end_time,
|
||||
)
|
||||
return {"result": "success", "notification_id": result.get("notificationId")}, 200
|
||||
|
||||
|
||||
@console_ns.route("/admin/batch_add_notification_accounts")
|
||||
class BatchAddNotificationAccountsApi(Resource):
|
||||
@console_ns.doc("batch_add_notification_accounts")
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Register target accounts for a notification by email address. "
|
||||
'JSON body: {"notification_id": "...", "user_email": ["a@example.com", ...]}. '
|
||||
"File upload: multipart/form-data with a 'file' field (CSV or TXT, one email per line) "
|
||||
"plus a 'notification_id' field. "
|
||||
"Emails that do not match any account are silently skipped."
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Accounts added successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
from models.account import Account
|
||||
|
||||
if "file" in request.files:
|
||||
notification_id = request.form.get("notification_id", "").strip()
|
||||
if not notification_id:
|
||||
raise BadRequest("notification_id is required.")
|
||||
emails = self._parse_emails_from_file()
|
||||
else:
|
||||
payload = BatchAddNotificationAccountsPayload.model_validate(console_ns.payload)
|
||||
notification_id = payload.notification_id
|
||||
emails = payload.user_email
|
||||
|
||||
if not emails:
|
||||
raise BadRequest("No valid email addresses provided.")
|
||||
|
||||
# Resolve emails → account IDs in chunks to avoid large IN-clause
|
||||
account_ids: list[str] = []
|
||||
chunk_size = 500
|
||||
for i in range(0, len(emails), chunk_size):
|
||||
chunk = emails[i : i + chunk_size]
|
||||
rows = db.session.execute(select(Account.id, Account.email).where(Account.email.in_(chunk))).all()
|
||||
account_ids.extend(str(row.id) for row in rows)
|
||||
|
||||
if not account_ids:
|
||||
raise BadRequest("None of the provided emails matched an existing account.")
|
||||
|
||||
# Send to dify-saas in batches of 1000
|
||||
total_count = 0
|
||||
batch_size = 1000
|
||||
for i in range(0, len(account_ids), batch_size):
|
||||
batch = account_ids[i : i + batch_size]
|
||||
result = BillingService.batch_add_notification_accounts(
|
||||
notification_id=notification_id,
|
||||
account_ids=batch,
|
||||
)
|
||||
total_count += result.get("count", 0)
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"emails_provided": len(emails),
|
||||
"accounts_matched": len(account_ids),
|
||||
"count": total_count,
|
||||
}, 200
|
||||
|
||||
@staticmethod
|
||||
def _parse_emails_from_file() -> list[str]:
|
||||
"""Parse email addresses from an uploaded CSV or TXT file."""
|
||||
file = request.files["file"]
|
||||
if not file.filename:
|
||||
raise BadRequest("Uploaded file has no filename.")
|
||||
|
||||
filename_lower = file.filename.lower()
|
||||
if not filename_lower.endswith((".csv", ".txt")):
|
||||
raise BadRequest("Invalid file type. Only CSV (.csv) and TXT (.txt) files are allowed.")
|
||||
|
||||
try:
|
||||
content = file.read().decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
file.seek(0)
|
||||
content = file.read().decode("gbk")
|
||||
except UnicodeDecodeError:
|
||||
raise BadRequest("Unable to decode the file. Please use UTF-8 or GBK encoding.")
|
||||
|
||||
emails: list[str] = []
|
||||
if filename_lower.endswith(".csv"):
|
||||
reader = csv.reader(io.StringIO(content))
|
||||
for row in reader:
|
||||
for cell in row:
|
||||
cell = cell.strip()
|
||||
if cell:
|
||||
emails.append(cell)
|
||||
else:
|
||||
for line in content.splitlines():
|
||||
line = line.strip()
|
||||
if line:
|
||||
emails.append(line)
|
||||
|
||||
# Deduplicate while preserving order
|
||||
seen: set[str] = set()
|
||||
unique_emails: list[str] = []
|
||||
for email in emails:
|
||||
if email.lower() not in seen:
|
||||
seen.add(email.lower())
|
||||
unique_emails.append(email)
|
||||
|
||||
return unique_emails
|
||||
|
||||
@@ -807,7 +807,7 @@ class DatasetApiKeyApi(Resource):
|
||||
console_ns.abort(
|
||||
400,
|
||||
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
|
||||
custom="max_keys_exceeded",
|
||||
code="max_keys_exceeded",
|
||||
)
|
||||
|
||||
key = ApiToken.generate_api_key(self.token_prefix, 24)
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
# Notification content is stored under three lang tags.
|
||||
_FALLBACK_LANG = "en-US"
|
||||
|
||||
|
||||
def _pick_lang_content(contents: dict, lang: str) -> dict:
|
||||
"""Return the single LangContent for *lang*, falling back to English."""
|
||||
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
|
||||
|
||||
|
||||
class DismissNotificationPayload(BaseModel):
|
||||
notification_id: str = Field(...)
|
||||
|
||||
|
||||
@console_ns.route("/notification")
|
||||
class NotificationApi(Resource):
|
||||
@console_ns.doc("get_notification")
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Return the active in-product notification for the current user "
|
||||
"in their interface language (falls back to English if unavailable). "
|
||||
"The notification is NOT marked as seen here; call POST /notification/dismiss "
|
||||
"when the user explicitly closes the modal."
|
||||
),
|
||||
responses={
|
||||
200: "Success — inspect should_show to decide whether to render the modal",
|
||||
401: "Unauthorized",
|
||||
},
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
result = BillingService.get_account_notification(str(current_user.id))
|
||||
|
||||
# Proto JSON uses camelCase field names (Kratos default marshaling).
|
||||
if not result.get("shouldShow"):
|
||||
return {"should_show": False, "notifications": []}, 200
|
||||
|
||||
lang = current_user.interface_language or _FALLBACK_LANG
|
||||
|
||||
notifications = []
|
||||
for notification in result.get("notifications") or []:
|
||||
contents: dict = notification.get("contents") or {}
|
||||
lang_content = _pick_lang_content(contents, lang)
|
||||
notifications.append(
|
||||
{
|
||||
"notification_id": notification.get("notificationId"),
|
||||
"frequency": notification.get("frequency"),
|
||||
"lang": lang_content.get("lang", lang),
|
||||
"title": lang_content.get("title", ""),
|
||||
"subtitle": lang_content.get("subtitle", ""),
|
||||
"body": lang_content.get("body", ""),
|
||||
"title_pic_url": lang_content.get("titlePicUrl", ""),
|
||||
}
|
||||
)
|
||||
|
||||
return {"should_show": bool(notifications), "notifications": notifications}, 200
|
||||
|
||||
|
||||
@console_ns.route("/notification/dismiss")
|
||||
class NotificationDismissApi(Resource):
|
||||
@console_ns.doc("dismiss_notification")
|
||||
@console_ns.doc(
|
||||
description="Mark a notification as dismissed for the current user.",
|
||||
responses={200: "Success", 401: "Unauthorized"},
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = DismissNotificationPayload.model_validate(request.get_json())
|
||||
BillingService.dismiss_notification(
|
||||
notification_id=payload.notification_id,
|
||||
account_id=str(current_user.id),
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
@@ -10,6 +10,7 @@ from controllers.common.file_response import enforce_download_for_html
|
||||
from controllers.files import files_ns
|
||||
from core.tools.signature import verify_tool_file_signature
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from extensions.ext_database import db as global_db
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
@@ -56,7 +57,7 @@ class ToolFileApi(Resource):
|
||||
raise Forbidden("Invalid request.")
|
||||
|
||||
try:
|
||||
tool_file_manager = ToolFileManager()
|
||||
tool_file_manager = ToolFileManager(engine=global_db.engine)
|
||||
stream, tool_file = tool_file_manager.get_file_generator_by_tool_file_id(
|
||||
file_id,
|
||||
)
|
||||
|
||||
@@ -239,7 +239,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
|
||||
def get(self, app_model, end_user, message_id):
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
raise NotCompletionAppError()
|
||||
|
||||
message_id = str(message_id)
|
||||
|
||||
|
||||
@@ -158,6 +158,7 @@ class PluginEntity(PluginInstallation):
|
||||
name: str
|
||||
installation_id: str
|
||||
version: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_plugin_id(self):
|
||||
|
||||
@@ -10,18 +10,28 @@ from typing import Union
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db as global_db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import MessageFile
|
||||
from models.tools import ToolFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
|
||||
class ToolFileManager:
|
||||
_engine: Engine
|
||||
|
||||
def __init__(self, engine: Engine | None = None):
|
||||
if engine is None:
|
||||
engine = global_db.engine
|
||||
self._engine = engine
|
||||
|
||||
@staticmethod
|
||||
def sign_file(tool_file_id: str, extension: str) -> str:
|
||||
"""
|
||||
@@ -79,7 +89,7 @@ class ToolFileManager:
|
||||
filepath = f"tools/{tenant_id}/{unique_filename}"
|
||||
storage.save(filepath, file_binary)
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
@@ -122,7 +132,7 @@ class ToolFileManager:
|
||||
filename = f"{unique_name}{extension}"
|
||||
filepath = f"tools/{tenant_id}/{filename}"
|
||||
storage.save(filepath, blob)
|
||||
with session_factory.create_session() as session:
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
@@ -147,7 +157,7 @@ class ToolFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.where(
|
||||
@@ -171,7 +181,7 @@ class ToolFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
message_file: MessageFile | None = (
|
||||
session.query(MessageFile)
|
||||
.where(
|
||||
@@ -215,7 +225,7 @@ class ToolFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
with session_factory.create_session() as session:
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.where(
|
||||
|
||||
@@ -250,7 +250,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
model_factory=self._llm_model_factory,
|
||||
model_instance=model_instance,
|
||||
memory=memory,
|
||||
http_client=self._http_request_http_client,
|
||||
)
|
||||
|
||||
if node_type == NodeType.DATASOURCE:
|
||||
@@ -293,7 +292,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
model_factory=self._llm_model_factory,
|
||||
model_instance=model_instance,
|
||||
memory=memory,
|
||||
http_client=self._http_request_http_client,
|
||||
)
|
||||
|
||||
if node_type == NodeType.PARAMETER_EXTRACTOR:
|
||||
|
||||
@@ -14,6 +14,7 @@ from dify_graph.model_runtime.utils.encoders import jsonable_encoder
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base import LLMUsageTrackingMixin
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
|
||||
from dify_graph.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source
|
||||
from dify_graph.variables import (
|
||||
ArrayFileSegment,
|
||||
@@ -46,6 +47,8 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
# Output variable for file
|
||||
_file_outputs: list["File"]
|
||||
|
||||
_llm_file_saver: LLMFileSaver
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
@@ -53,6 +56,8 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
rag_retrieval: RAGRetrievalProtocol,
|
||||
*,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
@@ -64,6 +69,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
|
||||
self._file_outputs = []
|
||||
self._rag_retrieval = rag_retrieval
|
||||
|
||||
if llm_file_saver is None:
|
||||
dify_ctx = self.require_dify_context()
|
||||
llm_file_saver = FileSaverImpl(
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
@classmethod
|
||||
def version(cls):
|
||||
return "1"
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import mimetypes
|
||||
import typing as tp
|
||||
|
||||
from sqlalchemy import Engine
|
||||
|
||||
from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from dify_graph.file import File, FileTransferMethod, FileType
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
from extensions.ext_database import db as global_db
|
||||
|
||||
|
||||
class LLMFileSaver(tp.Protocol):
|
||||
@@ -56,20 +59,30 @@ class LLMFileSaver(tp.Protocol):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
EngineFactory: tp.TypeAlias = tp.Callable[[], Engine]
|
||||
|
||||
|
||||
class FileSaverImpl(LLMFileSaver):
|
||||
_engine_factory: EngineFactory
|
||||
_tenant_id: str
|
||||
_user_id: str
|
||||
|
||||
def __init__(self, user_id: str, tenant_id: str, http_client: HttpClientProtocol):
|
||||
def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None):
|
||||
if engine_factory is None:
|
||||
|
||||
def _factory():
|
||||
return global_db.engine
|
||||
|
||||
engine_factory = _factory
|
||||
self._engine_factory = engine_factory
|
||||
self._user_id = user_id
|
||||
self._tenant_id = tenant_id
|
||||
self._http_client = http_client
|
||||
|
||||
def _get_tool_file_manager(self):
|
||||
return ToolFileManager()
|
||||
return ToolFileManager(engine=self._engine_factory())
|
||||
|
||||
def save_remote_url(self, url: str, file_type: FileType) -> File:
|
||||
http_response = self._http_client.get(url)
|
||||
http_response = ssrf_proxy.get(url)
|
||||
http_response.raise_for_status()
|
||||
data = http_response.content
|
||||
mime_type_from_header = http_response.headers.get("Content-Type")
|
||||
|
||||
@@ -64,7 +64,6 @@ from dify_graph.nodes.base.entities import VariableSelector
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
from dify_graph.runtime import VariablePool
|
||||
from dify_graph.variables import (
|
||||
ArrayFileSegment,
|
||||
@@ -128,7 +127,6 @@ class LLMNode(Node[LLMNodeData]):
|
||||
credentials_provider: CredentialsProvider,
|
||||
model_factory: ModelFactory,
|
||||
model_instance: ModelInstance,
|
||||
http_client: HttpClientProtocol,
|
||||
memory: PromptMessageMemory | None = None,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
):
|
||||
@@ -151,7 +149,6 @@ class LLMNode(Node[LLMNodeData]):
|
||||
llm_file_saver = FileSaverImpl(
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
http_client=http_client,
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
|
||||
@@ -28,7 +28,6 @@ from dify_graph.nodes.llm import (
|
||||
)
|
||||
from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
|
||||
from .entities import QuestionClassifierNodeData
|
||||
@@ -69,7 +68,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
credentials_provider: "CredentialsProvider",
|
||||
model_factory: "ModelFactory",
|
||||
model_instance: ModelInstance,
|
||||
http_client: HttpClientProtocol,
|
||||
memory: PromptMessageMemory | None = None,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
):
|
||||
@@ -92,7 +90,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
llm_file_saver = FileSaverImpl(
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
http_client=http_client,
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
|
||||
@@ -21,10 +21,6 @@ celery_redis = Redis(
|
||||
ssl_cert_reqs=getattr(dify_config, "REDIS_SSL_CERT_REQS", None) if dify_config.BROKER_USE_SSL else None,
|
||||
ssl_certfile=getattr(dify_config, "REDIS_SSL_CERTFILE", None) if dify_config.BROKER_USE_SSL else None,
|
||||
ssl_keyfile=getattr(dify_config, "REDIS_SSL_KEYFILE", None) if dify_config.BROKER_USE_SSL else None,
|
||||
# Add conservative socket timeouts and health checks to avoid long-lived half-open sockets
|
||||
socket_timeout=5,
|
||||
socket_connect_timeout=5,
|
||||
health_check_interval=30,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -3,7 +3,6 @@ import math
|
||||
import time
|
||||
from collections.abc import Iterable, Sequence
|
||||
|
||||
from celery import group
|
||||
from sqlalchemy import ColumnElement, and_, func, or_, select
|
||||
from sqlalchemy.engine.row import Row
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -86,25 +85,20 @@ def trigger_provider_refresh() -> None:
|
||||
lock_keys: list[str] = build_trigger_refresh_lock_keys(subscriptions)
|
||||
acquired: list[bool] = _acquire_locks(keys=lock_keys, ttl_seconds=lock_ttl)
|
||||
|
||||
if not any(acquired):
|
||||
continue
|
||||
|
||||
jobs = [
|
||||
trigger_subscription_refresh.s(tenant_id=tenant_id, subscription_id=subscription_id)
|
||||
for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired)
|
||||
if is_locked
|
||||
]
|
||||
result = group(jobs).apply_async()
|
||||
enqueued = len(jobs)
|
||||
enqueued: int = 0
|
||||
for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired):
|
||||
if not is_locked:
|
||||
continue
|
||||
trigger_subscription_refresh.delay(tenant_id=tenant_id, subscription_id=subscription_id)
|
||||
enqueued += 1
|
||||
|
||||
logger.info(
|
||||
"Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d result=%s",
|
||||
"Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d",
|
||||
page + 1,
|
||||
pages,
|
||||
len(subscriptions),
|
||||
sum(1 for x in acquired if x),
|
||||
enqueued,
|
||||
result,
|
||||
)
|
||||
|
||||
logger.info("Trigger refresh scan done: due=%d", total_due)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
from celery import current_app, group, shared_task
|
||||
from celery import group, shared_task
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
@@ -29,27 +29,31 @@ def poll_workflow_schedules() -> None:
|
||||
with session_factory() as session:
|
||||
total_dispatched = 0
|
||||
|
||||
# Process in batches until we've handled all due schedules or hit the limit
|
||||
while True:
|
||||
due_schedules = _fetch_due_schedules(session)
|
||||
|
||||
if not due_schedules:
|
||||
break
|
||||
|
||||
with current_app.producer_or_acquire() as producer: # type: ignore
|
||||
dispatched_count = _process_schedules(session, due_schedules, producer)
|
||||
total_dispatched += dispatched_count
|
||||
dispatched_count = _process_schedules(session, due_schedules)
|
||||
total_dispatched += dispatched_count
|
||||
|
||||
logger.debug("Batch processed: %d dispatched", dispatched_count)
|
||||
logger.debug("Batch processed: %d dispatched", dispatched_count)
|
||||
|
||||
# Circuit breaker: check if we've hit the per-tick limit (if enabled)
|
||||
if (
|
||||
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK > 0
|
||||
and total_dispatched >= dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK
|
||||
):
|
||||
logger.warning(
|
||||
"Circuit breaker activated: reached dispatch limit (%d), will continue next tick",
|
||||
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK,
|
||||
)
|
||||
break
|
||||
|
||||
# Circuit breaker: check if we've hit the per-tick limit (if enabled)
|
||||
if 0 < dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK <= total_dispatched:
|
||||
logger.warning(
|
||||
"Circuit breaker activated: reached dispatch limit (%d), will continue next tick",
|
||||
dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK,
|
||||
)
|
||||
break
|
||||
if total_dispatched > 0:
|
||||
logger.info("Total processed: %d workflow schedule(s) dispatched", total_dispatched)
|
||||
logger.info("Total processed: %d dispatched", total_dispatched)
|
||||
|
||||
|
||||
def _fetch_due_schedules(session: Session) -> list[WorkflowSchedulePlan]:
|
||||
@@ -86,7 +90,7 @@ def _fetch_due_schedules(session: Session) -> list[WorkflowSchedulePlan]:
|
||||
return list(due_schedules)
|
||||
|
||||
|
||||
def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan], producer=None) -> int:
|
||||
def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan]) -> int:
|
||||
"""Process schedules: check quota, update next run time and dispatch to Celery in parallel."""
|
||||
if not schedules:
|
||||
return 0
|
||||
@@ -103,7 +107,7 @@ def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan],
|
||||
|
||||
if tasks_to_dispatch:
|
||||
job = group(run_schedule_trigger.s(schedule_id) for schedule_id in tasks_to_dispatch)
|
||||
job.apply_async(producer=producer)
|
||||
job.apply_async()
|
||||
|
||||
logger.debug("Dispatched %d tasks in parallel", len(tasks_to_dispatch))
|
||||
|
||||
|
||||
@@ -393,78 +393,3 @@ class BillingService:
|
||||
for item in data:
|
||||
tenant_whitelist.append(item["tenant_id"])
|
||||
return tenant_whitelist
|
||||
|
||||
@classmethod
|
||||
def get_account_notification(cls, account_id: str) -> dict:
|
||||
"""Return the active in-product notification for account_id, if any.
|
||||
|
||||
Calling this endpoint also marks the notification as seen; subsequent
|
||||
calls will return should_show=false when frequency='once'.
|
||||
|
||||
Response shape (mirrors GetAccountNotificationReply):
|
||||
{
|
||||
"should_show": bool,
|
||||
"notification": { # present only when should_show=true
|
||||
"notification_id": str,
|
||||
"contents": { # lang -> LangContent
|
||||
"en": {"lang": "en", "title": ..., "subtitle": ..., "body": ..., "title_pic_url": ...},
|
||||
...
|
||||
},
|
||||
"frequency": "once" | "every_page_load"
|
||||
}
|
||||
}
|
||||
"""
|
||||
return cls._send_request("GET", "/notifications/active", params={"account_id": account_id})
|
||||
|
||||
@classmethod
|
||||
def upsert_notification(
|
||||
cls,
|
||||
contents: list[dict],
|
||||
frequency: str = "once",
|
||||
status: str = "active",
|
||||
notification_id: str | None = None,
|
||||
start_time: str | None = None,
|
||||
end_time: str | None = None,
|
||||
) -> dict:
|
||||
"""Create or update a notification.
|
||||
|
||||
contents: list of {"lang": str, "title": str, "subtitle": str, "body": str, "title_pic_url": str}
|
||||
start_time / end_time: RFC3339 strings (e.g. "2026-03-01T00:00:00Z"), optional.
|
||||
Returns {"notification_id": str}.
|
||||
"""
|
||||
payload: dict = {
|
||||
"contents": contents,
|
||||
"frequency": frequency,
|
||||
"status": status,
|
||||
}
|
||||
if notification_id:
|
||||
payload["notification_id"] = notification_id
|
||||
if start_time:
|
||||
payload["start_time"] = start_time
|
||||
if end_time:
|
||||
payload["end_time"] = end_time
|
||||
return cls._send_request("POST", "/notifications", json=payload)
|
||||
|
||||
@classmethod
|
||||
def batch_add_notification_accounts(cls, notification_id: str, account_ids: list[str]) -> dict:
|
||||
"""Register target account IDs for a notification (max 1000 per call).
|
||||
|
||||
Returns {"count": int}.
|
||||
"""
|
||||
return cls._send_request(
|
||||
"POST",
|
||||
f"/notifications/{notification_id}/accounts",
|
||||
json={"account_ids": account_ids},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def dismiss_notification(cls, notification_id: str, account_id: str) -> dict:
|
||||
"""Mark a notification as dismissed for an account.
|
||||
|
||||
Returns {"success": bool}.
|
||||
"""
|
||||
return cls._send_request(
|
||||
"POST",
|
||||
f"/notifications/{notification_id}/dismiss",
|
||||
json={"account_id": account_id},
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@ import logging
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
from services.enterprise.base import EnterprisePluginManagerRequest
|
||||
from services.errors.base import BaseServiceError
|
||||
|
||||
@@ -28,6 +29,11 @@ class CheckCredentialPolicyComplianceRequest(BaseModel):
|
||||
return data
|
||||
|
||||
|
||||
class PreUninstallPluginRequest(BaseModel):
|
||||
tenant_id: str
|
||||
plugin_unique_identifier: str
|
||||
|
||||
|
||||
class CredentialPolicyViolationError(BaseServiceError):
|
||||
pass
|
||||
|
||||
@@ -55,3 +61,24 @@ class PluginManagerService:
|
||||
body.dify_credential_id,
|
||||
ret.get("result", False),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def try_pre_uninstall_plugin(cls, body: PreUninstallPluginRequest):
|
||||
try:
|
||||
# the invocation must be synchronous.
|
||||
EnterprisePluginManagerRequest.send_request( # pyright: ignore[reportUnknownMemberType]
|
||||
"POST",
|
||||
"/pre-uninstall-plugin",
|
||||
json=body.model_dump(), # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType]
|
||||
raise_for_status=True,
|
||||
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"""
|
||||
failed to perform pre uninstall plugin hook. tenant_id: %s, plugin_unique_identifier: %s,
|
||||
this may cause plugin to be automatically garbage collected
|
||||
""",
|
||||
body.tenant_id,
|
||||
body.plugin_unique_identifier,
|
||||
)
|
||||
|
||||
@@ -32,6 +32,10 @@ from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider import Provider, ProviderCredential
|
||||
from models.provider_ids import GenericProviderID
|
||||
from services.enterprise.plugin_manager_service import (
|
||||
PluginManagerService,
|
||||
PreUninstallPluginRequest,
|
||||
)
|
||||
from services.errors.plugin import PluginInstallationForbiddenError
|
||||
from services.feature_service import FeatureService, PluginInstallationScope
|
||||
|
||||
@@ -519,6 +523,13 @@ class PluginService:
|
||||
if not plugin:
|
||||
return manager.uninstall(tenant_id, plugin_installation_id)
|
||||
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
PluginManagerService.try_pre_uninstall_plugin(
|
||||
PreUninstallPluginRequest(
|
||||
tenant_id=tenant_id,
|
||||
plugin_unique_identifier=plugin.plugin_unique_identifier,
|
||||
)
|
||||
)
|
||||
with Session(db.engine) as session, session.begin():
|
||||
plugin_id = plugin.plugin_id
|
||||
logger.info("Deleting credentials for plugin: %s", plugin_id)
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Protocol
|
||||
from collections.abc import Callable, Sequence
|
||||
|
||||
import click
|
||||
from celery import current_app, shared_task
|
||||
from celery import shared_task
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
@@ -20,12 +19,6 @@ from tasks.generate_summary_index_task import generate_summary_index_task
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CeleryTaskLike(Protocol):
|
||||
def delay(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
def apply_async(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
def document_indexing_task(dataset_id: str, document_ids: list):
|
||||
"""
|
||||
@@ -186,8 +179,8 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
||||
|
||||
|
||||
def _document_indexing_with_tenant_queue(
|
||||
tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: CeleryTaskLike
|
||||
) -> None:
|
||||
tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: Callable[[str, str, Sequence[str]], None]
|
||||
):
|
||||
try:
|
||||
_document_indexing(dataset_id, document_ids)
|
||||
except Exception:
|
||||
@@ -208,20 +201,16 @@ def _document_indexing_with_tenant_queue(
|
||||
logger.info("document indexing tenant isolation queue %s next tasks: %s", tenant_id, next_tasks)
|
||||
|
||||
if next_tasks:
|
||||
with current_app.producer_or_acquire() as producer: # type: ignore
|
||||
for next_task in next_tasks:
|
||||
document_task = DocumentTask(**next_task)
|
||||
# Keep the flag set to indicate a task is running
|
||||
tenant_isolated_task_queue.set_task_waiting_time()
|
||||
task_func.apply_async(
|
||||
kwargs={
|
||||
"tenant_id": document_task.tenant_id,
|
||||
"dataset_id": document_task.dataset_id,
|
||||
"document_ids": document_task.document_ids,
|
||||
},
|
||||
producer=producer,
|
||||
)
|
||||
|
||||
for next_task in next_tasks:
|
||||
document_task = DocumentTask(**next_task)
|
||||
# Process the next waiting task
|
||||
# Keep the flag set to indicate a task is running
|
||||
tenant_isolated_task_queue.set_task_waiting_time()
|
||||
task_func.delay( # type: ignore
|
||||
tenant_id=document_task.tenant_id,
|
||||
dataset_id=document_task.dataset_id,
|
||||
document_ids=document_task.document_ids,
|
||||
)
|
||||
else:
|
||||
# No more waiting tasks, clear the flag
|
||||
tenant_isolated_task_queue.delete_task_key()
|
||||
|
||||
@@ -3,13 +3,12 @@ import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from collections.abc import Mapping
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from itertools import islice
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
from celery import group, shared_task
|
||||
from celery import shared_task # type: ignore
|
||||
from flask import current_app, g
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
@@ -28,11 +27,6 @@ from services.file_service import FileService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def chunked(iterable: Sequence, size: int):
|
||||
it = iter(iterable)
|
||||
return iter(lambda: list(islice(it, size)), [])
|
||||
|
||||
|
||||
@shared_task(queue="pipeline")
|
||||
def rag_pipeline_run_task(
|
||||
rag_pipeline_invoke_entities_file_id: str,
|
||||
@@ -89,24 +83,16 @@ def rag_pipeline_run_task(
|
||||
logger.info("rag pipeline tenant isolation queue %s next files: %s", tenant_id, next_file_ids)
|
||||
|
||||
if next_file_ids:
|
||||
for batch in chunked(next_file_ids, 100):
|
||||
jobs = []
|
||||
for next_file_id in batch:
|
||||
tenant_isolated_task_queue.set_task_waiting_time()
|
||||
|
||||
file_id = (
|
||||
next_file_id.decode("utf-8") if isinstance(next_file_id, (bytes, bytearray)) else next_file_id
|
||||
)
|
||||
|
||||
jobs.append(
|
||||
rag_pipeline_run_task.s(
|
||||
rag_pipeline_invoke_entities_file_id=file_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
if jobs:
|
||||
group(jobs).apply_async()
|
||||
for next_file_id in next_file_ids:
|
||||
# Process the next waiting task
|
||||
# Keep the flag set to indicate a task is running
|
||||
tenant_isolated_task_queue.set_task_waiting_time()
|
||||
rag_pipeline_run_task.delay( # type: ignore
|
||||
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
|
||||
if isinstance(next_file_id, bytes)
|
||||
else next_file_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
else:
|
||||
# No more waiting tasks, clear the flag
|
||||
tenant_isolated_task_queue.delete_task_key()
|
||||
|
||||
@@ -11,7 +11,6 @@ from dify_graph.enums import WorkflowNodeExecutionStatus
|
||||
from dify_graph.node_events import StreamCompletedEvent
|
||||
from dify_graph.nodes.llm.node import LLMNode
|
||||
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from dify_graph.nodes.protocols import HttpClientProtocol
|
||||
from dify_graph.runtime import GraphRuntimeState, VariablePool
|
||||
from dify_graph.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
@@ -75,7 +74,6 @@ def init_llm_node(config: dict) -> LLMNode:
|
||||
credentials_provider=MagicMock(spec=CredentialsProvider),
|
||||
model_factory=MagicMock(spec=ModelFactory),
|
||||
model_instance=MagicMock(spec=ModelInstance),
|
||||
http_client=MagicMock(spec=HttpClientProtocol),
|
||||
)
|
||||
|
||||
return node
|
||||
|
||||
@@ -322,14 +322,11 @@ class TestDatasetIndexingTaskIntegration:
|
||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||
|
||||
# Assert
|
||||
# apply_async is used by implementation; assert it was called once with expected kwargs
|
||||
assert task_dispatch_spy.apply_async.call_count == 1
|
||||
call_kwargs = task_dispatch_spy.apply_async.call_args.kwargs.get("kwargs", {})
|
||||
assert call_kwargs == {
|
||||
"tenant_id": next_task["tenant_id"],
|
||||
"dataset_id": next_task["dataset_id"],
|
||||
"document_ids": next_task["document_ids"],
|
||||
}
|
||||
task_dispatch_spy.delay.assert_called_once_with(
|
||||
tenant_id=next_task["tenant_id"],
|
||||
dataset_id=next_task["dataset_id"],
|
||||
document_ids=next_task["document_ids"],
|
||||
)
|
||||
set_waiting_spy.assert_called_once()
|
||||
delete_key_spy.assert_not_called()
|
||||
|
||||
@@ -355,7 +352,7 @@ class TestDatasetIndexingTaskIntegration:
|
||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||
|
||||
# Assert
|
||||
task_dispatch_spy.apply_async.assert_not_called()
|
||||
task_dispatch_spy.delay.assert_not_called()
|
||||
delete_key_spy.assert_called_once()
|
||||
|
||||
def test_validation_failure_sets_error_status_when_vector_space_at_limit(
|
||||
@@ -450,7 +447,7 @@ class TestDatasetIndexingTaskIntegration:
|
||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||
|
||||
# Assert
|
||||
task_dispatch_spy.apply_async.assert_called_once()
|
||||
task_dispatch_spy.delay.assert_called_once()
|
||||
|
||||
def test_sessions_close_on_successful_indexing(
|
||||
self,
|
||||
@@ -537,7 +534,7 @@ class TestDatasetIndexingTaskIntegration:
|
||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||
|
||||
# Assert
|
||||
assert task_dispatch_spy.apply_async.call_count == concurrency_limit
|
||||
assert task_dispatch_spy.delay.call_count == concurrency_limit
|
||||
assert set_waiting_spy.call_count == concurrency_limit
|
||||
|
||||
def test_task_queue_fifo_ordering(self, db_session_with_containers, patched_external_dependencies):
|
||||
@@ -568,10 +565,9 @@ class TestDatasetIndexingTaskIntegration:
|
||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||
|
||||
# Assert
|
||||
assert task_dispatch_spy.apply_async.call_count == 3
|
||||
assert task_dispatch_spy.delay.call_count == 3
|
||||
for index, expected_task in enumerate(ordered_tasks):
|
||||
call_kwargs = task_dispatch_spy.apply_async.call_args_list[index].kwargs.get("kwargs", {})
|
||||
assert call_kwargs.get("document_ids") == expected_task["document_ids"]
|
||||
assert task_dispatch_spy.delay.call_args_list[index].kwargs["document_ids"] == expected_task["document_ids"]
|
||||
|
||||
def test_billing_disabled_skips_limit_checks(self, db_session_with_containers, patched_external_dependencies):
|
||||
"""Skip limit checks when billing feature is disabled."""
|
||||
|
||||
@@ -762,12 +762,11 @@ class TestDocumentIndexingTasks:
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify task function was called for each waiting task
|
||||
assert mock_task_func.apply_async.call_count == 1
|
||||
assert mock_task_func.delay.call_count == 1
|
||||
|
||||
# Verify correct parameters for each call
|
||||
calls = mock_task_func.apply_async.call_args_list
|
||||
sent_kwargs = calls[0][1]["kwargs"]
|
||||
assert sent_kwargs == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
|
||||
calls = mock_task_func.delay.call_args_list
|
||||
assert calls[0][1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
|
||||
|
||||
# Verify queue is empty after processing (tasks were pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10) # Pull more than we added
|
||||
@@ -831,15 +830,11 @@ class TestDocumentIndexingTasks:
|
||||
assert updated_document.processing_started_at is not None
|
||||
|
||||
# Verify waiting task was still processed despite core processing error
|
||||
mock_task_func.apply_async.assert_called_once()
|
||||
mock_task_func.delay.assert_called_once()
|
||||
|
||||
# Verify correct parameters for the call
|
||||
call = mock_task_func.apply_async.call_args
|
||||
assert call[1]["kwargs"] == {
|
||||
"tenant_id": tenant_id,
|
||||
"dataset_id": dataset_id,
|
||||
"document_ids": ["waiting-doc-1"],
|
||||
}
|
||||
call = mock_task_func.delay.call_args
|
||||
assert call[1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
|
||||
|
||||
# Verify queue is empty after processing (task was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
@@ -901,13 +896,9 @@ class TestDocumentIndexingTasks:
|
||||
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
|
||||
|
||||
# Verify only tenant1's waiting task was processed
|
||||
mock_task_func.apply_async.assert_called_once()
|
||||
call = mock_task_func.apply_async.call_args
|
||||
assert call[1]["kwargs"] == {
|
||||
"tenant_id": tenant1_id,
|
||||
"dataset_id": dataset1_id,
|
||||
"document_ids": ["tenant1-doc-1"],
|
||||
}
|
||||
mock_task_func.delay.assert_called_once()
|
||||
call = mock_task_func.delay.call_args
|
||||
assert call[1] == {"tenant_id": tenant1_id, "dataset_id": dataset1_id, "document_ids": ["tenant1-doc-1"]}
|
||||
|
||||
# Verify tenant1's queue is empty
|
||||
remaining_tasks1 = queue1.pull_tasks(count=10)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
@@ -388,10 +388,8 @@ class TestRagPipelineRunTasks:
|
||||
# Set the task key to indicate there are waiting tasks (legacy behavior)
|
||||
redis_client.set(legacy_task_key, 1, ex=60 * 60)
|
||||
|
||||
# Mock the Celery group scheduling used by the implementation
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
|
||||
mock_group.return_value.apply_async = MagicMock()
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Act: Execute the priority task with new code but legacy queue data
|
||||
rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
@@ -400,14 +398,13 @@ class TestRagPipelineRunTasks:
|
||||
mock_file_service["delete_file"].assert_called_once_with(file_id)
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify waiting tasks were processed via group, pull 1 task a time by default
|
||||
assert mock_group.return_value.apply_async.called
|
||||
# Verify waiting tasks were processed, pull 1 task a time by default
|
||||
assert mock_delay.call_count == 1
|
||||
|
||||
# Verify correct parameters for the first scheduled job signature
|
||||
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
|
||||
first_kwargs = jobs[0].kwargs if jobs else {}
|
||||
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0]
|
||||
assert first_kwargs.get("tenant_id") == tenant.id
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0]
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify that new code can process legacy queue entries
|
||||
# The new TenantIsolatedTaskQueue should be able to read from the legacy format
|
||||
@@ -449,10 +446,8 @@ class TestRagPipelineRunTasks:
|
||||
waiting_file_ids = [str(uuid.uuid4()) for _ in range(3)]
|
||||
queue.push_tasks(waiting_file_ids)
|
||||
|
||||
# Mock the Celery group scheduling used by the implementation
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
|
||||
mock_group.return_value.apply_async = MagicMock()
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Act: Execute the regular task
|
||||
rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
@@ -461,14 +456,13 @@ class TestRagPipelineRunTasks:
|
||||
mock_file_service["delete_file"].assert_called_once_with(file_id)
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify waiting tasks were processed via group.apply_async
|
||||
assert mock_group.return_value.apply_async.called
|
||||
# Verify waiting tasks were processed, pull 1 task a time by default
|
||||
assert mock_delay.call_count == 1
|
||||
|
||||
# Verify correct parameters for the first scheduled job signature
|
||||
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
|
||||
first_kwargs = jobs[0].kwargs if jobs else {}
|
||||
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
|
||||
assert first_kwargs.get("tenant_id") == tenant.id
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify queue still has remaining tasks (only 1 was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
@@ -563,10 +557,8 @@ class TestRagPipelineRunTasks:
|
||||
waiting_file_id = str(uuid.uuid4())
|
||||
queue.push_tasks([waiting_file_id])
|
||||
|
||||
# Mock the Celery group scheduling used by the implementation
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
|
||||
mock_group.return_value.apply_async = MagicMock()
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Act: Execute the regular task (should not raise exception)
|
||||
rag_pipeline_run_task(file_id, tenant.id)
|
||||
|
||||
@@ -577,13 +569,12 @@ class TestRagPipelineRunTasks:
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify waiting task was still processed despite core processing error
|
||||
assert mock_group.return_value.apply_async.called
|
||||
mock_delay.assert_called_once()
|
||||
|
||||
# Verify correct parameters for the first scheduled job signature
|
||||
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
|
||||
first_kwargs = jobs[0].kwargs if jobs else {}
|
||||
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
||||
assert first_kwargs.get("tenant_id") == tenant.id
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify queue is empty after processing (task was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
@@ -693,10 +684,8 @@ class TestRagPipelineRunTasks:
|
||||
queue1.push_tasks([waiting_file_id1])
|
||||
queue2.push_tasks([waiting_file_id2])
|
||||
|
||||
# Mock the Celery group scheduling used by the implementation
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
|
||||
mock_group.return_value.apply_async = MagicMock()
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Act: Execute the regular task for tenant1 only
|
||||
rag_pipeline_run_task(file_id1, tenant1.id)
|
||||
|
||||
@@ -705,12 +694,11 @@ class TestRagPipelineRunTasks:
|
||||
assert mock_file_service["delete_file"].call_count == 1
|
||||
assert mock_pipeline_generator.call_count == 1
|
||||
|
||||
# Verify only tenant1's waiting task was processed (via group)
|
||||
assert mock_group.return_value.apply_async.called
|
||||
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
|
||||
first_kwargs = jobs[0].kwargs if jobs else {}
|
||||
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
|
||||
assert first_kwargs.get("tenant_id") == tenant1.id
|
||||
# Verify only tenant1's waiting task was processed
|
||||
mock_delay.assert_called_once()
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
|
||||
assert call_kwargs.get("tenant_id") == tenant1.id
|
||||
|
||||
# Verify tenant1's queue is empty
|
||||
remaining_tasks1 = queue1.pull_tasks(count=10)
|
||||
@@ -925,10 +913,8 @@ class TestRagPipelineRunTasks:
|
||||
waiting_file_id = str(uuid.uuid4())
|
||||
queue.push_tasks([waiting_file_id])
|
||||
|
||||
# Mock the Celery group scheduling used by the implementation
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
|
||||
mock_group.return_value.apply_async = MagicMock()
|
||||
|
||||
# Mock the task function calls
|
||||
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
|
||||
# Act & Assert: Execute the regular task (should raise Exception)
|
||||
with pytest.raises(Exception, match="File not found"):
|
||||
rag_pipeline_run_task(file_id, tenant.id)
|
||||
@@ -938,13 +924,12 @@ class TestRagPipelineRunTasks:
|
||||
mock_pipeline_generator.assert_not_called()
|
||||
|
||||
# Verify waiting task was still processed despite file error
|
||||
assert mock_group.return_value.apply_async.called
|
||||
mock_delay.assert_called_once()
|
||||
|
||||
# Verify correct parameters for the first scheduled job signature
|
||||
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
|
||||
first_kwargs = jobs[0].kwargs if jobs else {}
|
||||
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
||||
assert first_kwargs.get("tenant_id") == tenant.id
|
||||
# Verify correct parameters for the call
|
||||
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
|
||||
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
|
||||
assert call_kwargs.get("tenant_id") == tenant.id
|
||||
|
||||
# Verify queue is empty after processing (task was pulled)
|
||||
remaining_tasks = queue.pull_tasks(count=10)
|
||||
|
||||
@@ -105,26 +105,18 @@ def app_model(
|
||||
|
||||
|
||||
class MockCeleryGroup:
|
||||
"""Mock for celery group() function that collects dispatched tasks.
|
||||
|
||||
Matches the Celery group API loosely, accepting arbitrary kwargs on apply_async
|
||||
(e.g. producer) so production code can pass broker-related options without
|
||||
breaking tests.
|
||||
"""
|
||||
"""Mock for celery group() function that collects dispatched tasks."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.collected: list[dict[str, Any]] = []
|
||||
self._applied = False
|
||||
self.last_apply_async_kwargs: dict[str, Any] | None = None
|
||||
|
||||
def __call__(self, items: Any) -> MockCeleryGroup:
|
||||
self.collected = list(items)
|
||||
return self
|
||||
|
||||
def apply_async(self, **kwargs: Any) -> None:
|
||||
# Accept arbitrary kwargs like producer to be compatible with Celery
|
||||
def apply_async(self) -> None:
|
||||
self._applied = True
|
||||
self.last_apply_async_kwargs = kwargs
|
||||
|
||||
@property
|
||||
def applied(self) -> bool:
|
||||
|
||||
@@ -1,817 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.datasource_auth import (
|
||||
DatasourceAuth,
|
||||
DatasourceAuthDefaultApi,
|
||||
DatasourceAuthDeleteApi,
|
||||
DatasourceAuthListApi,
|
||||
DatasourceAuthOauthCustomClient,
|
||||
DatasourceAuthUpdateApi,
|
||||
DatasourceHardCodeAuthListApi,
|
||||
DatasourceOAuthCallback,
|
||||
DatasourcePluginOAuthAuthorizationUrl,
|
||||
DatasourceUpdateProviderNameApi,
|
||||
)
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestDatasourcePluginOAuthAuthorizationUrl:
|
||||
def test_get_success(self, app):
|
||||
api = DatasourcePluginOAuthAuthorizationUrl()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="user-1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/?credential_id=cred-1"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value={"client_id": "abc"},
|
||||
),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"create_proxy_context",
|
||||
return_value="ctx-1",
|
||||
),
|
||||
patch.object(
|
||||
OAuthHandler,
|
||||
"get_authorization_url",
|
||||
return_value={"url": "http://auth"},
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_get_no_oauth_config(self, app):
|
||||
api = DatasourcePluginOAuthAuthorizationUrl()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
def test_get_without_credential_id_sets_cookie(self, app):
|
||||
api = DatasourcePluginOAuthAuthorizationUrl()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="user-1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value={"client_id": "abc"},
|
||||
),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"create_proxy_context",
|
||||
return_value="ctx-123",
|
||||
),
|
||||
patch.object(
|
||||
OAuthHandler,
|
||||
"get_authorization_url",
|
||||
return_value={"url": "http://auth"},
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "context_id" in response.headers.get("Set-Cookie")
|
||||
|
||||
|
||||
class TestDatasourceOAuthCallback:
|
||||
def test_callback_success_new_credential(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
oauth_response = MagicMock()
|
||||
oauth_response.credentials = {"token": "abc"}
|
||||
oauth_response.expires_at = None
|
||||
oauth_response.metadata = {"name": "test"}
|
||||
|
||||
context = {
|
||||
"user_id": "user-1",
|
||||
"tenant_id": "tenant-1",
|
||||
"credential_id": None,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/?context_id=ctx"),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"use_proxy_context",
|
||||
return_value=context,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value={"client_id": "abc"},
|
||||
),
|
||||
patch.object(
|
||||
OAuthHandler,
|
||||
"get_credentials",
|
||||
return_value=oauth_response,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"add_datasource_oauth_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
|
||||
assert response.status_code == 302
|
||||
|
||||
def test_callback_missing_context(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "notion")
|
||||
|
||||
def test_callback_invalid_context(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?context_id=bad"),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"use_proxy_context",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "notion")
|
||||
|
||||
def test_callback_oauth_config_not_found(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
context = {"user_id": "u", "tenant_id": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/?context_id=ctx"),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"use_proxy_context",
|
||||
return_value=context,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "notion")
|
||||
|
||||
def test_callback_reauthorize_existing_credential(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
oauth_response = MagicMock()
|
||||
oauth_response.credentials = {"token": "abc"}
|
||||
oauth_response.expires_at = None
|
||||
oauth_response.metadata = {} # avatar + name missing
|
||||
|
||||
context = {
|
||||
"user_id": "user-1",
|
||||
"tenant_id": "tenant-1",
|
||||
"credential_id": "cred-1",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/?context_id=ctx"),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"use_proxy_context",
|
||||
return_value=context,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value={"client_id": "abc"},
|
||||
),
|
||||
patch.object(
|
||||
OAuthHandler,
|
||||
"get_credentials",
|
||||
return_value=oauth_response,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"reauthorize_datasource_oauth_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
|
||||
assert response.status_code == 302
|
||||
assert "/oauth-callback" in response.location
|
||||
|
||||
def test_callback_context_id_from_cookie(self, app):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
oauth_response = MagicMock()
|
||||
oauth_response.credentials = {"token": "abc"}
|
||||
oauth_response.expires_at = None
|
||||
oauth_response.metadata = {}
|
||||
|
||||
context = {
|
||||
"user_id": "user-1",
|
||||
"tenant_id": "tenant-1",
|
||||
"credential_id": None,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", headers={"Cookie": "context_id=ctx"}),
|
||||
patch.object(
|
||||
OAuthProxyService,
|
||||
"use_proxy_context",
|
||||
return_value=context,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
return_value={"client_id": "abc"},
|
||||
),
|
||||
patch.object(
|
||||
OAuthHandler,
|
||||
"get_credentials",
|
||||
return_value=oauth_response,
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"add_datasource_oauth_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
|
||||
assert response.status_code == 302
|
||||
|
||||
|
||||
class TestDatasourceAuth:
|
||||
def test_post_success(self, app):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {"key": "val"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"add_datasource_api_key_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_post_invalid_credentials(self, app):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {"key": "bad"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"add_datasource_api_key_provider",
|
||||
side_effect=CredentialsValidateFailedError("invalid"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
def test_get_success(self, app):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"list_datasource_credentials",
|
||||
return_value=[{"id": "1"}],
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
assert response["result"]
|
||||
|
||||
def test_post_missing_credentials(self, app):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
def test_get_empty_list(self, app):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"list_datasource_credentials",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == []
|
||||
|
||||
|
||||
class TestDatasourceAuthDeleteApi:
|
||||
def test_delete_success(self, app):
|
||||
api = DatasourceAuthDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "cred-1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"remove_datasource_credentials",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_delete_missing_credential_id(self, app):
|
||||
api = DatasourceAuthDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
|
||||
class TestDatasourceAuthUpdateApi:
|
||||
def test_update_success(self, app):
|
||||
api = DatasourceAuthUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "credentials": {"k": "v"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_credentials",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 201
|
||||
|
||||
def test_update_with_credentials_none(self, app):
|
||||
api = DatasourceAuthUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "credentials": None}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_credentials",
|
||||
return_value=None,
|
||||
) as update_mock,
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
update_mock.assert_called_once()
|
||||
assert status == 201
|
||||
|
||||
def test_update_name_only(self, app):
|
||||
api = DatasourceAuthUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "name": "New Name"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_credentials",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
_, status = method(api, "notion")
|
||||
|
||||
assert status == 201
|
||||
|
||||
def test_update_with_empty_credentials_dict(self, app):
|
||||
api = DatasourceAuthUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "credentials": {}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_credentials",
|
||||
return_value=None,
|
||||
) as update_mock,
|
||||
):
|
||||
_, status = method(api, "notion")
|
||||
|
||||
update_mock.assert_called_once()
|
||||
assert status == 201
|
||||
|
||||
|
||||
class TestDatasourceAuthListApi:
|
||||
def test_list_success(self, app):
|
||||
api = DatasourceAuthListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_all_datasource_credentials",
|
||||
return_value=[{"id": "1"}],
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_auth_list_empty(self, app):
|
||||
api = DatasourceAuthListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_all_datasource_credentials",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == []
|
||||
|
||||
def test_hardcode_list_empty(self, app):
|
||||
api = DatasourceHardCodeAuthListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_hard_code_datasource_credentials",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == []
|
||||
|
||||
|
||||
class TestDatasourceHardCodeAuthListApi:
|
||||
def test_list_success(self, app):
|
||||
api = DatasourceHardCodeAuthListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_hard_code_datasource_credentials",
|
||||
return_value=[{"id": "1"}],
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
|
||||
|
||||
class TestDatasourceAuthOauthCustomClient:
|
||||
def test_post_success(self, app):
|
||||
api = DatasourceAuthOauthCustomClient()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"client_params": {}, "enable_oauth_custom_client": True}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"setup_oauth_custom_client_params",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_delete_success(self, app):
|
||||
api = DatasourceAuthOauthCustomClient()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"remove_oauth_custom_client_params",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_post_empty_payload(self, app):
|
||||
api = DatasourceAuthOauthCustomClient()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"setup_oauth_custom_client_params",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
_, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_post_disabled_flag(self, app):
|
||||
api = DatasourceAuthOauthCustomClient()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"client_params": {"a": 1},
|
||||
"enable_oauth_custom_client": False,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"setup_oauth_custom_client_params",
|
||||
return_value=None,
|
||||
) as setup_mock,
|
||||
):
|
||||
_, status = method(api, "notion")
|
||||
|
||||
setup_mock.assert_called_once()
|
||||
assert status == 200
|
||||
|
||||
|
||||
class TestDatasourceAuthDefaultApi:
|
||||
def test_set_default_success(self, app):
|
||||
api = DatasourceAuthDefaultApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"id": "cred-1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"set_default_datasource_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_default_missing_id(self, app):
|
||||
api = DatasourceAuthDefaultApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
|
||||
class TestDatasourceUpdateProviderNameApi:
|
||||
def test_update_name_success(self, app):
|
||||
api = DatasourceUpdateProviderNameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "name": "New Name"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_provider_name",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_update_name_too_long(self, app):
|
||||
api = DatasourceUpdateProviderNameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"credential_id": "id",
|
||||
"name": "x" * 101,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
|
||||
def test_update_name_missing_credential_id(self, app):
|
||||
api = DatasourceUpdateProviderNameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"name": "Valid"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
@@ -1,143 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.datasource_content_preview import (
|
||||
DataSourceContentPreviewApi,
|
||||
)
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestDataSourceContentPreviewApi:
|
||||
def _valid_payload(self):
|
||||
return {
|
||||
"inputs": {"query": "hello"},
|
||||
"datasource_type": "notion",
|
||||
"credential_id": "cred-1",
|
||||
}
|
||||
|
||||
def test_post_success(self, app):
|
||||
api = DataSourceContentPreviewApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._valid_payload()
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
node_id = "node-1"
|
||||
account = MagicMock(spec=Account)
|
||||
|
||||
preview_result = {"content": "preview data"}
|
||||
|
||||
service_instance = MagicMock()
|
||||
service_instance.run_datasource_node_preview.return_value = preview_result
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
|
||||
account,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.RagPipelineService",
|
||||
return_value=service_instance,
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline, node_id)
|
||||
|
||||
service_instance.run_datasource_node_preview.assert_called_once_with(
|
||||
pipeline=pipeline,
|
||||
node_id=node_id,
|
||||
user_inputs=payload["inputs"],
|
||||
account=account,
|
||||
datasource_type=payload["datasource_type"],
|
||||
is_published=True,
|
||||
credential_id=payload["credential_id"],
|
||||
)
|
||||
assert status == 200
|
||||
assert response == preview_result
|
||||
|
||||
def test_post_forbidden_non_account_user(self, app):
|
||||
api = DataSourceContentPreviewApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._valid_payload()
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
|
||||
MagicMock(), # NOT Account
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, pipeline, "node-1")
|
||||
|
||||
def test_post_invalid_payload(self, app):
|
||||
api = DataSourceContentPreviewApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"inputs": {"query": "hello"},
|
||||
# datasource_type missing
|
||||
}
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
account = MagicMock(spec=Account)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
|
||||
account,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, pipeline, "node-1")
|
||||
|
||||
def test_post_without_credential_id(self, app):
|
||||
api = DataSourceContentPreviewApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"inputs": {"query": "hello"},
|
||||
"datasource_type": "notion",
|
||||
"credential_id": None,
|
||||
}
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
account = MagicMock(spec=Account)
|
||||
|
||||
service_instance = MagicMock()
|
||||
service_instance.run_datasource_node_preview.return_value = {"ok": True}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
|
||||
account,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_content_preview.RagPipelineService",
|
||||
return_value=service_instance,
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline, "node-1")
|
||||
|
||||
service_instance.run_datasource_node_preview.assert_called_once()
|
||||
assert status == 200
|
||||
assert response == {"ok": True}
|
||||
@@ -1,187 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline import (
|
||||
CustomizedPipelineTemplateApi,
|
||||
PipelineTemplateDetailApi,
|
||||
PipelineTemplateListApi,
|
||||
PublishCustomizedPipelineTemplateApi,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestPipelineTemplateListApi:
|
||||
def test_get_success(self, app):
|
||||
api = PipelineTemplateListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
templates = [{"id": "t1"}]
|
||||
|
||||
with (
|
||||
app.test_request_context("/?type=built-in&language=en-US"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.get_pipeline_templates",
|
||||
return_value=templates,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert response == templates
|
||||
|
||||
|
||||
class TestPipelineTemplateDetailApi:
|
||||
def test_get_success(self, app):
|
||||
api = PipelineTemplateDetailApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
template = {"id": "tpl-1"}
|
||||
|
||||
service = MagicMock()
|
||||
service.get_pipeline_template_detail.return_value = template
|
||||
|
||||
with (
|
||||
app.test_request_context("/?type=built-in"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "tpl-1")
|
||||
|
||||
assert status == 200
|
||||
assert response == template
|
||||
|
||||
|
||||
class TestCustomizedPipelineTemplateApi:
|
||||
def test_patch_success(self, app):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
payload = {
|
||||
"name": "Template",
|
||||
"description": "Desc",
|
||||
"icon_info": {"icon": "📘"},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.update_customized_pipeline_template"
|
||||
) as update_mock,
|
||||
):
|
||||
response = method(api, "tpl-1")
|
||||
|
||||
update_mock.assert_called_once()
|
||||
assert response == 200
|
||||
|
||||
def test_delete_success(self, app):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.delete_customized_pipeline_template"
|
||||
) as delete_mock,
|
||||
):
|
||||
response = method(api, "tpl-1")
|
||||
|
||||
delete_mock.assert_called_once_with("tpl-1")
|
||||
assert response == 200
|
||||
|
||||
def test_post_success(self, app):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
template = MagicMock()
|
||||
template.yaml_content = "yaml-data"
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = template
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "tpl-1")
|
||||
|
||||
assert status == 200
|
||||
assert response == {"data": "yaml-data"}
|
||||
|
||||
def test_post_template_not_found(self, app):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "tpl-1")
|
||||
|
||||
|
||||
class TestPublishCustomizedPipelineTemplateApi:
|
||||
def test_post_success(self, app):
|
||||
api = PublishCustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"name": "Template",
|
||||
"description": "Desc",
|
||||
"icon_info": {"icon": "📘"},
|
||||
}
|
||||
|
||||
service = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response = method(api, "pipeline-1")
|
||||
|
||||
service.publish_customized_pipeline_template.assert_called_once()
|
||||
assert response == {"result": "success"}
|
||||
@@ -1,187 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_datasets import (
|
||||
CreateEmptyRagPipelineDatasetApi,
|
||||
CreateRagPipelineDatasetApi,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestCreateRagPipelineDatasetApi:
|
||||
def _valid_payload(self):
|
||||
return {"yaml_content": "name: test"}
|
||||
|
||||
def test_post_success(self, app):
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._valid_payload()
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
import_info = {"dataset_id": "ds-1"}
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service.create_rag_pipeline_dataset.return_value = import_info
|
||||
|
||||
mock_session_ctx = MagicMock()
|
||||
mock_session_ctx.__enter__.return_value = MagicMock()
|
||||
mock_session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session",
|
||||
return_value=mock_session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
|
||||
return_value=mock_service,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 201
|
||||
assert response == import_info
|
||||
|
||||
def test_post_forbidden_non_editor(self, app):
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._valid_payload()
|
||||
user = MagicMock(is_dataset_editor=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
|
||||
def test_post_dataset_name_duplicate(self, app):
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._valid_payload()
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
|
||||
mock_service = MagicMock()
|
||||
mock_service.create_rag_pipeline_dataset.side_effect = services.errors.dataset.DatasetNameDuplicateError()
|
||||
|
||||
mock_session_ctx = MagicMock()
|
||||
mock_session_ctx.__enter__.return_value = MagicMock()
|
||||
mock_session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session",
|
||||
return_value=mock_session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
|
||||
return_value=mock_service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(DatasetNameDuplicateError):
|
||||
method(api)
|
||||
|
||||
def test_post_invalid_payload(self, app):
|
||||
api = CreateRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestCreateEmptyRagPipelineDatasetApi:
|
||||
def test_post_success(self, app):
|
||||
api = CreateEmptyRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
dataset = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.DatasetService.create_empty_rag_pipeline_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.marshal",
|
||||
return_value={"id": "ds-1"},
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 201
|
||||
assert response == {"id": "ds-1"}
|
||||
|
||||
def test_post_forbidden_non_editor(self, app):
|
||||
api = CreateEmptyRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock(is_dataset_editor=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
@@ -1,324 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Response
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import DraftWorkflowNotExist
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable import (
|
||||
RagPipelineEnvironmentVariableCollectionApi,
|
||||
RagPipelineNodeVariableCollectionApi,
|
||||
RagPipelineSystemVariableCollectionApi,
|
||||
RagPipelineVariableApi,
|
||||
RagPipelineVariableCollectionApi,
|
||||
RagPipelineVariableResetApi,
|
||||
)
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from dify_graph.variables.types import SegmentType
|
||||
from models.account import Account
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_db():
|
||||
db = MagicMock()
|
||||
db.engine = MagicMock()
|
||||
db.session.return_value = MagicMock()
|
||||
return db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def editor_user():
|
||||
user = MagicMock(spec=Account)
|
||||
user.has_edit_permission = True
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def restx_config(app):
|
||||
return patch.dict(app.config, {"RESTX_MASK_HEADER": "X-Fields"})
|
||||
|
||||
|
||||
class TestRagPipelineVariableCollectionApi:
|
||||
def test_get_variables_success(self, app, fake_db, editor_user, restx_config):
|
||||
api = RagPipelineVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
|
||||
rag_srv = MagicMock()
|
||||
rag_srv.is_workflow_exist.return_value = True
|
||||
|
||||
# IMPORTANT: RESTX expects .variables
|
||||
var_list = MagicMock()
|
||||
var_list.variables = []
|
||||
|
||||
draft_srv = MagicMock()
|
||||
draft_srv.list_variables_without_values.return_value = var_list
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&limit=10"),
|
||||
restx_config,
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
|
||||
return_value=rag_srv,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=draft_srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert result["items"] == []
|
||||
|
||||
def test_get_variables_workflow_not_exist(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
|
||||
rag_srv = MagicMock()
|
||||
rag_srv.is_workflow_exist.return_value = False
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
|
||||
return_value=rag_srv,
|
||||
),
|
||||
):
|
||||
with pytest.raises(DraftWorkflowNotExist):
|
||||
method(api, pipeline)
|
||||
|
||||
def test_delete_variables_success(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableCollectionApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService"),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.status_code == 204
|
||||
|
||||
|
||||
class TestRagPipelineNodeVariableCollectionApi:
|
||||
def test_get_node_variables_success(self, app, fake_db, editor_user, restx_config):
|
||||
api = RagPipelineNodeVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
|
||||
var_list = MagicMock()
|
||||
var_list.variables = []
|
||||
|
||||
srv = MagicMock()
|
||||
srv.list_node_variables.return_value = var_list
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
restx_config,
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "node1")
|
||||
|
||||
assert result["items"] == []
|
||||
|
||||
def test_get_node_variables_invalid_node(self, app, editor_user):
|
||||
api = RagPipelineNodeVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
):
|
||||
with pytest.raises(InvalidArgumentError):
|
||||
method(api, MagicMock(), SYSTEM_VARIABLE_NODE_ID)
|
||||
|
||||
|
||||
class TestRagPipelineVariableApi:
|
||||
def test_get_variable_not_found(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
srv = MagicMock()
|
||||
srv.get_variable.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFoundError):
|
||||
method(api, MagicMock(), "v1")
|
||||
|
||||
def test_patch_variable_invalid_file_payload(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
pipeline = MagicMock(id="p1", tenant_id="t1")
|
||||
variable = MagicMock(app_id="p1", value_type=SegmentType.FILE)
|
||||
|
||||
srv = MagicMock()
|
||||
srv.get_variable.return_value = variable
|
||||
|
||||
payload = {"value": "invalid"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvalidArgumentError):
|
||||
method(api, pipeline, "v1")
|
||||
|
||||
def test_delete_variable_success(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
variable = MagicMock(app_id="p1")
|
||||
|
||||
srv = MagicMock()
|
||||
srv.get_variable.return_value = variable
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "v1")
|
||||
|
||||
assert result.status_code == 204
|
||||
|
||||
|
||||
class TestRagPipelineVariableResetApi:
|
||||
def test_reset_variable_success(self, app, fake_db, editor_user):
|
||||
api = RagPipelineVariableResetApi()
|
||||
method = unwrap(api.put)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
workflow = MagicMock()
|
||||
variable = MagicMock(app_id="p1")
|
||||
|
||||
srv = MagicMock()
|
||||
srv.get_variable.return_value = variable
|
||||
srv.reset_variable.return_value = variable
|
||||
|
||||
rag_srv = MagicMock()
|
||||
rag_srv.get_draft_workflow.return_value = workflow
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
|
||||
return_value=rag_srv,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.marshal",
|
||||
return_value={"id": "v1"},
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "v1")
|
||||
|
||||
assert result == {"id": "v1"}
|
||||
|
||||
|
||||
class TestSystemAndEnvironmentVariablesApi:
|
||||
def test_system_variables_success(self, app, fake_db, editor_user, restx_config):
|
||||
api = RagPipelineSystemVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock(id="p1")
|
||||
|
||||
var_list = MagicMock()
|
||||
var_list.variables = []
|
||||
|
||||
srv = MagicMock()
|
||||
srv.list_system_variables.return_value = var_list
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
restx_config,
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
|
||||
return_value=srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert result["items"] == []
|
||||
|
||||
def test_environment_variables_success(self, app, editor_user):
|
||||
api = RagPipelineEnvironmentVariableCollectionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
env_var = MagicMock(
|
||||
id="e1",
|
||||
name="ENV",
|
||||
description="d",
|
||||
selector="s",
|
||||
value_type=MagicMock(value="string"),
|
||||
value="x",
|
||||
)
|
||||
|
||||
workflow = MagicMock(environment_variables=[env_var])
|
||||
pipeline = MagicMock(id="p1")
|
||||
|
||||
rag_srv = MagicMock()
|
||||
rag_srv.get_draft_workflow.return_value = workflow
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
|
||||
return_value=rag_srv,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert len(result["items"]) == 1
|
||||
@@ -1,329 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_import import (
|
||||
RagPipelineExportApi,
|
||||
RagPipelineImportApi,
|
||||
RagPipelineImportCheckDependenciesApi,
|
||||
RagPipelineImportConfirmApi,
|
||||
)
|
||||
from models.dataset import Pipeline
|
||||
from services.app_dsl_service import ImportStatus
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestRagPipelineImportApi:
|
||||
def _payload(self, mode="create"):
|
||||
return {
|
||||
"mode": mode,
|
||||
"yaml_content": "content",
|
||||
"name": "Test",
|
||||
}
|
||||
|
||||
def test_post_success_200(self, app):
|
||||
api = RagPipelineImportApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._payload()
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = "completed"
|
||||
result.model_dump.return_value = {"status": "success"}
|
||||
|
||||
service = MagicMock()
|
||||
service.import_rag_pipeline.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert response == {"status": "success"}
|
||||
|
||||
def test_post_failed_400(self, app):
|
||||
api = RagPipelineImportApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._payload()
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = ImportStatus.FAILED
|
||||
result.model_dump.return_value = {"status": "failed"}
|
||||
|
||||
service = MagicMock()
|
||||
service.import_rag_pipeline.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 400
|
||||
assert response == {"status": "failed"}
|
||||
|
||||
def test_post_pending_202(self, app):
|
||||
api = RagPipelineImportApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._payload()
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = ImportStatus.PENDING
|
||||
result.model_dump.return_value = {"status": "pending"}
|
||||
|
||||
service = MagicMock()
|
||||
service.import_rag_pipeline.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 202
|
||||
assert response == {"status": "pending"}
|
||||
|
||||
|
||||
class TestRagPipelineImportConfirmApi:
|
||||
def test_confirm_success(self, app):
|
||||
api = RagPipelineImportConfirmApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = "completed"
|
||||
result.model_dump.return_value = {"ok": True}
|
||||
|
||||
service = MagicMock()
|
||||
service.confirm_import.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "import-1")
|
||||
|
||||
assert status == 200
|
||||
assert response == {"ok": True}
|
||||
|
||||
def test_confirm_failed(self, app):
|
||||
api = RagPipelineImportConfirmApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = ImportStatus.FAILED
|
||||
result.model_dump.return_value = {"ok": False}
|
||||
|
||||
service = MagicMock()
|
||||
service.confirm_import.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "import-1")
|
||||
|
||||
assert status == 400
|
||||
assert response == {"ok": False}
|
||||
|
||||
|
||||
class TestRagPipelineImportCheckDependenciesApi:
|
||||
def test_get_success(self, app):
|
||||
api = RagPipelineImportCheckDependenciesApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
result = MagicMock()
|
||||
result.model_dump.return_value = {"deps": []}
|
||||
|
||||
service = MagicMock()
|
||||
service.check_dependencies.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline)
|
||||
|
||||
assert status == 200
|
||||
assert response == {"deps": []}
|
||||
|
||||
|
||||
class TestRagPipelineExportApi:
|
||||
def test_get_with_include_secret(self, app):
|
||||
api = RagPipelineExportApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock(spec=Pipeline)
|
||||
service = MagicMock()
|
||||
service.export_rag_pipeline_dsl.return_value = {"yaml": "data"}
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/?include_secret=true"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline)
|
||||
|
||||
assert status == 200
|
||||
assert response == {"data": {"yaml": "data"}}
|
||||
@@ -1,688 +0,0 @@
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import (
|
||||
DefaultRagPipelineBlockConfigApi,
|
||||
DraftRagPipelineApi,
|
||||
DraftRagPipelineRunApi,
|
||||
PublishedAllRagPipelineApi,
|
||||
PublishedRagPipelineApi,
|
||||
PublishedRagPipelineRunApi,
|
||||
RagPipelineByIdApi,
|
||||
RagPipelineDatasourceVariableApi,
|
||||
RagPipelineDraftNodeRunApi,
|
||||
RagPipelineDraftRunIterationNodeApi,
|
||||
RagPipelineDraftRunLoopNodeApi,
|
||||
RagPipelineRecommendedPluginApi,
|
||||
RagPipelineTaskStopApi,
|
||||
RagPipelineTransformApi,
|
||||
RagPipelineWorkflowLastRunApi,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestDraftWorkflowApi:
|
||||
def test_get_draft_success(self, app):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
workflow = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.get_draft_workflow.return_value = workflow
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
assert result == workflow
|
||||
|
||||
def test_get_draft_not_exist(self, app):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
service = MagicMock()
|
||||
service.get_draft_workflow.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(DraftWorkflowNotExist):
|
||||
method(api, pipeline)
|
||||
|
||||
def test_sync_hash_not_match(self, app):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.sync_draft_workflow.side_effect = WorkflowHashNotEqualError()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"graph": {}, "features": {}}),
|
||||
patch.object(type(console_ns), "payload", {"graph": {}, "features": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(DraftWorkflowNotSync):
|
||||
method(api, pipeline)
|
||||
|
||||
def test_sync_invalid_text_plain(self, app):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", data="bad-json", headers={"Content-Type": "text/plain"}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
):
|
||||
response, status = method(api, pipeline)
|
||||
assert status == 400
|
||||
|
||||
|
||||
class TestDraftRunNodes:
|
||||
def test_iteration_node_success(self, app):
|
||||
api = RagPipelineDraftRunIterationNodeApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"inputs": {}}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "node")
|
||||
assert result == {"ok": True}
|
||||
|
||||
def test_iteration_node_conversation_not_exists(self, app):
|
||||
api = RagPipelineDraftRunIterationNodeApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"inputs": {}}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration",
|
||||
side_effect=services.errors.conversation.ConversationNotExistsError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, pipeline, "node")
|
||||
|
||||
def test_loop_node_success(self, app):
|
||||
api = RagPipelineDraftRunLoopNodeApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"inputs": {}}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_loop",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, pipeline, "node") == {"ok": True}
|
||||
|
||||
|
||||
class TestPipelineRunApis:
|
||||
def test_draft_run_success(self, app):
|
||||
api = DraftRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
payload = {
|
||||
"inputs": {},
|
||||
"datasource_type": "x",
|
||||
"datasource_info_list": [],
|
||||
"start_node_id": "n",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, pipeline) == {"ok": True}
|
||||
|
||||
def test_draft_run_rate_limit(self, app):
|
||||
api = DraftRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/", json={"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"}
|
||||
),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
{"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
|
||||
side_effect=InvokeRateLimitError("limit"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvokeRateLimitHttpError):
|
||||
method(api, pipeline)
|
||||
|
||||
|
||||
class TestDraftNodeRun:
|
||||
def test_execution_not_found(self, app):
|
||||
api = RagPipelineDraftNodeRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.run_draft_workflow_node.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"inputs": {}}),
|
||||
patch.object(type(console_ns), "payload", {"inputs": {}}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, pipeline, "node")
|
||||
|
||||
|
||||
class TestPublishedPipelineApis:
|
||||
def test_publish_success(self, app):
|
||||
api = PublishedRagPipelineApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
workflow = MagicMock(
|
||||
id="w1",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
session = MagicMock()
|
||||
session.merge.return_value = pipeline
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
service = MagicMock()
|
||||
service.publish_workflow.return_value = workflow
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert result["result"] == "success"
|
||||
assert "created_at" in result
|
||||
|
||||
|
||||
class TestMiscApis:
|
||||
def test_task_stop(self, app):
|
||||
api = RagPipelineTaskStopApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.AppQueueManager.set_stop_flag"
|
||||
) as stop_mock,
|
||||
):
|
||||
result = method(api, pipeline, "task-1")
|
||||
stop_mock.assert_called_once()
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_transform_forbidden(self, app):
|
||||
api = RagPipelineTransformApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock(has_edit_permission=False, is_dataset_operator=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "ds1")
|
||||
|
||||
def test_recommended_plugins(self, app):
|
||||
api = RagPipelineRecommendedPluginApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
service = MagicMock()
|
||||
service.get_recommended_plugins.return_value = [{"id": "p1"}]
|
||||
|
||||
with (
|
||||
app.test_request_context("/?type=all"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
assert result == [{"id": "p1"}]
|
||||
|
||||
|
||||
class TestPublishedRagPipelineRunApi:
|
||||
def test_published_run_success(self, app):
|
||||
api = PublishedRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
payload = {
|
||||
"inputs": {},
|
||||
"datasource_type": "x",
|
||||
"datasource_info_list": [],
|
||||
"start_node_id": "n",
|
||||
"response_mode": "blocking",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
assert result == {"ok": True}
|
||||
|
||||
def test_published_run_rate_limit(self, app):
|
||||
api = PublishedRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
payload = {
|
||||
"inputs": {},
|
||||
"datasource_type": "x",
|
||||
"datasource_info_list": [],
|
||||
"start_node_id": "n",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
|
||||
side_effect=InvokeRateLimitError("limit"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvokeRateLimitHttpError):
|
||||
method(api, pipeline)
|
||||
|
||||
|
||||
class TestDefaultBlockConfigApi:
|
||||
def test_get_block_config_success(self, app):
|
||||
api = DefaultRagPipelineBlockConfigApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.get_default_block_config.return_value = {"k": "v"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/?q={}"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "llm")
|
||||
assert result == {"k": "v"}
|
||||
|
||||
def test_get_block_config_invalid_json(self, app):
|
||||
api = DefaultRagPipelineBlockConfigApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
|
||||
with app.test_request_context("/?q=bad-json"):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, pipeline, "llm")
|
||||
|
||||
|
||||
class TestPublishedAllRagPipelineApi:
|
||||
def test_get_published_workflows_success(self, app):
|
||||
api = PublishedAllRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
service = MagicMock()
|
||||
service.get_all_published_workflow.return_value = ([{"id": "w1"}], False)
|
||||
|
||||
session = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
|
||||
assert result["items"] == [{"id": "w1"}]
|
||||
assert result["has_more"] is False
|
||||
|
||||
def test_get_published_workflows_forbidden(self, app):
|
||||
api = PublishedAllRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/?user_id=u2"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, pipeline)
|
||||
|
||||
|
||||
class TestRagPipelineByIdApi:
|
||||
def test_patch_success(self, app):
|
||||
api = RagPipelineByIdApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
pipeline = MagicMock(tenant_id="t1")
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
workflow = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.update_workflow.return_value = workflow
|
||||
|
||||
session = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
payload = {"marked_name": "test"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "w1")
|
||||
|
||||
assert result == workflow
|
||||
|
||||
def test_patch_no_fields(self, app):
|
||||
api = RagPipelineByIdApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
patch.object(type(console_ns), "payload", {}),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
):
|
||||
result, status = method(api, pipeline, "w1")
|
||||
assert status == 400
|
||||
|
||||
|
||||
class TestRagPipelineWorkflowLastRunApi:
|
||||
def test_last_run_success(self, app):
|
||||
api = RagPipelineWorkflowLastRunApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
workflow = MagicMock()
|
||||
node_exec = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.get_draft_workflow.return_value = workflow
|
||||
service.get_node_last_run.return_value = node_exec
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline, "node1")
|
||||
assert result == node_exec
|
||||
|
||||
def test_last_run_not_found(self, app):
|
||||
api = RagPipelineWorkflowLastRunApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
pipeline = MagicMock()
|
||||
|
||||
service = MagicMock()
|
||||
service.get_draft_workflow.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, pipeline, "node1")
|
||||
|
||||
|
||||
class TestRagPipelineDatasourceVariableApi:
|
||||
def test_set_datasource_variables_success(self, app):
|
||||
api = RagPipelineDatasourceVariableApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
payload = {
|
||||
"datasource_type": "db",
|
||||
"datasource_info": {},
|
||||
"start_node_id": "n1",
|
||||
"start_node_title": "Node",
|
||||
}
|
||||
|
||||
service = MagicMock()
|
||||
service.set_datasource_variables.return_value = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
),
|
||||
):
|
||||
result = method(api, pipeline)
|
||||
assert result is not None
|
||||
@@ -1,444 +0,0 @@
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console.datasets import data_source
|
||||
from controllers.console.datasets.data_source import (
|
||||
DataSourceApi,
|
||||
DataSourceNotionApi,
|
||||
DataSourceNotionDatasetSyncApi,
|
||||
DataSourceNotionDocumentSyncApi,
|
||||
DataSourceNotionListApi,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_ctx():
|
||||
return (MagicMock(id="u1"), "tenant-1")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_tenant(tenant_ctx):
|
||||
with patch(
|
||||
"controllers.console.datasets.data_source.current_account_with_tenant",
|
||||
return_value=tenant_ctx,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_engine():
|
||||
with patch.object(
|
||||
type(data_source.db),
|
||||
"engine",
|
||||
new_callable=PropertyMock,
|
||||
return_value=MagicMock(),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
class TestDataSourceApi:
|
||||
def test_get_success(self, app, patch_tenant):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
binding = MagicMock(
|
||||
id="b1",
|
||||
provider="notion",
|
||||
created_at="now",
|
||||
disabled=False,
|
||||
source_info={},
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.db.session.scalars",
|
||||
return_value=MagicMock(all=lambda: [binding]),
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert response["data"][0]["is_bound"] is True
|
||||
|
||||
def test_get_no_bindings(self, app, patch_tenant):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.db.session.scalars",
|
||||
return_value=MagicMock(all=lambda: []),
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert response["data"] == []
|
||||
|
||||
def test_patch_enable_binding(self, app, patch_tenant, mock_engine):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
binding = MagicMock(id="b1", disabled=True)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
|
||||
patch("controllers.console.datasets.data_source.db.session.add"),
|
||||
patch("controllers.console.datasets.data_source.db.session.commit"),
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
|
||||
|
||||
response, status = method(api, "b1", "enable")
|
||||
|
||||
assert status == 200
|
||||
assert binding.disabled is False
|
||||
|
||||
def test_patch_disable_binding(self, app, patch_tenant, mock_engine):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
binding = MagicMock(id="b1", disabled=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
|
||||
patch("controllers.console.datasets.data_source.db.session.add"),
|
||||
patch("controllers.console.datasets.data_source.db.session.commit"),
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
|
||||
|
||||
response, status = method(api, "b1", "disable")
|
||||
|
||||
assert status == 200
|
||||
assert binding.disabled is True
|
||||
|
||||
def test_patch_binding_not_found(self, app, patch_tenant, mock_engine):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "b1", "enable")
|
||||
|
||||
def test_patch_enable_already_enabled(self, app, patch_tenant, mock_engine):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
binding = MagicMock(id="b1", disabled=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "b1", "enable")
|
||||
|
||||
def test_patch_disable_already_disabled(self, app, patch_tenant, mock_engine):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
binding = MagicMock(id="b1", disabled=True)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "b1", "disable")
|
||||
|
||||
|
||||
class TestDataSourceNotionListApi:
|
||||
def test_get_credential_not_found(self, app, patch_tenant):
|
||||
api = DataSourceNotionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?credential_id=c1"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api)
|
||||
|
||||
def test_get_success_no_dataset_id(self, app, patch_tenant, mock_engine):
|
||||
api = DataSourceNotionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
page = MagicMock(
|
||||
page_id="p1",
|
||||
page_name="Page 1",
|
||||
type="page",
|
||||
parent_id="parent",
|
||||
page_icon=None,
|
||||
)
|
||||
|
||||
online_document_message = MagicMock(
|
||||
result=[
|
||||
MagicMock(
|
||||
workspace_id="w1",
|
||||
workspace_name="My Workspace",
|
||||
workspace_icon="icon",
|
||||
pages=[page],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?credential_id=c1"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
|
||||
return_value={"token": "t"},
|
||||
),
|
||||
patch(
|
||||
"core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime",
|
||||
return_value=MagicMock(
|
||||
get_online_document_pages=lambda **kw: iter([online_document_message]),
|
||||
datasource_provider_type=lambda: None,
|
||||
),
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_get_success_with_dataset_id(self, app, patch_tenant, mock_engine):
|
||||
api = DataSourceNotionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
page = MagicMock(
|
||||
page_id="p1",
|
||||
page_name="Page 1",
|
||||
type="page",
|
||||
parent_id="parent",
|
||||
page_icon=None,
|
||||
)
|
||||
|
||||
online_document_message = MagicMock(
|
||||
result=[
|
||||
MagicMock(
|
||||
workspace_id="w1",
|
||||
workspace_name="My Workspace",
|
||||
workspace_icon="icon",
|
||||
pages=[page],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
dataset = MagicMock(data_source_type="notion_import")
|
||||
document = MagicMock(data_source_info='{"notion_page_id": "p1"}')
|
||||
|
||||
with (
|
||||
app.test_request_context("/?credential_id=c1&dataset_id=ds1"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
|
||||
return_value={"token": "t"},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
|
||||
patch(
|
||||
"core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime",
|
||||
return_value=MagicMock(
|
||||
get_online_document_pages=lambda **kw: iter([online_document_message]),
|
||||
datasource_provider_type=lambda: None,
|
||||
),
|
||||
),
|
||||
):
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalars.return_value.all.return_value = [document]
|
||||
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_get_invalid_dataset_type(self, app, patch_tenant, mock_engine):
|
||||
api = DataSourceNotionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
dataset = MagicMock(data_source_type="other_type")
|
||||
|
||||
with (
|
||||
app.test_request_context("/?credential_id=c1&dataset_id=ds1"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
|
||||
return_value={"token": "t"},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch("controllers.console.datasets.data_source.Session"),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestDataSourceNotionApi:
|
||||
def test_get_preview_success(self, app, patch_tenant):
|
||||
api = DataSourceNotionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
extractor = MagicMock(extract=lambda: [MagicMock(page_content="hello")])
|
||||
|
||||
with (
|
||||
app.test_request_context("/?credential_id=c1"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
|
||||
return_value={"integration_secret": "t"},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.NotionExtractor",
|
||||
return_value=extractor,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "p1", "page")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_post_indexing_estimate_success(self, app, patch_tenant):
|
||||
api = DataSourceNotionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"notion_info_list": [
|
||||
{
|
||||
"workspace_id": "w1",
|
||||
"credential_id": "c1",
|
||||
"pages": [{"page_id": "p1", "type": "page"}],
|
||||
}
|
||||
],
|
||||
"process_rule": {"rules": {}},
|
||||
"doc_form": "text_model",
|
||||
"doc_language": "English",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", method="POST", json=payload, headers={"Content-Type": "application/json"}),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DocumentService.estimate_args_validate",
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.IndexingRunner.indexing_estimate",
|
||||
return_value=MagicMock(model_dump=lambda: {"total_pages": 1}),
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
|
||||
|
||||
class TestDataSourceNotionDatasetSyncApi:
|
||||
def test_get_success(self, app, patch_tenant):
|
||||
api = DataSourceNotionDatasetSyncApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasetService.get_dataset",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DocumentService.get_document_by_dataset_id",
|
||||
return_value=[MagicMock(id="d1")],
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.document_indexing_sync_task.delay",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "ds-1")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_get_dataset_not_found(self, app, patch_tenant):
|
||||
api = DataSourceNotionDatasetSyncApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasetService.get_dataset",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "ds-1")
|
||||
|
||||
|
||||
class TestDataSourceNotionDocumentSyncApi:
|
||||
def test_get_success(self, app, patch_tenant):
|
||||
api = DataSourceNotionDocumentSyncApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasetService.get_dataset",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DocumentService.get_document",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.document_indexing_sync_task.delay",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "ds-1", "doc-1")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_get_document_not_found(self, app, patch_tenant):
|
||||
api = DataSourceNotionDocumentSyncApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DatasetService.get_dataset",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.data_source.DocumentService.get_document",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "ds-1", "doc-1")
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,399 +0,0 @@
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.datasets.external import (
|
||||
BedrockRetrievalApi,
|
||||
ExternalApiTemplateApi,
|
||||
ExternalApiTemplateListApi,
|
||||
ExternalDatasetCreateApi,
|
||||
ExternalKnowledgeHitTestingApi,
|
||||
)
|
||||
from services.dataset_service import DatasetService
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
from services.knowledge_service import ExternalDatasetTestService
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask("test_external_dataset")
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def current_user():
|
||||
user = MagicMock()
|
||||
user.id = "user-1"
|
||||
user.is_dataset_editor = True
|
||||
user.has_edit_permission = True
|
||||
user.is_dataset_operator = True
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth(mocker, current_user):
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.external.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
)
|
||||
|
||||
|
||||
class TestExternalApiTemplateListApi:
|
||||
def test_get_success(self, app):
|
||||
api = ExternalApiTemplateListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
api_item = MagicMock()
|
||||
api_item.to_dict.return_value = {"id": "1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&limit=20"),
|
||||
patch.object(
|
||||
ExternalDatasetService,
|
||||
"get_external_knowledge_apis",
|
||||
return_value=([api_item], 1),
|
||||
),
|
||||
):
|
||||
resp, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert resp["total"] == 1
|
||||
assert resp["data"][0]["id"] == "1"
|
||||
|
||||
def test_post_forbidden(self, app, current_user):
|
||||
current_user.is_dataset_editor = False
|
||||
api = ExternalApiTemplateListApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"name": "x", "settings": {"k": "v"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch.object(ExternalDatasetService, "validate_api_list"),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
|
||||
def test_post_duplicate_name(self, app):
|
||||
api = ExternalApiTemplateListApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"name": "x", "settings": {"k": "v"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch.object(ExternalDatasetService, "validate_api_list"),
|
||||
patch.object(
|
||||
ExternalDatasetService,
|
||||
"create_external_knowledge_api",
|
||||
side_effect=services.errors.dataset.DatasetNameDuplicateError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(DatasetNameDuplicateError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestExternalApiTemplateApi:
|
||||
def test_get_not_found(self, app):
|
||||
api = ExternalApiTemplateApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
ExternalDatasetService,
|
||||
"get_external_knowledge_api",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "api-id")
|
||||
|
||||
def test_delete_forbidden(self, app, current_user):
|
||||
current_user.has_edit_permission = False
|
||||
current_user.is_dataset_operator = False
|
||||
|
||||
api = ExternalApiTemplateApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "api-id")
|
||||
|
||||
|
||||
class TestExternalDatasetCreateApi:
|
||||
def test_create_success(self, app):
|
||||
api = ExternalDatasetCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"external_knowledge_api_id": "api",
|
||||
"external_knowledge_id": "kid",
|
||||
"name": "dataset",
|
||||
}
|
||||
|
||||
dataset = MagicMock()
|
||||
|
||||
dataset.embedding_available = False
|
||||
dataset.built_in_field_enabled = False
|
||||
dataset.is_published = False
|
||||
dataset.enable_api = False
|
||||
dataset.enable_qa = False
|
||||
dataset.enable_vector_store = False
|
||||
dataset.vector_store_setting = None
|
||||
dataset.is_multimodal = False
|
||||
|
||||
dataset.retrieval_model_dict = {}
|
||||
dataset.tags = []
|
||||
dataset.external_knowledge_info = None
|
||||
dataset.external_retrieval_model = None
|
||||
dataset.doc_metadata = []
|
||||
dataset.icon_info = None
|
||||
|
||||
dataset.summary_index_setting = MagicMock()
|
||||
dataset.summary_index_setting.enable = False
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch.object(
|
||||
ExternalDatasetService,
|
||||
"create_external_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
):
|
||||
_, status = method(api)
|
||||
|
||||
assert status == 201
|
||||
|
||||
def test_create_forbidden(self, app, current_user):
|
||||
current_user.is_dataset_editor = False
|
||||
api = ExternalDatasetCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"external_knowledge_api_id": "api",
|
||||
"external_knowledge_id": "kid",
|
||||
"name": "dataset",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestExternalKnowledgeHitTestingApi:
|
||||
def test_hit_testing_dataset_not_found(self, app):
|
||||
api = ExternalKnowledgeHitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "dataset-id")
|
||||
|
||||
def test_hit_testing_success(self, app):
|
||||
api = ExternalKnowledgeHitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"query": "hello"}
|
||||
|
||||
dataset = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch.object(DatasetService, "get_dataset", return_value=dataset),
|
||||
patch.object(DatasetService, "check_dataset_permission"),
|
||||
patch.object(
|
||||
HitTestingService,
|
||||
"external_retrieve",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
resp = method(api, "dataset-id")
|
||||
|
||||
assert resp["ok"] is True
|
||||
|
||||
|
||||
class TestBedrockRetrievalApi:
|
||||
def test_bedrock_retrieval(self, app):
|
||||
api = BedrockRetrievalApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"retrieval_setting": {},
|
||||
"query": "hello",
|
||||
"knowledge_id": "kid",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
patch.object(
|
||||
ExternalDatasetTestService,
|
||||
"knowledge_retrieval",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
resp, status = method()
|
||||
|
||||
assert status == 200
|
||||
assert resp["ok"] is True
|
||||
|
||||
|
||||
class TestExternalApiTemplateListApiAdvanced:
|
||||
def test_post_duplicate_name_error(self, app, mock_auth, current_user):
|
||||
api = ExternalApiTemplateListApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"name": "duplicate_api", "settings": {"key": "value"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch("controllers.console.datasets.external.ExternalDatasetService.validate_api_list"),
|
||||
patch(
|
||||
"controllers.console.datasets.external.ExternalDatasetService.create_external_knowledge_api",
|
||||
side_effect=services.errors.dataset.DatasetNameDuplicateError("Duplicate"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(DatasetNameDuplicateError):
|
||||
method(api)
|
||||
|
||||
def test_get_with_pagination(self, app, mock_auth, current_user):
|
||||
api = ExternalApiTemplateListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
templates = [MagicMock(id=f"api-{i}") for i in range(3)]
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&limit=20"),
|
||||
patch(
|
||||
"controllers.console.datasets.external.ExternalDatasetService.get_external_knowledge_apis",
|
||||
return_value=(templates, 25),
|
||||
),
|
||||
):
|
||||
resp, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert resp["total"] == 25
|
||||
assert len(resp["data"]) == 3
|
||||
|
||||
|
||||
class TestExternalDatasetCreateApiAdvanced:
|
||||
def test_create_forbidden(self, app, mock_auth, current_user):
|
||||
"""Test creating external dataset without permission"""
|
||||
api = ExternalDatasetCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
current_user.is_dataset_editor = False
|
||||
|
||||
payload = {
|
||||
"external_knowledge_api_id": "api-1",
|
||||
"external_knowledge_id": "ek-1",
|
||||
"name": "new_dataset",
|
||||
"description": "A dataset",
|
||||
}
|
||||
|
||||
with app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestExternalKnowledgeHitTestingApiAdvanced:
|
||||
def test_hit_testing_dataset_not_found(self, app, mock_auth, current_user):
|
||||
"""Test hit testing on non-existent dataset"""
|
||||
api = ExternalKnowledgeHitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"query": "test query",
|
||||
"external_retrieval_model": None,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.external.DatasetService.get_dataset",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "ds-1")
|
||||
|
||||
def test_hit_testing_with_custom_retrieval_model(self, app, mock_auth, current_user):
|
||||
api = ExternalKnowledgeHitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
dataset = MagicMock()
|
||||
payload = {
|
||||
"query": "test query",
|
||||
"external_retrieval_model": {"type": "bm25"},
|
||||
"metadata_filtering_conditions": {"status": "active"},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.external.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch("controllers.console.datasets.external.DatasetService.check_dataset_permission"),
|
||||
patch(
|
||||
"controllers.console.datasets.external.HitTestingService.external_retrieve",
|
||||
return_value={"results": []},
|
||||
),
|
||||
):
|
||||
resp = method(api, "ds-1")
|
||||
|
||||
assert resp["results"] == []
|
||||
|
||||
|
||||
class TestBedrockRetrievalApiAdvanced:
|
||||
def test_bedrock_retrieval_with_invalid_setting(self, app, mock_auth, current_user):
|
||||
api = BedrockRetrievalApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"retrieval_setting": {},
|
||||
"query": "test",
|
||||
"knowledge_id": "k-1",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.external.ExternalDatasetTestService.knowledge_retrieval",
|
||||
side_effect=ValueError("Invalid settings"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method()
|
||||
@@ -1,160 +0,0 @@
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.hit_testing import HitTestingApi
|
||||
from controllers.console.datasets.hit_testing_base import HitTestingPayload
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
"""Recursively unwrap decorated functions."""
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask("test_hit_testing")
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_id():
|
||||
return uuid.uuid4()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset():
|
||||
return MagicMock(id="dataset-1")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def bypass_decorators(mocker):
|
||||
"""Bypass all decorators on the API method."""
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.hit_testing.setup_required",
|
||||
lambda f: f,
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.hit_testing.login_required",
|
||||
return_value=lambda f: f,
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.hit_testing.account_initialization_required",
|
||||
return_value=lambda f: f,
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.hit_testing.cloud_edition_billing_rate_limit_check",
|
||||
return_value=lambda *_: (lambda f: f),
|
||||
)
|
||||
|
||||
|
||||
class TestHitTestingApi:
|
||||
def test_hit_testing_success(self, app, dataset, dataset_id):
|
||||
api = HitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"query": "what is vector search",
|
||||
"top_k": 3,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch.object(
|
||||
HitTestingPayload,
|
||||
"model_validate",
|
||||
return_value=MagicMock(model_dump=lambda **_: payload),
|
||||
),
|
||||
patch.object(
|
||||
HitTestingApi,
|
||||
"get_and_validate_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
HitTestingApi,
|
||||
"hit_testing_args_check",
|
||||
),
|
||||
patch.object(
|
||||
HitTestingApi,
|
||||
"perform_hit_testing",
|
||||
return_value={"query": "what is vector search", "records": []},
|
||||
),
|
||||
):
|
||||
result = method(api, dataset_id)
|
||||
|
||||
assert "query" in result
|
||||
assert "records" in result
|
||||
assert result["records"] == []
|
||||
|
||||
def test_hit_testing_dataset_not_found(self, app, dataset_id):
|
||||
api = HitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"query": "test",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch.object(
|
||||
HitTestingApi,
|
||||
"get_and_validate_dataset",
|
||||
side_effect=NotFound("Dataset not found"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound, match="Dataset not found"):
|
||||
method(api, dataset_id)
|
||||
|
||||
def test_hit_testing_invalid_args(self, app, dataset, dataset_id):
|
||||
api = HitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"query": "",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch.object(
|
||||
HitTestingPayload,
|
||||
"model_validate",
|
||||
return_value=MagicMock(model_dump=lambda **_: payload),
|
||||
),
|
||||
patch.object(
|
||||
HitTestingApi,
|
||||
"get_and_validate_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
HitTestingApi,
|
||||
"hit_testing_args_check",
|
||||
side_effect=ValueError("Invalid parameters"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Invalid parameters"):
|
||||
method(api, dataset_id)
|
||||
@@ -1,207 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.console.app.error import (
|
||||
CompletionRequestError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.datasets.error import DatasetNotInitializedError
|
||||
from controllers.console.datasets.hit_testing_base import (
|
||||
DatasetsHitTestingBase,
|
||||
)
|
||||
from core.errors.error import (
|
||||
LLMBadRequestError,
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
from models.account import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def account():
|
||||
acc = MagicMock(spec=Account)
|
||||
return acc
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_current_user(mocker, account):
|
||||
"""Patch current_user to a valid Account."""
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.hit_testing_base.current_user",
|
||||
account,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset():
|
||||
return MagicMock(id="dataset-1")
|
||||
|
||||
|
||||
class TestGetAndValidateDataset:
|
||||
def test_success(self, dataset):
|
||||
with (
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"check_dataset_permission",
|
||||
),
|
||||
):
|
||||
result = DatasetsHitTestingBase.get_and_validate_dataset("dataset-1")
|
||||
|
||||
assert result == dataset
|
||||
|
||||
def test_dataset_not_found(self):
|
||||
with patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=None,
|
||||
):
|
||||
with pytest.raises(NotFound, match="Dataset not found"):
|
||||
DatasetsHitTestingBase.get_and_validate_dataset("dataset-1")
|
||||
|
||||
def test_permission_denied(self, dataset):
|
||||
with (
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"check_dataset_permission",
|
||||
side_effect=services.errors.account.NoPermissionError("no access"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden, match="no access"):
|
||||
DatasetsHitTestingBase.get_and_validate_dataset("dataset-1")
|
||||
|
||||
|
||||
class TestHitTestingArgsCheck:
|
||||
def test_args_check_called(self):
|
||||
args = {"query": "test"}
|
||||
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"hit_testing_args_check",
|
||||
) as check_mock:
|
||||
DatasetsHitTestingBase.hit_testing_args_check(args)
|
||||
|
||||
check_mock.assert_called_once_with(args)
|
||||
|
||||
|
||||
class TestParseArgs:
|
||||
def test_parse_args_success(self):
|
||||
payload = {"query": "hello"}
|
||||
|
||||
result = DatasetsHitTestingBase.parse_args(payload)
|
||||
|
||||
assert result["query"] == "hello"
|
||||
|
||||
def test_parse_args_invalid(self):
|
||||
payload = {"query": "x" * 300}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
DatasetsHitTestingBase.parse_args(payload)
|
||||
|
||||
|
||||
class TestPerformHitTesting:
|
||||
def test_success(self, dataset):
|
||||
response = {
|
||||
"query": "hello",
|
||||
"records": [],
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
return_value=response,
|
||||
):
|
||||
result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
assert result["query"] == "hello"
|
||||
assert result["records"] == []
|
||||
|
||||
def test_index_not_initialized(self, dataset):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=services.errors.index.IndexNotInitializedError(),
|
||||
):
|
||||
with pytest.raises(DatasetNotInitializedError):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
def test_provider_token_not_init(self, dataset):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=ProviderTokenNotInitError("token missing"),
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
def test_quota_exceeded(self, dataset):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=QuotaExceededError(),
|
||||
):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
def test_model_not_supported(self, dataset):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=ModelCurrentlyNotSupportError(),
|
||||
):
|
||||
with pytest.raises(ProviderModelCurrentlyNotSupportError):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
def test_llm_bad_request(self, dataset):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=LLMBadRequestError("bad request"),
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
def test_invoke_error(self, dataset):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=InvokeError("invoke failed"),
|
||||
):
|
||||
with pytest.raises(CompletionRequestError):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
def test_value_error(self, dataset):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=ValueError("bad args"),
|
||||
):
|
||||
with pytest.raises(ValueError, match="bad args"):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
|
||||
def test_unexpected_error(self, dataset):
|
||||
with patch.object(
|
||||
HitTestingService,
|
||||
"retrieve",
|
||||
side_effect=Exception("boom"),
|
||||
):
|
||||
with pytest.raises(InternalServerError, match="boom"):
|
||||
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
|
||||
@@ -1,362 +0,0 @@
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.metadata import (
|
||||
DatasetMetadataApi,
|
||||
DatasetMetadataBuiltInFieldActionApi,
|
||||
DatasetMetadataBuiltInFieldApi,
|
||||
DatasetMetadataCreateApi,
|
||||
DocumentMetadataEditApi,
|
||||
)
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
MetadataArgs,
|
||||
MetadataOperationData,
|
||||
)
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
"""Recursively unwrap decorated functions."""
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask("test_dataset_metadata")
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def current_user():
|
||||
user = MagicMock()
|
||||
user.id = "user-1"
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset():
|
||||
ds = MagicMock()
|
||||
ds.id = "dataset-1"
|
||||
return ds
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_id():
|
||||
return uuid.uuid4()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def metadata_id():
|
||||
return uuid.uuid4()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def bypass_decorators(mocker):
|
||||
"""Bypass setup/login/license decorators."""
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.metadata.setup_required",
|
||||
lambda f: f,
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.metadata.login_required",
|
||||
lambda f: f,
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.metadata.account_initialization_required",
|
||||
lambda f: f,
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.metadata.enterprise_license_required",
|
||||
lambda f: f,
|
||||
)
|
||||
|
||||
|
||||
class TestDatasetMetadataCreateApi:
|
||||
def test_create_metadata_success(self, app, current_user, dataset, dataset_id):
|
||||
api = DatasetMetadataCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"name": "author"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.metadata.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
MetadataArgs,
|
||||
"model_validate",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"check_dataset_permission",
|
||||
),
|
||||
patch.object(
|
||||
MetadataService,
|
||||
"create_metadata",
|
||||
return_value={"id": "m1", "name": "author"},
|
||||
),
|
||||
):
|
||||
result, status = method(api, dataset_id)
|
||||
|
||||
assert status == 201
|
||||
assert result["name"] == "author"
|
||||
|
||||
def test_create_metadata_dataset_not_found(self, app, current_user, dataset_id):
|
||||
api = DatasetMetadataCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
valid_payload = {
|
||||
"type": "string",
|
||||
"name": "author",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=valid_payload,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.metadata.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
MetadataArgs,
|
||||
"model_validate",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound, match="Dataset not found"):
|
||||
method(api, dataset_id)
|
||||
|
||||
|
||||
class TestDatasetMetadataGetApi:
|
||||
def test_get_metadata_success(self, app, dataset, dataset_id):
|
||||
api = DatasetMetadataCreateApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
MetadataService,
|
||||
"get_dataset_metadatas",
|
||||
return_value=[{"id": "m1"}],
|
||||
),
|
||||
):
|
||||
result, status = method(api, dataset_id)
|
||||
|
||||
assert status == 200
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_get_metadata_dataset_not_found(self, app, dataset_id):
|
||||
api = DatasetMetadataCreateApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, dataset_id)
|
||||
|
||||
|
||||
class TestDatasetMetadataApi:
|
||||
def test_update_metadata_success(self, app, current_user, dataset, dataset_id, metadata_id):
|
||||
api = DatasetMetadataApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
payload = {"name": "updated-name"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.metadata.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"check_dataset_permission",
|
||||
),
|
||||
patch.object(
|
||||
MetadataService,
|
||||
"update_metadata_name",
|
||||
return_value={"id": "m1", "name": "updated-name"},
|
||||
),
|
||||
):
|
||||
result, status = method(api, dataset_id, metadata_id)
|
||||
|
||||
assert status == 200
|
||||
assert result["name"] == "updated-name"
|
||||
|
||||
def test_delete_metadata_success(self, app, current_user, dataset, dataset_id, metadata_id):
|
||||
api = DatasetMetadataApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.metadata.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"check_dataset_permission",
|
||||
),
|
||||
patch.object(
|
||||
MetadataService,
|
||||
"delete_metadata",
|
||||
),
|
||||
):
|
||||
result, status = method(api, dataset_id, metadata_id)
|
||||
|
||||
assert status == 204
|
||||
assert result["result"] == "success"
|
||||
|
||||
|
||||
class TestDatasetMetadataBuiltInFieldApi:
|
||||
def test_get_built_in_fields(self, app):
|
||||
api = DatasetMetadataBuiltInFieldApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
MetadataService,
|
||||
"get_built_in_fields",
|
||||
return_value=["title", "source"],
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["fields"] == ["title", "source"]
|
||||
|
||||
|
||||
class TestDatasetMetadataBuiltInFieldActionApi:
|
||||
def test_enable_built_in_field(self, app, current_user, dataset, dataset_id):
|
||||
api = DatasetMetadataBuiltInFieldActionApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.metadata.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"check_dataset_permission",
|
||||
),
|
||||
patch.object(
|
||||
MetadataService,
|
||||
"enable_built_in_field",
|
||||
),
|
||||
):
|
||||
result, status = method(api, dataset_id, "enable")
|
||||
|
||||
assert status == 200
|
||||
assert result["result"] == "success"
|
||||
|
||||
|
||||
class TestDocumentMetadataEditApi:
|
||||
def test_update_document_metadata_success(self, app, current_user, dataset, dataset_id):
|
||||
api = DocumentMetadataEditApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"operation": "add", "metadata": {}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.metadata.current_account_with_tenant",
|
||||
return_value=(current_user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"get_dataset",
|
||||
return_value=dataset,
|
||||
),
|
||||
patch.object(
|
||||
DatasetService,
|
||||
"check_dataset_permission",
|
||||
),
|
||||
patch.object(
|
||||
MetadataOperationData,
|
||||
"model_validate",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch.object(
|
||||
MetadataService,
|
||||
"update_documents_metadata",
|
||||
),
|
||||
):
|
||||
result, status = method(api, dataset_id)
|
||||
|
||||
assert status == 200
|
||||
assert result["result"] == "success"
|
||||
@@ -1,233 +0,0 @@
|
||||
from unittest.mock import Mock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import WebsiteCrawlError
|
||||
from controllers.console.datasets.website import (
|
||||
WebsiteCrawlApi,
|
||||
WebsiteCrawlStatusApi,
|
||||
)
|
||||
from services.website_service import (
|
||||
WebsiteCrawlApiRequest,
|
||||
WebsiteCrawlStatusApiRequest,
|
||||
WebsiteService,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
"""Recursively unwrap decorated functions."""
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask("test_website_crawl")
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def bypass_auth_and_setup(mocker):
|
||||
"""Bypass setup/login/account decorators."""
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.website.login_required",
|
||||
lambda f: f,
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.website.setup_required",
|
||||
lambda f: f,
|
||||
)
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.website.account_initialization_required",
|
||||
lambda f: f,
|
||||
)
|
||||
|
||||
|
||||
class TestWebsiteCrawlApi:
|
||||
def test_crawl_success(self, app, mocker):
|
||||
api = WebsiteCrawlApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"provider": "firecrawl",
|
||||
"url": "https://example.com",
|
||||
"options": {"depth": 1},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
):
|
||||
mock_request = Mock(spec=WebsiteCrawlApiRequest)
|
||||
mocker.patch.object(
|
||||
WebsiteCrawlApiRequest,
|
||||
"from_args",
|
||||
return_value=mock_request,
|
||||
)
|
||||
|
||||
mocker.patch.object(
|
||||
WebsiteService,
|
||||
"crawl_url",
|
||||
return_value={"job_id": "job-1"},
|
||||
)
|
||||
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["job_id"] == "job-1"
|
||||
|
||||
def test_crawl_invalid_payload(self, app, mocker):
|
||||
api = WebsiteCrawlApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"provider": "firecrawl",
|
||||
"url": "bad-url",
|
||||
"options": {},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
):
|
||||
mocker.patch.object(
|
||||
WebsiteCrawlApiRequest,
|
||||
"from_args",
|
||||
side_effect=ValueError("invalid payload"),
|
||||
)
|
||||
|
||||
with pytest.raises(WebsiteCrawlError, match="invalid payload"):
|
||||
method(api)
|
||||
|
||||
def test_crawl_service_error(self, app, mocker):
|
||||
api = WebsiteCrawlApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"provider": "firecrawl",
|
||||
"url": "https://example.com",
|
||||
"options": {},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
):
|
||||
mock_request = Mock(spec=WebsiteCrawlApiRequest)
|
||||
mocker.patch.object(
|
||||
WebsiteCrawlApiRequest,
|
||||
"from_args",
|
||||
return_value=mock_request,
|
||||
)
|
||||
|
||||
mocker.patch.object(
|
||||
WebsiteService,
|
||||
"crawl_url",
|
||||
side_effect=Exception("crawl failed"),
|
||||
)
|
||||
|
||||
with pytest.raises(WebsiteCrawlError, match="crawl failed"):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestWebsiteCrawlStatusApi:
|
||||
def test_get_status_success(self, app, mocker):
|
||||
api = WebsiteCrawlStatusApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
job_id = "job-123"
|
||||
args = {"provider": "firecrawl"}
|
||||
|
||||
with app.test_request_context("/?provider=firecrawl"):
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.website.request.args.to_dict",
|
||||
return_value=args,
|
||||
)
|
||||
|
||||
mock_request = Mock(spec=WebsiteCrawlStatusApiRequest)
|
||||
mocker.patch.object(
|
||||
WebsiteCrawlStatusApiRequest,
|
||||
"from_args",
|
||||
return_value=mock_request,
|
||||
)
|
||||
|
||||
mocker.patch.object(
|
||||
WebsiteService,
|
||||
"get_crawl_status_typed",
|
||||
return_value={"status": "completed"},
|
||||
)
|
||||
|
||||
result, status = method(api, job_id)
|
||||
|
||||
assert status == 200
|
||||
assert result["status"] == "completed"
|
||||
|
||||
def test_get_status_invalid_provider(self, app, mocker):
|
||||
api = WebsiteCrawlStatusApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
job_id = "job-123"
|
||||
args = {"provider": "firecrawl"}
|
||||
|
||||
with app.test_request_context("/?provider=firecrawl"):
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.website.request.args.to_dict",
|
||||
return_value=args,
|
||||
)
|
||||
|
||||
mocker.patch.object(
|
||||
WebsiteCrawlStatusApiRequest,
|
||||
"from_args",
|
||||
side_effect=ValueError("invalid provider"),
|
||||
)
|
||||
|
||||
with pytest.raises(WebsiteCrawlError, match="invalid provider"):
|
||||
method(api, job_id)
|
||||
|
||||
def test_get_status_service_error(self, app, mocker):
|
||||
api = WebsiteCrawlStatusApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
job_id = "job-123"
|
||||
args = {"provider": "firecrawl"}
|
||||
|
||||
with app.test_request_context("/?provider=firecrawl"):
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.website.request.args.to_dict",
|
||||
return_value=args,
|
||||
)
|
||||
|
||||
mock_request = Mock(spec=WebsiteCrawlStatusApiRequest)
|
||||
mocker.patch.object(
|
||||
WebsiteCrawlStatusApiRequest,
|
||||
"from_args",
|
||||
return_value=mock_request,
|
||||
)
|
||||
|
||||
mocker.patch.object(
|
||||
WebsiteService,
|
||||
"get_crawl_status_typed",
|
||||
side_effect=Exception("status lookup failed"),
|
||||
)
|
||||
|
||||
with pytest.raises(WebsiteCrawlError, match="status lookup failed"):
|
||||
method(api, job_id)
|
||||
@@ -1,117 +0,0 @@
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.datasets.error import PipelineNotFoundError
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from models.dataset import Pipeline
|
||||
|
||||
|
||||
class TestGetRagPipeline:
|
||||
def test_missing_pipeline_id(self):
|
||||
@get_rag_pipeline
|
||||
def dummy_view(**kwargs):
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(ValueError, match="missing pipeline_id"):
|
||||
dummy_view()
|
||||
|
||||
def test_pipeline_not_found(self, mocker):
|
||||
@get_rag_pipeline
|
||||
def dummy_view(**kwargs):
|
||||
return "ok"
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.current_account_with_tenant",
|
||||
return_value=(Mock(), "tenant-1"),
|
||||
)
|
||||
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.first.return_value = None
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.db.session.query",
|
||||
return_value=mock_query,
|
||||
)
|
||||
|
||||
with pytest.raises(PipelineNotFoundError):
|
||||
dummy_view(pipeline_id="pipeline-1")
|
||||
|
||||
def test_pipeline_found_and_injected(self, mocker):
|
||||
pipeline = Mock(spec=Pipeline)
|
||||
pipeline.id = "pipeline-1"
|
||||
pipeline.tenant_id = "tenant-1"
|
||||
|
||||
@get_rag_pipeline
|
||||
def dummy_view(**kwargs):
|
||||
return kwargs["pipeline"]
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.current_account_with_tenant",
|
||||
return_value=(Mock(), "tenant-1"),
|
||||
)
|
||||
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.first.return_value = pipeline
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.db.session.query",
|
||||
return_value=mock_query,
|
||||
)
|
||||
|
||||
result = dummy_view(pipeline_id="pipeline-1")
|
||||
|
||||
assert result is pipeline
|
||||
|
||||
def test_pipeline_id_removed_from_kwargs(self, mocker):
|
||||
pipeline = Mock(spec=Pipeline)
|
||||
|
||||
@get_rag_pipeline
|
||||
def dummy_view(**kwargs):
|
||||
assert "pipeline_id" not in kwargs
|
||||
return "ok"
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.current_account_with_tenant",
|
||||
return_value=(Mock(), "tenant-1"),
|
||||
)
|
||||
|
||||
mock_query = Mock()
|
||||
mock_query.where.return_value.first.return_value = pipeline
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.db.session.query",
|
||||
return_value=mock_query,
|
||||
)
|
||||
|
||||
result = dummy_view(pipeline_id="pipeline-1")
|
||||
|
||||
assert result == "ok"
|
||||
|
||||
def test_pipeline_id_cast_to_string(self, mocker):
|
||||
pipeline = Mock(spec=Pipeline)
|
||||
|
||||
@get_rag_pipeline
|
||||
def dummy_view(**kwargs):
|
||||
return kwargs["pipeline"]
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.current_account_with_tenant",
|
||||
return_value=(Mock(), "tenant-1"),
|
||||
)
|
||||
|
||||
def where_side_effect(*args, **kwargs):
|
||||
assert args[0].right.value == "123"
|
||||
return Mock(first=lambda: pipeline)
|
||||
|
||||
mock_query = Mock()
|
||||
mock_query.where.side_effect = where_side_effect
|
||||
|
||||
mocker.patch(
|
||||
"controllers.console.datasets.wraps.db.session.query",
|
||||
return_value=mock_query,
|
||||
)
|
||||
|
||||
result = dummy_view(pipeline_id=123)
|
||||
|
||||
assert result is pipeline
|
||||
@@ -1,341 +0,0 @@
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
EmailAlreadyInUseError,
|
||||
EmailCodeError,
|
||||
)
|
||||
from controllers.console.error import AccountInFreezeError
|
||||
from controllers.console.workspace.account import (
|
||||
AccountAvatarApi,
|
||||
AccountDeleteApi,
|
||||
AccountDeleteVerifyApi,
|
||||
AccountInitApi,
|
||||
AccountIntegrateApi,
|
||||
AccountInterfaceLanguageApi,
|
||||
AccountInterfaceThemeApi,
|
||||
AccountNameApi,
|
||||
AccountPasswordApi,
|
||||
AccountProfileApi,
|
||||
AccountTimezoneApi,
|
||||
ChangeEmailCheckApi,
|
||||
ChangeEmailResetApi,
|
||||
CheckEmailUnique,
|
||||
)
|
||||
from controllers.console.workspace.error import (
|
||||
AccountAlreadyInitedError,
|
||||
CurrentPasswordIncorrectError,
|
||||
InvalidAccountDeletionCodeError,
|
||||
)
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServicePwdError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestAccountInitApi:
|
||||
def test_init_success(self, app):
|
||||
api = AccountInitApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
account = MagicMock(status="inactive")
|
||||
payload = {
|
||||
"interface_language": "en-US",
|
||||
"timezone": "UTC",
|
||||
"invitation_code": "code123",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/account/init", json=payload),
|
||||
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
|
||||
patch("controllers.console.workspace.account.db.session.commit", return_value=None),
|
||||
patch("controllers.console.workspace.account.dify_config.EDITION", "CLOUD"),
|
||||
patch("controllers.console.workspace.account.db.session.query") as query_mock,
|
||||
):
|
||||
query_mock.return_value.where.return_value.first.return_value = MagicMock(status="unused")
|
||||
resp = method(api)
|
||||
|
||||
assert resp["result"] == "success"
|
||||
|
||||
def test_init_already_initialized(self, app):
|
||||
api = AccountInitApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
account = MagicMock(status="active")
|
||||
|
||||
with (
|
||||
app.test_request_context("/account/init"),
|
||||
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
|
||||
):
|
||||
with pytest.raises(AccountAlreadyInitedError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestAccountProfileApi:
|
||||
def test_get_profile_success(self, app):
|
||||
api = AccountProfileApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock()
|
||||
user.id = "u1"
|
||||
user.name = "John"
|
||||
user.email = "john@test.com"
|
||||
user.avatar = "avatar.png"
|
||||
user.interface_language = "en-US"
|
||||
user.interface_theme = "light"
|
||||
user.timezone = "UTC"
|
||||
user.last_login_ip = "127.0.0.1"
|
||||
|
||||
with (
|
||||
app.test_request_context("/account/profile"),
|
||||
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["id"] == "u1"
|
||||
|
||||
|
||||
class TestAccountUpdateApis:
|
||||
@pytest.mark.parametrize(
|
||||
("api_cls", "payload"),
|
||||
[
|
||||
(AccountNameApi, {"name": "test"}),
|
||||
(AccountAvatarApi, {"avatar": "img.png"}),
|
||||
(AccountInterfaceLanguageApi, {"interface_language": "en-US"}),
|
||||
(AccountInterfaceThemeApi, {"interface_theme": "dark"}),
|
||||
(AccountTimezoneApi, {"timezone": "UTC"}),
|
||||
],
|
||||
)
|
||||
def test_update_success(self, app, api_cls, payload):
|
||||
api = api_cls()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock()
|
||||
user.id = "u1"
|
||||
user.name = "John"
|
||||
user.email = "john@test.com"
|
||||
user.avatar = "avatar.png"
|
||||
user.interface_language = "en-US"
|
||||
user.interface_theme = "light"
|
||||
user.timezone = "UTC"
|
||||
user.last_login_ip = "127.0.0.1"
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.account.AccountService.update_account", return_value=user),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["id"] == "u1"
|
||||
|
||||
|
||||
class TestAccountPasswordApi:
|
||||
def test_password_success(self, app):
|
||||
api = AccountPasswordApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"password": "old",
|
||||
"new_password": "new123",
|
||||
"repeat_new_password": "new123",
|
||||
}
|
||||
|
||||
user = MagicMock()
|
||||
user.id = "u1"
|
||||
user.name = "John"
|
||||
user.email = "john@test.com"
|
||||
user.avatar = "avatar.png"
|
||||
user.interface_language = "en-US"
|
||||
user.interface_theme = "light"
|
||||
user.timezone = "UTC"
|
||||
user.last_login_ip = "127.0.0.1"
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.account.AccountService.update_account_password", return_value=None),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["id"] == "u1"
|
||||
|
||||
def test_password_wrong_current(self, app):
|
||||
api = AccountPasswordApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"password": "bad",
|
||||
"new_password": "new123",
|
||||
"repeat_new_password": "new123",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.account.AccountService.update_account_password",
|
||||
side_effect=ServicePwdError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(CurrentPasswordIncorrectError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestAccountIntegrateApi:
|
||||
def test_get_integrates(self, app):
|
||||
api = AccountIntegrateApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
account = MagicMock(id="acc1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
|
||||
patch("controllers.console.workspace.account.db.session.scalars") as scalars_mock,
|
||||
):
|
||||
scalars_mock.return_value.all.return_value = []
|
||||
result = method(api)
|
||||
|
||||
assert "data" in result
|
||||
assert len(result["data"]) == 2
|
||||
|
||||
|
||||
class TestAccountDeleteApi:
|
||||
def test_delete_verify_success(self, app):
|
||||
api = AccountDeleteVerifyApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.account.AccountService.generate_account_deletion_verification_code",
|
||||
return_value=("token", "1234"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.account.AccountService.send_account_deletion_verification_email",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_delete_invalid_code(self, app):
|
||||
api = AccountDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"token": "t", "code": "x"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.account.AccountService.verify_account_deletion_code",
|
||||
return_value=False,
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvalidAccountDeletionCodeError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestChangeEmailApis:
|
||||
def test_check_email_code_invalid(self, app):
|
||||
api = ChangeEmailCheckApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"email": "a@test.com", "code": "x", "token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit",
|
||||
return_value=False,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.account.AccountService.get_change_email_data",
|
||||
return_value={"email": "a@test.com", "code": "y"},
|
||||
),
|
||||
):
|
||||
with pytest.raises(EmailCodeError):
|
||||
method(api)
|
||||
|
||||
def test_reset_email_already_used(self, app):
|
||||
api = ChangeEmailResetApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"new_email": "x@test.com", "token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=False),
|
||||
patch("controllers.console.workspace.account.AccountService.check_email_unique", return_value=False),
|
||||
):
|
||||
with pytest.raises(EmailAlreadyInUseError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestCheckEmailUniqueApi:
|
||||
def test_email_unique_success(self, app):
|
||||
api = CheckEmailUnique()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"email": "ok@test.com"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=False),
|
||||
patch("controllers.console.workspace.account.AccountService.check_email_unique", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_email_in_freeze(self, app):
|
||||
api = CheckEmailUnique()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"email": "x@test.com"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
new_callable=PropertyMock,
|
||||
return_value=payload,
|
||||
),
|
||||
patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=True),
|
||||
):
|
||||
with pytest.raises(AccountInFreezeError):
|
||||
method(api)
|
||||
@@ -1,139 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.error import AccountNotFound
|
||||
from controllers.console.workspace.agent_providers import (
|
||||
AgentProviderApi,
|
||||
AgentProviderListApi,
|
||||
)
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestAgentProviderListApi:
|
||||
def test_get_success(self, app):
|
||||
api = AgentProviderListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="user1")
|
||||
tenant_id = "tenant1"
|
||||
providers = [{"name": "openai"}, {"name": "anthropic"}]
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.current_account_with_tenant",
|
||||
return_value=(user, tenant_id),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.AgentService.list_agent_providers",
|
||||
return_value=providers,
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result == providers
|
||||
|
||||
def test_get_empty_list(self, app):
|
||||
api = AgentProviderListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="user1")
|
||||
tenant_id = "tenant1"
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.current_account_with_tenant",
|
||||
return_value=(user, tenant_id),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.AgentService.list_agent_providers",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_get_account_not_found(self, app):
|
||||
api = AgentProviderListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.current_account_with_tenant",
|
||||
side_effect=AccountNotFound(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(AccountNotFound):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestAgentProviderApi:
|
||||
def test_get_success(self, app):
|
||||
api = AgentProviderApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="user1")
|
||||
tenant_id = "tenant1"
|
||||
provider_name = "openai"
|
||||
provider_data = {"name": "openai", "models": ["gpt-4"]}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.current_account_with_tenant",
|
||||
return_value=(user, tenant_id),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.AgentService.get_agent_provider",
|
||||
return_value=provider_data,
|
||||
),
|
||||
):
|
||||
result = method(api, provider_name)
|
||||
|
||||
assert result == provider_data
|
||||
|
||||
def test_get_provider_not_found(self, app):
|
||||
api = AgentProviderApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="user1")
|
||||
tenant_id = "tenant1"
|
||||
provider_name = "unknown"
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.current_account_with_tenant",
|
||||
return_value=(user, tenant_id),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.AgentService.get_agent_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = method(api, provider_name)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_get_account_not_found(self, app):
|
||||
api = AgentProviderApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.agent_providers.current_account_with_tenant",
|
||||
side_effect=AccountNotFound(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(AccountNotFound):
|
||||
method(api, "openai")
|
||||
@@ -1,305 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.workspace.endpoint import (
|
||||
EndpointCreateApi,
|
||||
EndpointDeleteApi,
|
||||
EndpointDisableApi,
|
||||
EndpointEnableApi,
|
||||
EndpointListApi,
|
||||
EndpointListForSinglePluginApi,
|
||||
EndpointUpdateApi,
|
||||
)
|
||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_and_tenant():
|
||||
return MagicMock(id="u1"), "t1"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_current_account(user_and_tenant):
|
||||
with patch(
|
||||
"controllers.console.workspace.endpoint.current_account_with_tenant",
|
||||
return_value=user_and_tenant,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointCreateApi:
|
||||
def test_create_success(self, app):
|
||||
api = EndpointCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"plugin_unique_identifier": "plugin-1",
|
||||
"name": "endpoint",
|
||||
"settings": {"a": 1},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.create_endpoint", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
def test_create_permission_denied(self, app):
|
||||
api = EndpointCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"plugin_unique_identifier": "plugin-1",
|
||||
"name": "endpoint",
|
||||
"settings": {},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.endpoint.EndpointService.create_endpoint",
|
||||
side_effect=PluginPermissionDeniedError("denied"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
def test_create_validation_error(self, app):
|
||||
api = EndpointCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"plugin_unique_identifier": "p1",
|
||||
"name": "",
|
||||
"settings": {},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointListApi:
|
||||
def test_list_success(self, app):
|
||||
api = EndpointListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&page_size=10"),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.list_endpoints", return_value=[{"id": "e1"}]),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert "endpoints" in result
|
||||
assert len(result["endpoints"]) == 1
|
||||
|
||||
def test_list_invalid_query(self, app):
|
||||
api = EndpointListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=0&page_size=10"),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointListForSinglePluginApi:
|
||||
def test_list_for_plugin_success(self, app):
|
||||
api = EndpointListForSinglePluginApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&page_size=10&plugin_id=p1"),
|
||||
patch(
|
||||
"controllers.console.workspace.endpoint.EndpointService.list_endpoints_for_single_plugin",
|
||||
return_value=[{"id": "e1"}],
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert "endpoints" in result
|
||||
|
||||
def test_list_for_plugin_missing_param(self, app):
|
||||
api = EndpointListForSinglePluginApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&page_size=10"),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointDeleteApi:
|
||||
def test_delete_success(self, app):
|
||||
api = EndpointDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
def test_delete_invalid_payload(self, app):
|
||||
api = EndpointDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
def test_delete_service_failure(self, app):
|
||||
api = EndpointDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=False),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointUpdateApi:
|
||||
def test_update_success(self, app):
|
||||
api = EndpointUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"endpoint_id": "e1",
|
||||
"name": "new-name",
|
||||
"settings": {"x": 1},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
def test_update_validation_error(self, app):
|
||||
api = EndpointUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1", "settings": {}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
def test_update_service_failure(self, app):
|
||||
api = EndpointUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"endpoint_id": "e1",
|
||||
"name": "n",
|
||||
"settings": {},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=False),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointEnableApi:
|
||||
def test_enable_success(self, app):
|
||||
api = EndpointEnableApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.enable_endpoint", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
def test_enable_invalid_payload(self, app):
|
||||
api = EndpointEnableApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
def test_enable_service_failure(self, app):
|
||||
api = EndpointEnableApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.enable_endpoint", return_value=False),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointDisableApi:
|
||||
def test_disable_success(self, app):
|
||||
api = EndpointDisableApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.disable_endpoint", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
def test_disable_invalid_payload(self, app):
|
||||
api = EndpointDisableApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
@@ -1,607 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import HTTPException
|
||||
|
||||
import services
|
||||
from controllers.console.auth.error import (
|
||||
CannotTransferOwnerToSelfError,
|
||||
EmailCodeError,
|
||||
InvalidEmailError,
|
||||
InvalidTokenError,
|
||||
MemberNotInTenantError,
|
||||
NotOwnerError,
|
||||
OwnerTransferLimitError,
|
||||
)
|
||||
from controllers.console.error import EmailSendIpLimitError, WorkspaceMembersLimitExceeded
|
||||
from controllers.console.workspace.members import (
|
||||
DatasetOperatorMemberListApi,
|
||||
MemberCancelInviteApi,
|
||||
MemberInviteEmailApi,
|
||||
MemberListApi,
|
||||
MemberUpdateRoleApi,
|
||||
OwnerTransfer,
|
||||
OwnerTransferCheckApi,
|
||||
SendOwnerTransferEmailApi,
|
||||
)
|
||||
from services.errors.account import AccountAlreadyInTenantError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestMemberListApi:
|
||||
def test_get_success(self, app):
|
||||
api = MemberListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
member.id = "m1"
|
||||
member.name = "Member"
|
||||
member.email = "member@test.com"
|
||||
member.avatar = "avatar.png"
|
||||
member.role = "admin"
|
||||
member.status = "active"
|
||||
members = [member]
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.get_tenant_members", return_value=members),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert len(result["accounts"]) == 1
|
||||
|
||||
def test_get_no_tenant(self, app):
|
||||
api = MemberListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(current_tenant=None)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestMemberInviteEmailApi:
|
||||
def test_invite_success(self, app):
|
||||
api = MemberInviteEmailApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
features = MagicMock()
|
||||
features.workspace_members.is_available.return_value = True
|
||||
|
||||
payload = {
|
||||
"emails": ["a@test.com"],
|
||||
"role": "normal",
|
||||
"language": "en-US",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
|
||||
patch("controllers.console.workspace.members.RegisterService.invite_new_member", return_value="token"),
|
||||
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 201
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_invite_limit_exceeded(self, app):
|
||||
api = MemberInviteEmailApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
features = MagicMock()
|
||||
features.workspace_members.is_available.return_value = False
|
||||
|
||||
payload = {
|
||||
"emails": ["a@test.com"],
|
||||
"role": "normal",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
|
||||
):
|
||||
with pytest.raises(WorkspaceMembersLimitExceeded):
|
||||
method(api)
|
||||
|
||||
def test_invite_already_member(self, app):
|
||||
api = MemberInviteEmailApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
features = MagicMock()
|
||||
features.workspace_members.is_available.return_value = True
|
||||
|
||||
payload = {
|
||||
"emails": ["a@test.com"],
|
||||
"role": "normal",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
|
||||
patch(
|
||||
"controllers.console.workspace.members.RegisterService.invite_new_member",
|
||||
side_effect=AccountAlreadyInTenantError(),
|
||||
),
|
||||
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert result["invitation_results"][0]["status"] == "success"
|
||||
|
||||
def test_invite_invalid_role(self, app):
|
||||
api = MemberInviteEmailApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"emails": ["a@test.com"],
|
||||
"role": "owner",
|
||||
}
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 400
|
||||
assert result["code"] == "invalid-role"
|
||||
|
||||
def test_invite_generic_exception(self, app):
|
||||
api = MemberInviteEmailApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
features = MagicMock()
|
||||
features.workspace_members.is_available.return_value = True
|
||||
|
||||
payload = {
|
||||
"emails": ["a@test.com"],
|
||||
"role": "normal",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
|
||||
patch(
|
||||
"controllers.console.workspace.members.RegisterService.invite_new_member",
|
||||
side_effect=Exception("boom"),
|
||||
),
|
||||
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"),
|
||||
):
|
||||
result, _ = method(api)
|
||||
|
||||
assert result["invitation_results"][0]["status"] == "failed"
|
||||
|
||||
|
||||
class TestMemberCancelInviteApi:
|
||||
def test_cancel_success(self, app):
|
||||
api = MemberCancelInviteApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.query") as q,
|
||||
patch("controllers.console.workspace.members.TenantService.remove_member_from_tenant"),
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = member
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 200
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_cancel_not_found(self, app):
|
||||
api = MemberCancelInviteApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.query") as q,
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(HTTPException):
|
||||
method(api, "x")
|
||||
|
||||
def test_cancel_cannot_operate_self(self, app):
|
||||
api = MemberCancelInviteApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.query") as q,
|
||||
patch(
|
||||
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
|
||||
side_effect=services.errors.account.CannotOperateSelfError("x"),
|
||||
),
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = member
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 400
|
||||
|
||||
def test_cancel_no_permission(self, app):
|
||||
api = MemberCancelInviteApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.query") as q,
|
||||
patch(
|
||||
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
|
||||
side_effect=services.errors.account.NoPermissionError("x"),
|
||||
),
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = member
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 403
|
||||
|
||||
def test_cancel_member_not_in_tenant(self, app):
|
||||
api = MemberCancelInviteApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
tenant = MagicMock(id="t1")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.query") as q,
|
||||
patch(
|
||||
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
|
||||
side_effect=services.errors.account.MemberNotInTenantError(),
|
||||
),
|
||||
):
|
||||
q.return_value.where.return_value.first.return_value = member
|
||||
result, status = method(api, member.id)
|
||||
|
||||
assert status == 404
|
||||
|
||||
|
||||
class TestMemberUpdateRoleApi:
|
||||
def test_update_success(self, app):
|
||||
api = MemberUpdateRoleApi()
|
||||
method = unwrap(api.put)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
|
||||
payload = {"role": "normal"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.db.session.get", return_value=member),
|
||||
patch("controllers.console.workspace.members.TenantService.update_member_role"),
|
||||
):
|
||||
result = method(api, "id")
|
||||
|
||||
if isinstance(result, tuple):
|
||||
result = result[0]
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_update_invalid_role(self, app):
|
||||
api = MemberUpdateRoleApi()
|
||||
method = unwrap(api.put)
|
||||
|
||||
payload = {"role": "invalid-role"}
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
result, status = method(api, "id")
|
||||
|
||||
assert status == 400
|
||||
|
||||
def test_update_member_not_found(self, app):
|
||||
api = MemberUpdateRoleApi()
|
||||
method = unwrap(api.put)
|
||||
|
||||
payload = {"role": "normal"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.members.current_account_with_tenant",
|
||||
return_value=(MagicMock(current_tenant=MagicMock()), "t1"),
|
||||
),
|
||||
patch("controllers.console.workspace.members.db.session.get", return_value=None),
|
||||
):
|
||||
with pytest.raises(HTTPException):
|
||||
method(api, "id")
|
||||
|
||||
|
||||
class TestDatasetOperatorMemberListApi:
|
||||
def test_get_success(self, app):
|
||||
api = DatasetOperatorMemberListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
member.id = "op1"
|
||||
member.name = "Operator"
|
||||
member.email = "operator@test.com"
|
||||
member.avatar = "avatar.png"
|
||||
member.role = "operator"
|
||||
member.status = "active"
|
||||
members = [member]
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.members.TenantService.get_dataset_operator_members", return_value=members
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert len(result["accounts"]) == 1
|
||||
|
||||
def test_get_no_tenant(self, app):
|
||||
api = DatasetOperatorMemberListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(current_tenant=None)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestSendOwnerTransferEmailApi:
|
||||
def test_send_success(self, app):
|
||||
api = SendOwnerTransferEmailApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(name="ws")
|
||||
user = MagicMock(email="a@test.com", current_tenant=tenant)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"),
|
||||
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=False),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.send_owner_transfer_email", return_value="token"
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_send_ip_limit(self, app):
|
||||
api = SendOwnerTransferEmailApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"),
|
||||
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=True),
|
||||
):
|
||||
with pytest.raises(EmailSendIpLimitError):
|
||||
method(api)
|
||||
|
||||
def test_send_not_owner(self, app):
|
||||
api = SendOwnerTransferEmailApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"),
|
||||
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=False),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=False),
|
||||
):
|
||||
with pytest.raises(NotOwnerError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestOwnerTransferCheckApi:
|
||||
def test_check_invalid_code(self, app):
|
||||
api = OwnerTransferCheckApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(email="a@test.com", current_tenant=tenant)
|
||||
|
||||
payload = {"code": "x", "token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
|
||||
return_value=False,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.get_owner_transfer_data",
|
||||
return_value={"email": "a@test.com", "code": "y"},
|
||||
),
|
||||
):
|
||||
with pytest.raises(EmailCodeError):
|
||||
method(api)
|
||||
|
||||
def test_rate_limited(self, app):
|
||||
api = OwnerTransferCheckApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(email="a@test.com", current_tenant=tenant)
|
||||
|
||||
payload = {"code": "x", "token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
|
||||
return_value=True,
|
||||
),
|
||||
):
|
||||
with pytest.raises(OwnerTransferLimitError):
|
||||
method(api)
|
||||
|
||||
def test_invalid_token(self, app):
|
||||
api = OwnerTransferCheckApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(email="a@test.com", current_tenant=tenant)
|
||||
|
||||
payload = {"code": "x", "token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
|
||||
return_value=False,
|
||||
),
|
||||
patch("controllers.console.workspace.members.AccountService.get_owner_transfer_data", return_value=None),
|
||||
):
|
||||
with pytest.raises(InvalidTokenError):
|
||||
method(api)
|
||||
|
||||
def test_invalid_email(self, app):
|
||||
api = OwnerTransferCheckApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(email="a@test.com", current_tenant=tenant)
|
||||
|
||||
payload = {"code": "x", "token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
|
||||
return_value=False,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.get_owner_transfer_data",
|
||||
return_value={"email": "b@test.com", "code": "x"},
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvalidEmailError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestOwnerTransferApi:
|
||||
def test_transfer_self(self, app):
|
||||
api = OwnerTransfer()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(id="1", email="a@test.com", current_tenant=tenant)
|
||||
|
||||
payload = {"token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
):
|
||||
with pytest.raises(CannotTransferOwnerToSelfError):
|
||||
method(api, "1")
|
||||
|
||||
def test_invalid_token(self, app):
|
||||
api = OwnerTransfer()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(id="1", email="a@test.com", current_tenant=tenant)
|
||||
|
||||
payload = {"token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch("controllers.console.workspace.members.AccountService.get_owner_transfer_data", return_value=None),
|
||||
):
|
||||
with pytest.raises(InvalidTokenError):
|
||||
method(api, "2")
|
||||
|
||||
def test_member_not_in_tenant(self, app):
|
||||
api = OwnerTransfer()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
user = MagicMock(id="1", email="a@test.com", current_tenant=tenant)
|
||||
member = MagicMock()
|
||||
|
||||
payload = {"token": "t"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
|
||||
patch(
|
||||
"controllers.console.workspace.members.AccountService.get_owner_transfer_data",
|
||||
return_value={"email": "a@test.com"},
|
||||
),
|
||||
patch("controllers.console.workspace.members.db.session.get", return_value=member),
|
||||
patch("controllers.console.workspace.members.TenantService.is_member", return_value=False),
|
||||
):
|
||||
with pytest.raises(MemberNotInTenantError):
|
||||
method(api, "2")
|
||||
@@ -1,388 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic_core import ValidationError
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.workspace.model_providers import (
|
||||
ModelProviderCredentialApi,
|
||||
ModelProviderCredentialSwitchApi,
|
||||
ModelProviderIconApi,
|
||||
ModelProviderListApi,
|
||||
ModelProviderPaymentCheckoutUrlApi,
|
||||
ModelProviderValidateApi,
|
||||
PreferredProviderTypeUpdateApi,
|
||||
)
|
||||
from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
|
||||
VALID_UUID = "123e4567-e89b-12d3-a456-426614174000"
|
||||
INVALID_UUID = "123"
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestModelProviderListApi:
|
||||
def test_get_success(self, app):
|
||||
api = ModelProviderListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?model_type=llm"),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.get_provider_list",
|
||||
return_value=[{"name": "openai"}],
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert "data" in result
|
||||
|
||||
|
||||
class TestModelProviderCredentialApi:
|
||||
def test_get_success(self, app):
|
||||
api = ModelProviderCredentialApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context(f"/?credential_id={VALID_UUID}"),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.get_provider_credential",
|
||||
return_value={"key": "value"},
|
||||
),
|
||||
):
|
||||
result = method(api, provider="openai")
|
||||
|
||||
assert "credentials" in result
|
||||
|
||||
def test_get_invalid_uuid(self, app):
|
||||
api = ModelProviderCredentialApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context(f"/?credential_id={INVALID_UUID}"),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValidationError):
|
||||
method(api, provider="openai")
|
||||
|
||||
def test_post_create_success(self, app):
|
||||
api = ModelProviderCredentialApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {"a": "b"}, "name": "test"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.create_provider_credential",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result, status = method(api, provider="openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
assert status == 201
|
||||
|
||||
def test_post_create_validation_error(self, app):
|
||||
api = ModelProviderCredentialApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {"a": "b"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.create_provider_credential",
|
||||
side_effect=CredentialsValidateFailedError("bad"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, provider="openai")
|
||||
|
||||
def test_put_update_success(self, app):
|
||||
api = ModelProviderCredentialApi()
|
||||
method = unwrap(api.put)
|
||||
|
||||
payload = {"credential_id": VALID_UUID, "credentials": {"a": "b"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.update_provider_credential",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = method(api, provider="openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_put_invalid_uuid(self, app):
|
||||
api = ModelProviderCredentialApi()
|
||||
method = unwrap(api.put)
|
||||
|
||||
payload = {"credential_id": INVALID_UUID, "credentials": {"a": "b"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValidationError):
|
||||
method(api, provider="openai")
|
||||
|
||||
def test_delete_success(self, app):
|
||||
api = ModelProviderCredentialApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
payload = {"credential_id": VALID_UUID}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.remove_provider_credential",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result, status = method(api, provider="openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
assert status == 204
|
||||
|
||||
|
||||
class TestModelProviderCredentialSwitchApi:
|
||||
def test_switch_success(self, app):
|
||||
api = ModelProviderCredentialSwitchApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": VALID_UUID}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.switch_active_provider_credential",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = method(api, provider="openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_switch_invalid_uuid(self, app):
|
||||
api = ModelProviderCredentialSwitchApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": INVALID_UUID}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValidationError):
|
||||
method(api, provider="openai")
|
||||
|
||||
|
||||
class TestModelProviderValidateApi:
|
||||
def test_validate_success(self, app):
|
||||
api = ModelProviderValidateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {"a": "b"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.validate_provider_credentials",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = method(api, provider="openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_validate_failure(self, app):
|
||||
api = ModelProviderValidateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {"a": "b"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.validate_provider_credentials",
|
||||
side_effect=CredentialsValidateFailedError("bad"),
|
||||
),
|
||||
):
|
||||
result = method(api, provider="openai")
|
||||
|
||||
assert result["result"] == "error"
|
||||
|
||||
|
||||
class TestModelProviderIconApi:
|
||||
def test_icon_success(self, app):
|
||||
api = ModelProviderIconApi()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.get_model_provider_icon",
|
||||
return_value=(b"123", "image/png"),
|
||||
),
|
||||
):
|
||||
response = api.get("t1", "openai", "logo", "en")
|
||||
|
||||
assert response.mimetype == "image/png"
|
||||
|
||||
def test_icon_not_found(self, app):
|
||||
api = ModelProviderIconApi()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.get_model_provider_icon",
|
||||
return_value=(None, None),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
api.get("t1", "openai", "logo", "en")
|
||||
|
||||
|
||||
class TestPreferredProviderTypeUpdateApi:
|
||||
def test_update_success(self, app):
|
||||
api = PreferredProviderTypeUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"preferred_provider_type": "custom"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.ModelProviderService.switch_preferred_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = method(api, provider="openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_invalid_enum(self, app):
|
||||
api = PreferredProviderTypeUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"preferred_provider_type": "invalid"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValidationError):
|
||||
method(api, provider="openai")
|
||||
|
||||
|
||||
class TestModelProviderPaymentCheckoutUrlApi:
|
||||
def test_checkout_success(self, app):
|
||||
api = ModelProviderPaymentCheckoutUrlApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="u1", email="x@test.com")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(user, "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.BillingService.is_tenant_owner_or_admin",
|
||||
return_value=None,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.BillingService.get_model_provider_payment_link",
|
||||
return_value={"url": "x"},
|
||||
),
|
||||
):
|
||||
result = method(api, provider="anthropic")
|
||||
|
||||
assert "url" in result
|
||||
|
||||
def test_invalid_provider(self, app):
|
||||
api = ModelProviderPaymentCheckoutUrlApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, provider="openai")
|
||||
|
||||
def test_permission_denied(self, app):
|
||||
api = ModelProviderPaymentCheckoutUrlApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="u1", email="x@test.com")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.current_account_with_tenant",
|
||||
return_value=(user, "tenant1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.model_providers.BillingService.is_tenant_owner_or_admin",
|
||||
side_effect=Forbidden(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, provider="anthropic")
|
||||
@@ -1,447 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.workspace.models import (
|
||||
DefaultModelApi,
|
||||
ModelProviderAvailableModelApi,
|
||||
ModelProviderModelApi,
|
||||
ModelProviderModelCredentialApi,
|
||||
ModelProviderModelCredentialSwitchApi,
|
||||
ModelProviderModelDisableApi,
|
||||
ModelProviderModelEnableApi,
|
||||
ModelProviderModelParameterRuleApi,
|
||||
ModelProviderModelValidateApi,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestDefaultModelApi:
|
||||
def test_get_success(self, app: Flask):
|
||||
api = DefaultModelApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={"model_type": ModelType.LLM.value},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
|
||||
):
|
||||
service_mock.return_value.get_default_model_of_model_type.return_value = {"model": "gpt-4"}
|
||||
|
||||
result = method(api)
|
||||
|
||||
assert "data" in result
|
||||
|
||||
def test_post_success(self, app: Flask):
|
||||
api = DefaultModelApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"model_settings": [
|
||||
{
|
||||
"model_type": ModelType.LLM.value,
|
||||
"provider": "openai",
|
||||
"model": "gpt-4",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_get_returns_empty_when_no_default(self, app):
|
||||
api = DefaultModelApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={"model_type": ModelType.LLM.value}),
|
||||
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service,
|
||||
):
|
||||
service.return_value.get_default_model_of_model_type.return_value = None
|
||||
|
||||
result = method(api)
|
||||
|
||||
assert "data" in result
|
||||
|
||||
|
||||
class TestModelProviderModelApi:
|
||||
def test_get_models_success(self, app: Flask):
|
||||
api = ModelProviderModelApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
|
||||
):
|
||||
service_mock.return_value.get_models_by_provider.return_value = []
|
||||
|
||||
result = method(api, "openai")
|
||||
|
||||
assert "data" in result
|
||||
|
||||
def test_post_models_success(self, app: Flask):
|
||||
api = ModelProviderModelApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"load_balancing": {
|
||||
"configs": [{"weight": 1}],
|
||||
"enabled": True,
|
||||
},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
patch("controllers.console.workspace.models.ModelLoadBalancingService"),
|
||||
):
|
||||
result, status = method(api, "openai")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_delete_model_success(self, app: Flask):
|
||||
api = ModelProviderModelApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result, status = method(api, "openai")
|
||||
|
||||
assert status == 204
|
||||
|
||||
def test_get_models_returns_empty(self, app):
|
||||
api = ModelProviderModelApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service,
|
||||
):
|
||||
service.return_value.get_models_by_provider.return_value = []
|
||||
|
||||
result = method(api, "openai")
|
||||
|
||||
assert "data" in result
|
||||
|
||||
|
||||
class TestModelProviderModelCredentialApi:
|
||||
def test_get_credentials_success(self, app: Flask):
|
||||
api = ModelProviderModelCredentialApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/",
|
||||
query_string={
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as provider_service,
|
||||
patch("controllers.console.workspace.models.ModelLoadBalancingService") as lb_service,
|
||||
):
|
||||
provider_service.return_value.get_model_credential.return_value = {
|
||||
"credentials": {},
|
||||
"current_credential_id": None,
|
||||
"current_credential_name": None,
|
||||
}
|
||||
provider_service.return_value.provider_manager.get_provider_model_available_credentials.return_value = []
|
||||
lb_service.return_value.get_load_balancing_configs.return_value = (False, [])
|
||||
|
||||
result = method(api, "openai")
|
||||
|
||||
assert "credentials" in result
|
||||
|
||||
def test_create_credential_success(self, app: Flask):
|
||||
api = ModelProviderModelCredentialApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"credentials": {"key": "val"},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result, status = method(api, "openai")
|
||||
|
||||
assert status == 201
|
||||
|
||||
def test_get_empty_credentials(self, app):
|
||||
api = ModelProviderModelCredentialApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={"model": "gpt", "model_type": ModelType.LLM.value}),
|
||||
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service,
|
||||
patch("controllers.console.workspace.models.ModelLoadBalancingService") as lb,
|
||||
):
|
||||
service.return_value.get_model_credential.return_value = None
|
||||
service.return_value.provider_manager.get_provider_model_available_credentials.return_value = []
|
||||
lb.return_value.get_load_balancing_configs.return_value = (False, [])
|
||||
|
||||
result = method(api, "openai")
|
||||
|
||||
assert result["credentials"] == {}
|
||||
|
||||
def test_delete_success(self, app):
|
||||
api = ModelProviderModelCredentialApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
payload = {
|
||||
"model": "gpt",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"credential_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result, status = method(api, "openai")
|
||||
|
||||
assert status == 204
|
||||
|
||||
|
||||
class TestModelProviderModelCredentialSwitchApi:
|
||||
def test_switch_success(self, app: Flask):
|
||||
api = ModelProviderModelCredentialSwitchApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"credential_id": "abc",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result = method(api, "openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
|
||||
class TestModelEnableDisableApis:
|
||||
def test_enable_model(self, app: Flask):
|
||||
api = ModelProviderModelEnableApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result = method(api, "openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_disable_model(self, app: Flask):
|
||||
api = ModelProviderModelDisableApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result = method(api, "openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
|
||||
class TestModelProviderModelValidateApi:
|
||||
def test_validate_success(self, app: Flask):
|
||||
api = ModelProviderModelValidateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"model": "gpt-4",
|
||||
"model_type": ModelType.LLM.value,
|
||||
"credentials": {"key": "val"},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService"),
|
||||
):
|
||||
result = method(api, "openai")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
@pytest.mark.parametrize("model_name", ["gpt-4", "gpt"])
|
||||
def test_validate_failure(self, app: Flask, model_name: str):
|
||||
api = ModelProviderModelValidateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"model": model_name,
|
||||
"model_type": ModelType.LLM.value,
|
||||
"credentials": {},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
|
||||
):
|
||||
service_mock.return_value.validate_model_credentials.side_effect = CredentialsValidateFailedError("invalid")
|
||||
|
||||
result = method(api, "openai")
|
||||
|
||||
assert result["result"] == "error"
|
||||
|
||||
|
||||
class TestParameterAndAvailableModels:
|
||||
def test_parameter_rules(self, app: Flask):
|
||||
api = ModelProviderModelParameterRuleApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={"model": "gpt-4"}),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
|
||||
):
|
||||
service_mock.return_value.get_model_parameter_rules.return_value = []
|
||||
|
||||
result = method(api, "openai")
|
||||
|
||||
assert "data" in result
|
||||
|
||||
def test_available_models(self, app: Flask):
|
||||
api = ModelProviderAvailableModelApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.models.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant1"),
|
||||
),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service_mock,
|
||||
):
|
||||
service_mock.return_value.get_models_by_model_type.return_value = []
|
||||
|
||||
result = method(api, ModelType.LLM.value)
|
||||
|
||||
assert "data" in result
|
||||
|
||||
def test_empty_rules(self, app):
|
||||
api = ModelProviderModelParameterRuleApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={"model": "gpt"}),
|
||||
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service,
|
||||
):
|
||||
service.return_value.get_model_parameter_rules.return_value = []
|
||||
|
||||
result = method(api, "openai")
|
||||
|
||||
assert result["data"] == []
|
||||
|
||||
def test_no_models(self, app):
|
||||
api = ModelProviderAvailableModelApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")),
|
||||
patch("controllers.console.workspace.models.ModelProviderService") as service,
|
||||
):
|
||||
service.return_value.get_models_by_model_type.return_value = []
|
||||
|
||||
result = method(api, ModelType.LLM.value)
|
||||
|
||||
assert result["data"] == []
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,52 +4,16 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask_restx import Api
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.workspace.tool_providers import (
|
||||
ToolApiListApi,
|
||||
ToolApiProviderAddApi,
|
||||
ToolApiProviderDeleteApi,
|
||||
ToolApiProviderGetApi,
|
||||
ToolApiProviderGetRemoteSchemaApi,
|
||||
ToolApiProviderListToolsApi,
|
||||
ToolApiProviderUpdateApi,
|
||||
ToolBuiltinListApi,
|
||||
ToolBuiltinProviderAddApi,
|
||||
ToolBuiltinProviderCredentialsSchemaApi,
|
||||
ToolBuiltinProviderDeleteApi,
|
||||
ToolBuiltinProviderGetCredentialInfoApi,
|
||||
ToolBuiltinProviderGetCredentialsApi,
|
||||
ToolBuiltinProviderGetOauthClientSchemaApi,
|
||||
ToolBuiltinProviderIconApi,
|
||||
ToolBuiltinProviderInfoApi,
|
||||
ToolBuiltinProviderListToolsApi,
|
||||
ToolBuiltinProviderSetDefaultApi,
|
||||
ToolBuiltinProviderUpdateApi,
|
||||
ToolLabelsApi,
|
||||
ToolOAuthCallback,
|
||||
ToolOAuthCustomClient,
|
||||
ToolPluginOAuthApi,
|
||||
ToolProviderListApi,
|
||||
ToolProviderMCPApi,
|
||||
ToolWorkflowListApi,
|
||||
ToolWorkflowProviderCreateApi,
|
||||
ToolWorkflowProviderDeleteApi,
|
||||
ToolWorkflowProviderGetApi,
|
||||
ToolWorkflowProviderUpdateApi,
|
||||
is_valid_url,
|
||||
)
|
||||
from controllers.console.workspace.tool_providers import ToolProviderMCPApi
|
||||
from core.db.session_factory import configure_session_factory
|
||||
from extensions.ext_database import db
|
||||
from services.tools.mcp_tools_manage_service import ReconnectResult
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
# Backward-compat fixtures referenced by @pytest.mark.usefixtures in this file.
|
||||
# They are intentionally no-ops because the test already patches the required
|
||||
# behaviors explicitly via @patch and context managers below.
|
||||
@pytest.fixture
|
||||
def _mock_cache():
|
||||
return
|
||||
@@ -143,602 +107,3 @@ def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_
|
||||
# 若 transform 后包含 tools 字段,确保非空
|
||||
assert isinstance(body.get("tools"), list)
|
||||
assert body["tools"]
|
||||
|
||||
|
||||
class TestUtils:
|
||||
def test_is_valid_url(self):
|
||||
assert is_valid_url("https://example.com")
|
||||
assert is_valid_url("http://example.com")
|
||||
assert not is_valid_url("")
|
||||
assert not is_valid_url("ftp://example.com")
|
||||
assert not is_valid_url("not-a-url")
|
||||
assert not is_valid_url(None)
|
||||
|
||||
|
||||
class TestToolProviderListApi:
|
||||
def test_get_success(self, app):
|
||||
api = ToolProviderListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u1"), "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ToolCommonService.list_tool_providers",
|
||||
return_value=["p1"],
|
||||
),
|
||||
):
|
||||
assert method(api) == ["p1"]
|
||||
|
||||
|
||||
class TestBuiltinProviderApis:
|
||||
def test_list_tools(self, app):
|
||||
api = ToolBuiltinProviderListToolsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(None, "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tool_provider_tools",
|
||||
return_value=[{"a": 1}],
|
||||
),
|
||||
):
|
||||
assert method(api, "provider") == [{"a": 1}]
|
||||
|
||||
def test_info(self, app):
|
||||
api = ToolBuiltinProviderInfoApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(None, "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_info",
|
||||
return_value={"x": 1},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider") == {"x": 1}
|
||||
|
||||
def test_delete(self, app):
|
||||
api = ToolBuiltinProviderDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credential_id": "cid"}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(None, "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.delete_builtin_tool_provider",
|
||||
return_value={"result": "success"},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider")["result"] == "success"
|
||||
|
||||
def test_add_invalid_type(self, app):
|
||||
api = ToolBuiltinProviderAddApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credentials": {}, "type": "invalid"}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "provider")
|
||||
|
||||
def test_add_success(self, app):
|
||||
api = ToolBuiltinProviderAddApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {}, "type": "oauth2", "name": "n"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.add_builtin_tool_provider",
|
||||
return_value={"id": 1},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider")["id"] == 1
|
||||
|
||||
def test_update(self, app):
|
||||
api = ToolBuiltinProviderUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "c1", "credentials": {}, "name": "n"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.update_builtin_tool_provider",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider")["ok"]
|
||||
|
||||
def test_get_credentials(self, app):
|
||||
api = ToolBuiltinProviderGetCredentialsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(None, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credentials",
|
||||
return_value={"k": "v"},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider") == {"k": "v"}
|
||||
|
||||
def test_icon(self, app):
|
||||
api = ToolBuiltinProviderIconApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_icon",
|
||||
return_value=(b"x", "image/png"),
|
||||
),
|
||||
):
|
||||
response = method(api, "provider")
|
||||
assert response.mimetype == "image/png"
|
||||
|
||||
def test_credentials_schema(self, app):
|
||||
api = ToolBuiltinProviderCredentialsSchemaApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_provider_credentials_schema",
|
||||
return_value={"schema": {}},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider", "oauth2") == {"schema": {}}
|
||||
|
||||
def test_set_default_credential(self, app):
|
||||
api = ToolBuiltinProviderSetDefaultApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"id": "c1"}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.set_default_provider",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider")["ok"]
|
||||
|
||||
def test_get_credential_info(self, app):
|
||||
api = ToolBuiltinProviderGetCredentialInfoApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credential_info",
|
||||
return_value={"info": "x"},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider") == {"info": "x"}
|
||||
|
||||
def test_get_oauth_client_schema(self, app):
|
||||
api = ToolBuiltinProviderGetOauthClientSchemaApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema",
|
||||
return_value={"schema": {}},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider") == {"schema": {}}
|
||||
|
||||
|
||||
class TestApiProviderApis:
|
||||
def test_add(self, app):
|
||||
api = ToolApiProviderAddApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"credentials": {},
|
||||
"schema_type": "openapi",
|
||||
"schema": "{}",
|
||||
"provider": "p",
|
||||
"icon": {},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.create_api_tool_provider",
|
||||
return_value={"id": 1},
|
||||
),
|
||||
):
|
||||
assert method(api)["id"] == 1
|
||||
|
||||
def test_remote_schema(self, app):
|
||||
api = ToolApiProviderGetRemoteSchemaApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?url=http://x.com"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider_remote_schema",
|
||||
return_value={"schema": "x"},
|
||||
),
|
||||
):
|
||||
assert method(api)["schema"] == "x"
|
||||
|
||||
def test_list_tools(self, app):
|
||||
api = ToolApiProviderListToolsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?provider=p"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tool_provider_tools",
|
||||
return_value=[{"tool": 1}],
|
||||
),
|
||||
):
|
||||
assert method(api) == [{"tool": 1}]
|
||||
|
||||
def test_update(self, app):
|
||||
api = ToolApiProviderUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"credentials": {},
|
||||
"schema_type": "openapi",
|
||||
"schema": "{}",
|
||||
"provider": "p",
|
||||
"original_provider": "o",
|
||||
"icon": {},
|
||||
"privacy_policy": "",
|
||||
"custom_disclaimer": "",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.update_api_tool_provider",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api)["ok"]
|
||||
|
||||
def test_delete(self, app):
|
||||
api = ToolApiProviderDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"provider": "p"}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.delete_api_tool_provider",
|
||||
return_value={"result": "success"},
|
||||
),
|
||||
):
|
||||
assert method(api)["result"] == "success"
|
||||
|
||||
def test_get(self, app):
|
||||
api = ToolApiProviderGetApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?provider=p"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider",
|
||||
return_value={"x": 1},
|
||||
),
|
||||
):
|
||||
assert method(api) == {"x": 1}
|
||||
|
||||
|
||||
class TestWorkflowApis:
|
||||
def test_create(self, app):
|
||||
api = ToolWorkflowProviderCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"workflow_app_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"name": "n",
|
||||
"label": "l",
|
||||
"description": "d",
|
||||
"icon": {},
|
||||
"parameters": [],
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.WorkflowToolManageService.create_workflow_tool",
|
||||
return_value={"id": 1},
|
||||
),
|
||||
):
|
||||
assert method(api)["id"] == 1
|
||||
|
||||
def test_update_invalid(self, app):
|
||||
api = ToolWorkflowProviderUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"workflow_tool_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"name": "Tool",
|
||||
"label": "Tool Label",
|
||||
"description": "A tool",
|
||||
"icon": {},
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.WorkflowToolManageService.update_workflow_tool",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
assert result["ok"]
|
||||
|
||||
def test_delete(self, app):
|
||||
api = ToolWorkflowProviderDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"workflow_tool_id": "123e4567-e89b-12d3-a456-426614174000"}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.WorkflowToolManageService.delete_workflow_tool",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api)["ok"]
|
||||
|
||||
def test_get_error(self, app):
|
||||
api = ToolWorkflowProviderGetApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestLists:
|
||||
def test_builtin_list(self, app):
|
||||
api = ToolBuiltinListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
m = MagicMock()
|
||||
m.to_dict.return_value = {"x": 1}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tools",
|
||||
return_value=[m],
|
||||
),
|
||||
):
|
||||
assert method(api) == [{"x": 1}]
|
||||
|
||||
def test_api_list(self, app):
|
||||
api = ToolApiListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
m = MagicMock()
|
||||
m.to_dict.return_value = {"x": 1}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(None, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tools",
|
||||
return_value=[m],
|
||||
),
|
||||
):
|
||||
assert method(api) == [{"x": 1}]
|
||||
|
||||
def test_workflow_list(self, app):
|
||||
api = ToolWorkflowListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
m = MagicMock()
|
||||
m.to_dict.return_value = {"x": 1}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.WorkflowToolManageService.list_tenant_workflow_tools",
|
||||
return_value=[m],
|
||||
),
|
||||
):
|
||||
assert method(api) == [{"x": 1}]
|
||||
|
||||
|
||||
class TestLabels:
|
||||
def test_labels(self, app):
|
||||
api = ToolLabelsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.ToolLabelsService.list_tool_labels",
|
||||
return_value=["l1"],
|
||||
),
|
||||
):
|
||||
assert method(api) == ["l1"]
|
||||
|
||||
|
||||
class TestOAuth:
|
||||
def test_oauth_no_client(self, app):
|
||||
api = ToolPluginOAuthApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u"), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_oauth_client",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "provider")
|
||||
|
||||
def test_oauth_callback_no_cookie(self, app):
|
||||
api = ToolOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "provider")
|
||||
|
||||
|
||||
class TestOAuthCustomClient:
|
||||
def test_save_custom_client(self, app):
|
||||
api = ToolOAuthCustomClient()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"client_params": {"a": 1}}),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.save_custom_oauth_client_params",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider")["ok"]
|
||||
|
||||
def test_get_custom_client(self, app):
|
||||
api = ToolOAuthCustomClient()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.get_custom_oauth_client_params",
|
||||
return_value={"client_id": "x"},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider") == {"client_id": "x"}
|
||||
|
||||
def test_delete_custom_client(self, app):
|
||||
api = ToolOAuthCustomClient()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.tool_providers.BuiltinToolManageService.delete_custom_oauth_client_params",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "provider")["ok"]
|
||||
|
||||
@@ -1,558 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import BadRequest, Forbidden
|
||||
|
||||
from controllers.console.workspace.trigger_providers import (
|
||||
TriggerOAuthAuthorizeApi,
|
||||
TriggerOAuthCallbackApi,
|
||||
TriggerOAuthClientManageApi,
|
||||
TriggerProviderIconApi,
|
||||
TriggerProviderInfoApi,
|
||||
TriggerProviderListApi,
|
||||
TriggerSubscriptionBuilderBuildApi,
|
||||
TriggerSubscriptionBuilderCreateApi,
|
||||
TriggerSubscriptionBuilderGetApi,
|
||||
TriggerSubscriptionBuilderLogsApi,
|
||||
TriggerSubscriptionBuilderUpdateApi,
|
||||
TriggerSubscriptionBuilderVerifyApi,
|
||||
TriggerSubscriptionDeleteApi,
|
||||
TriggerSubscriptionListApi,
|
||||
TriggerSubscriptionUpdateApi,
|
||||
TriggerSubscriptionVerifyApi,
|
||||
)
|
||||
from controllers.web.error import NotFoundError
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from models.account import Account
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def mock_user():
|
||||
user = MagicMock(spec=Account)
|
||||
user.id = "u1"
|
||||
user.current_tenant_id = "t1"
|
||||
return user
|
||||
|
||||
|
||||
class TestTriggerProviderApis:
|
||||
def test_icon_success(self, app):
|
||||
api = TriggerProviderIconApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerManager.get_trigger_plugin_icon",
|
||||
return_value="icon",
|
||||
),
|
||||
):
|
||||
assert method(api, "github") == "icon"
|
||||
|
||||
def test_list_providers(self, app):
|
||||
api = TriggerProviderListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_providers",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
assert method(api) == []
|
||||
|
||||
def test_provider_info(self, app):
|
||||
api = TriggerProviderInfoApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_trigger_provider",
|
||||
return_value={"id": "p1"},
|
||||
),
|
||||
):
|
||||
assert method(api, "github") == {"id": "p1"}
|
||||
|
||||
|
||||
class TestTriggerSubscriptionListApi:
|
||||
def test_list_success(self, app):
|
||||
api = TriggerSubscriptionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_provider_subscriptions",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
assert method(api, "github") == []
|
||||
|
||||
def test_list_invalid_provider(self, app):
|
||||
api = TriggerSubscriptionListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_provider_subscriptions",
|
||||
side_effect=ValueError("bad"),
|
||||
),
|
||||
):
|
||||
result, status = method(api, "bad")
|
||||
assert status == 404
|
||||
|
||||
|
||||
class TestTriggerSubscriptionBuilderApis:
|
||||
def test_create_builder(self, app):
|
||||
api = TriggerSubscriptionBuilderCreateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credential_type": "UNAUTHORIZED"}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.create_trigger_subscription_builder",
|
||||
return_value={"id": "b1"},
|
||||
),
|
||||
):
|
||||
result = method(api, "github")
|
||||
assert "subscription_builder" in result
|
||||
|
||||
def test_get_builder(self, app):
|
||||
api = TriggerSubscriptionBuilderGetApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.get_subscription_builder_by_id",
|
||||
return_value={"id": "b1"},
|
||||
),
|
||||
):
|
||||
assert method(api, "github", "b1") == {"id": "b1"}
|
||||
|
||||
def test_verify_builder(self, app):
|
||||
api = TriggerSubscriptionBuilderVerifyApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credentials": {"a": 1}}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_verify_builder",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "github", "b1") == {"ok": True}
|
||||
|
||||
def test_verify_builder_error(self, app):
|
||||
api = TriggerSubscriptionBuilderVerifyApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credentials": {}}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_verify_builder",
|
||||
side_effect=Exception("err"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "github", "b1")
|
||||
|
||||
def test_update_builder(self, app):
|
||||
api = TriggerSubscriptionBuilderUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"name": "n"}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_trigger_subscription_builder",
|
||||
return_value={"id": "b1"},
|
||||
),
|
||||
):
|
||||
assert method(api, "github", "b1") == {"id": "b1"}
|
||||
|
||||
def test_logs(self, app):
|
||||
api = TriggerSubscriptionBuilderLogsApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
log = MagicMock()
|
||||
log.model_dump.return_value = {"a": 1}
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.list_logs",
|
||||
return_value=[log],
|
||||
),
|
||||
):
|
||||
assert "logs" in method(api, "github", "b1")
|
||||
|
||||
def test_build(self, app):
|
||||
api = TriggerSubscriptionBuilderBuildApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"name": "x"}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_build_builder",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
assert method(api, "github", "b1") == 200
|
||||
|
||||
|
||||
class TestTriggerSubscriptionCrud:
|
||||
def test_update_rename_only(self, app):
|
||||
api = TriggerSubscriptionUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
sub = MagicMock()
|
||||
sub.provider_id = "github"
|
||||
sub.credential_type = CredentialType.UNAUTHORIZED
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"name": "x"}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id",
|
||||
return_value=sub,
|
||||
),
|
||||
patch("controllers.console.workspace.trigger_providers.TriggerProviderService.update_trigger_subscription"),
|
||||
):
|
||||
assert method(api, "s1") == 200
|
||||
|
||||
def test_update_not_found(self, app):
|
||||
api = TriggerSubscriptionUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"name": "x"}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFoundError):
|
||||
method(api, "x")
|
||||
|
||||
def test_update_rebuild(self, app):
|
||||
api = TriggerSubscriptionUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
sub = MagicMock()
|
||||
sub.provider_id = "github"
|
||||
sub.credential_type = CredentialType.OAUTH2
|
||||
sub.credentials = {}
|
||||
sub.parameters = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credentials": {}}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id",
|
||||
return_value=sub,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.rebuild_trigger_subscription"
|
||||
),
|
||||
):
|
||||
assert method(api, "s1") == 200
|
||||
|
||||
def test_delete_subscription(self, app):
|
||||
api = TriggerSubscriptionDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch("controllers.console.workspace.trigger_providers.db") as mock_db,
|
||||
patch("controllers.console.workspace.trigger_providers.Session") as mock_session_cls,
|
||||
patch("controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider"),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription"
|
||||
),
|
||||
):
|
||||
mock_db.engine = MagicMock()
|
||||
mock_session_cls.return_value.__enter__.return_value = mock_session
|
||||
|
||||
result = method(api, "sub1")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_delete_subscription_value_error(self, app):
|
||||
api = TriggerSubscriptionDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch("controllers.console.workspace.trigger_providers.db") as mock_db,
|
||||
patch("controllers.console.workspace.trigger_providers.Session") as session_cls,
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider",
|
||||
side_effect=ValueError("bad"),
|
||||
),
|
||||
):
|
||||
mock_db.engine = MagicMock()
|
||||
session_cls.return_value.__enter__.return_value = MagicMock()
|
||||
|
||||
with pytest.raises(BadRequest):
|
||||
method(api, "sub1")
|
||||
|
||||
|
||||
class TestTriggerOAuthApis:
|
||||
def test_oauth_authorize_success(self, app):
|
||||
api = TriggerOAuthAuthorizeApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
|
||||
return_value={"a": 1},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.create_trigger_subscription_builder",
|
||||
return_value=MagicMock(id="b1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.OAuthProxyService.create_proxy_context",
|
||||
return_value="ctx",
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.OAuthHandler.get_authorization_url",
|
||||
return_value=MagicMock(authorization_url="url"),
|
||||
),
|
||||
):
|
||||
resp = method(api, "github")
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_oauth_authorize_no_client(self, app):
|
||||
api = TriggerOAuthAuthorizeApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFoundError):
|
||||
method(api, "github")
|
||||
|
||||
def test_oauth_callback_forbidden(self, app):
|
||||
api = TriggerOAuthCallbackApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "github")
|
||||
|
||||
def test_oauth_callback_success(self, app):
|
||||
api = TriggerOAuthCallbackApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
ctx = {
|
||||
"user_id": "u1",
|
||||
"tenant_id": "t1",
|
||||
"subscription_builder_id": "b1",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", headers={"Cookie": "context_id=ctx"}),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context", return_value=ctx
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
|
||||
return_value={"a": 1},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.OAuthHandler.get_credentials",
|
||||
return_value=MagicMock(credentials={"a": 1}, expires_at=1),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_trigger_subscription_builder"
|
||||
),
|
||||
):
|
||||
resp = method(api, "github")
|
||||
assert resp.status_code == 302
|
||||
|
||||
def test_oauth_callback_no_oauth_client(self, app):
|
||||
api = TriggerOAuthCallbackApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
ctx = {
|
||||
"user_id": "u1",
|
||||
"tenant_id": "t1",
|
||||
"subscription_builder_id": "b1",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", headers={"Cookie": "context_id=ctx"}),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context",
|
||||
return_value=ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "github")
|
||||
|
||||
def test_oauth_callback_empty_credentials(self, app):
|
||||
api = TriggerOAuthCallbackApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
ctx = {
|
||||
"user_id": "u1",
|
||||
"tenant_id": "t1",
|
||||
"subscription_builder_id": "b1",
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", headers={"Cookie": "context_id=ctx"}),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context",
|
||||
return_value=ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client",
|
||||
return_value={"a": 1},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.OAuthHandler.get_credentials",
|
||||
return_value=MagicMock(credentials=None, expires_at=None),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "github")
|
||||
|
||||
|
||||
class TestTriggerOAuthClientManageApi:
|
||||
def test_get_client(self, app):
|
||||
api = TriggerOAuthClientManageApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.get_custom_oauth_client_params",
|
||||
return_value={},
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.is_oauth_custom_client_enabled",
|
||||
return_value=False,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.is_oauth_system_client_exists",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerManager.get_trigger_provider",
|
||||
return_value=MagicMock(get_oauth_client_schema=lambda: {}),
|
||||
),
|
||||
):
|
||||
result = method(api, "github")
|
||||
assert "configured" in result
|
||||
|
||||
def test_post_client(self, app):
|
||||
api = TriggerOAuthClientManageApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"enabled": True}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.save_custom_oauth_client_params",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "github") == {"ok": True}
|
||||
|
||||
def test_delete_client(self, app):
|
||||
api = TriggerOAuthClientManageApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.delete_custom_oauth_client_params",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "github") == {"ok": True}
|
||||
|
||||
def test_oauth_client_post_value_error(self, app):
|
||||
api = TriggerOAuthClientManageApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"enabled": True}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.save_custom_oauth_client_params",
|
||||
side_effect=ValueError("bad"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(BadRequest):
|
||||
method(api, "github")
|
||||
|
||||
|
||||
class TestTriggerSubscriptionVerifyApi:
|
||||
def test_verify_success(self, app):
|
||||
api = TriggerSubscriptionVerifyApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credentials": {}}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.verify_subscription_credentials",
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
assert method(api, "github", "s1") == {"ok": True}
|
||||
|
||||
@pytest.mark.parametrize("raised_exception", [ValueError("bad"), Exception("boom")])
|
||||
def test_verify_errors(self, app, raised_exception):
|
||||
api = TriggerSubscriptionVerifyApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={"credentials": {}}),
|
||||
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
|
||||
patch(
|
||||
"controllers.console.workspace.trigger_providers.TriggerProviderService.verify_subscription_credentials",
|
||||
side_effect=raised_exception,
|
||||
),
|
||||
):
|
||||
with pytest.raises(BadRequest):
|
||||
method(api, "github", "s1")
|
||||
@@ -1,605 +0,0 @@
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
import services
|
||||
from controllers.common.errors import (
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.console.error import AccountNotLinkTenantError
|
||||
from controllers.console.workspace.workspace import (
|
||||
CustomConfigWorkspaceApi,
|
||||
SwitchWorkspaceApi,
|
||||
TenantApi,
|
||||
TenantListApi,
|
||||
WebappLogoWorkspaceApi,
|
||||
WorkspaceInfoApi,
|
||||
WorkspaceListApi,
|
||||
WorkspacePermissionApi,
|
||||
)
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from models.account import TenantStatus
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestTenantListApi:
|
||||
def test_get_success(self, app):
|
||||
api = TenantListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
tenant1 = MagicMock(
|
||||
id="t1",
|
||||
name="Tenant 1",
|
||||
status="active",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
tenant2 = MagicMock(
|
||||
id="t2",
|
||||
name="Tenant 2",
|
||||
status="active",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
features = MagicMock()
|
||||
features.billing.enabled = True
|
||||
features.billing.subscription.plan = CloudPlan.SANDBOX
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
|
||||
return_value=[tenant1, tenant2],
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.FeatureService.get_features", return_value=features),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert len(result["workspaces"]) == 2
|
||||
assert result["workspaces"][0]["current"] is True
|
||||
|
||||
def test_get_billing_disabled(self, app):
|
||||
api = TenantListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
tenant = MagicMock(
|
||||
id="t1",
|
||||
name="Tenant",
|
||||
status="active",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
features = MagicMock()
|
||||
features.billing.enabled = False
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
|
||||
return_value=[tenant],
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.FeatureService.get_features",
|
||||
return_value=features,
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
|
||||
|
||||
|
||||
class TestWorkspaceListApi:
|
||||
def test_get_success(self, app):
|
||||
api = WorkspaceListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
tenant = MagicMock(id="t1", name="T", status="active", created_at=datetime.utcnow())
|
||||
|
||||
paginate_result = MagicMock(
|
||||
items=[tenant],
|
||||
has_next=False,
|
||||
total=1,
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context("/all-workspaces", query_string={"page": 1, "limit": 20}),
|
||||
patch("controllers.console.workspace.workspace.db.paginate", return_value=paginate_result),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["total"] == 1
|
||||
assert result["has_more"] is False
|
||||
|
||||
def test_get_has_next_true(self, app):
|
||||
api = WorkspaceListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
tenant = MagicMock(
|
||||
id="t1",
|
||||
name="T",
|
||||
status="active",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
paginate_result = MagicMock(
|
||||
items=[tenant],
|
||||
has_next=True,
|
||||
total=10,
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context("/all-workspaces", query_string={"page": 1, "limit": 1}),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.db.paginate",
|
||||
return_value=paginate_result,
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["has_more"] is True
|
||||
|
||||
|
||||
class TestTenantApi:
|
||||
def test_post_active_tenant(self, app):
|
||||
api = TenantApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(status="active")
|
||||
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/current"),
|
||||
patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t1"}
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["id"] == "t1"
|
||||
|
||||
def test_post_archived_with_switch(self, app):
|
||||
api = TenantApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
archived = MagicMock(status=TenantStatus.ARCHIVE)
|
||||
new_tenant = MagicMock(status="active")
|
||||
|
||||
user = MagicMock(current_tenant=archived)
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/current"),
|
||||
patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[new_tenant]),
|
||||
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "new"}
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert result["id"] == "new"
|
||||
|
||||
def test_post_archived_no_tenant(self, app):
|
||||
api = TenantApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user = MagicMock(current_tenant=MagicMock(status=TenantStatus.ARCHIVE))
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/current"),
|
||||
patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[]),
|
||||
):
|
||||
with pytest.raises(Unauthorized):
|
||||
method(api)
|
||||
|
||||
def test_post_info_path(self, app):
|
||||
api = TenantApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(status="active")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
|
||||
with (
|
||||
app.test_request_context("/info"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(user, "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info",
|
||||
return_value={"id": "t1"},
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.logger.warning") as warn_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
warn_mock.assert_called_once()
|
||||
assert status == 200
|
||||
|
||||
|
||||
class TestSwitchWorkspaceApi:
|
||||
def test_switch_success(self, app):
|
||||
api = SwitchWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"tenant_id": "t2"}
|
||||
tenant = MagicMock(id="t2")
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/switch", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
|
||||
patch("controllers.console.workspace.workspace.db.session.query") as query_mock,
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t2"}
|
||||
),
|
||||
):
|
||||
query_mock.return_value.get.return_value = tenant
|
||||
result = method(api)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_switch_not_linked(self, app):
|
||||
api = SwitchWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"tenant_id": "bad"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/switch", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.TenantService.switch_tenant", side_effect=Exception),
|
||||
):
|
||||
with pytest.raises(AccountNotLinkTenantError):
|
||||
method(api)
|
||||
|
||||
def test_switch_tenant_not_found(self, app):
|
||||
api = SwitchWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"tenant_id": "missing"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/switch", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
|
||||
patch("controllers.console.workspace.workspace.db.session.query") as query_mock,
|
||||
):
|
||||
query_mock.return_value.get.return_value = None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestCustomConfigWorkspaceApi:
|
||||
def test_post_success(self, app):
|
||||
api = CustomConfigWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(custom_config_dict={})
|
||||
|
||||
payload = {"remove_webapp_brand": True}
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/custom-config", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.db.get_or_404", return_value=tenant),
|
||||
patch("controllers.console.workspace.workspace.db.session.commit"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t1"}
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_logo_fallback(self, app):
|
||||
api = CustomConfigWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(custom_config_dict={"replace_webapp_logo": "old-logo"})
|
||||
|
||||
payload = {"remove_webapp_brand": False}
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/custom-config", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.db.get_or_404",
|
||||
return_value=tenant,
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.db.session.commit"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info",
|
||||
return_value={"id": "t1"},
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert tenant.custom_config_dict["replace_webapp_logo"] == "old-logo"
|
||||
assert result["result"] == "success"
|
||||
|
||||
|
||||
class TestWebappLogoWorkspaceApi:
|
||||
def test_no_file(self, app):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/upload", data={}),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
):
|
||||
with pytest.raises(NoFileUploadedError):
|
||||
method(api)
|
||||
|
||||
def test_too_many_files(self, app):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
data = {
|
||||
"file": MagicMock(),
|
||||
"extra": MagicMock(),
|
||||
}
|
||||
|
||||
with (
|
||||
app.test_request_context("/upload", data=data),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(TooManyFilesError):
|
||||
method(api)
|
||||
|
||||
def test_invalid_extension(self, app):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
file = MagicMock(filename="test.txt")
|
||||
|
||||
with (
|
||||
app.test_request_context("/upload", data={"file": file}),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
):
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
method(api)
|
||||
|
||||
def test_upload_success(self, app):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
file = FileStorage(
|
||||
stream=BytesIO(b"data"),
|
||||
filename="logo.png",
|
||||
content_type="image/png",
|
||||
)
|
||||
|
||||
upload = MagicMock(id="file1")
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/upload",
|
||||
data={"file": file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.FileService") as fs,
|
||||
patch("controllers.console.workspace.workspace.db") as mock_db,
|
||||
):
|
||||
mock_db.engine = MagicMock()
|
||||
fs.return_value.upload_file.return_value = upload
|
||||
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 201
|
||||
assert result["id"] == "file1"
|
||||
|
||||
def test_filename_missing(self, app):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
file = FileStorage(
|
||||
stream=BytesIO(b"data"),
|
||||
filename="",
|
||||
content_type="image/png",
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/upload",
|
||||
data={"file": file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(FilenameNotExistsError):
|
||||
method(api)
|
||||
|
||||
def test_file_too_large(self, app):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
file = FileStorage(
|
||||
stream=BytesIO(b"x"),
|
||||
filename="logo.png",
|
||||
content_type="image/png",
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/upload",
|
||||
data={"file": file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.FileService") as fs,
|
||||
patch("controllers.console.workspace.workspace.db") as mock_db,
|
||||
):
|
||||
mock_db.engine = MagicMock()
|
||||
fs.return_value.upload_file.side_effect = services.errors.file.FileTooLargeError("too big")
|
||||
|
||||
with pytest.raises(FileTooLargeError):
|
||||
method(api)
|
||||
|
||||
def test_service_unsupported_file(self, app):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
file = FileStorage(
|
||||
stream=BytesIO(b"x"),
|
||||
filename="logo.png",
|
||||
content_type="image/png",
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/upload",
|
||||
data={"file": file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.FileService") as fs,
|
||||
patch("controllers.console.workspace.workspace.db") as mock_db,
|
||||
):
|
||||
mock_db.engine = MagicMock()
|
||||
fs.return_value.upload_file.side_effect = services.errors.file.UnsupportedFileTypeError()
|
||||
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestWorkspaceInfoApi:
|
||||
def test_post_success(self, app):
|
||||
api = WorkspaceInfoApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
|
||||
payload = {"name": "New Name"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/info", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.db.get_or_404", return_value=tenant),
|
||||
patch("controllers.console.workspace.workspace.db.session.commit"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info",
|
||||
return_value={"name": "New Name"},
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_no_current_tenant(self, app):
|
||||
api = WorkspaceInfoApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = {"name": "X"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/info", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), None),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
|
||||
|
||||
class TestWorkspacePermissionApi:
|
||||
def test_get_success(self, app):
|
||||
api = WorkspacePermissionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
permission = MagicMock(
|
||||
workspace_id="t1",
|
||||
allow_member_invite=True,
|
||||
allow_owner_transfer=False,
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context("/permission"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.EnterpriseService.WorkspacePermissionService.get_permission",
|
||||
return_value=permission,
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["workspace_id"] == "t1"
|
||||
|
||||
def test_no_current_tenant(self, app):
|
||||
api = WorkspacePermissionApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/permission"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), None),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
@@ -1,142 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.workspace import plugin_permission_required
|
||||
from models.account import TenantPluginPermission
|
||||
|
||||
|
||||
class _SessionStub:
|
||||
def __init__(self, permission):
|
||||
self._permission = permission
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def query(self, *_args, **_kwargs):
|
||||
return self
|
||||
|
||||
def where(self, *_args, **_kwargs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return self._permission
|
||||
|
||||
|
||||
def _workspace_module():
|
||||
return importlib.import_module(plugin_permission_required.__module__)
|
||||
|
||||
|
||||
def _patch_session(monkeypatch: pytest.MonkeyPatch, permission):
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "Session", lambda *_args, **_kwargs: _SessionStub(permission))
|
||||
monkeypatch.setattr(module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
|
||||
def test_plugin_permission_allows_without_permission(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=False)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, None)
|
||||
|
||||
@plugin_permission_required()
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
assert handler() == "ok"
|
||||
|
||||
|
||||
def test_plugin_permission_install_nobody_forbidden(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=True)
|
||||
permission = SimpleNamespace(
|
||||
install_permission=TenantPluginPermission.InstallPermission.NOBODY,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, permission)
|
||||
|
||||
@plugin_permission_required(install_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
|
||||
|
||||
def test_plugin_permission_install_admin_requires_admin(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=False)
|
||||
permission = SimpleNamespace(
|
||||
install_permission=TenantPluginPermission.InstallPermission.ADMINS,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, permission)
|
||||
|
||||
@plugin_permission_required(install_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
|
||||
|
||||
def test_plugin_permission_install_admin_allows_admin(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=True)
|
||||
permission = SimpleNamespace(
|
||||
install_permission=TenantPluginPermission.InstallPermission.ADMINS,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, permission)
|
||||
|
||||
@plugin_permission_required(install_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
assert handler() == "ok"
|
||||
|
||||
|
||||
def test_plugin_permission_debug_nobody_forbidden(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=True)
|
||||
permission = SimpleNamespace(
|
||||
install_permission=TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.NOBODY,
|
||||
)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, permission)
|
||||
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
|
||||
|
||||
def test_plugin_permission_debug_admin_requires_admin(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=False)
|
||||
permission = SimpleNamespace(
|
||||
install_permission=TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.ADMINS,
|
||||
)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, permission)
|
||||
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
@@ -1,85 +0,0 @@
|
||||
"""Shared fixtures for controllers.web unit tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
"""Minimal Flask app for request contexts."""
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
class FakeSession:
|
||||
"""Stand-in for db.session that returns pre-seeded objects by model class name."""
|
||||
|
||||
def __init__(self, mapping: dict[str, Any] | None = None):
|
||||
self._mapping: dict[str, Any] = mapping or {}
|
||||
self._model_name: str | None = None
|
||||
|
||||
def query(self, model: type) -> FakeSession:
|
||||
self._model_name = model.__name__
|
||||
return self
|
||||
|
||||
def where(self, *_args: object, **_kwargs: object) -> FakeSession:
|
||||
return self
|
||||
|
||||
def first(self) -> Any:
|
||||
assert self._model_name is not None
|
||||
return self._mapping.get(self._model_name)
|
||||
|
||||
|
||||
class FakeDB:
|
||||
"""Minimal db stub exposing engine and session."""
|
||||
|
||||
def __init__(self, session: FakeSession | None = None):
|
||||
self.session = session or FakeSession()
|
||||
self.engine = object()
|
||||
|
||||
|
||||
def make_app_model(
|
||||
*,
|
||||
app_id: str = "app-1",
|
||||
tenant_id: str = "tenant-1",
|
||||
mode: str = "chat",
|
||||
enable_site: bool = True,
|
||||
status: str = "normal",
|
||||
) -> SimpleNamespace:
|
||||
"""Build a fake App model with common defaults."""
|
||||
tenant = SimpleNamespace(
|
||||
id=tenant_id,
|
||||
status="normal",
|
||||
plan="basic",
|
||||
custom_config_dict={},
|
||||
)
|
||||
return SimpleNamespace(
|
||||
id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
tenant=tenant,
|
||||
mode=mode,
|
||||
enable_site=enable_site,
|
||||
status=status,
|
||||
workflow=None,
|
||||
app_model_config=None,
|
||||
)
|
||||
|
||||
|
||||
def make_end_user(
|
||||
*,
|
||||
user_id: str = "end-user-1",
|
||||
session_id: str = "session-1",
|
||||
external_user_id: str = "ext-user-1",
|
||||
) -> SimpleNamespace:
|
||||
"""Build a fake EndUser model with common defaults."""
|
||||
return SimpleNamespace(
|
||||
id=user_id,
|
||||
session_id=session_id,
|
||||
external_user_id=external_user_id,
|
||||
)
|
||||
@@ -1,165 +0,0 @@
|
||||
"""Unit tests for controllers.web.app endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.app import AppAccessMode, AppMeta, AppParameterApi, AppWebAuthPermission
|
||||
from controllers.web.error import AppUnavailableError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AppParameterApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAppParameterApi:
|
||||
def test_advanced_chat_mode_uses_workflow(self, app: Flask) -> None:
|
||||
features_dict = {"opening_statement": "Hello"}
|
||||
workflow = SimpleNamespace(
|
||||
features_dict=features_dict,
|
||||
user_input_form=lambda to_old_structure=False: [],
|
||||
)
|
||||
app_model = SimpleNamespace(mode="advanced-chat", workflow=workflow)
|
||||
|
||||
with (
|
||||
app.test_request_context("/parameters"),
|
||||
patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as mock_params,
|
||||
patch("controllers.web.app.fields.Parameters") as mock_fields,
|
||||
):
|
||||
mock_fields.model_validate.return_value.model_dump.return_value = {"result": "ok"}
|
||||
result = AppParameterApi().get(app_model, SimpleNamespace())
|
||||
|
||||
mock_params.assert_called_once_with(features_dict=features_dict, user_input_form=[])
|
||||
assert result == {"result": "ok"}
|
||||
|
||||
def test_workflow_mode_uses_workflow(self, app: Flask) -> None:
|
||||
features_dict = {}
|
||||
workflow = SimpleNamespace(
|
||||
features_dict=features_dict,
|
||||
user_input_form=lambda to_old_structure=False: [{"var": "x"}],
|
||||
)
|
||||
app_model = SimpleNamespace(mode="workflow", workflow=workflow)
|
||||
|
||||
with (
|
||||
app.test_request_context("/parameters"),
|
||||
patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as mock_params,
|
||||
patch("controllers.web.app.fields.Parameters") as mock_fields,
|
||||
):
|
||||
mock_fields.model_validate.return_value.model_dump.return_value = {}
|
||||
AppParameterApi().get(app_model, SimpleNamespace())
|
||||
|
||||
mock_params.assert_called_once_with(features_dict=features_dict, user_input_form=[{"var": "x"}])
|
||||
|
||||
def test_advanced_chat_mode_no_workflow_raises(self, app: Flask) -> None:
|
||||
app_model = SimpleNamespace(mode="advanced-chat", workflow=None)
|
||||
with app.test_request_context("/parameters"):
|
||||
with pytest.raises(AppUnavailableError):
|
||||
AppParameterApi().get(app_model, SimpleNamespace())
|
||||
|
||||
def test_standard_mode_uses_app_model_config(self, app: Flask) -> None:
|
||||
config = SimpleNamespace(to_dict=lambda: {"user_input_form": [{"var": "y"}], "key": "val"})
|
||||
app_model = SimpleNamespace(mode="chat", app_model_config=config)
|
||||
|
||||
with (
|
||||
app.test_request_context("/parameters"),
|
||||
patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as mock_params,
|
||||
patch("controllers.web.app.fields.Parameters") as mock_fields,
|
||||
):
|
||||
mock_fields.model_validate.return_value.model_dump.return_value = {}
|
||||
AppParameterApi().get(app_model, SimpleNamespace())
|
||||
|
||||
call_kwargs = mock_params.call_args
|
||||
assert call_kwargs.kwargs["user_input_form"] == [{"var": "y"}]
|
||||
|
||||
def test_standard_mode_no_config_raises(self, app: Flask) -> None:
|
||||
app_model = SimpleNamespace(mode="chat", app_model_config=None)
|
||||
with app.test_request_context("/parameters"):
|
||||
with pytest.raises(AppUnavailableError):
|
||||
AppParameterApi().get(app_model, SimpleNamespace())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AppMeta
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAppMeta:
|
||||
@patch("controllers.web.app.AppService")
|
||||
def test_get_returns_meta(self, mock_service_cls: MagicMock, app: Flask) -> None:
|
||||
mock_service_cls.return_value.get_app_meta.return_value = {"tool_icons": {}}
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
|
||||
with app.test_request_context("/meta"):
|
||||
result = AppMeta().get(app_model, SimpleNamespace())
|
||||
|
||||
assert result == {"tool_icons": {}}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AppAccessMode
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAppAccessMode:
|
||||
@patch("controllers.web.app.FeatureService.get_system_features")
|
||||
def test_returns_public_when_webapp_auth_disabled(self, mock_features: MagicMock, app: Flask) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
with app.test_request_context("/webapp/access-mode?appId=app-1"):
|
||||
result = AppAccessMode().get()
|
||||
|
||||
assert result == {"accessMode": "public"}
|
||||
|
||||
@patch("controllers.web.app.EnterpriseService.WebAppAuth.get_app_access_mode_by_id")
|
||||
@patch("controllers.web.app.FeatureService.get_system_features")
|
||||
def test_returns_access_mode_with_app_id(
|
||||
self, mock_features: MagicMock, mock_access: MagicMock, app: Flask
|
||||
) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True))
|
||||
mock_access.return_value = SimpleNamespace(access_mode="internal")
|
||||
|
||||
with app.test_request_context("/webapp/access-mode?appId=app-1"):
|
||||
result = AppAccessMode().get()
|
||||
|
||||
assert result == {"accessMode": "internal"}
|
||||
mock_access.assert_called_once_with("app-1")
|
||||
|
||||
@patch("controllers.web.app.AppService.get_app_id_by_code", return_value="resolved-id")
|
||||
@patch("controllers.web.app.EnterpriseService.WebAppAuth.get_app_access_mode_by_id")
|
||||
@patch("controllers.web.app.FeatureService.get_system_features")
|
||||
def test_resolves_app_code_to_id(
|
||||
self, mock_features: MagicMock, mock_access: MagicMock, mock_resolve: MagicMock, app: Flask
|
||||
) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True))
|
||||
mock_access.return_value = SimpleNamespace(access_mode="external")
|
||||
|
||||
with app.test_request_context("/webapp/access-mode?appCode=code1"):
|
||||
result = AppAccessMode().get()
|
||||
|
||||
mock_resolve.assert_called_once_with("code1")
|
||||
mock_access.assert_called_once_with("resolved-id")
|
||||
assert result == {"accessMode": "external"}
|
||||
|
||||
@patch("controllers.web.app.FeatureService.get_system_features")
|
||||
def test_raises_when_no_app_id_or_code(self, mock_features: MagicMock, app: Flask) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True))
|
||||
|
||||
with app.test_request_context("/webapp/access-mode"):
|
||||
with pytest.raises(ValueError, match="appId or appCode"):
|
||||
AppAccessMode().get()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AppWebAuthPermission
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAppWebAuthPermission:
|
||||
@patch("controllers.web.app.WebAppAuthService.is_app_require_permission_check", return_value=False)
|
||||
def test_returns_true_when_no_permission_check_required(self, mock_check: MagicMock, app: Flask) -> None:
|
||||
with app.test_request_context("/webapp/permission?appId=app-1", headers={"X-App-Code": "code1"}):
|
||||
result = AppWebAuthPermission().get()
|
||||
|
||||
assert result == {"result": True}
|
||||
|
||||
def test_raises_when_missing_app_id(self, app: Flask) -> None:
|
||||
with app.test_request_context("/webapp/permission", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(ValueError, match="appId"):
|
||||
AppWebAuthPermission().get()
|
||||
@@ -1,135 +0,0 @@
|
||||
"""Unit tests for controllers.web.audio endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from io import BytesIO
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.audio import AudioApi, TextApi
|
||||
from controllers.web.error import (
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
NoAudioUploadedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
)
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
from services.errors.audio import (
|
||||
AudioTooLargeServiceError,
|
||||
NoAudioUploadedServiceError,
|
||||
ProviderNotSupportSpeechToTextServiceError,
|
||||
UnsupportedAudioTypeServiceError,
|
||||
)
|
||||
|
||||
|
||||
def _app_model() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="chat")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1", external_user_id="ext-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AudioApi (audio-to-text)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAudioApi:
|
||||
@patch("controllers.web.audio.AudioService.transcript_asr", return_value={"text": "hello"})
|
||||
def test_happy_path(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
data = {"file": (BytesIO(b"fake-audio"), "test.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
result = AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
assert result == {"text": "hello"}
|
||||
|
||||
@patch("controllers.web.audio.AudioService.transcript_asr", side_effect=NoAudioUploadedServiceError())
|
||||
def test_no_audio_uploaded(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b""), "empty.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(NoAudioUploadedError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch("controllers.web.audio.AudioService.transcript_asr", side_effect=AudioTooLargeServiceError("too big"))
|
||||
def test_audio_too_large(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"big"), "big.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(AudioTooLargeError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch("controllers.web.audio.AudioService.transcript_asr", side_effect=UnsupportedAudioTypeServiceError())
|
||||
def test_unsupported_type(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"bad"), "bad.xyz")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(UnsupportedAudioTypeError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch(
|
||||
"controllers.web.audio.AudioService.transcript_asr",
|
||||
side_effect=ProviderNotSupportSpeechToTextServiceError(),
|
||||
)
|
||||
def test_provider_not_support(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"x"), "x.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(ProviderNotSupportSpeechToTextError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch(
|
||||
"controllers.web.audio.AudioService.transcript_asr",
|
||||
side_effect=ProviderTokenNotInitError(description="no token"),
|
||||
)
|
||||
def test_provider_not_init(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"x"), "x.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch("controllers.web.audio.AudioService.transcript_asr", side_effect=QuotaExceededError())
|
||||
def test_quota_exceeded(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"x"), "x.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch("controllers.web.audio.AudioService.transcript_asr", side_effect=ModelCurrentlyNotSupportError())
|
||||
def test_model_not_support(self, mock_asr: MagicMock, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"x"), "x.mp3")}
|
||||
with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(ProviderModelCurrentlyNotSupportError):
|
||||
AudioApi().post(_app_model(), _end_user())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TextApi (text-to-audio)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestTextApi:
|
||||
@patch("controllers.web.audio.AudioService.transcript_tts", return_value="audio-bytes")
|
||||
@patch("controllers.web.audio.web_ns")
|
||||
def test_happy_path(self, mock_ns: MagicMock, mock_tts: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"text": "hello", "voice": "alloy"}
|
||||
|
||||
with app.test_request_context("/text-to-audio", method="POST"):
|
||||
result = TextApi().post(_app_model(), _end_user())
|
||||
|
||||
assert result == "audio-bytes"
|
||||
mock_tts.assert_called_once()
|
||||
|
||||
@patch(
|
||||
"controllers.web.audio.AudioService.transcript_tts",
|
||||
side_effect=InvokeError(description="invoke failed"),
|
||||
)
|
||||
@patch("controllers.web.audio.web_ns")
|
||||
def test_invoke_error_mapped(self, mock_ns: MagicMock, mock_tts: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"text": "hello"}
|
||||
|
||||
with app.test_request_context("/text-to-audio", method="POST"):
|
||||
with pytest.raises(CompletionRequestError):
|
||||
TextApi().post(_app_model(), _end_user())
|
||||
@@ -1,161 +0,0 @@
|
||||
"""Unit tests for controllers.web.completion endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
|
||||
from controllers.web.error import (
|
||||
CompletionRequestError,
|
||||
NotChatAppError,
|
||||
NotCompletionAppError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeError
|
||||
|
||||
|
||||
def _completion_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="completion")
|
||||
|
||||
|
||||
def _chat_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="chat")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CompletionApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestCompletionApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/completion-messages", method="POST"):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
CompletionApi().post(_chat_app(), _end_user())
|
||||
|
||||
@patch("controllers.web.completion.helper.compact_generate_response", return_value={"answer": "hi"})
|
||||
@patch("controllers.web.completion.AppGenerateService.generate")
|
||||
@patch("controllers.web.completion.web_ns")
|
||||
def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}, "query": "test"}
|
||||
mock_gen.return_value = "response-obj"
|
||||
|
||||
with app.test_request_context("/completion-messages", method="POST"):
|
||||
result = CompletionApi().post(_completion_app(), _end_user())
|
||||
|
||||
assert result == {"answer": "hi"}
|
||||
|
||||
@patch(
|
||||
"controllers.web.completion.AppGenerateService.generate",
|
||||
side_effect=ProviderTokenNotInitError(description="not init"),
|
||||
)
|
||||
@patch("controllers.web.completion.web_ns")
|
||||
def test_provider_not_init_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}}
|
||||
|
||||
with app.test_request_context("/completion-messages", method="POST"):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
CompletionApi().post(_completion_app(), _end_user())
|
||||
|
||||
@patch(
|
||||
"controllers.web.completion.AppGenerateService.generate",
|
||||
side_effect=QuotaExceededError(),
|
||||
)
|
||||
@patch("controllers.web.completion.web_ns")
|
||||
def test_quota_exceeded_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}}
|
||||
|
||||
with app.test_request_context("/completion-messages", method="POST"):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
CompletionApi().post(_completion_app(), _end_user())
|
||||
|
||||
@patch(
|
||||
"controllers.web.completion.AppGenerateService.generate",
|
||||
side_effect=ModelCurrentlyNotSupportError(),
|
||||
)
|
||||
@patch("controllers.web.completion.web_ns")
|
||||
def test_model_not_support_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}}
|
||||
|
||||
with app.test_request_context("/completion-messages", method="POST"):
|
||||
with pytest.raises(ProviderModelCurrentlyNotSupportError):
|
||||
CompletionApi().post(_completion_app(), _end_user())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CompletionStopApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestCompletionStopApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/completion-messages/task-1/stop", method="POST"):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
CompletionStopApi().post(_chat_app(), _end_user(), "task-1")
|
||||
|
||||
@patch("controllers.web.completion.AppTaskService.stop_task")
|
||||
def test_stop_success(self, mock_stop: MagicMock, app: Flask) -> None:
|
||||
with app.test_request_context("/completion-messages/task-1/stop", method="POST"):
|
||||
result, status = CompletionStopApi().post(_completion_app(), _end_user(), "task-1")
|
||||
|
||||
assert status == 200
|
||||
assert result == {"result": "success"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ChatApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestChatApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/chat-messages", method="POST"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ChatApi().post(_completion_app(), _end_user())
|
||||
|
||||
@patch("controllers.web.completion.helper.compact_generate_response", return_value={"answer": "reply"})
|
||||
@patch("controllers.web.completion.AppGenerateService.generate")
|
||||
@patch("controllers.web.completion.web_ns")
|
||||
def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}, "query": "hi"}
|
||||
mock_gen.return_value = "response"
|
||||
|
||||
with app.test_request_context("/chat-messages", method="POST"):
|
||||
result = ChatApi().post(_chat_app(), _end_user())
|
||||
|
||||
assert result == {"answer": "reply"}
|
||||
|
||||
@patch(
|
||||
"controllers.web.completion.AppGenerateService.generate",
|
||||
side_effect=InvokeError(description="rate limit"),
|
||||
)
|
||||
@patch("controllers.web.completion.web_ns")
|
||||
def test_invoke_error_mapped(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}, "query": "x"}
|
||||
|
||||
with app.test_request_context("/chat-messages", method="POST"):
|
||||
with pytest.raises(CompletionRequestError):
|
||||
ChatApi().post(_chat_app(), _end_user())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ChatStopApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestChatStopApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/chat-messages/task-1/stop", method="POST"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ChatStopApi().post(_completion_app(), _end_user(), "task-1")
|
||||
|
||||
@patch("controllers.web.completion.AppTaskService.stop_task")
|
||||
def test_stop_success(self, mock_stop: MagicMock, app: Flask) -> None:
|
||||
with app.test_request_context("/chat-messages/task-1/stop", method="POST"):
|
||||
result, status = ChatStopApi().post(_chat_app(), _end_user(), "task-1")
|
||||
|
||||
assert status == 200
|
||||
assert result == {"result": "success"}
|
||||
@@ -1,183 +0,0 @@
|
||||
"""Unit tests for controllers.web.conversation endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.web.conversation import (
|
||||
ConversationApi,
|
||||
ConversationListApi,
|
||||
ConversationPinApi,
|
||||
ConversationRenameApi,
|
||||
ConversationUnPinApi,
|
||||
)
|
||||
from controllers.web.error import NotChatAppError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
|
||||
|
||||
def _chat_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="chat")
|
||||
|
||||
|
||||
def _completion_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="completion")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationListApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestConversationListApi:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/conversations"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationListApi().get(_completion_app(), _end_user())
|
||||
|
||||
@patch("controllers.web.conversation.WebConversationService.pagination_by_last_id")
|
||||
@patch("controllers.web.conversation.db")
|
||||
def test_happy_path(self, mock_db: MagicMock, mock_paginate: MagicMock, app: Flask) -> None:
|
||||
conv_id = str(uuid4())
|
||||
conv = SimpleNamespace(
|
||||
id=conv_id,
|
||||
name="Test",
|
||||
inputs={},
|
||||
status="normal",
|
||||
introduction="",
|
||||
created_at=1700000000,
|
||||
updated_at=1700000000,
|
||||
)
|
||||
mock_paginate.return_value = SimpleNamespace(limit=20, has_more=False, data=[conv])
|
||||
mock_db.engine = "engine"
|
||||
|
||||
session_mock = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/conversations?limit=20"),
|
||||
patch("controllers.web.conversation.Session", return_value=session_ctx),
|
||||
):
|
||||
result = ConversationListApi().get(_chat_app(), _end_user())
|
||||
|
||||
assert result["limit"] == 20
|
||||
assert result["has_more"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationApi (delete)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestConversationApi:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context(f"/conversations/{uuid4()}"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationApi().delete(_completion_app(), _end_user(), uuid4())
|
||||
|
||||
@patch("controllers.web.conversation.ConversationService.delete")
|
||||
def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}"):
|
||||
result, status = ConversationApi().delete(_chat_app(), _end_user(), c_id)
|
||||
|
||||
assert status == 204
|
||||
assert result["result"] == "success"
|
||||
|
||||
@patch("controllers.web.conversation.ConversationService.delete", side_effect=ConversationNotExistsError())
|
||||
def test_delete_not_found(self, mock_delete: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}"):
|
||||
with pytest.raises(NotFound, match="Conversation Not Exists"):
|
||||
ConversationApi().delete(_chat_app(), _end_user(), c_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationRenameApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestConversationRenameApi:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context(f"/conversations/{uuid4()}/name", method="POST", json={"name": "x"}):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationRenameApi().post(_completion_app(), _end_user(), uuid4())
|
||||
|
||||
@patch("controllers.web.conversation.ConversationService.rename")
|
||||
@patch("controllers.web.conversation.web_ns")
|
||||
def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
mock_ns.payload = {"name": "New Name", "auto_generate": False}
|
||||
conv = SimpleNamespace(
|
||||
id=str(c_id),
|
||||
name="New Name",
|
||||
inputs={},
|
||||
status="normal",
|
||||
introduction="",
|
||||
created_at=1700000000,
|
||||
updated_at=1700000000,
|
||||
)
|
||||
mock_rename.return_value = conv
|
||||
|
||||
with app.test_request_context(f"/conversations/{c_id}/name", method="POST", json={"name": "New Name"}):
|
||||
result = ConversationRenameApi().post(_chat_app(), _end_user(), c_id)
|
||||
|
||||
assert result["name"] == "New Name"
|
||||
|
||||
@patch(
|
||||
"controllers.web.conversation.ConversationService.rename",
|
||||
side_effect=ConversationNotExistsError(),
|
||||
)
|
||||
@patch("controllers.web.conversation.web_ns")
|
||||
def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
mock_ns.payload = {"name": "X", "auto_generate": False}
|
||||
|
||||
with app.test_request_context(f"/conversations/{c_id}/name", method="POST", json={"name": "X"}):
|
||||
with pytest.raises(NotFound, match="Conversation Not Exists"):
|
||||
ConversationRenameApi().post(_chat_app(), _end_user(), c_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationPinApi / ConversationUnPinApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestConversationPinApi:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context(f"/conversations/{uuid4()}/pin", method="PATCH"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationPinApi().patch(_completion_app(), _end_user(), uuid4())
|
||||
|
||||
@patch("controllers.web.conversation.WebConversationService.pin")
|
||||
def test_pin_success(self, mock_pin: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"):
|
||||
result = ConversationPinApi().patch(_chat_app(), _end_user(), c_id)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
@patch("controllers.web.conversation.WebConversationService.pin", side_effect=ConversationNotExistsError())
|
||||
def test_pin_not_found(self, mock_pin: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"):
|
||||
with pytest.raises(NotFound):
|
||||
ConversationPinApi().patch(_chat_app(), _end_user(), c_id)
|
||||
|
||||
|
||||
class TestConversationUnPinApi:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context(f"/conversations/{uuid4()}/unpin", method="PATCH"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationUnPinApi().patch(_completion_app(), _end_user(), uuid4())
|
||||
|
||||
@patch("controllers.web.conversation.WebConversationService.unpin")
|
||||
def test_unpin_success(self, mock_unpin: MagicMock, app: Flask) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}/unpin", method="PATCH"):
|
||||
result = ConversationUnPinApi().patch(_chat_app(), _end_user(), c_id)
|
||||
|
||||
assert result["result"] == "success"
|
||||
@@ -1,75 +0,0 @@
|
||||
"""Unit tests for controllers.web.error HTTP exception classes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.web.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
AppSuggestedQuestionsAfterAnswerDisabledError,
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
InvalidArgumentError,
|
||||
InvokeRateLimitError,
|
||||
NoAudioUploadedError,
|
||||
NotChatAppError,
|
||||
NotCompletionAppError,
|
||||
NotFoundError,
|
||||
NotWorkflowAppError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderNotSupportSpeechToTextError,
|
||||
ProviderQuotaExceededError,
|
||||
UnsupportedAudioTypeError,
|
||||
WebAppAuthAccessDeniedError,
|
||||
WebAppAuthRequiredError,
|
||||
WebFormRateLimitExceededError,
|
||||
)
|
||||
|
||||
_ERROR_SPECS: list[tuple[type, str, int]] = [
|
||||
(AppUnavailableError, "app_unavailable", 400),
|
||||
(NotCompletionAppError, "not_completion_app", 400),
|
||||
(NotChatAppError, "not_chat_app", 400),
|
||||
(NotWorkflowAppError, "not_workflow_app", 400),
|
||||
(ConversationCompletedError, "conversation_completed", 400),
|
||||
(ProviderNotInitializeError, "provider_not_initialize", 400),
|
||||
(ProviderQuotaExceededError, "provider_quota_exceeded", 400),
|
||||
(ProviderModelCurrentlyNotSupportError, "model_currently_not_support", 400),
|
||||
(CompletionRequestError, "completion_request_error", 400),
|
||||
(AppMoreLikeThisDisabledError, "app_more_like_this_disabled", 403),
|
||||
(AppSuggestedQuestionsAfterAnswerDisabledError, "app_suggested_questions_after_answer_disabled", 403),
|
||||
(NoAudioUploadedError, "no_audio_uploaded", 400),
|
||||
(AudioTooLargeError, "audio_too_large", 413),
|
||||
(UnsupportedAudioTypeError, "unsupported_audio_type", 415),
|
||||
(ProviderNotSupportSpeechToTextError, "provider_not_support_speech_to_text", 400),
|
||||
(WebAppAuthRequiredError, "web_sso_auth_required", 401),
|
||||
(WebAppAuthAccessDeniedError, "web_app_access_denied", 401),
|
||||
(InvokeRateLimitError, "rate_limit_error", 429),
|
||||
(WebFormRateLimitExceededError, "web_form_rate_limit_exceeded", 429),
|
||||
(NotFoundError, "not_found", 404),
|
||||
(InvalidArgumentError, "invalid_param", 400),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("cls", "expected_code", "expected_status"),
|
||||
_ERROR_SPECS,
|
||||
ids=[cls.__name__ for cls, _, _ in _ERROR_SPECS],
|
||||
)
|
||||
def test_error_class_attributes(cls: type, expected_code: str, expected_status: int) -> None:
|
||||
"""Each error class exposes the correct error_code and HTTP status code."""
|
||||
assert cls.error_code == expected_code
|
||||
assert cls.code == expected_status
|
||||
|
||||
|
||||
def test_error_classes_have_description() -> None:
|
||||
"""Every error class has a description (string or None for generic errors)."""
|
||||
# NotFoundError and InvalidArgumentError use None description by design
|
||||
_NO_DESCRIPTION = {NotFoundError, InvalidArgumentError}
|
||||
for cls, _, _ in _ERROR_SPECS:
|
||||
if cls in _NO_DESCRIPTION:
|
||||
continue
|
||||
assert isinstance(cls.description, str), f"{cls.__name__} missing description"
|
||||
assert len(cls.description) > 0, f"{cls.__name__} has empty description"
|
||||
@@ -1,38 +0,0 @@
|
||||
"""Unit tests for controllers.web.feature endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.feature import SystemFeatureApi
|
||||
|
||||
|
||||
class TestSystemFeatureApi:
|
||||
@patch("controllers.web.feature.FeatureService.get_system_features")
|
||||
def test_returns_system_features(self, mock_features: MagicMock, app: Flask) -> None:
|
||||
mock_model = MagicMock()
|
||||
mock_model.model_dump.return_value = {"sso_enforced_for_signin": False, "webapp_auth": {"enabled": False}}
|
||||
mock_features.return_value = mock_model
|
||||
|
||||
with app.test_request_context("/system-features"):
|
||||
result = SystemFeatureApi().get()
|
||||
|
||||
assert result == {"sso_enforced_for_signin": False, "webapp_auth": {"enabled": False}}
|
||||
mock_features.assert_called_once()
|
||||
|
||||
@patch("controllers.web.feature.FeatureService.get_system_features")
|
||||
def test_unauthenticated_access(self, mock_features: MagicMock, app: Flask) -> None:
|
||||
"""SystemFeatureApi is unauthenticated by design — no WebApiResource decorator."""
|
||||
mock_model = MagicMock()
|
||||
mock_model.model_dump.return_value = {}
|
||||
mock_features.return_value = mock_model
|
||||
|
||||
# Verify it's a bare Resource, not WebApiResource
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.web.wraps import WebApiResource
|
||||
|
||||
assert issubclass(SystemFeatureApi, Resource)
|
||||
assert not issubclass(SystemFeatureApi, WebApiResource)
|
||||
@@ -1,89 +0,0 @@
|
||||
"""Unit tests for controllers.web.files endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from io import BytesIO
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.common.errors import (
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
NoFileUploadedError,
|
||||
TooManyFilesError,
|
||||
)
|
||||
from controllers.web.files import FileApi
|
||||
|
||||
|
||||
def _app_model() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
class TestFileApi:
|
||||
def test_no_file_uploaded(self, app: Flask) -> None:
|
||||
with app.test_request_context("/files/upload", method="POST", content_type="multipart/form-data"):
|
||||
with pytest.raises(NoFileUploadedError):
|
||||
FileApi().post(_app_model(), _end_user())
|
||||
|
||||
def test_too_many_files(self, app: Flask) -> None:
|
||||
data = {
|
||||
"file": (BytesIO(b"a"), "a.txt"),
|
||||
"file2": (BytesIO(b"b"), "b.txt"),
|
||||
}
|
||||
with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"):
|
||||
# Now has "file" key but len(request.files) > 1
|
||||
with pytest.raises(TooManyFilesError):
|
||||
FileApi().post(_app_model(), _end_user())
|
||||
|
||||
def test_filename_missing(self, app: Flask) -> None:
|
||||
data = {"file": (BytesIO(b"content"), "")}
|
||||
with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(FilenameNotExistsError):
|
||||
FileApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch("controllers.web.files.FileService")
|
||||
@patch("controllers.web.files.db")
|
||||
def test_upload_success(self, mock_db: MagicMock, mock_file_svc_cls: MagicMock, app: Flask) -> None:
|
||||
mock_db.engine = "engine"
|
||||
from datetime import datetime
|
||||
|
||||
upload_file = SimpleNamespace(
|
||||
id="file-1",
|
||||
name="test.txt",
|
||||
size=100,
|
||||
extension="txt",
|
||||
mime_type="text/plain",
|
||||
created_by="eu-1",
|
||||
created_at=datetime(2024, 1, 1),
|
||||
)
|
||||
mock_file_svc_cls.return_value.upload_file.return_value = upload_file
|
||||
|
||||
data = {"file": (BytesIO(b"content"), "test.txt")}
|
||||
with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"):
|
||||
result, status = FileApi().post(_app_model(), _end_user())
|
||||
|
||||
assert status == 201
|
||||
assert result["id"] == "file-1"
|
||||
assert result["name"] == "test.txt"
|
||||
|
||||
@patch("controllers.web.files.FileService")
|
||||
@patch("controllers.web.files.db")
|
||||
def test_file_too_large_from_service(self, mock_db: MagicMock, mock_file_svc_cls: MagicMock, app: Flask) -> None:
|
||||
import services.errors.file
|
||||
|
||||
mock_db.engine = "engine"
|
||||
mock_file_svc_cls.return_value.upload_file.side_effect = services.errors.file.FileTooLargeError(
|
||||
description="max 10MB"
|
||||
)
|
||||
|
||||
data = {"file": (BytesIO(b"big"), "big.txt")}
|
||||
with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"):
|
||||
with pytest.raises(FileTooLargeError):
|
||||
FileApi().post(_app_model(), _end_user())
|
||||
@@ -1,156 +0,0 @@
|
||||
"""Unit tests for controllers.web.message — feedback, more-like-this, suggested questions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.web.error import (
|
||||
AppMoreLikeThisDisabledError,
|
||||
NotChatAppError,
|
||||
NotCompletionAppError,
|
||||
)
|
||||
from controllers.web.message import (
|
||||
MessageFeedbackApi,
|
||||
MessageMoreLikeThisApi,
|
||||
MessageSuggestedQuestionApi,
|
||||
)
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
|
||||
def _chat_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="chat")
|
||||
|
||||
|
||||
def _completion_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="completion")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MessageFeedbackApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestMessageFeedbackApi:
|
||||
@patch("controllers.web.message.MessageService.create_feedback")
|
||||
@patch("controllers.web.message.web_ns")
|
||||
def test_feedback_success(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"rating": "like", "content": "great"}
|
||||
msg_id = uuid4()
|
||||
|
||||
with app.test_request_context(f"/messages/{msg_id}/feedbacks", method="POST"):
|
||||
result = MessageFeedbackApi().post(_chat_app(), _end_user(), msg_id)
|
||||
|
||||
assert result == {"result": "success"}
|
||||
mock_create.assert_called_once()
|
||||
|
||||
@patch("controllers.web.message.MessageService.create_feedback")
|
||||
@patch("controllers.web.message.web_ns")
|
||||
def test_feedback_null_rating(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"rating": None}
|
||||
msg_id = uuid4()
|
||||
|
||||
with app.test_request_context(f"/messages/{msg_id}/feedbacks", method="POST"):
|
||||
result = MessageFeedbackApi().post(_chat_app(), _end_user(), msg_id)
|
||||
|
||||
assert result == {"result": "success"}
|
||||
|
||||
@patch(
|
||||
"controllers.web.message.MessageService.create_feedback",
|
||||
side_effect=MessageNotExistsError(),
|
||||
)
|
||||
@patch("controllers.web.message.web_ns")
|
||||
def test_feedback_message_not_found(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"rating": "dislike"}
|
||||
msg_id = uuid4()
|
||||
|
||||
with app.test_request_context(f"/messages/{msg_id}/feedbacks", method="POST"):
|
||||
with pytest.raises(NotFound, match="Message Not Exists"):
|
||||
MessageFeedbackApi().post(_chat_app(), _end_user(), msg_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MessageMoreLikeThisApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestMessageMoreLikeThisApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
MessageMoreLikeThisApi().get(_chat_app(), _end_user(), msg_id)
|
||||
|
||||
@patch("controllers.web.message.helper.compact_generate_response", return_value={"answer": "similar"})
|
||||
@patch("controllers.web.message.AppGenerateService.generate_more_like_this")
|
||||
def test_happy_path(self, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
mock_gen.return_value = "response"
|
||||
|
||||
with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"):
|
||||
result = MessageMoreLikeThisApi().get(_completion_app(), _end_user(), msg_id)
|
||||
|
||||
assert result == {"answer": "similar"}
|
||||
|
||||
@patch(
|
||||
"controllers.web.message.AppGenerateService.generate_more_like_this",
|
||||
side_effect=MessageNotExistsError(),
|
||||
)
|
||||
def test_message_not_found(self, mock_gen: MagicMock, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"):
|
||||
with pytest.raises(NotFound, match="Message Not Exists"):
|
||||
MessageMoreLikeThisApi().get(_completion_app(), _end_user(), msg_id)
|
||||
|
||||
@patch(
|
||||
"controllers.web.message.AppGenerateService.generate_more_like_this",
|
||||
side_effect=MoreLikeThisDisabledError(),
|
||||
)
|
||||
def test_feature_disabled(self, mock_gen: MagicMock, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"):
|
||||
with pytest.raises(AppMoreLikeThisDisabledError):
|
||||
MessageMoreLikeThisApi().get(_completion_app(), _end_user(), msg_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MessageSuggestedQuestionApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestMessageSuggestedQuestionApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/messages/{msg_id}/suggested-questions"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id)
|
||||
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/messages/{msg_id}/suggested-questions"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id)
|
||||
|
||||
@patch("controllers.web.message.MessageService.get_suggested_questions_after_answer")
|
||||
def test_happy_path(self, mock_suggest: MagicMock, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
mock_suggest.return_value = ["What about X?", "Tell me more about Y."]
|
||||
|
||||
with app.test_request_context(f"/messages/{msg_id}/suggested-questions"):
|
||||
result = MessageSuggestedQuestionApi().get(_chat_app(), _end_user(), msg_id)
|
||||
|
||||
assert result["data"] == ["What about X?", "Tell me more about Y."]
|
||||
|
||||
@patch(
|
||||
"controllers.web.message.MessageService.get_suggested_questions_after_answer",
|
||||
side_effect=MessageNotExistsError(),
|
||||
)
|
||||
def test_message_not_found(self, mock_suggest: MagicMock, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/messages/{msg_id}/suggested-questions"):
|
||||
with pytest.raises(NotFound, match="Message not found"):
|
||||
MessageSuggestedQuestionApi().get(_chat_app(), _end_user(), msg_id)
|
||||
@@ -1,103 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from controllers.web.error import WebAppAuthRequiredError
|
||||
from controllers.web.passport import (
|
||||
PassportService,
|
||||
decode_enterprise_webapp_user_id,
|
||||
exchange_token_for_existing_web_user,
|
||||
generate_session_id,
|
||||
)
|
||||
from services.webapp_auth_service import WebAppAuthType
|
||||
|
||||
|
||||
def test_decode_enterprise_webapp_user_id_none() -> None:
|
||||
assert decode_enterprise_webapp_user_id(None) is None
|
||||
|
||||
|
||||
def test_decode_enterprise_webapp_user_id_invalid_source(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(PassportService, "verify", lambda *_args, **_kwargs: {"token_source": "bad"})
|
||||
with pytest.raises(Unauthorized):
|
||||
decode_enterprise_webapp_user_id("token")
|
||||
|
||||
|
||||
def test_decode_enterprise_webapp_user_id_valid(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
decoded = {"token_source": "webapp_login_token", "user_id": "u1"}
|
||||
monkeypatch.setattr(PassportService, "verify", lambda *_args, **_kwargs: decoded)
|
||||
assert decode_enterprise_webapp_user_id("token") == decoded
|
||||
|
||||
|
||||
def test_exchange_token_public_flow(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal")
|
||||
app_model = SimpleNamespace(id="a1", status="normal", enable_site=True)
|
||||
|
||||
def _scalar_side_effect(*_args, **_kwargs):
|
||||
if not hasattr(_scalar_side_effect, "calls"):
|
||||
_scalar_side_effect.calls = 0
|
||||
_scalar_side_effect.calls += 1
|
||||
return site if _scalar_side_effect.calls == 1 else app_model
|
||||
|
||||
db_session = SimpleNamespace(scalar=_scalar_side_effect)
|
||||
monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session))
|
||||
monkeypatch.setattr("controllers.web.passport._exchange_for_public_app_token", lambda *_args, **_kwargs: "resp")
|
||||
|
||||
decoded = {"auth_type": "public"}
|
||||
result = exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.PUBLIC)
|
||||
assert result == "resp"
|
||||
|
||||
|
||||
def test_exchange_token_requires_external(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal")
|
||||
app_model = SimpleNamespace(id="a1", status="normal", enable_site=True)
|
||||
|
||||
def _scalar_side_effect(*_args, **_kwargs):
|
||||
if not hasattr(_scalar_side_effect, "calls"):
|
||||
_scalar_side_effect.calls = 0
|
||||
_scalar_side_effect.calls += 1
|
||||
return site if _scalar_side_effect.calls == 1 else app_model
|
||||
|
||||
db_session = SimpleNamespace(scalar=_scalar_side_effect)
|
||||
monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session))
|
||||
|
||||
decoded = {"auth_type": "internal"}
|
||||
with pytest.raises(WebAppAuthRequiredError):
|
||||
exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.EXTERNAL)
|
||||
|
||||
|
||||
def test_exchange_token_missing_session_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal")
|
||||
app_model = SimpleNamespace(id="a1", status="normal", enable_site=True, tenant_id="t1")
|
||||
|
||||
def _scalar_side_effect(*_args, **_kwargs):
|
||||
if not hasattr(_scalar_side_effect, "calls"):
|
||||
_scalar_side_effect.calls = 0
|
||||
_scalar_side_effect.calls += 1
|
||||
if _scalar_side_effect.calls == 1:
|
||||
return site
|
||||
if _scalar_side_effect.calls == 2:
|
||||
return app_model
|
||||
return None
|
||||
|
||||
db_session = SimpleNamespace(scalar=_scalar_side_effect, add=lambda *_a, **_k: None, commit=lambda: None)
|
||||
monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session))
|
||||
|
||||
decoded = {"auth_type": "internal"}
|
||||
with pytest.raises(NotFound):
|
||||
exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.INTERNAL)
|
||||
|
||||
|
||||
def test_generate_session_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
counts = [1, 0]
|
||||
|
||||
def _scalar(*_args, **_kwargs):
|
||||
return counts.pop(0)
|
||||
|
||||
db_session = SimpleNamespace(scalar=_scalar)
|
||||
monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session))
|
||||
|
||||
session_id = generate_session_id()
|
||||
assert session_id
|
||||
@@ -1,423 +0,0 @@
|
||||
"""Unit tests for Pydantic models defined in controllers.web modules.
|
||||
|
||||
Covers validation logic, field defaults, constraints, and custom validators
|
||||
for all ~15 Pydantic models across the web controller layer.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# app.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.app import AppAccessModeQuery
|
||||
|
||||
|
||||
class TestAppAccessModeQuery:
|
||||
def test_alias_resolution(self) -> None:
|
||||
q = AppAccessModeQuery.model_validate({"appId": "abc", "appCode": "xyz"})
|
||||
assert q.app_id == "abc"
|
||||
assert q.app_code == "xyz"
|
||||
|
||||
def test_defaults_to_none(self) -> None:
|
||||
q = AppAccessModeQuery.model_validate({})
|
||||
assert q.app_id is None
|
||||
assert q.app_code is None
|
||||
|
||||
def test_accepts_snake_case(self) -> None:
|
||||
q = AppAccessModeQuery(app_id="id1", app_code="code1")
|
||||
assert q.app_id == "id1"
|
||||
assert q.app_code == "code1"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# audio.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.audio import TextToAudioPayload
|
||||
|
||||
|
||||
class TestTextToAudioPayload:
|
||||
def test_defaults(self) -> None:
|
||||
p = TextToAudioPayload.model_validate({})
|
||||
assert p.message_id is None
|
||||
assert p.voice is None
|
||||
assert p.text is None
|
||||
assert p.streaming is None
|
||||
|
||||
def test_valid_uuid_message_id(self) -> None:
|
||||
uid = str(uuid4())
|
||||
p = TextToAudioPayload(message_id=uid)
|
||||
assert p.message_id == uid
|
||||
|
||||
def test_none_message_id_passthrough(self) -> None:
|
||||
p = TextToAudioPayload(message_id=None)
|
||||
assert p.message_id is None
|
||||
|
||||
def test_invalid_uuid_message_id(self) -> None:
|
||||
with pytest.raises(ValidationError, match="not a valid uuid"):
|
||||
TextToAudioPayload(message_id="not-a-uuid")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# completion.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.completion import ChatMessagePayload, CompletionMessagePayload
|
||||
|
||||
|
||||
class TestCompletionMessagePayload:
|
||||
def test_defaults(self) -> None:
|
||||
p = CompletionMessagePayload(inputs={})
|
||||
assert p.query == ""
|
||||
assert p.files is None
|
||||
assert p.response_mode is None
|
||||
assert p.retriever_from == "web_app"
|
||||
|
||||
def test_accepts_full_payload(self) -> None:
|
||||
p = CompletionMessagePayload(
|
||||
inputs={"key": "val"},
|
||||
query="test",
|
||||
files=[{"id": "f1"}],
|
||||
response_mode="streaming",
|
||||
)
|
||||
assert p.response_mode == "streaming"
|
||||
assert p.files == [{"id": "f1"}]
|
||||
|
||||
def test_invalid_response_mode(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
CompletionMessagePayload(inputs={}, response_mode="invalid")
|
||||
|
||||
|
||||
class TestChatMessagePayload:
|
||||
def test_valid_uuid_fields(self) -> None:
|
||||
cid = str(uuid4())
|
||||
pid = str(uuid4())
|
||||
p = ChatMessagePayload(inputs={}, query="hi", conversation_id=cid, parent_message_id=pid)
|
||||
assert p.conversation_id == cid
|
||||
assert p.parent_message_id == pid
|
||||
|
||||
def test_none_uuid_fields(self) -> None:
|
||||
p = ChatMessagePayload(inputs={}, query="hi")
|
||||
assert p.conversation_id is None
|
||||
assert p.parent_message_id is None
|
||||
|
||||
def test_invalid_conversation_id(self) -> None:
|
||||
with pytest.raises(ValidationError, match="not a valid uuid"):
|
||||
ChatMessagePayload(inputs={}, query="hi", conversation_id="bad")
|
||||
|
||||
def test_invalid_parent_message_id(self) -> None:
|
||||
with pytest.raises(ValidationError, match="not a valid uuid"):
|
||||
ChatMessagePayload(inputs={}, query="hi", parent_message_id="bad")
|
||||
|
||||
def test_query_required(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ChatMessagePayload(inputs={})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# conversation.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.conversation import ConversationListQuery, ConversationRenamePayload
|
||||
|
||||
|
||||
class TestConversationListQuery:
|
||||
def test_defaults(self) -> None:
|
||||
q = ConversationListQuery()
|
||||
assert q.last_id is None
|
||||
assert q.limit == 20
|
||||
assert q.pinned is None
|
||||
assert q.sort_by == "-updated_at"
|
||||
|
||||
def test_limit_lower_bound(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ConversationListQuery(limit=0)
|
||||
|
||||
def test_limit_upper_bound(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ConversationListQuery(limit=101)
|
||||
|
||||
def test_limit_boundaries_valid(self) -> None:
|
||||
assert ConversationListQuery(limit=1).limit == 1
|
||||
assert ConversationListQuery(limit=100).limit == 100
|
||||
|
||||
def test_valid_sort_by_options(self) -> None:
|
||||
for opt in ("created_at", "-created_at", "updated_at", "-updated_at"):
|
||||
assert ConversationListQuery(sort_by=opt).sort_by == opt
|
||||
|
||||
def test_invalid_sort_by(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ConversationListQuery(sort_by="invalid")
|
||||
|
||||
def test_valid_last_id(self) -> None:
|
||||
uid = str(uuid4())
|
||||
assert ConversationListQuery(last_id=uid).last_id == uid
|
||||
|
||||
def test_invalid_last_id(self) -> None:
|
||||
with pytest.raises(ValidationError, match="not a valid uuid"):
|
||||
ConversationListQuery(last_id="not-uuid")
|
||||
|
||||
|
||||
class TestConversationRenamePayload:
|
||||
def test_auto_generate_true_no_name_required(self) -> None:
|
||||
p = ConversationRenamePayload(auto_generate=True)
|
||||
assert p.name is None
|
||||
|
||||
def test_auto_generate_false_requires_name(self) -> None:
|
||||
with pytest.raises(ValidationError, match="name is required"):
|
||||
ConversationRenamePayload(auto_generate=False)
|
||||
|
||||
def test_auto_generate_false_blank_name_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError, match="name is required"):
|
||||
ConversationRenamePayload(auto_generate=False, name=" ")
|
||||
|
||||
def test_auto_generate_false_with_valid_name(self) -> None:
|
||||
p = ConversationRenamePayload(auto_generate=False, name="My Chat")
|
||||
assert p.name == "My Chat"
|
||||
|
||||
def test_defaults(self) -> None:
|
||||
p = ConversationRenamePayload(name="test")
|
||||
assert p.auto_generate is False
|
||||
assert p.name == "test"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# message.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.message import MessageFeedbackPayload, MessageListQuery, MessageMoreLikeThisQuery
|
||||
|
||||
|
||||
class TestMessageListQuery:
|
||||
def test_valid_query(self) -> None:
|
||||
cid = str(uuid4())
|
||||
q = MessageListQuery(conversation_id=cid)
|
||||
assert q.conversation_id == cid
|
||||
assert q.first_id is None
|
||||
assert q.limit == 20
|
||||
|
||||
def test_invalid_conversation_id(self) -> None:
|
||||
with pytest.raises(ValidationError, match="not a valid uuid"):
|
||||
MessageListQuery(conversation_id="bad")
|
||||
|
||||
def test_limit_bounds(self) -> None:
|
||||
cid = str(uuid4())
|
||||
with pytest.raises(ValidationError):
|
||||
MessageListQuery(conversation_id=cid, limit=0)
|
||||
with pytest.raises(ValidationError):
|
||||
MessageListQuery(conversation_id=cid, limit=101)
|
||||
|
||||
def test_valid_first_id(self) -> None:
|
||||
cid = str(uuid4())
|
||||
fid = str(uuid4())
|
||||
q = MessageListQuery(conversation_id=cid, first_id=fid)
|
||||
assert q.first_id == fid
|
||||
|
||||
def test_invalid_first_id(self) -> None:
|
||||
cid = str(uuid4())
|
||||
with pytest.raises(ValidationError, match="not a valid uuid"):
|
||||
MessageListQuery(conversation_id=cid, first_id="invalid")
|
||||
|
||||
|
||||
class TestMessageFeedbackPayload:
|
||||
def test_defaults(self) -> None:
|
||||
p = MessageFeedbackPayload()
|
||||
assert p.rating is None
|
||||
assert p.content is None
|
||||
|
||||
def test_valid_ratings(self) -> None:
|
||||
assert MessageFeedbackPayload(rating="like").rating == "like"
|
||||
assert MessageFeedbackPayload(rating="dislike").rating == "dislike"
|
||||
|
||||
def test_invalid_rating(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
MessageFeedbackPayload(rating="neutral")
|
||||
|
||||
|
||||
class TestMessageMoreLikeThisQuery:
|
||||
def test_valid_modes(self) -> None:
|
||||
assert MessageMoreLikeThisQuery(response_mode="blocking").response_mode == "blocking"
|
||||
assert MessageMoreLikeThisQuery(response_mode="streaming").response_mode == "streaming"
|
||||
|
||||
def test_invalid_mode(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
MessageMoreLikeThisQuery(response_mode="invalid")
|
||||
|
||||
def test_required(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
MessageMoreLikeThisQuery()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# remote_files.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.remote_files import RemoteFileUploadPayload
|
||||
|
||||
|
||||
class TestRemoteFileUploadPayload:
|
||||
def test_valid_url(self) -> None:
|
||||
p = RemoteFileUploadPayload(url="https://example.com/file.pdf")
|
||||
assert str(p.url) == "https://example.com/file.pdf"
|
||||
|
||||
def test_invalid_url(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
RemoteFileUploadPayload(url="not-a-url")
|
||||
|
||||
def test_url_required(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
RemoteFileUploadPayload()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# saved_message.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.saved_message import SavedMessageCreatePayload, SavedMessageListQuery
|
||||
|
||||
|
||||
class TestSavedMessageListQuery:
|
||||
def test_defaults(self) -> None:
|
||||
q = SavedMessageListQuery()
|
||||
assert q.last_id is None
|
||||
assert q.limit == 20
|
||||
|
||||
def test_limit_bounds(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
SavedMessageListQuery(limit=0)
|
||||
with pytest.raises(ValidationError):
|
||||
SavedMessageListQuery(limit=101)
|
||||
|
||||
def test_valid_last_id(self) -> None:
|
||||
uid = str(uuid4())
|
||||
q = SavedMessageListQuery(last_id=uid)
|
||||
assert q.last_id == uid
|
||||
|
||||
def test_empty_last_id(self) -> None:
|
||||
q = SavedMessageListQuery(last_id="")
|
||||
assert q.last_id == ""
|
||||
|
||||
|
||||
class TestSavedMessageCreatePayload:
|
||||
def test_valid_message_id(self) -> None:
|
||||
uid = str(uuid4())
|
||||
p = SavedMessageCreatePayload(message_id=uid)
|
||||
assert p.message_id == uid
|
||||
|
||||
def test_required(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
SavedMessageCreatePayload()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# workflow.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.workflow import WorkflowRunPayload
|
||||
|
||||
|
||||
class TestWorkflowRunPayload:
|
||||
def test_defaults(self) -> None:
|
||||
p = WorkflowRunPayload(inputs={})
|
||||
assert p.inputs == {}
|
||||
assert p.files is None
|
||||
|
||||
def test_with_files(self) -> None:
|
||||
p = WorkflowRunPayload(inputs={"k": "v"}, files=[{"id": "f1"}])
|
||||
assert p.files == [{"id": "f1"}]
|
||||
|
||||
def test_inputs_required(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
WorkflowRunPayload()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# forgot_password.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.forgot_password import (
|
||||
ForgotPasswordCheckPayload,
|
||||
ForgotPasswordResetPayload,
|
||||
ForgotPasswordSendPayload,
|
||||
)
|
||||
|
||||
|
||||
class TestForgotPasswordSendPayload:
|
||||
def test_valid_email(self) -> None:
|
||||
p = ForgotPasswordSendPayload(email="user@example.com")
|
||||
assert p.email == "user@example.com"
|
||||
|
||||
def test_invalid_email(self) -> None:
|
||||
with pytest.raises(ValidationError, match="not a valid email"):
|
||||
ForgotPasswordSendPayload(email="not-an-email")
|
||||
|
||||
def test_language_optional(self) -> None:
|
||||
p = ForgotPasswordSendPayload(email="a@b.com")
|
||||
assert p.language is None
|
||||
|
||||
|
||||
class TestForgotPasswordCheckPayload:
|
||||
def test_valid(self) -> None:
|
||||
p = ForgotPasswordCheckPayload(email="a@b.com", code="1234", token="tok")
|
||||
assert p.email == "a@b.com"
|
||||
assert p.code == "1234"
|
||||
assert p.token == "tok"
|
||||
|
||||
def test_empty_token_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ForgotPasswordCheckPayload(email="a@b.com", code="1234", token="")
|
||||
|
||||
|
||||
class TestForgotPasswordResetPayload:
|
||||
def test_valid_passwords(self) -> None:
|
||||
p = ForgotPasswordResetPayload(token="tok", new_password="Valid1234", password_confirm="Valid1234")
|
||||
assert p.new_password == "Valid1234"
|
||||
|
||||
def test_weak_password_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError, match="Password must contain"):
|
||||
ForgotPasswordResetPayload(token="tok", new_password="short", password_confirm="short")
|
||||
|
||||
def test_letters_only_password_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError, match="Password must contain"):
|
||||
ForgotPasswordResetPayload(token="tok", new_password="abcdefghi", password_confirm="abcdefghi")
|
||||
|
||||
def test_digits_only_password_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError, match="Password must contain"):
|
||||
ForgotPasswordResetPayload(token="tok", new_password="123456789", password_confirm="123456789")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# login.py models
|
||||
# ---------------------------------------------------------------------------
|
||||
from controllers.web.login import EmailCodeLoginSendPayload, EmailCodeLoginVerifyPayload, LoginPayload
|
||||
|
||||
|
||||
class TestLoginPayload:
|
||||
def test_valid(self) -> None:
|
||||
p = LoginPayload(email="a@b.com", password="Valid1234")
|
||||
assert p.email == "a@b.com"
|
||||
|
||||
def test_invalid_email(self) -> None:
|
||||
with pytest.raises(ValidationError, match="not a valid email"):
|
||||
LoginPayload(email="bad", password="Valid1234")
|
||||
|
||||
def test_weak_password(self) -> None:
|
||||
with pytest.raises(ValidationError, match="Password must contain"):
|
||||
LoginPayload(email="a@b.com", password="weak")
|
||||
|
||||
|
||||
class TestEmailCodeLoginSendPayload:
|
||||
def test_valid(self) -> None:
|
||||
p = EmailCodeLoginSendPayload(email="a@b.com")
|
||||
assert p.language is None
|
||||
|
||||
def test_with_language(self) -> None:
|
||||
p = EmailCodeLoginSendPayload(email="a@b.com", language="zh-Hans")
|
||||
assert p.language == "zh-Hans"
|
||||
|
||||
|
||||
class TestEmailCodeLoginVerifyPayload:
|
||||
def test_valid(self) -> None:
|
||||
p = EmailCodeLoginVerifyPayload(email="a@b.com", code="1234", token="tok")
|
||||
assert p.code == "1234"
|
||||
|
||||
def test_empty_token_rejected(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
EmailCodeLoginVerifyPayload(email="a@b.com", code="1234", token="")
|
||||
@@ -1,147 +0,0 @@
|
||||
"""Unit tests for controllers.web.remote_files endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.common.errors import FileTooLargeError, RemoteFileUploadError
|
||||
from controllers.web.remote_files import RemoteFileInfoApi, RemoteFileUploadApi
|
||||
|
||||
|
||||
def _app_model() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RemoteFileInfoApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestRemoteFileInfoApi:
|
||||
@patch("controllers.web.remote_files.ssrf_proxy")
|
||||
def test_head_success(self, mock_proxy: MagicMock, app: Flask) -> None:
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.headers = {"Content-Type": "application/pdf", "Content-Length": "1024"}
|
||||
mock_proxy.head.return_value = mock_resp
|
||||
|
||||
with app.test_request_context("/remote-files/https%3A%2F%2Fexample.com%2Ffile.pdf"):
|
||||
result = RemoteFileInfoApi().get(_app_model(), _end_user(), "https%3A%2F%2Fexample.com%2Ffile.pdf")
|
||||
|
||||
assert result["file_type"] == "application/pdf"
|
||||
assert result["file_length"] == 1024
|
||||
|
||||
@patch("controllers.web.remote_files.ssrf_proxy")
|
||||
def test_fallback_to_get(self, mock_proxy: MagicMock, app: Flask) -> None:
|
||||
head_resp = MagicMock()
|
||||
head_resp.status_code = 405 # Method not allowed
|
||||
get_resp = MagicMock()
|
||||
get_resp.status_code = 200
|
||||
get_resp.headers = {"Content-Type": "text/plain", "Content-Length": "42"}
|
||||
get_resp.raise_for_status = MagicMock()
|
||||
mock_proxy.head.return_value = head_resp
|
||||
mock_proxy.get.return_value = get_resp
|
||||
|
||||
with app.test_request_context("/remote-files/https%3A%2F%2Fexample.com%2Ffile.txt"):
|
||||
result = RemoteFileInfoApi().get(_app_model(), _end_user(), "https%3A%2F%2Fexample.com%2Ffile.txt")
|
||||
|
||||
assert result["file_type"] == "text/plain"
|
||||
mock_proxy.get.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RemoteFileUploadApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestRemoteFileUploadApi:
|
||||
@patch("controllers.web.remote_files.file_helpers.get_signed_file_url", return_value="https://signed-url")
|
||||
@patch("controllers.web.remote_files.FileService")
|
||||
@patch("controllers.web.remote_files.helpers.guess_file_info_from_response")
|
||||
@patch("controllers.web.remote_files.ssrf_proxy")
|
||||
@patch("controllers.web.remote_files.web_ns")
|
||||
@patch("controllers.web.remote_files.db")
|
||||
def test_upload_success(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_ns: MagicMock,
|
||||
mock_proxy: MagicMock,
|
||||
mock_guess: MagicMock,
|
||||
mock_file_svc_cls: MagicMock,
|
||||
mock_signed: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_db.engine = "engine"
|
||||
mock_ns.payload = {"url": "https://example.com/file.pdf"}
|
||||
head_resp = MagicMock()
|
||||
head_resp.status_code = 200
|
||||
head_resp.content = b"pdf-content"
|
||||
head_resp.request.method = "HEAD"
|
||||
mock_proxy.head.return_value = head_resp
|
||||
get_resp = MagicMock()
|
||||
get_resp.content = b"pdf-content"
|
||||
mock_proxy.get.return_value = get_resp
|
||||
|
||||
mock_guess.return_value = SimpleNamespace(
|
||||
filename="file.pdf", extension="pdf", mimetype="application/pdf", size=100
|
||||
)
|
||||
mock_file_svc_cls.is_file_size_within_limit.return_value = True
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
upload_file = SimpleNamespace(
|
||||
id="f-1",
|
||||
name="file.pdf",
|
||||
size=100,
|
||||
extension="pdf",
|
||||
mime_type="application/pdf",
|
||||
created_by="eu-1",
|
||||
created_at=datetime(2024, 1, 1),
|
||||
)
|
||||
mock_file_svc_cls.return_value.upload_file.return_value = upload_file
|
||||
|
||||
with app.test_request_context("/remote-files/upload", method="POST"):
|
||||
result, status = RemoteFileUploadApi().post(_app_model(), _end_user())
|
||||
|
||||
assert status == 201
|
||||
assert result["id"] == "f-1"
|
||||
|
||||
@patch("controllers.web.remote_files.FileService.is_file_size_within_limit", return_value=False)
|
||||
@patch("controllers.web.remote_files.helpers.guess_file_info_from_response")
|
||||
@patch("controllers.web.remote_files.ssrf_proxy")
|
||||
@patch("controllers.web.remote_files.web_ns")
|
||||
def test_file_too_large(
|
||||
self,
|
||||
mock_ns: MagicMock,
|
||||
mock_proxy: MagicMock,
|
||||
mock_guess: MagicMock,
|
||||
mock_size_check: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_ns.payload = {"url": "https://example.com/big.zip"}
|
||||
head_resp = MagicMock()
|
||||
head_resp.status_code = 200
|
||||
mock_proxy.head.return_value = head_resp
|
||||
mock_guess.return_value = SimpleNamespace(
|
||||
filename="big.zip", extension="zip", mimetype="application/zip", size=999999999
|
||||
)
|
||||
|
||||
with app.test_request_context("/remote-files/upload", method="POST"):
|
||||
with pytest.raises(FileTooLargeError):
|
||||
RemoteFileUploadApi().post(_app_model(), _end_user())
|
||||
|
||||
@patch("controllers.web.remote_files.ssrf_proxy")
|
||||
@patch("controllers.web.remote_files.web_ns")
|
||||
def test_fetch_failure_raises(self, mock_ns: MagicMock, mock_proxy: MagicMock, app: Flask) -> None:
|
||||
import httpx
|
||||
|
||||
mock_ns.payload = {"url": "https://example.com/bad"}
|
||||
mock_proxy.head.side_effect = httpx.RequestError("connection failed")
|
||||
|
||||
with app.test_request_context("/remote-files/upload", method="POST"):
|
||||
with pytest.raises(RemoteFileUploadError):
|
||||
RemoteFileUploadApi().post(_app_model(), _end_user())
|
||||
@@ -1,97 +0,0 @@
|
||||
"""Unit tests for controllers.web.saved_message endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.web.error import NotCompletionAppError
|
||||
from controllers.web.saved_message import SavedMessageApi, SavedMessageListApi
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
|
||||
def _completion_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="completion")
|
||||
|
||||
|
||||
def _chat_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="chat")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SavedMessageListApi (GET)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestSavedMessageListApiGet:
|
||||
def test_non_completion_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/saved-messages"):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
SavedMessageListApi().get(_chat_app(), _end_user())
|
||||
|
||||
@patch("controllers.web.saved_message.SavedMessageService.pagination_by_last_id")
|
||||
def test_happy_path(self, mock_paginate: MagicMock, app: Flask) -> None:
|
||||
mock_paginate.return_value = SimpleNamespace(limit=20, has_more=False, data=[])
|
||||
|
||||
with app.test_request_context("/saved-messages?limit=20"):
|
||||
result = SavedMessageListApi().get(_completion_app(), _end_user())
|
||||
|
||||
assert result["limit"] == 20
|
||||
assert result["has_more"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SavedMessageListApi (POST)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestSavedMessageListApiPost:
|
||||
def test_non_completion_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/saved-messages", method="POST"):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
SavedMessageListApi().post(_chat_app(), _end_user())
|
||||
|
||||
@patch("controllers.web.saved_message.SavedMessageService.save")
|
||||
@patch("controllers.web.saved_message.web_ns")
|
||||
def test_save_success(self, mock_ns: MagicMock, mock_save: MagicMock, app: Flask) -> None:
|
||||
msg_id = str(uuid4())
|
||||
mock_ns.payload = {"message_id": msg_id}
|
||||
|
||||
with app.test_request_context("/saved-messages", method="POST"):
|
||||
result = SavedMessageListApi().post(_completion_app(), _end_user())
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
@patch("controllers.web.saved_message.SavedMessageService.save", side_effect=MessageNotExistsError())
|
||||
@patch("controllers.web.saved_message.web_ns")
|
||||
def test_save_not_found(self, mock_ns: MagicMock, mock_save: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"message_id": str(uuid4())}
|
||||
|
||||
with app.test_request_context("/saved-messages", method="POST"):
|
||||
with pytest.raises(NotFound, match="Message Not Exists"):
|
||||
SavedMessageListApi().post(_completion_app(), _end_user())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SavedMessageApi (DELETE)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestSavedMessageApi:
|
||||
def test_non_completion_mode_raises(self, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/saved-messages/{msg_id}", method="DELETE"):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
SavedMessageApi().delete(_chat_app(), _end_user(), msg_id)
|
||||
|
||||
@patch("controllers.web.saved_message.SavedMessageService.delete")
|
||||
def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None:
|
||||
msg_id = uuid4()
|
||||
with app.test_request_context(f"/saved-messages/{msg_id}", method="DELETE"):
|
||||
result, status = SavedMessageApi().delete(_completion_app(), _end_user(), msg_id)
|
||||
|
||||
assert status == 204
|
||||
assert result["result"] == "success"
|
||||
@@ -1,126 +0,0 @@
|
||||
"""Unit tests for controllers.web.site endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.web.site import AppSiteApi, AppSiteInfo
|
||||
|
||||
|
||||
def _tenant(*, status: str = "normal") -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id="tenant-1",
|
||||
status=status,
|
||||
plan="basic",
|
||||
custom_config_dict={"remove_webapp_brand": False, "replace_webapp_logo": False},
|
||||
)
|
||||
|
||||
|
||||
def _site() -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
title="Site",
|
||||
icon_type="emoji",
|
||||
icon="robot",
|
||||
icon_background="#fff",
|
||||
description="desc",
|
||||
default_language="en",
|
||||
chat_color_theme="light",
|
||||
chat_color_theme_inverted=False,
|
||||
copyright=None,
|
||||
privacy_policy=None,
|
||||
custom_disclaimer=None,
|
||||
prompt_public=False,
|
||||
show_workflow_steps=True,
|
||||
use_icon_as_answer_icon=False,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AppSiteApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAppSiteApi:
|
||||
@patch("controllers.web.site.FeatureService.get_features")
|
||||
@patch("controllers.web.site.db")
|
||||
def test_happy_path(self, mock_db: MagicMock, mock_features: MagicMock, app: Flask) -> None:
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
mock_features.return_value = SimpleNamespace(can_replace_logo=False)
|
||||
site_obj = _site()
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = site_obj
|
||||
tenant = _tenant()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
|
||||
end_user = SimpleNamespace(id="eu-1")
|
||||
|
||||
with app.test_request_context("/site"):
|
||||
result = AppSiteApi().get(app_model, end_user)
|
||||
|
||||
# marshal_with serializes AppSiteInfo to a dict
|
||||
assert result["app_id"] == "app-1"
|
||||
assert result["plan"] == "basic"
|
||||
assert result["enable_site"] is True
|
||||
|
||||
@patch("controllers.web.site.db")
|
||||
def test_missing_site_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None:
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
tenant = _tenant()
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant)
|
||||
end_user = SimpleNamespace(id="eu-1")
|
||||
|
||||
with app.test_request_context("/site"):
|
||||
with pytest.raises(Forbidden):
|
||||
AppSiteApi().get(app_model, end_user)
|
||||
|
||||
@patch("controllers.web.site.db")
|
||||
def test_archived_tenant_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None:
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
from models.account import TenantStatus
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = _site()
|
||||
tenant = SimpleNamespace(
|
||||
id="tenant-1",
|
||||
status=TenantStatus.ARCHIVE,
|
||||
plan="basic",
|
||||
custom_config_dict={},
|
||||
)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant)
|
||||
end_user = SimpleNamespace(id="eu-1")
|
||||
|
||||
with app.test_request_context("/site"):
|
||||
with pytest.raises(Forbidden):
|
||||
AppSiteApi().get(app_model, end_user)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AppSiteInfo
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestAppSiteInfo:
|
||||
def test_basic_fields(self) -> None:
|
||||
tenant = _tenant()
|
||||
site_obj = _site()
|
||||
info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", False)
|
||||
|
||||
assert info.app_id == "app-1"
|
||||
assert info.end_user_id == "eu-1"
|
||||
assert info.enable_site is True
|
||||
assert info.plan == "basic"
|
||||
assert info.can_replace_logo is False
|
||||
assert info.model_config is None
|
||||
|
||||
@patch("controllers.web.site.dify_config", SimpleNamespace(FILES_URL="https://files.example.com"))
|
||||
def test_can_replace_logo_sets_custom_config(self) -> None:
|
||||
tenant = SimpleNamespace(
|
||||
id="tenant-1",
|
||||
plan="pro",
|
||||
custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": True},
|
||||
)
|
||||
site_obj = _site()
|
||||
info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", True)
|
||||
|
||||
assert info.can_replace_logo is True
|
||||
assert info.custom_config["remove_webapp_brand"] is True
|
||||
assert "webapp-logo" in info.custom_config["replace_webapp_logo"]
|
||||
@@ -5,8 +5,7 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
import services.errors.account
|
||||
from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi, LoginApi, LoginStatusApi, LogoutApi
|
||||
from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi
|
||||
|
||||
|
||||
def encode_code(code: str) -> str:
|
||||
@@ -90,114 +89,3 @@ class TestEmailCodeLoginApi:
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_login.assert_called_once()
|
||||
mock_reset_login_rate.assert_called_once_with("user@example.com")
|
||||
|
||||
|
||||
class TestLoginApi:
|
||||
@patch("controllers.web.login.WebAppAuthService.login", return_value="access-tok")
|
||||
@patch("controllers.web.login.WebAppAuthService.authenticate")
|
||||
def test_login_success(self, mock_auth: MagicMock, mock_login: MagicMock, app: Flask) -> None:
|
||||
mock_auth.return_value = MagicMock()
|
||||
|
||||
with app.test_request_context(
|
||||
"/web/login",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()},
|
||||
):
|
||||
response = LoginApi().post()
|
||||
|
||||
assert response.get_json()["data"]["access_token"] == "access-tok"
|
||||
mock_auth.assert_called_once()
|
||||
|
||||
@patch(
|
||||
"controllers.web.login.WebAppAuthService.authenticate",
|
||||
side_effect=services.errors.account.AccountLoginError(),
|
||||
)
|
||||
def test_login_banned_account(self, mock_auth: MagicMock, app: Flask) -> None:
|
||||
from controllers.console.error import AccountBannedError
|
||||
|
||||
with app.test_request_context(
|
||||
"/web/login",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()},
|
||||
):
|
||||
with pytest.raises(AccountBannedError):
|
||||
LoginApi().post()
|
||||
|
||||
@patch(
|
||||
"controllers.web.login.WebAppAuthService.authenticate",
|
||||
side_effect=services.errors.account.AccountPasswordError(),
|
||||
)
|
||||
def test_login_wrong_password(self, mock_auth: MagicMock, app: Flask) -> None:
|
||||
from controllers.console.auth.error import AuthenticationFailedError
|
||||
|
||||
with app.test_request_context(
|
||||
"/web/login",
|
||||
method="POST",
|
||||
json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()},
|
||||
):
|
||||
with pytest.raises(AuthenticationFailedError):
|
||||
LoginApi().post()
|
||||
|
||||
|
||||
class TestLoginStatusApi:
|
||||
@patch("controllers.web.login.extract_webapp_access_token", return_value=None)
|
||||
def test_no_app_code_returns_logged_in_false(self, mock_extract: MagicMock, app: Flask) -> None:
|
||||
with app.test_request_context("/web/login/status"):
|
||||
result = LoginStatusApi().get()
|
||||
|
||||
assert result["logged_in"] is False
|
||||
assert result["app_logged_in"] is False
|
||||
|
||||
@patch("controllers.web.login.decode_jwt_token")
|
||||
@patch("controllers.web.login.PassportService")
|
||||
@patch("controllers.web.login.WebAppAuthService.is_app_require_permission_check", return_value=False)
|
||||
@patch("controllers.web.login.AppService.get_app_id_by_code", return_value="app-1")
|
||||
@patch("controllers.web.login.extract_webapp_access_token", return_value="tok")
|
||||
def test_public_app_user_logged_in(
|
||||
self,
|
||||
mock_extract: MagicMock,
|
||||
mock_app_id: MagicMock,
|
||||
mock_perm: MagicMock,
|
||||
mock_passport: MagicMock,
|
||||
mock_decode: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_decode.return_value = (MagicMock(), MagicMock())
|
||||
|
||||
with app.test_request_context("/web/login/status?app_code=code1"):
|
||||
result = LoginStatusApi().get()
|
||||
|
||||
assert result["logged_in"] is True
|
||||
assert result["app_logged_in"] is True
|
||||
|
||||
@patch("controllers.web.login.decode_jwt_token", side_effect=Exception("bad"))
|
||||
@patch("controllers.web.login.PassportService")
|
||||
@patch("controllers.web.login.WebAppAuthService.is_app_require_permission_check", return_value=True)
|
||||
@patch("controllers.web.login.AppService.get_app_id_by_code", return_value="app-1")
|
||||
@patch("controllers.web.login.extract_webapp_access_token", return_value="tok")
|
||||
def test_private_app_passport_fails(
|
||||
self,
|
||||
mock_extract: MagicMock,
|
||||
mock_app_id: MagicMock,
|
||||
mock_perm: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_decode: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_passport_cls.return_value.verify.side_effect = Exception("bad")
|
||||
|
||||
with app.test_request_context("/web/login/status?app_code=code1"):
|
||||
result = LoginStatusApi().get()
|
||||
|
||||
assert result["logged_in"] is False
|
||||
assert result["app_logged_in"] is False
|
||||
|
||||
|
||||
class TestLogoutApi:
|
||||
@patch("controllers.web.login.clear_webapp_access_token_from_cookie")
|
||||
def test_logout_success(self, mock_clear: MagicMock, app: Flask) -> None:
|
||||
with app.test_request_context("/web/logout", method="POST"):
|
||||
response = LogoutApi().post()
|
||||
|
||||
assert response.get_json() == {"result": "success"}
|
||||
mock_clear.assert_called_once()
|
||||
|
||||
@@ -1,192 +0,0 @@
|
||||
"""Unit tests for controllers.web.passport — token issuance and enterprise auth exchange."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from controllers.web.error import WebAppAuthRequiredError
|
||||
from controllers.web.passport import (
|
||||
PassportResource,
|
||||
decode_enterprise_webapp_user_id,
|
||||
exchange_token_for_existing_web_user,
|
||||
generate_session_id,
|
||||
)
|
||||
from services.webapp_auth_service import WebAppAuthType
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# decode_enterprise_webapp_user_id
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestDecodeEnterpriseWebappUserId:
|
||||
def test_none_token_returns_none(self) -> None:
|
||||
assert decode_enterprise_webapp_user_id(None) is None
|
||||
|
||||
@patch("controllers.web.passport.PassportService")
|
||||
def test_valid_token_returns_decoded(self, mock_passport_cls: MagicMock) -> None:
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"token_source": "webapp_login_token",
|
||||
"user_id": "u1",
|
||||
}
|
||||
result = decode_enterprise_webapp_user_id("valid-jwt")
|
||||
assert result["user_id"] == "u1"
|
||||
|
||||
@patch("controllers.web.passport.PassportService")
|
||||
def test_wrong_source_raises_unauthorized(self, mock_passport_cls: MagicMock) -> None:
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"token_source": "other_source",
|
||||
}
|
||||
with pytest.raises(Unauthorized, match="Expected 'webapp_login_token'"):
|
||||
decode_enterprise_webapp_user_id("bad-jwt")
|
||||
|
||||
@patch("controllers.web.passport.PassportService")
|
||||
def test_missing_source_raises_unauthorized(self, mock_passport_cls: MagicMock) -> None:
|
||||
mock_passport_cls.return_value.verify.return_value = {}
|
||||
with pytest.raises(Unauthorized, match="Expected 'webapp_login_token'"):
|
||||
decode_enterprise_webapp_user_id("no-source-jwt")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# generate_session_id
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestGenerateSessionId:
|
||||
@patch("controllers.web.passport.db")
|
||||
def test_returns_unique_session_id(self, mock_db: MagicMock) -> None:
|
||||
mock_db.session.scalar.return_value = 0
|
||||
sid = generate_session_id()
|
||||
assert isinstance(sid, str)
|
||||
assert len(sid) == 36 # UUID format
|
||||
|
||||
@patch("controllers.web.passport.db")
|
||||
def test_retries_on_collision(self, mock_db: MagicMock) -> None:
|
||||
# First call returns count=1 (collision), second returns 0
|
||||
mock_db.session.scalar.side_effect = [1, 0]
|
||||
sid = generate_session_id()
|
||||
assert isinstance(sid, str)
|
||||
assert mock_db.session.scalar.call_count == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# exchange_token_for_existing_web_user
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestExchangeTokenForExistingWebUser:
|
||||
@patch("controllers.web.passport.PassportService")
|
||||
@patch("controllers.web.passport.db")
|
||||
def test_external_auth_type_mismatch_raises(self, mock_db: MagicMock, mock_passport_cls: MagicMock) -> None:
|
||||
site = SimpleNamespace(code="code1", app_id="app-1")
|
||||
app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1")
|
||||
mock_db.session.scalar.side_effect = [site, app_model]
|
||||
|
||||
decoded = {"user_id": "u1", "auth_type": "internal"} # mismatch: expected "external"
|
||||
with pytest.raises(WebAppAuthRequiredError, match="external"):
|
||||
exchange_token_for_existing_web_user(
|
||||
app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.EXTERNAL
|
||||
)
|
||||
|
||||
@patch("controllers.web.passport.PassportService")
|
||||
@patch("controllers.web.passport.db")
|
||||
def test_internal_auth_type_mismatch_raises(self, mock_db: MagicMock, mock_passport_cls: MagicMock) -> None:
|
||||
site = SimpleNamespace(code="code1", app_id="app-1")
|
||||
app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1")
|
||||
mock_db.session.scalar.side_effect = [site, app_model]
|
||||
|
||||
decoded = {"user_id": "u1", "auth_type": "external"} # mismatch: expected "internal"
|
||||
with pytest.raises(WebAppAuthRequiredError, match="internal"):
|
||||
exchange_token_for_existing_web_user(
|
||||
app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.INTERNAL
|
||||
)
|
||||
|
||||
@patch("controllers.web.passport.PassportService")
|
||||
@patch("controllers.web.passport.db")
|
||||
def test_site_not_found_raises(self, mock_db: MagicMock, mock_passport_cls: MagicMock) -> None:
|
||||
mock_db.session.scalar.return_value = None
|
||||
decoded = {"user_id": "u1", "auth_type": "external"}
|
||||
with pytest.raises(NotFound):
|
||||
exchange_token_for_existing_web_user(
|
||||
app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.EXTERNAL
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PassportResource.get
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestPassportResource:
|
||||
@patch("controllers.web.passport.FeatureService.get_system_features")
|
||||
def test_missing_app_code_raises_unauthorized(self, mock_features: MagicMock, app: Flask) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
with app.test_request_context("/passport"):
|
||||
with pytest.raises(Unauthorized, match="X-App-Code"):
|
||||
PassportResource().get()
|
||||
|
||||
@patch("controllers.web.passport.PassportService")
|
||||
@patch("controllers.web.passport.generate_session_id", return_value="new-sess-id")
|
||||
@patch("controllers.web.passport.db")
|
||||
@patch("controllers.web.passport.FeatureService.get_system_features")
|
||||
def test_creates_new_end_user_when_no_user_id(
|
||||
self,
|
||||
mock_features: MagicMock,
|
||||
mock_db: MagicMock,
|
||||
mock_gen_session: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
site = SimpleNamespace(app_id="app-1", code="code1")
|
||||
app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1")
|
||||
mock_db.session.scalar.side_effect = [site, app_model]
|
||||
mock_passport_cls.return_value.issue.return_value = "issued-token"
|
||||
|
||||
with app.test_request_context("/passport", headers={"X-App-Code": "code1"}):
|
||||
response = PassportResource().get()
|
||||
|
||||
assert response.get_json()["access_token"] == "issued-token"
|
||||
mock_db.session.add.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@patch("controllers.web.passport.PassportService")
|
||||
@patch("controllers.web.passport.db")
|
||||
@patch("controllers.web.passport.FeatureService.get_system_features")
|
||||
def test_reuses_existing_end_user_when_user_id_provided(
|
||||
self,
|
||||
mock_features: MagicMock,
|
||||
mock_db: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
site = SimpleNamespace(app_id="app-1", code="code1")
|
||||
app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1")
|
||||
existing_user = SimpleNamespace(id="eu-1", session_id="sess-existing")
|
||||
mock_db.session.scalar.side_effect = [site, app_model, existing_user]
|
||||
mock_passport_cls.return_value.issue.return_value = "reused-token"
|
||||
|
||||
with app.test_request_context("/passport?user_id=sess-existing", headers={"X-App-Code": "code1"}):
|
||||
response = PassportResource().get()
|
||||
|
||||
assert response.get_json()["access_token"] == "reused-token"
|
||||
# Should not create a new end user
|
||||
mock_db.session.add.assert_not_called()
|
||||
|
||||
@patch("controllers.web.passport.db")
|
||||
@patch("controllers.web.passport.FeatureService.get_system_features")
|
||||
def test_site_not_found_raises(self, mock_features: MagicMock, mock_db: MagicMock, app: Flask) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
mock_db.session.scalar.return_value = None
|
||||
with app.test_request_context("/passport", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(NotFound):
|
||||
PassportResource().get()
|
||||
|
||||
@patch("controllers.web.passport.db")
|
||||
@patch("controllers.web.passport.FeatureService.get_system_features")
|
||||
def test_disabled_app_raises_not_found(self, mock_features: MagicMock, mock_db: MagicMock, app: Flask) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
site = SimpleNamespace(app_id="app-1", code="code1")
|
||||
disabled_app = SimpleNamespace(id="app-1", status="normal", enable_site=False)
|
||||
mock_db.session.scalar.side_effect = [site, disabled_app]
|
||||
with app.test_request_context("/passport", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(NotFound):
|
||||
PassportResource().get()
|
||||
@@ -1,95 +0,0 @@
|
||||
"""Unit tests for controllers.web.workflow endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.error import (
|
||||
NotWorkflowAppError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.web.workflow import WorkflowRunApi, WorkflowTaskStopApi
|
||||
from core.errors.error import ProviderTokenNotInitError, QuotaExceededError
|
||||
|
||||
|
||||
def _workflow_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="workflow")
|
||||
|
||||
|
||||
def _chat_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", mode="chat")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WorkflowRunApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestWorkflowRunApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/workflows/run", method="POST"):
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
WorkflowRunApi().post(_chat_app(), _end_user())
|
||||
|
||||
@patch("controllers.web.workflow.helper.compact_generate_response", return_value={"result": "ok"})
|
||||
@patch("controllers.web.workflow.AppGenerateService.generate")
|
||||
@patch("controllers.web.workflow.web_ns")
|
||||
def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {"key": "val"}}
|
||||
mock_gen.return_value = "response"
|
||||
|
||||
with app.test_request_context("/workflows/run", method="POST"):
|
||||
result = WorkflowRunApi().post(_workflow_app(), _end_user())
|
||||
|
||||
assert result == {"result": "ok"}
|
||||
|
||||
@patch(
|
||||
"controllers.web.workflow.AppGenerateService.generate",
|
||||
side_effect=ProviderTokenNotInitError(description="not init"),
|
||||
)
|
||||
@patch("controllers.web.workflow.web_ns")
|
||||
def test_provider_not_init(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}}
|
||||
|
||||
with app.test_request_context("/workflows/run", method="POST"):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
WorkflowRunApi().post(_workflow_app(), _end_user())
|
||||
|
||||
@patch(
|
||||
"controllers.web.workflow.AppGenerateService.generate",
|
||||
side_effect=QuotaExceededError(),
|
||||
)
|
||||
@patch("controllers.web.workflow.web_ns")
|
||||
def test_quota_exceeded(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None:
|
||||
mock_ns.payload = {"inputs": {}}
|
||||
|
||||
with app.test_request_context("/workflows/run", method="POST"):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
WorkflowRunApi().post(_workflow_app(), _end_user())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WorkflowTaskStopApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestWorkflowTaskStopApi:
|
||||
def test_wrong_mode_raises(self, app: Flask) -> None:
|
||||
with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"):
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
WorkflowTaskStopApi().post(_chat_app(), _end_user(), "task-1")
|
||||
|
||||
@patch("controllers.web.workflow.GraphEngineManager.send_stop_command")
|
||||
@patch("controllers.web.workflow.AppQueueManager.set_stop_flag_no_user_check")
|
||||
def test_stop_calls_both_mechanisms(self, mock_legacy: MagicMock, mock_graph: MagicMock, app: Flask) -> None:
|
||||
with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"):
|
||||
result = WorkflowTaskStopApi().post(_workflow_app(), _end_user(), "task-1")
|
||||
|
||||
assert result == {"result": "success"}
|
||||
mock_legacy.assert_called_once_with("task-1")
|
||||
mock_graph.assert_called_once_with("task-1")
|
||||
@@ -1,127 +0,0 @@
|
||||
"""Unit tests for controllers.web.workflow_events endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.error import NotFoundError
|
||||
from controllers.web.workflow_events import WorkflowEventsApi
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
|
||||
def _workflow_app() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-1", tenant_id="tenant-1", mode="workflow")
|
||||
|
||||
|
||||
def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WorkflowEventsApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestWorkflowEventsApi:
|
||||
@patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
|
||||
@patch("controllers.web.workflow_events.db")
|
||||
def test_workflow_run_not_found(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None:
|
||||
mock_db.engine = "engine"
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = None
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_repo
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events"):
|
||||
with pytest.raises(NotFoundError):
|
||||
WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")
|
||||
|
||||
@patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
|
||||
@patch("controllers.web.workflow_events.db")
|
||||
def test_workflow_run_wrong_app(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None:
|
||||
mock_db.engine = "engine"
|
||||
run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="other-app",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="eu-1",
|
||||
finished_at=None,
|
||||
)
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_repo
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events"):
|
||||
with pytest.raises(NotFoundError):
|
||||
WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")
|
||||
|
||||
@patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
|
||||
@patch("controllers.web.workflow_events.db")
|
||||
def test_workflow_run_not_created_by_end_user(
|
||||
self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask
|
||||
) -> None:
|
||||
mock_db.engine = "engine"
|
||||
run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by="eu-1",
|
||||
finished_at=None,
|
||||
)
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_repo
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events"):
|
||||
with pytest.raises(NotFoundError):
|
||||
WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")
|
||||
|
||||
@patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
|
||||
@patch("controllers.web.workflow_events.db")
|
||||
def test_workflow_run_wrong_end_user(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None:
|
||||
mock_db.engine = "engine"
|
||||
run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="other-user",
|
||||
finished_at=None,
|
||||
)
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_repo
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events"):
|
||||
with pytest.raises(NotFoundError):
|
||||
WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")
|
||||
|
||||
@patch("controllers.web.workflow_events.WorkflowResponseConverter")
|
||||
@patch("controllers.web.workflow_events.DifyAPIRepositoryFactory")
|
||||
@patch("controllers.web.workflow_events.db")
|
||||
def test_finished_run_returns_sse_response(
|
||||
self, mock_db: MagicMock, mock_factory: MagicMock, mock_converter: MagicMock, app: Flask
|
||||
) -> None:
|
||||
from datetime import datetime
|
||||
|
||||
mock_db.engine = "engine"
|
||||
run = SimpleNamespace(
|
||||
id="run-1",
|
||||
app_id="app-1",
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="eu-1",
|
||||
finished_at=datetime(2024, 1, 1),
|
||||
)
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_repo
|
||||
|
||||
finish_response = MagicMock()
|
||||
finish_response.model_dump.return_value = {"task_id": "run-1"}
|
||||
finish_response.event.value = "workflow_finished"
|
||||
mock_converter.workflow_run_result_to_finish_response.return_value = finish_response
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events"):
|
||||
response = WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1")
|
||||
|
||||
assert response.mimetype == "text/event-stream"
|
||||
@@ -1,393 +0,0 @@
|
||||
"""Unit tests for controllers.web.wraps — JWT auth decorator and validation helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
|
||||
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
|
||||
from controllers.web.wraps import (
|
||||
_validate_user_accessibility,
|
||||
_validate_webapp_token,
|
||||
decode_jwt_token,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_webapp_token
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestValidateWebappToken:
|
||||
def test_enterprise_enabled_and_app_auth_requires_webapp_source(self) -> None:
|
||||
"""When both flags are true, a non-webapp source must raise."""
|
||||
decoded = {"token_source": "other"}
|
||||
with pytest.raises(WebAppAuthRequiredError):
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True)
|
||||
|
||||
def test_enterprise_enabled_and_app_auth_accepts_webapp_source(self) -> None:
|
||||
decoded = {"token_source": "webapp"}
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True)
|
||||
|
||||
def test_enterprise_enabled_and_app_auth_missing_source_raises(self) -> None:
|
||||
decoded = {}
|
||||
with pytest.raises(WebAppAuthRequiredError):
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True)
|
||||
|
||||
def test_public_app_rejects_webapp_source(self) -> None:
|
||||
"""When auth is not required, a webapp-sourced token must be rejected."""
|
||||
decoded = {"token_source": "webapp"}
|
||||
with pytest.raises(Unauthorized):
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False)
|
||||
|
||||
def test_public_app_accepts_non_webapp_source(self) -> None:
|
||||
decoded = {"token_source": "other"}
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False)
|
||||
|
||||
def test_public_app_accepts_no_source(self) -> None:
|
||||
decoded = {}
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False)
|
||||
|
||||
def test_system_enabled_but_app_public(self) -> None:
|
||||
"""system_webapp_auth_enabled=True but app is public — webapp source rejected."""
|
||||
decoded = {"token_source": "webapp"}
|
||||
with pytest.raises(Unauthorized):
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=True)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_user_accessibility
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestValidateUserAccessibility:
|
||||
def test_skips_when_auth_disabled(self) -> None:
|
||||
"""No checks when system or app auth is disabled."""
|
||||
_validate_user_accessibility(
|
||||
decoded={},
|
||||
app_code="code",
|
||||
app_web_auth_enabled=False,
|
||||
system_webapp_auth_enabled=False,
|
||||
webapp_settings=None,
|
||||
)
|
||||
|
||||
def test_missing_user_id_raises(self) -> None:
|
||||
decoded = {}
|
||||
with pytest.raises(WebAppAuthRequiredError):
|
||||
_validate_user_accessibility(
|
||||
decoded=decoded,
|
||||
app_code="code",
|
||||
app_web_auth_enabled=True,
|
||||
system_webapp_auth_enabled=True,
|
||||
webapp_settings=SimpleNamespace(access_mode="internal"),
|
||||
)
|
||||
|
||||
def test_missing_webapp_settings_raises(self) -> None:
|
||||
decoded = {"user_id": "u1"}
|
||||
with pytest.raises(WebAppAuthRequiredError, match="settings not found"):
|
||||
_validate_user_accessibility(
|
||||
decoded=decoded,
|
||||
app_code="code",
|
||||
app_web_auth_enabled=True,
|
||||
system_webapp_auth_enabled=True,
|
||||
webapp_settings=None,
|
||||
)
|
||||
|
||||
def test_missing_auth_type_raises(self) -> None:
|
||||
decoded = {"user_id": "u1", "granted_at": 1}
|
||||
settings = SimpleNamespace(access_mode="public")
|
||||
with pytest.raises(WebAppAuthAccessDeniedError, match="auth_type"):
|
||||
_validate_user_accessibility(
|
||||
decoded=decoded,
|
||||
app_code="code",
|
||||
app_web_auth_enabled=True,
|
||||
system_webapp_auth_enabled=True,
|
||||
webapp_settings=settings,
|
||||
)
|
||||
|
||||
def test_missing_granted_at_raises(self) -> None:
|
||||
decoded = {"user_id": "u1", "auth_type": "external"}
|
||||
settings = SimpleNamespace(access_mode="public")
|
||||
with pytest.raises(WebAppAuthAccessDeniedError, match="granted_at"):
|
||||
_validate_user_accessibility(
|
||||
decoded=decoded,
|
||||
app_code="code",
|
||||
app_web_auth_enabled=True,
|
||||
system_webapp_auth_enabled=True,
|
||||
webapp_settings=settings,
|
||||
)
|
||||
|
||||
@patch("controllers.web.wraps.EnterpriseService.get_app_sso_settings_last_update_time")
|
||||
@patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False)
|
||||
def test_external_auth_type_checks_sso_update_time(
|
||||
self, mock_perm_check: MagicMock, mock_sso_time: MagicMock
|
||||
) -> None:
|
||||
# granted_at is before SSO update time → denied
|
||||
mock_sso_time.return_value = datetime.now(UTC)
|
||||
old_granted = int((datetime.now(UTC) - timedelta(hours=1)).timestamp())
|
||||
decoded = {"user_id": "u1", "auth_type": "external", "granted_at": old_granted}
|
||||
settings = SimpleNamespace(access_mode="public")
|
||||
with pytest.raises(WebAppAuthAccessDeniedError, match="SSO settings"):
|
||||
_validate_user_accessibility(
|
||||
decoded=decoded,
|
||||
app_code="code",
|
||||
app_web_auth_enabled=True,
|
||||
system_webapp_auth_enabled=True,
|
||||
webapp_settings=settings,
|
||||
)
|
||||
|
||||
@patch("controllers.web.wraps.EnterpriseService.get_workspace_sso_settings_last_update_time")
|
||||
@patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False)
|
||||
def test_internal_auth_type_checks_workspace_sso_update_time(
|
||||
self, mock_perm_check: MagicMock, mock_workspace_sso: MagicMock
|
||||
) -> None:
|
||||
mock_workspace_sso.return_value = datetime.now(UTC)
|
||||
old_granted = int((datetime.now(UTC) - timedelta(hours=1)).timestamp())
|
||||
decoded = {"user_id": "u1", "auth_type": "internal", "granted_at": old_granted}
|
||||
settings = SimpleNamespace(access_mode="public")
|
||||
with pytest.raises(WebAppAuthAccessDeniedError, match="SSO settings"):
|
||||
_validate_user_accessibility(
|
||||
decoded=decoded,
|
||||
app_code="code",
|
||||
app_web_auth_enabled=True,
|
||||
system_webapp_auth_enabled=True,
|
||||
webapp_settings=settings,
|
||||
)
|
||||
|
||||
@patch("controllers.web.wraps.EnterpriseService.get_app_sso_settings_last_update_time")
|
||||
@patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False)
|
||||
def test_external_auth_passes_when_granted_after_sso_update(
|
||||
self, mock_perm_check: MagicMock, mock_sso_time: MagicMock
|
||||
) -> None:
|
||||
mock_sso_time.return_value = datetime.now(UTC) - timedelta(hours=2)
|
||||
recent_granted = int(datetime.now(UTC).timestamp())
|
||||
decoded = {"user_id": "u1", "auth_type": "external", "granted_at": recent_granted}
|
||||
settings = SimpleNamespace(access_mode="public")
|
||||
# Should not raise
|
||||
_validate_user_accessibility(
|
||||
decoded=decoded,
|
||||
app_code="code",
|
||||
app_web_auth_enabled=True,
|
||||
system_webapp_auth_enabled=True,
|
||||
webapp_settings=settings,
|
||||
)
|
||||
|
||||
@patch("controllers.web.wraps.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp", return_value=False)
|
||||
@patch("controllers.web.wraps.AppService.get_app_id_by_code", return_value="app-id-1")
|
||||
@patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=True)
|
||||
def test_permission_check_denies_unauthorized_user(
|
||||
self, mock_perm: MagicMock, mock_app_id: MagicMock, mock_allowed: MagicMock
|
||||
) -> None:
|
||||
decoded = {"user_id": "u1", "auth_type": "external", "granted_at": int(datetime.now(UTC).timestamp())}
|
||||
settings = SimpleNamespace(access_mode="internal")
|
||||
with pytest.raises(WebAppAuthAccessDeniedError):
|
||||
_validate_user_accessibility(
|
||||
decoded=decoded,
|
||||
app_code="code",
|
||||
app_web_auth_enabled=True,
|
||||
system_webapp_auth_enabled=True,
|
||||
webapp_settings=settings,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# decode_jwt_token
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestDecodeJwtToken:
|
||||
@patch("controllers.web.wraps._validate_user_accessibility")
|
||||
@patch("controllers.web.wraps._validate_webapp_token")
|
||||
@patch("controllers.web.wraps.EnterpriseService.WebAppAuth.get_app_access_mode_by_id")
|
||||
@patch("controllers.web.wraps.AppService.get_app_id_by_code")
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_happy_path(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
mock_app_id: MagicMock,
|
||||
mock_access_mode: MagicMock,
|
||||
mock_validate_token: MagicMock,
|
||||
mock_validate_user: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", enable_site=True)
|
||||
site = SimpleNamespace(code="code1")
|
||||
end_user = SimpleNamespace(id="eu-1", session_id="sess-1")
|
||||
|
||||
# Configure session mock to return correct objects via scalar()
|
||||
session_mock = MagicMock()
|
||||
session_mock.scalar.side_effect = [app_model, site, end_user]
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
result_app, result_user = decode_jwt_token()
|
||||
|
||||
assert result_app.id == "app-1"
|
||||
assert result_user.id == "eu-1"
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
def test_missing_token_raises_unauthorized(
|
||||
self, mock_extract: MagicMock, mock_features: MagicMock, app: Flask
|
||||
) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
mock_extract.return_value = None
|
||||
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(Unauthorized):
|
||||
decode_jwt_token()
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_missing_app_raises_not_found(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
session_mock = MagicMock()
|
||||
session_mock.scalar.return_value = None # No app found
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(NotFound):
|
||||
decode_jwt_token()
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_disabled_site_raises_bad_request(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", enable_site=False)
|
||||
|
||||
session_mock = MagicMock()
|
||||
# scalar calls: app_model, site (code found), then end_user
|
||||
session_mock.scalar.side_effect = [app_model, SimpleNamespace(code="code1"), None]
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(BadRequest, match="Site is disabled"):
|
||||
decode_jwt_token()
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_missing_end_user_raises_not_found(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", enable_site=True)
|
||||
site = SimpleNamespace(code="code1")
|
||||
|
||||
session_mock = MagicMock()
|
||||
session_mock.scalar.side_effect = [app_model, site, None] # end_user is None
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(NotFound):
|
||||
decode_jwt_token()
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_user_id_mismatch_raises_unauthorized(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app: Flask,
|
||||
) -> None:
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", enable_site=True)
|
||||
site = SimpleNamespace(code="code1")
|
||||
end_user = SimpleNamespace(id="eu-1", session_id="sess-1")
|
||||
|
||||
session_mock = MagicMock()
|
||||
session_mock.scalar.side_effect = [app_model, site, end_user]
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(Unauthorized, match="expired"):
|
||||
decode_jwt_token(user_id="different-user")
|
||||
@@ -269,6 +269,7 @@ class TestPluginLoading:
|
||||
id="task-123",
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
plugin_unique_identifier="test-org/test-plugin/1.0.0",
|
||||
status=PluginInstallTaskStatus.Running,
|
||||
total_plugins=3,
|
||||
completed_plugins=1,
|
||||
@@ -720,6 +721,7 @@ class TestPluginTaskManagement:
|
||||
id="task-1",
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
plugin_unique_identifier="test-org/test-plugin-a/1.0.0",
|
||||
status=PluginInstallTaskStatus.Running,
|
||||
total_plugins=2,
|
||||
completed_plugins=1,
|
||||
@@ -729,6 +731,7 @@ class TestPluginTaskManagement:
|
||||
id="task-2",
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
plugin_unique_identifier="test-org/test-plugin-b/1.0.0",
|
||||
status=PluginInstallTaskStatus.Success,
|
||||
total_plugins=1,
|
||||
completed_plugins=1,
|
||||
@@ -1256,6 +1259,7 @@ class TestPluginTaskStatusTransitions:
|
||||
id="pending-task",
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
plugin_unique_identifier="test-org/test-plugin/1.0.0",
|
||||
status=PluginInstallTaskStatus.Pending,
|
||||
total_plugins=3,
|
||||
completed_plugins=0, # No plugins completed yet
|
||||
@@ -1283,6 +1287,7 @@ class TestPluginTaskStatusTransitions:
|
||||
id="running-task",
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
plugin_unique_identifier="test-org/test-plugin/1.0.0",
|
||||
status=PluginInstallTaskStatus.Running,
|
||||
total_plugins=5,
|
||||
completed_plugins=2, # 2 out of 5 completed
|
||||
@@ -1311,6 +1316,7 @@ class TestPluginTaskStatusTransitions:
|
||||
id="success-task",
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
plugin_unique_identifier="test-org/test-plugin/1.0.0",
|
||||
status=PluginInstallTaskStatus.Success,
|
||||
total_plugins=4,
|
||||
completed_plugins=4, # All plugins completed
|
||||
@@ -1338,6 +1344,7 @@ class TestPluginTaskStatusTransitions:
|
||||
id="failed-task",
|
||||
created_at=datetime.datetime.now(),
|
||||
updated_at=datetime.datetime.now(),
|
||||
plugin_unique_identifier="test-org/test-plugin/1.0.0",
|
||||
status=PluginInstallTaskStatus.Failed,
|
||||
total_plugins=3,
|
||||
completed_plugins=1, # Only 1 completed before failure
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user