Compare commits


30 Commits

Author SHA1 Message Date
rajatagarwal-oss
a480e9beb1 test: unit test case for controllers.console.app module (#32247)
2026-03-10 00:54:47 +08:00
Stephen Zhou
a59c54b3e7 ci: update actions version, reuse workflow by composite action (#33177) 2026-03-09 23:44:17 +08:00
dependabot[bot]
7737bdc699 chore(deps): bump the npm-dependencies group across 1 directory with 55 updates (#33170)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-09 23:24:48 +08:00
dependabot[bot]
65637fc6b7 chore(deps): bump the npm-dependencies group across 1 directory with 55 updates (#33170)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-09 23:24:36 +08:00
dependabot[bot]
be6f7b8712 chore(deps-dev): bump the eslint-group group in /web with 5 updates (#33168)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
2026-03-09 19:59:44 +08:00
dependabot[bot]
b257e8ed44 chore(deps-dev): bump the storybook group in /web with 7 updates (#33163) 2026-03-09 19:36:00 +08:00
Stephen Zhou
176d3c8c3a ci: ignore ky and tailwind-merge in update (#33167) 2026-03-09 19:21:18 +08:00
Stephen Zhou
c72ac8a434 ci: ignore some major update (#33161) 2026-03-09 17:24:56 +08:00
rajatagarwal-oss
497feac48e test: unit test case for controllers.console.workspace module (#32181)
2026-03-09 17:07:40 +08:00
rajatagarwal-oss
8906ab8e52 test: unit test cases for console.datasets module (#32179)
Co-authored-by: akashseth-ifp <akash.seth@infocusp.com>
2026-03-09 17:07:13 +08:00
非法操作
03dcbeafdf fix: stop responding icon not display (#33154)
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
2026-03-09 16:27:45 +08:00
wangxiaolei
bbfa28e8a7 refactor: file saver decouple db engine and ssrf proxy (#33076)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-09 16:09:44 +08:00
Dev Sharma
6c19e75969 test: improve unit tests for controllers.web (#32150)
Co-authored-by: Rajat Agarwal <rajat.agarwal@infocusp.com>
2026-03-09 15:58:34 +08:00
wangxiaolei
9970f4449a refactor: reuse redis connection instead of create new one (#32678)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-09 15:53:21 +08:00
Sense_wang
cbb19cce39 docs: use docker compose command consistently in README (#33077)
Co-authored-by: Contributor <contributor@example.com>
2026-03-09 15:02:30 +08:00
hj24
0aef09d630 feat: support relative mode for message clean command (#32834) 2026-03-09 14:32:35 +08:00
wangxiaolei
d2208ad43e fix: fix allow handle value is none (#33031) 2026-03-09 14:20:44 +08:00
非法操作
4a2ba058bb feat: when copy/paste multi nodes not require reconnect them (#32631) 2026-03-09 13:55:12 +08:00
非法操作
654e41d47f fix: workflow_as_tool not work with json input (#32554) 2026-03-09 13:54:54 +08:00
非法操作
ec5409756e feat: keep connections when change node (#31982) 2026-03-09 13:54:10 +08:00
Olexandr88
8b1ea3a8f5 refactor: deduplicate legacy section mapping in ConfigHelper (#32715) 2026-03-09 13:43:06 +08:00
yyh
f2d3feca66 fix(web): fix tool item text not vertically centered in block selector (#33148) 2026-03-09 13:38:11 +08:00
yyh
0590b09958 feat(web): add context menu primitive and dropdown link item (#33125) 2026-03-09 12:05:38 +08:00
wangxiaolei
66f9fde2fe fix: fix metadata filter condition not extract from {{}} (#33141)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-09 11:51:08 +08:00
Stephen Zhou
1811a855ab chore: update vinext, agentation, remove Prism in lexical (#33142) 2026-03-09 11:40:04 +08:00
Jiaquan Yi
322cd37de1 fix: handle backslash path separators in DOCX ZIP entries exported on…(#33129) (#33131)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-09 10:49:42 +08:00
rajatagarwal-oss
2cc0de9c1b test: unit test case for controllers.common module (#32056) 2026-03-09 09:45:42 +08:00
wangxiaolei
46098b2be6 refactor: use thread.Timer instead of time.sleep (#33121) 2026-03-09 09:38:16 +08:00
akashseth-ifp
7dcf94f48f test: remaining header component and increase branch coverage (#33052)
Co-authored-by: sahil <sahil@infocusp.com>
2026-03-09 09:18:11 +08:00
yyh
7869551afd fix(web): stabilize dayjs timezone tests against DST transitions (#33134) 2026-03-09 09:16:45 +08:00
224 changed files with 36937 additions and 4933 deletions

.github/actions/setup-web/action.yml (new file, 33 lines)
View File

@@ -0,0 +1,33 @@
name: Setup Web Environment
description: Setup pnpm, Node.js, and install web dependencies.
inputs:
node-version:
description: Node.js version to use
required: false
default: "22"
install-dependencies:
description: Whether to install web dependencies after setting up Node.js
required: false
default: "true"
runs:
using: composite
steps:
- name: Install pnpm
uses: pnpm/action-setup@41ff72655975bd51cab0327fa583b6e92b6d3061 # v4.2.0
with:
package_json_file: web/package.json
run_install: false
- name: Setup Node.js
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0
with:
node-version: ${{ inputs.node-version }}
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Install dependencies
if: ${{ inputs.install-dependencies == 'true' }}
shell: bash
run: pnpm --dir web install --frozen-lockfile
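
This composite action centralizes the pnpm and Node.js setup that the later workflow diffs in this changeset previously duplicated inline. A consuming job references it by repository path after checkout, roughly as follows (a minimal sketch; the job name and the final lint command are illustrative, while the action path and inputs come from the diffs below):

```yaml
# Sketch of a workflow job consuming the composite action.
# Checkout must run first so ./.github/actions/setup-web exists on disk.
jobs:
  web-style:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v6
      - name: Setup web environment
        uses: ./.github/actions/setup-web
        with:
          node-version: "22"           # optional; defaults to "22"
          install-dependencies: "true" # "false" skips pnpm install
      - name: Web style check
        run: pnpm --dir web lint       # illustrative follow-up command
```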

View File

@@ -24,14 +24,36 @@ updates:
schedule:
interval: "weekly"
open-pull-requests-limit: 2
ignore:
- dependency-name: "ky"
- dependency-name: "tailwind-merge"
update-types: ["version-update:semver-major"]
- dependency-name: "tailwindcss"
update-types: ["version-update:semver-major"]
- dependency-name: "react-markdown"
update-types: ["version-update:semver-major"]
- dependency-name: "react-syntax-highlighter"
update-types: ["version-update:semver-major"]
- dependency-name: "react-window"
update-types: ["version-update:semver-major"]
groups:
lexical:
patterns:
- "lexical"
- "@lexical/*"
storybook:
patterns:
- "storybook"
- "@storybook/*"
eslint-group:
patterns:
- "*eslint*"
npm-dependencies:
patterns:
- "*"
exclude-patterns:
- "lexical"
- "@lexical/*"
- "storybook"
- "@storybook/*"
- "*eslint*"

View File

@@ -22,12 +22,12 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@@ -51,7 +51,7 @@ jobs:
run: sh .github/workflows/expose_service_ports.sh
- name: Set up Sandbox
uses: hoverkraft-tech/compose-action@v2
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
with:
compose-file: |
docker/docker-compose.middleware.yaml

View File

@@ -12,22 +12,22 @@ jobs:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Check Docker Compose inputs
id: docker-compose-changes
uses: tj-actions/changed-files@v47
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
with:
files: |
docker/generate_docker_compose
docker/.env.example
docker/docker-compose-template.yaml
docker/docker-compose.yaml
- uses: actions/setup-python@v6
- uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
with:
python-version: "3.11"
- uses: astral-sh/setup-uv@v7
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
- name: Generate Docker Compose
if: steps.docker-compose-changes.outputs.any_changed == 'true'
@@ -84,4 +84,14 @@ jobs:
run: |
uvx --python 3.13 mdformat . --exclude ".agents/skills/**"
- - uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27
+ - name: Setup web environment
+ uses: ./.github/actions/setup-web
+ with:
+ node-version: "24"
+ - name: ESLint autofix
+ run: |
+ cd web
+ pnpm eslint --concurrency=2 --prune-suppressions
+ - uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3

View File

@@ -53,26 +53,26 @@ jobs:
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Login to Docker Hub
uses: docker/login-action@v3
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
with:
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
- name: Extract metadata for Docker
id: meta
uses: docker/metadata-action@v5
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6.0.0
with:
images: ${{ env[matrix.image_name_env] }}
- name: Build Docker image
id: build
uses: docker/build-push-action@v6
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
with:
context: "{{defaultContext}}:${{ matrix.context }}"
platforms: ${{ matrix.platform }}
@@ -91,7 +91,7 @@ jobs:
touch "/tmp/digests/${sanitized_digest}"
- name: Upload digest
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }}
path: /tmp/digests/*
@@ -113,21 +113,21 @@ jobs:
context: "web"
steps:
- name: Download digests
uses: actions/download-artifact@v7
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
with:
path: /tmp/digests
pattern: digests-${{ matrix.context }}-*
merge-multiple: true
- name: Login to Docker Hub
uses: docker/login-action@v3
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0
with:
username: ${{ env.DOCKERHUB_USER }}
password: ${{ env.DOCKERHUB_TOKEN }}
- name: Extract metadata for Docker
id: meta
uses: docker/metadata-action@v5
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6.0.0
with:
images: ${{ env[matrix.image_name_env] }}
tags: |

View File

@@ -13,13 +13,13 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
with:
enable-cache: true
python-version: "3.12"
@@ -40,7 +40,7 @@ jobs:
cp middleware.env.example middleware.env
- name: Set up Middlewares
uses: hoverkraft-tech/compose-action@v2.0.2
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
with:
compose-file: |
docker/docker-compose.middleware.yaml
@@ -63,13 +63,13 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
with:
enable-cache: true
python-version: "3.12"
@@ -94,7 +94,7 @@ jobs:
sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env
- name: Set up Middlewares
uses: hoverkraft-tech/compose-action@v2.0.2
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
with:
compose-file: |
docker/docker-compose.middleware.yaml

View File

@@ -19,7 +19,7 @@ jobs:
github.event.workflow_run.head_branch == 'deploy/agent-dev'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v1
uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5
with:
host: ${{ secrets.AGENT_DEV_SSH_HOST }}
username: ${{ secrets.SSH_USER }}

View File

@@ -16,7 +16,7 @@ jobs:
github.event.workflow_run.head_branch == 'deploy/dev'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v1
uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5
with:
host: ${{ secrets.SSH_HOST }}
username: ${{ secrets.SSH_USER }}

View File

@@ -16,7 +16,7 @@ jobs:
github.event.workflow_run.head_branch == 'build/feat/hitl'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v1
uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5
with:
host: ${{ secrets.HITL_SSH_HOST }}
username: ${{ secrets.SSH_USER }}

View File

@@ -32,13 +32,13 @@ jobs:
context: "web"
steps:
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
- name: Build Docker Image
uses: docker/build-push-action@v6
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
with:
push: false
context: "{{defaultContext}}:${{ matrix.context }}"

View File

@@ -9,6 +9,6 @@ jobs:
pull-requests: write
runs-on: ubuntu-latest
steps:
- uses: actions/labeler@v6
- uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1
with:
sync-labels: true

View File

@@ -27,8 +27,8 @@ jobs:
vdb-changed: ${{ steps.changes.outputs.vdb }}
migration-changed: ${{ steps.changes.outputs.migration }}
steps:
- - uses: actions/checkout@v6
+ - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- - uses: dorny/paths-filter@v3
+ - uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
id: changes
with:
filters: |
@@ -39,6 +39,7 @@ jobs:
web:
- 'web/**'
- '.github/workflows/web-tests.yml'
- '.github/actions/setup-web/**'
vdb:
- 'api/core/rag/datasource/**'
- 'docker/**'

View File

@@ -21,7 +21,7 @@ jobs:
if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }}
steps:
- name: Download pyrefly diff artifact
uses: actions/github-script@v8
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
@@ -49,7 +49,7 @@ jobs:
run: unzip -o pyrefly_diff.zip
- name: Post comment
uses: actions/github-script@v8
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |

View File

@@ -17,12 +17,12 @@ jobs:
pull-requests: write
steps:
- name: Checkout PR branch
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
- name: Setup Python & UV
uses: astral-sh/setup-uv@v5
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
with:
enable-cache: true
@@ -55,7 +55,7 @@ jobs:
echo ${{ github.event.pull_request.number }} > pr_number.txt
- name: Upload pyrefly diff
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: pyrefly_diff
path: |
@@ -64,7 +64,7 @@ jobs:
- name: Comment PR with pyrefly diff
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
uses: actions/github-script@v8
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |

View File

@@ -16,6 +16,6 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Check title
uses: amannn/action-semantic-pull-request@v6.1.1
uses: amannn/action-semantic-pull-request@48f256284bd46cdaab1048c3721360e808335d50 # v6.1.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -18,7 +18,7 @@ jobs:
pull-requests: write
steps:
- uses: actions/stale@v10
- uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0
with:
days-before-issue-stale: 15
days-before-issue-close: 3

View File

@@ -19,13 +19,13 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v47
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
with:
files: |
api/**
@@ -33,7 +33,7 @@ jobs:
- name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true'
uses: astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
with:
enable-cache: false
python-version: "3.12"
@@ -67,36 +67,22 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v47
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
with:
files: |
web/**
.github/workflows/style.yml
.github/actions/setup-web/**
- - name: Install pnpm
- uses: pnpm/action-setup@v4
- with:
- package_json_file: web/package.json
- run_install: false
- - name: Setup NodeJS
- uses: actions/setup-node@v6
+ - name: Setup web environment
  if: steps.changed-files.outputs.any_changed == 'true'
- with:
- node-version: 22
- cache: pnpm
- cache-dependency-path: ./web/pnpm-lock.yaml
- - name: Web dependencies
- if: steps.changed-files.outputs.any_changed == 'true'
- working-directory: ./web
- run: pnpm install --frozen-lockfile
+ uses: ./.github/actions/setup-web
- name: Web style check
if: steps.changed-files.outputs.any_changed == 'true'
@@ -134,14 +120,14 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v47
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
with:
files: |
**.sh
@@ -152,7 +138,7 @@ jobs:
.editorconfig
- name: Super-linter
uses: super-linter/super-linter/slim@v8
uses: super-linter/super-linter/slim@61abc07d755095a68f4987d1c2c3d1d64408f1f9 # v8.5.0
if: steps.changed-files.outputs.any_changed == 'true'
env:
BASH_SEVERITY: warning

View File

@@ -21,12 +21,12 @@ jobs:
working-directory: sdks/nodejs-client
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Use Node.js
uses: actions/setup-node@v6
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0
with:
node-version: 22
cache: ''

View File

@@ -38,7 +38,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
@@ -48,18 +48,10 @@ jobs:
git config --global user.name "github-actions[bot]"
git config --global user.email "github-actions[bot]@users.noreply.github.com"
- - name: Install pnpm
- uses: pnpm/action-setup@v4
+ - name: Setup web environment
+ uses: ./.github/actions/setup-web
  with:
- package_json_file: web/package.json
- run_install: false
- - name: Set up Node.js
- uses: actions/setup-node@v6
- with:
- node-version: 22
- cache: pnpm
- cache-dependency-path: ./web/pnpm-lock.yaml
+ install-dependencies: "false"
- name: Detect changed files and generate diff
id: detect_changes
@@ -130,7 +122,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.detect_changes.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@v1
uses: anthropics/claude-code-action@26ec041249acb0a944c0a47b6c0c13f05dbc5b44 # v1.0.70
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -21,7 +21,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
@@ -59,7 +59,7 @@ jobs:
- name: Trigger i18n sync workflow
if: steps.detect.outputs.has_changes == 'true'
uses: peter-evans/repository-dispatch@v3
uses: peter-evans/repository-dispatch@28959ce8df70de7be546dd1250a005dd32156697 # v4.0.1
with:
token: ${{ secrets.GITHUB_TOKEN }}
event-type: i18n-sync

View File

@@ -19,19 +19,19 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Free Disk Space
uses: endersonmenezes/free-disk-space@v3
uses: endersonmenezes/free-disk-space@7901478139cff6e9d44df5972fd8ab8fcade4db1 # v3.2.2
with:
remove_dotnet: true
remove_haskell: true
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # v7.3.1
with:
enable-cache: true
python-version: ${{ matrix.python-version }}
@@ -60,7 +60,7 @@ jobs:
# tiflash
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase)
uses: hoverkraft-tech/compose-action@v2.0.2
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
with:
compose-file: |
docker/docker-compose.yaml

View File

@@ -26,32 +26,19 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Setup web environment
uses: ./.github/actions/setup-web
- name: Run tests
run: pnpm vitest run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
- name: Upload blob report
if: ${{ !cancelled() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: blob-report-${{ matrix.shardIndex }}
path: web/.vitest-reports/*
@@ -70,28 +57,15 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Setup web environment
uses: ./.github/actions/setup-web
- name: Download blob reports
uses: actions/download-artifact@v6
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
with:
path: web/.vitest-reports
pattern: blob-report-*
@@ -419,7 +393,7 @@ jobs:
- name: Upload Coverage Artifact
if: steps.coverage-summary.outputs.has_coverage == 'true'
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: web-coverage-report
path: web/coverage
@@ -435,36 +409,22 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v47
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
with:
files: |
web/**
.github/workflows/web-tests.yml
.github/actions/setup-web/**
- - name: Install pnpm
- uses: pnpm/action-setup@v4
- with:
- package_json_file: web/package.json
- run_install: false
- - name: Setup NodeJS
- uses: actions/setup-node@v6
+ - name: Setup web environment
  if: steps.changed-files.outputs.any_changed == 'true'
- with:
- node-version: 22
- cache: pnpm
- cache-dependency-path: ./web/pnpm-lock.yaml
- - name: Web dependencies
- if: steps.changed-files.outputs.any_changed == 'true'
- working-directory: ./web
- run: pnpm install --frozen-lockfile
+ uses: ./.github/actions/setup-web
- name: Web build check
if: steps.changed-files.outputs.any_changed == 'true'

View File

@@ -133,7 +133,7 @@ Star Dify on GitHub and be instantly notified of new releases.
### Custom configurations
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
#### Customizing Suggested Questions

View File

@@ -44,7 +44,6 @@ forbidden_modules =
allow_indirect_imports = True
ignore_imports =
dify_graph.nodes.agent.agent_node -> extensions.ext_database
dify_graph.nodes.llm.file_saver -> extensions.ext_database
dify_graph.nodes.llm.node -> extensions.ext_database
dify_graph.nodes.tool.tool_node -> extensions.ext_database
dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
@@ -114,7 +113,6 @@ ignore_imports =
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
dify_graph.nodes.tool.tool_node -> models
dify_graph.nodes.agent.agent_node -> models.model
dify_graph.nodes.llm.file_saver -> core.helper.ssrf_proxy
dify_graph.nodes.llm.node -> core.helper.code_executor
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
@@ -135,7 +133,6 @@ ignore_imports =
dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager
dify_graph.nodes.tool.tool_node -> core.tools.errors
dify_graph.nodes.agent.agent_node -> extensions.ext_database
dify_graph.nodes.llm.file_saver -> extensions.ext_database
dify_graph.nodes.llm.node -> extensions.ext_database
dify_graph.nodes.tool.tool_node -> extensions.ext_database
dify_graph.nodes.agent.agent_node -> models

View File

@@ -30,6 +30,7 @@ from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from extensions.storage.opendal_storage import OpenDALStorage
from extensions.storage.storage_type import StorageType
from libs.datetime_utils import naive_utc_now
from libs.db_migration_lock import DbMigrationAutoRenewLock
from libs.helper import email as email_validate
from libs.password import hash_password, password_pattern, valid_password
@@ -2598,15 +2599,29 @@ def migrate_oss(
@click.option(
"--start-from",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
required=True,
required=False,
default=None,
help="Lower bound (inclusive) for created_at.",
)
@click.option(
"--end-before",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
required=True,
required=False,
default=None,
help="Upper bound (exclusive) for created_at.",
)
@click.option(
"--from-days-ago",
type=int,
default=None,
help="Relative lower bound in days ago (inclusive). Must be used with --before-days.",
)
@click.option(
"--before-days",
type=int,
default=None,
help="Relative upper bound in days ago (exclusive). Required for relative mode.",
)
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.")
@click.option(
"--graceful-period",
@@ -2618,8 +2633,10 @@ def migrate_oss(
def clean_expired_messages(
batch_size: int,
graceful_period: int,
start_from: datetime.datetime,
end_before: datetime.datetime,
start_from: datetime.datetime | None,
end_before: datetime.datetime | None,
from_days_ago: int | None,
before_days: int | None,
dry_run: bool,
):
"""
@@ -2630,18 +2647,70 @@ def clean_expired_messages(
     start_at = time.perf_counter()
     try:
+        abs_mode = start_from is not None and end_before is not None
+        rel_mode = before_days is not None
+        if abs_mode and rel_mode:
+            raise click.UsageError(
+                "Options are mutually exclusive: use either (--start-from,--end-before) "
+                "or (--from-days-ago,--before-days)."
+            )
+        if from_days_ago is not None and before_days is None:
+            raise click.UsageError("--from-days-ago must be used together with --before-days.")
+        if (start_from is None) ^ (end_before is None):
+            raise click.UsageError("Both --start-from and --end-before are required when using absolute time range.")
+        if not abs_mode and not rel_mode:
+            raise click.UsageError(
+                "You must provide either (--start-from,--end-before) or (--before-days [--from-days-ago])."
+            )
+        if rel_mode:
+            assert before_days is not None
+            if before_days < 0:
+                raise click.UsageError("--before-days must be >= 0.")
+            if from_days_ago is not None:
+                if from_days_ago < 0:
+                    raise click.UsageError("--from-days-ago must be >= 0.")
+                if from_days_ago <= before_days:
+                    raise click.UsageError("--from-days-ago must be greater than --before-days.")
+
         # Create policy based on billing configuration
         # NOTE: graceful_period will be ignored when billing is disabled.
         policy = create_message_clean_policy(graceful_period_days=graceful_period)
 
         # Create and run the cleanup service
-        service = MessagesCleanService.from_time_range(
-            policy=policy,
-            start_from=start_from,
-            end_before=end_before,
-            batch_size=batch_size,
-            dry_run=dry_run,
-        )
+        if abs_mode:
+            assert start_from is not None
+            assert end_before is not None
+            service = MessagesCleanService.from_time_range(
+                policy=policy,
+                start_from=start_from,
+                end_before=end_before,
+                batch_size=batch_size,
+                dry_run=dry_run,
+            )
+        elif from_days_ago is None:
+            assert before_days is not None
+            service = MessagesCleanService.from_days(
+                policy=policy,
+                days=before_days,
+                batch_size=batch_size,
+                dry_run=dry_run,
+            )
+        else:
+            assert before_days is not None
+            assert from_days_ago is not None
+            now = naive_utc_now()
+            service = MessagesCleanService.from_time_range(
+                policy=policy,
+                start_from=now - datetime.timedelta(days=from_days_ago),
+                end_before=now - datetime.timedelta(days=before_days),
+                batch_size=batch_size,
+                dry_run=dry_run,
+            )
+
         stats = service.run()
         end_at = time.perf_counter()
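The hunk above makes the two time-range modes mutually exclusive. A minimal standalone sketch of the same validation and window-resolution rules, where `resolve_window` is a hypothetical stand-in for the command body (the real command dispatches to `MessagesCleanService` rather than returning a tuple):

```python
import datetime


def resolve_window(start_from, end_before, from_days_ago, before_days, now):
    # Hypothetical stand-in: validate the four CLI options and return
    # one absolute (start, end) window; start may be None in relative
    # mode without --from-days-ago, mirroring from_days().
    abs_mode = start_from is not None and end_before is not None
    rel_mode = before_days is not None
    if abs_mode and rel_mode:
        raise ValueError("absolute and relative ranges are mutually exclusive")
    if from_days_ago is not None and before_days is None:
        raise ValueError("--from-days-ago must be used together with --before-days")
    if (start_from is None) ^ (end_before is None):
        raise ValueError("--start-from and --end-before must be given together")
    if not abs_mode and not rel_mode:
        raise ValueError("provide either an absolute or a relative range")
    if abs_mode:
        return start_from, end_before
    if before_days < 0 or (from_days_ago is not None and from_days_ago < 0):
        raise ValueError("day offsets must be >= 0")
    if from_days_ago is not None and from_days_ago <= before_days:
        raise ValueError("--from-days-ago must be greater than --before-days")
    start = now - datetime.timedelta(days=from_days_ago) if from_days_ago is not None else None
    return start, now - datetime.timedelta(days=before_days)
```

Because `--from-days-ago` is the lower bound and `--before-days` the upper bound, the former must be the larger number of days, which is why the check rejects `from_days_ago <= before_days`.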

View File

@@ -807,7 +807,7 @@ class DatasetApiKeyApi(Resource):
             console_ns.abort(
                 400,
                 message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
-                code="max_keys_exceeded",
+                custom="max_keys_exceeded",
             )
 
         key = ApiToken.generate_api_key(self.token_prefix, 24)

View File

@@ -10,7 +10,6 @@ from controllers.common.file_response import enforce_download_for_html
 from controllers.files import files_ns
 from core.tools.signature import verify_tool_file_signature
 from core.tools.tool_file_manager import ToolFileManager
-from extensions.ext_database import db as global_db
 
 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -57,7 +56,7 @@ class ToolFileApi(Resource):
             raise Forbidden("Invalid request.")
 
         try:
-            tool_file_manager = ToolFileManager(engine=global_db.engine)
+            tool_file_manager = ToolFileManager()
             stream, tool_file = tool_file_manager.get_file_generator_by_tool_file_id(
                 file_id,
             )

View File

@@ -239,7 +239,7 @@ class MessageSuggestedQuestionApi(WebApiResource):
     def get(self, app_model, end_user, message_id):
         app_mode = AppMode.value_of(app_model.mode)
         if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
-            raise NotCompletionAppError()
+            raise NotChatAppError()
 
         message_id = str(message_id)

View File

@@ -1,7 +1,6 @@
 import hashlib
 import logging
-import time
-from threading import Thread
+from threading import Thread, Timer
 from typing import Union
 
 from flask import Flask, current_app
@@ -96,9 +95,9 @@ class MessageCycleManager:
         if auto_generate_conversation_name and is_first_message:
             # start generate thread
-            # time.sleep not block other logic
-            time.sleep(1)
-            thread = Thread(
-                target=self._generate_conversation_name_worker,
+            thread = Timer(
+                1,
+                self._generate_conversation_name_worker,
                 kwargs={
                     "flask_app": current_app._get_current_object(),  # type: ignore
                     "conversation_id": conversation_id,
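The switch from `time.sleep(1)` plus `Thread` to `threading.Timer` moves the one-second delay off the calling thread: `Timer` is a `Thread` subclass that sleeps in its own thread before invoking the callable, so `start()` returns immediately. A small self-contained illustration (the `worker` and `results` names are invented for the demo, and a 0.2-second delay stands in for the production 1 second):

```python
import threading
import time

results = []


def worker(conversation_id):
    # Stand-in for _generate_conversation_name_worker.
    results.append(conversation_id)


start = time.perf_counter()
# Timer sleeps in its own thread, then calls worker; start() returns at once.
timer = threading.Timer(0.2, worker, kwargs={"conversation_id": "c1"})
timer.start()
blocked_for = time.perf_counter() - start  # the caller was never blocked
timer.join()  # only so the demo can observe the side effect
```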

View File

@@ -10,28 +10,18 @@ from typing import Union
 from uuid import uuid4
 
 import httpx
-from sqlalchemy.orm import Session
 
 from configs import dify_config
+from core.db.session_factory import session_factory
 from core.helper import ssrf_proxy
-from extensions.ext_database import db as global_db
 from extensions.ext_storage import storage
 from models.model import MessageFile
 from models.tools import ToolFile
 
 logger = logging.getLogger(__name__)
 
-from sqlalchemy.engine import Engine
-
 
 class ToolFileManager:
-    _engine: Engine
-
-    def __init__(self, engine: Engine | None = None):
-        if engine is None:
-            engine = global_db.engine
-        self._engine = engine
-
     @staticmethod
     def sign_file(tool_file_id: str, extension: str) -> str:
         """
"""
@@ -89,7 +79,7 @@ class ToolFileManager:
         filepath = f"tools/{tenant_id}/{unique_filename}"
         storage.save(filepath, file_binary)
 
-        with Session(self._engine, expire_on_commit=False) as session:
+        with session_factory.create_session() as session:
             tool_file = ToolFile(
                 user_id=user_id,
                 tenant_id=tenant_id,
@@ -132,7 +122,7 @@ class ToolFileManager:
         filename = f"{unique_name}{extension}"
         filepath = f"tools/{tenant_id}/{filename}"
         storage.save(filepath, blob)
 
-        with Session(self._engine, expire_on_commit=False) as session:
+        with session_factory.create_session() as session:
             tool_file = ToolFile(
                 user_id=user_id,
                 tenant_id=tenant_id,
@@ -157,7 +147,7 @@ class ToolFileManager:
         :return: the binary of the file, mime type
         """
-        with Session(self._engine, expire_on_commit=False) as session:
+        with session_factory.create_session() as session:
             tool_file: ToolFile | None = (
                 session.query(ToolFile)
                 .where(
@@ -181,7 +171,7 @@ class ToolFileManager:
         :return: the binary of the file, mime type
         """
-        with Session(self._engine, expire_on_commit=False) as session:
+        with session_factory.create_session() as session:
             message_file: MessageFile | None = (
                 session.query(MessageFile)
                 .where(
@@ -225,7 +215,7 @@ class ToolFileManager:
         :return: the binary of the file, mime type
         """
-        with Session(self._engine, expire_on_commit=False) as session:
+        with session_factory.create_session() as session:
             tool_file: ToolFile | None = (
                 session.query(ToolFile)
                 .where(

View File

@@ -37,6 +37,7 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
     VariableEntityType.CHECKBOX: ToolParameter.ToolParameterType.BOOLEAN,
     VariableEntityType.FILE: ToolParameter.ToolParameterType.FILE,
     VariableEntityType.FILE_LIST: ToolParameter.ToolParameterType.FILES,
+    VariableEntityType.JSON_OBJECT: ToolParameter.ToolParameterType.OBJECT,
 }

View File

@@ -250,6 +250,7 @@ class DifyNodeFactory(NodeFactory):
                 model_factory=self._llm_model_factory,
                 model_instance=model_instance,
                 memory=memory,
+                http_client=self._http_request_http_client,
             )
 
         if node_type == NodeType.DATASOURCE:
@@ -292,6 +293,7 @@ class DifyNodeFactory(NodeFactory):
                 model_factory=self._llm_model_factory,
                 model_instance=model_instance,
                 memory=memory,
+                http_client=self._http_request_http_client,
             )
 
         if node_type == NodeType.PARAMETER_EXTRACTOR:

View File

@@ -4,6 +4,7 @@ import json
 import logging
 import os
 import tempfile
+import zipfile
 from collections.abc import Mapping, Sequence
 from typing import TYPE_CHECKING, Any
@@ -82,8 +83,18 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
         value = variable.value
         inputs = {"variable_selector": variable_selector}
+        if isinstance(value, list):
+            value = list(filter(lambda x: x, value))
         process_data = {"documents": value if isinstance(value, list) else [value]}
 
+        if not value:
+            return NodeRunResult(
+                status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                inputs=inputs,
+                process_data=process_data,
+                outputs={"text": ArrayStringSegment(value=[])},
+            )
+
         try:
             if isinstance(value, list):
                 extracted_text_list = [
@@ -111,6 +122,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
             else:
                 raise DocumentExtractorError(f"Unsupported variable type: {type(value)}")
         except DocumentExtractorError as e:
+            logger.warning(e, exc_info=True)
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED,
                 error=str(e),
@@ -385,6 +397,32 @@ def parser_docx_part(block, doc: Document, content_items, i):
         content_items.append((i, "table", Table(block, doc)))
 
 
+def _normalize_docx_zip(file_content: bytes) -> bytes:
+    """
+    Some DOCX files (e.g. exported by Evernote on Windows) are malformed:
+    ZIP entry names use backslash (\\) as path separator instead of the forward
+    slash (/) required by both the ZIP spec and OOXML. On Linux/Mac the entry
+    "word\\document.xml" is never found when python-docx looks for
+    "word/document.xml", which triggers a KeyError about a missing relationship.
+
+    This function rewrites the ZIP in-memory, normalizing all entry names to
+    use forward slashes without touching any actual document content.
+    """
+    try:
+        with zipfile.ZipFile(io.BytesIO(file_content), "r") as zin:
+            out_buf = io.BytesIO()
+            with zipfile.ZipFile(out_buf, "w", compression=zipfile.ZIP_DEFLATED) as zout:
+                for item in zin.infolist():
+                    data = zin.read(item.filename)
+                    # Normalize backslash path separators to forward slash
+                    item.filename = item.filename.replace("\\", "/")
+                    zout.writestr(item, data)
+            return out_buf.getvalue()
+    except zipfile.BadZipFile:
+        # Not a valid zip — return as-is and let python-docx report the real error
+        return file_content
+
+
 def _extract_text_from_docx(file_content: bytes) -> str:
     """
     Extract text from a DOCX file.
@@ -392,7 +430,15 @@ def _extract_text_from_docx(file_content: bytes) -> str:
     """
     try:
         doc_file = io.BytesIO(file_content)
-        doc = docx.Document(doc_file)
+        try:
+            doc = docx.Document(doc_file)
+        except Exception as e:
+            logger.warning("Failed to parse DOCX, attempting to normalize ZIP entry paths: %s", e)
+            # Some DOCX files exported by tools like Evernote on Windows use
+            # backslash path separators in ZIP entries and/or single-quoted XML
+            # attributes, both of which break python-docx on Linux. Normalize and retry.
+            file_content = _normalize_docx_zip(file_content)
+            doc = docx.Document(io.BytesIO(file_content))
         text = []
         # Keep track of paragraph and table positions
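The normalization added above can be checked against a synthetic archive. The sketch below re-implements the same rewrite logic standalone (outside the module) and feeds it a ZIP whose entry name uses a backslash, the way the malformed Evernote exports do:

```python
import io
import zipfile


def normalize_zip_entry_names(file_content: bytes) -> bytes:
    # Same rewrite as _normalize_docx_zip: copy every entry, replacing
    # backslash separators in entry names with forward slashes.
    try:
        with zipfile.ZipFile(io.BytesIO(file_content), "r") as zin:
            out_buf = io.BytesIO()
            with zipfile.ZipFile(out_buf, "w", compression=zipfile.ZIP_DEFLATED) as zout:
                for item in zin.infolist():
                    data = zin.read(item.filename)  # read before renaming
                    item.filename = item.filename.replace("\\", "/")
                    zout.writestr(item, data)
            return out_buf.getvalue()
    except zipfile.BadZipFile:
        # Not a ZIP at all: hand back unchanged so the caller reports the real error
        return file_content


# Build a malformed archive the way a broken exporter would.
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as z:
    z.writestr("word\\document.xml", b"<w:document/>")

fixed = normalize_zip_entry_names(buf.getvalue())
with zipfile.ZipFile(io.BytesIO(fixed)) as z:
    names = z.namelist()
    content = z.read("word/document.xml")
```

The entry content is copied byte-for-byte; only the name changes, so python-docx can then locate `word/document.xml` on any platform.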

View File

@@ -14,7 +14,6 @@ from dify_graph.model_runtime.utils.encoders import jsonable_encoder
 from dify_graph.node_events import NodeRunResult
 from dify_graph.nodes.base import LLMUsageTrackingMixin
 from dify_graph.nodes.base.node import Node
-from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
 from dify_graph.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source
 from dify_graph.variables import (
     ArrayFileSegment,
@@ -23,7 +22,11 @@ from dify_graph.variables import (
 )
 from dify_graph.variables.segments import ArrayObjectSegment
 
-from .entities import KnowledgeRetrievalNodeData
+from .entities import (
+    Condition,
+    KnowledgeRetrievalNodeData,
+    MetadataFilteringCondition,
+)
 from .exc import (
     KnowledgeRetrievalNodeError,
     RateLimitExceededError,
@@ -43,8 +46,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
     # Output variable for file
     _file_outputs: list["File"]
 
-    _llm_file_saver: LLMFileSaver
-
     def __init__(
         self,
         id: str,
@@ -52,8 +53,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
         graph_init_params: "GraphInitParams",
         graph_runtime_state: "GraphRuntimeState",
         rag_retrieval: RAGRetrievalProtocol,
-        *,
-        llm_file_saver: LLMFileSaver | None = None,
     ):
         super().__init__(
             id=id,
@@ -65,14 +64,6 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
         self._file_outputs = []
         self._rag_retrieval = rag_retrieval
 
-        if llm_file_saver is None:
-            dify_ctx = self.require_dify_context()
-            llm_file_saver = FileSaverImpl(
-                user_id=dify_ctx.user_id,
-                tenant_id=dify_ctx.tenant_id,
-            )
-        self._llm_file_saver = llm_file_saver
-
     @classmethod
     def version(cls):
         return "1"
@@ -171,6 +162,12 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
         if node_data.metadata_filtering_mode is not None:
             metadata_filtering_mode = node_data.metadata_filtering_mode
 
+        resolved_metadata_conditions = (
+            self._resolve_metadata_filtering_conditions(node_data.metadata_filtering_conditions)
+            if node_data.metadata_filtering_conditions
+            else None
+        )
+
         if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query:
             # fetch model config
             if node_data.single_retrieval_config is None:
@@ -189,7 +186,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
                 model_mode=model.mode,
                 model_name=model.name,
                 metadata_model_config=node_data.metadata_model_config,
-                metadata_filtering_conditions=node_data.metadata_filtering_conditions,
+                metadata_filtering_conditions=resolved_metadata_conditions,
                 metadata_filtering_mode=metadata_filtering_mode,
                 query=query,
             )
@@ -247,7 +244,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
                 weights=weights,
                 reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
                 metadata_model_config=node_data.metadata_model_config,
-                metadata_filtering_conditions=node_data.metadata_filtering_conditions,
+                metadata_filtering_conditions=resolved_metadata_conditions,
                 metadata_filtering_mode=metadata_filtering_mode,
                 attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
             )
@@ -256,6 +253,48 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
         usage = self._rag_retrieval.llm_usage
         return retrieval_resource_list, usage
 
+    def _resolve_metadata_filtering_conditions(
+        self, conditions: MetadataFilteringCondition
+    ) -> MetadataFilteringCondition:
+        if conditions.conditions is None:
+            return MetadataFilteringCondition(
+                logical_operator=conditions.logical_operator,
+                conditions=None,
+            )
+
+        variable_pool = self.graph_runtime_state.variable_pool
+        resolved_conditions: list[Condition] = []
+        for cond in conditions.conditions or []:
+            value = cond.value
+            if isinstance(value, str):
+                segment_group = variable_pool.convert_template(value)
+                if len(segment_group.value) == 1:
+                    resolved_value = segment_group.value[0].to_object()
+                else:
+                    resolved_value = segment_group.text
+            elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value):
+                resolved_values = []
+                for v in value:  # type: ignore
+                    segment_group = variable_pool.convert_template(v)
+                    if len(segment_group.value) == 1:
+                        resolved_values.append(segment_group.value[0].to_object())
+                    else:
+                        resolved_values.append(segment_group.text)
+                resolved_value = resolved_values
+            else:
+                resolved_value = value
+            resolved_conditions.append(
+                Condition(
+                    name=cond.name,
+                    comparison_operator=cond.comparison_operator,
+                    value=resolved_value,
+                )
+            )
+
+        return MetadataFilteringCondition(
+            logical_operator=conditions.logical_operator or "and",
+            conditions=resolved_conditions,
+        )
+
     @classmethod
     def _extract_variable_selector_to_variable_mapping(
         cls,
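The value-resolution rule in `_resolve_metadata_filtering_conditions` (a lone variable reference keeps its native type, while a mixed template collapses to rendered text) can be sketched with a toy template resolver. Everything here, including `convert_template` and its `{{#...#}}` regex syntax, is a simplified stand-in for the real `VariablePool` API, which returns segment objects rather than raw Python values:

```python
import re


def convert_template(template: str, variables: dict) -> list:
    # Toy resolver: split on {{#...#}} references and substitute values,
    # keeping literal text fragments as-is.
    parts = re.split(r"(\{\{#.*?#\}\})", template)
    segments = []
    for part in parts:
        if not part:
            continue
        m = re.fullmatch(r"\{\{#(.*?)#\}\}", part)
        segments.append(variables[m.group(1)] if m else part)
    return segments


def resolve_condition_value(value: str, variables: dict):
    # Mirrors the rule above: a single segment keeps its native type,
    # a mixed template collapses to its rendered text.
    segments = convert_template(value, variables)
    if len(segments) == 1:
        return segments[0]
    return "".join(str(s) for s in segments)


vars_ = {"start.year": 2024}
single = resolve_condition_value("{{#start.year#}}", vars_)      # stays an int
mixed = resolve_condition_value("year-{{#start.year#}}", vars_)  # rendered text
```

This distinction matters for comparison operators: a numeric metadata filter still receives a number when the condition is exactly one variable reference.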

View File

@@ -1,14 +1,11 @@
 import mimetypes
 import typing as tp
 
-from sqlalchemy import Engine
-
 from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
-from core.helper import ssrf_proxy
 from core.tools.signature import sign_tool_file
 from core.tools.tool_file_manager import ToolFileManager
 from dify_graph.file import File, FileTransferMethod, FileType
-from extensions.ext_database import db as global_db
+from dify_graph.nodes.protocols import HttpClientProtocol
 
 
 class LLMFileSaver(tp.Protocol):
@@ -59,30 +56,20 @@ class LLMFileSaver(tp.Protocol):
         raise NotImplementedError()
 
 
-EngineFactory: tp.TypeAlias = tp.Callable[[], Engine]
-
-
 class FileSaverImpl(LLMFileSaver):
-    _engine_factory: EngineFactory
     _tenant_id: str
     _user_id: str
 
-    def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None):
-        if engine_factory is None:
-
-            def _factory():
-                return global_db.engine
-
-            engine_factory = _factory
-        self._engine_factory = engine_factory
+    def __init__(self, user_id: str, tenant_id: str, http_client: HttpClientProtocol):
         self._user_id = user_id
         self._tenant_id = tenant_id
+        self._http_client = http_client
 
     def _get_tool_file_manager(self):
-        return ToolFileManager(engine=self._engine_factory())
+        return ToolFileManager()
 
     def save_remote_url(self, url: str, file_type: FileType) -> File:
-        http_response = ssrf_proxy.get(url)
+        http_response = self._http_client.get(url)
         http_response.raise_for_status()
 
         data = http_response.content
         mime_type_from_header = http_response.headers.get("Content-Type")

View File

@@ -64,6 +64,7 @@ from dify_graph.nodes.base.entities import VariableSelector
 from dify_graph.nodes.base.node import Node
 from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
 from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
+from dify_graph.nodes.protocols import HttpClientProtocol
 from dify_graph.runtime import VariablePool
 from dify_graph.variables import (
     ArrayFileSegment,
@@ -127,6 +128,7 @@ class LLMNode(Node[LLMNodeData]):
         credentials_provider: CredentialsProvider,
         model_factory: ModelFactory,
         model_instance: ModelInstance,
+        http_client: HttpClientProtocol,
         memory: PromptMessageMemory | None = None,
         llm_file_saver: LLMFileSaver | None = None,
     ):
@@ -149,6 +151,7 @@ class LLMNode(Node[LLMNodeData]):
             llm_file_saver = FileSaverImpl(
                 user_id=dify_ctx.user_id,
                 tenant_id=dify_ctx.tenant_id,
+                http_client=http_client,
             )
         self._llm_file_saver = llm_file_saver

View File

@@ -28,6 +28,7 @@ from dify_graph.nodes.llm import (
 )
 from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
 from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
+from dify_graph.nodes.protocols import HttpClientProtocol
 from libs.json_in_md_parser import parse_and_check_json_markdown
 
 from .entities import QuestionClassifierNodeData
@@ -68,6 +69,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         credentials_provider: "CredentialsProvider",
         model_factory: "ModelFactory",
         model_instance: ModelInstance,
+        http_client: HttpClientProtocol,
         memory: PromptMessageMemory | None = None,
         llm_file_saver: LLMFileSaver | None = None,
     ):
@@ -90,6 +92,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
             llm_file_saver = FileSaverImpl(
                 user_id=dify_ctx.user_id,
                 tenant_id=dify_ctx.tenant_id,
+                http_client=http_client,
             )
         self._llm_file_saver = llm_file_saver

View File

@@ -21,6 +21,10 @@ celery_redis = Redis(
     ssl_cert_reqs=getattr(dify_config, "REDIS_SSL_CERT_REQS", None) if dify_config.BROKER_USE_SSL else None,
     ssl_certfile=getattr(dify_config, "REDIS_SSL_CERTFILE", None) if dify_config.BROKER_USE_SSL else None,
     ssl_keyfile=getattr(dify_config, "REDIS_SSL_KEYFILE", None) if dify_config.BROKER_USE_SSL else None,
+    # Add conservative socket timeouts and health checks to avoid long-lived half-open sockets
+    socket_timeout=5,
+    socket_connect_timeout=5,
+    health_check_interval=30,
 )
 
 logger = logging.getLogger(__name__)

View File

@@ -3,6 +3,7 @@ import math
 import time
 from collections.abc import Iterable, Sequence
 
+from celery import group
 from sqlalchemy import ColumnElement, and_, func, or_, select
 from sqlalchemy.engine.row import Row
 from sqlalchemy.orm import Session
@@ -85,20 +86,25 @@ def trigger_provider_refresh() -> None:
         lock_keys: list[str] = build_trigger_refresh_lock_keys(subscriptions)
         acquired: list[bool] = _acquire_locks(keys=lock_keys, ttl_seconds=lock_ttl)
 
-        enqueued: int = 0
-        for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired):
-            if not is_locked:
-                continue
-            trigger_subscription_refresh.delay(tenant_id=tenant_id, subscription_id=subscription_id)
-            enqueued += 1
+        if not any(acquired):
+            continue
+
+        jobs = [
+            trigger_subscription_refresh.s(tenant_id=tenant_id, subscription_id=subscription_id)
+            for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired)
+            if is_locked
+        ]
+        result = group(jobs).apply_async()
+        enqueued = len(jobs)
 
         logger.info(
-            "Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d",
+            "Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d result=%s",
             page + 1,
             pages,
             len(subscriptions),
             sum(1 for x in acquired if x),
             enqueued,
+            result,
         )
 
     logger.info("Trigger refresh scan done: due=%d", total_due)

View File

@@ -1,6 +1,6 @@
 import logging
 
-from celery import group, shared_task
+from celery import current_app, group, shared_task
 from sqlalchemy import and_, select
 from sqlalchemy.orm import Session, sessionmaker
@@ -29,31 +29,27 @@ def poll_workflow_schedules() -> None:
     with session_factory() as session:
         total_dispatched = 0
 
         # Process in batches until we've handled all due schedules or hit the limit
         while True:
             due_schedules = _fetch_due_schedules(session)
 
             if not due_schedules:
                 break
 
-            dispatched_count = _process_schedules(session, due_schedules)
-            total_dispatched += dispatched_count
-
-            logger.debug("Batch processed: %d dispatched", dispatched_count)
-
-            # Circuit breaker: check if we've hit the per-tick limit (if enabled)
-            if (
-                dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK > 0
-                and total_dispatched >= dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK
-            ):
-                logger.warning(
-                    "Circuit breaker activated: reached dispatch limit (%d), will continue next tick",
-                    dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK,
-                )
-                break
+            with current_app.producer_or_acquire() as producer:  # type: ignore
+                dispatched_count = _process_schedules(session, due_schedules, producer)
+                total_dispatched += dispatched_count
+                logger.debug("Batch processed: %d dispatched", dispatched_count)
+
+                # Circuit breaker: check if we've hit the per-tick limit (if enabled)
+                if 0 < dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK <= total_dispatched:
+                    logger.warning(
+                        "Circuit breaker activated: reached dispatch limit (%d), will continue next tick",
+                        dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK,
+                    )
+                    break
 
     if total_dispatched > 0:
-        logger.info("Total processed: %d dispatched", total_dispatched)
+        logger.info("Total processed: %d workflow schedule(s) dispatched", total_dispatched)
 
 
 def _fetch_due_schedules(session: Session) -> list[WorkflowSchedulePlan]:
@@ -90,7 +86,7 @@ def _fetch_due_schedules(session: Session) -> list[WorkflowSchedulePlan]:
     return list(due_schedules)
 
 
-def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan]) -> int:
+def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan], producer=None) -> int:
     """Process schedules: check quota, update next run time and dispatch to Celery in parallel."""
     if not schedules:
         return 0
@@ -107,7 +103,7 @@ def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan])
     if tasks_to_dispatch:
         job = group(run_schedule_trigger.s(schedule_id) for schedule_id in tasks_to_dispatch)
-        job.apply_async()
+        job.apply_async(producer=producer)
         logger.debug("Dispatched %d tasks in parallel", len(tasks_to_dispatch))

View File

@@ -12,6 +12,7 @@ from sqlalchemy.engine import CursorResult
 from sqlalchemy.orm import Session
 
 from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
 from models.model import (
     App,
     AppAnnotationHitHistory,
@@ -142,7 +143,7 @@ class MessagesCleanService:
         if batch_size <= 0:
             raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
 
-        end_before = datetime.datetime.now() - datetime.timedelta(days=days)
+        end_before = naive_utc_now() - datetime.timedelta(days=days)
 
         logger.info(
             "clean_messages: days=%s, end_before=%s, batch_size=%s, policy=%s",

View File

@@ -1,9 +1,10 @@
 import logging
 import time
-from collections.abc import Callable, Sequence
+from collections.abc import Sequence
+from typing import Any, Protocol
 
 import click
-from celery import shared_task
+from celery import current_app, shared_task
 
 from configs import dify_config
 from core.db.session_factory import session_factory
@@ -19,6 +20,12 @@ from tasks.generate_summary_index_task import generate_summary_index_task
 logger = logging.getLogger(__name__)
 
 
+class CeleryTaskLike(Protocol):
+    def delay(self, *args: Any, **kwargs: Any) -> Any: ...
+
+    def apply_async(self, *args: Any, **kwargs: Any) -> Any: ...
+
+
 @shared_task(queue="dataset")
 def document_indexing_task(dataset_id: str, document_ids: list):
     """
@@ -179,8 +186,8 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
 def _document_indexing_with_tenant_queue(
-    tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: Callable[[str, str, Sequence[str]], None]
-):
+    tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: CeleryTaskLike
+) -> None:
     try:
         _document_indexing(dataset_id, document_ids)
     except Exception:
@@ -201,16 +208,20 @@ def _document_indexing_with_tenant_queue(
         logger.info("document indexing tenant isolation queue %s next tasks: %s", tenant_id, next_tasks)
 
         if next_tasks:
-            for next_task in next_tasks:
-                document_task = DocumentTask(**next_task)
-                # Process the next waiting task
-                # Keep the flag set to indicate a task is running
-                tenant_isolated_task_queue.set_task_waiting_time()
-                task_func.delay(  # type: ignore
-                    tenant_id=document_task.tenant_id,
-                    dataset_id=document_task.dataset_id,
-                    document_ids=document_task.document_ids,
-                )
+            with current_app.producer_or_acquire() as producer:  # type: ignore
+                for next_task in next_tasks:
+                    document_task = DocumentTask(**next_task)
+                    # Keep the flag set to indicate a task is running
+                    tenant_isolated_task_queue.set_task_waiting_time()
+                    task_func.apply_async(
+                        kwargs={
+                            "tenant_id": document_task.tenant_id,
+                            "dataset_id": document_task.dataset_id,
+                            "document_ids": document_task.document_ids,
+                        },
+                        producer=producer,
+                    )
         else:
             # No more waiting tasks, clear the flag
             tenant_isolated_task_queue.delete_task_key()

View File

@@ -3,12 +3,13 @@ import json
 import logging
 import time
 import uuid
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from concurrent.futures import ThreadPoolExecutor
+from itertools import islice
 from typing import Any
 
 import click
-from celery import shared_task  # type: ignore
+from celery import group, shared_task
 from flask import current_app, g
 from sqlalchemy.orm import Session, sessionmaker
@@ -27,6 +28,11 @@ from services.file_service import FileService
 logger = logging.getLogger(__name__)
 
 
+def chunked(iterable: Sequence, size: int):
+    it = iter(iterable)
+    return iter(lambda: list(islice(it, size)), [])
+
+
 @shared_task(queue="pipeline")
 def rag_pipeline_run_task(
     rag_pipeline_invoke_entities_file_id: str,
@@ -83,16 +89,24 @@ def rag_pipeline_run_task(
         logger.info("rag pipeline tenant isolation queue %s next files: %s", tenant_id, next_file_ids)
 
         if next_file_ids:
-            for next_file_id in next_file_ids:
-                # Process the next waiting task
-                # Keep the flag set to indicate a task is running
-                tenant_isolated_task_queue.set_task_waiting_time()
-                rag_pipeline_run_task.delay(  # type: ignore
-                    rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
-                    if isinstance(next_file_id, bytes)
-                    else next_file_id,
-                    tenant_id=tenant_id,
-                )
+            for batch in chunked(next_file_ids, 100):
+                jobs = []
+                for next_file_id in batch:
+                    # Keep the flag set to indicate a task is running
+                    tenant_isolated_task_queue.set_task_waiting_time()
+                    file_id = (
+                        next_file_id.decode("utf-8") if isinstance(next_file_id, (bytes, bytearray)) else next_file_id
+                    )
+                    jobs.append(
+                        rag_pipeline_run_task.s(
+                            rag_pipeline_invoke_entities_file_id=file_id,
+                            tenant_id=tenant_id,
+                        )
+                    )
+                if jobs:
+                    group(jobs).apply_async()
         else:
             # No more waiting tasks, clear the flag
             tenant_isolated_task_queue.delete_task_key()
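The `chunked` helper introduced in this file relies on the two-argument form of `iter()`: the lambda is called repeatedly, and iteration stops when it returns the sentinel `[]`, i.e. when the underlying iterator is exhausted. Its batching behavior in isolation:

```python
from itertools import islice


def chunked(iterable, size: int):
    # iter(callable, sentinel): keep calling the lambda until it yields [].
    it = iter(iterable)
    return iter(lambda: list(islice(it, size)), [])


batches = list(chunked(["f1", "f2", "f3", "f4", "f5"], 2))
# batches == [["f1", "f2"], ["f3", "f4"], ["f5"]]
```

With the batch size of 100 used above, each group of waiting file ids is scheduled as one Celery `group` instead of one broker round-trip per task.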

View File

@@ -11,6 +11,7 @@ from dify_graph.enums import WorkflowNodeExecutionStatus
 from dify_graph.node_events import StreamCompletedEvent
 from dify_graph.nodes.llm.node import LLMNode
 from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
+from dify_graph.nodes.protocols import HttpClientProtocol
 from dify_graph.runtime import GraphRuntimeState, VariablePool
 from dify_graph.system_variable import SystemVariable
 from extensions.ext_database import db
@@ -74,6 +75,7 @@ def init_llm_node(config: dict) -> LLMNode:
         credentials_provider=MagicMock(spec=CredentialsProvider),
         model_factory=MagicMock(spec=ModelFactory),
         model_instance=MagicMock(spec=ModelInstance),
+        http_client=MagicMock(spec=HttpClientProtocol),
     )
     return node

View File

@@ -322,11 +322,14 @@ class TestDatasetIndexingTaskIntegration:
         _document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
 
         # Assert
-        task_dispatch_spy.delay.assert_called_once_with(
-            tenant_id=next_task["tenant_id"],
-            dataset_id=next_task["dataset_id"],
-            document_ids=next_task["document_ids"],
-        )
+        # apply_async is used by implementation; assert it was called once with expected kwargs
+        assert task_dispatch_spy.apply_async.call_count == 1
+        call_kwargs = task_dispatch_spy.apply_async.call_args.kwargs.get("kwargs", {})
+        assert call_kwargs == {
+            "tenant_id": next_task["tenant_id"],
+            "dataset_id": next_task["dataset_id"],
+            "document_ids": next_task["document_ids"],
+        }
         set_waiting_spy.assert_called_once()
         delete_key_spy.assert_not_called()
@@ -352,7 +355,7 @@ class TestDatasetIndexingTaskIntegration:
         _document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
 
         # Assert
-        task_dispatch_spy.delay.assert_not_called()
+        task_dispatch_spy.apply_async.assert_not_called()
         delete_key_spy.assert_called_once()
 
     def test_validation_failure_sets_error_status_when_vector_space_at_limit(
@@ -447,7 +450,7 @@ class TestDatasetIndexingTaskIntegration:
         _document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
 
         # Assert
-        task_dispatch_spy.delay.assert_called_once()
+        task_dispatch_spy.apply_async.assert_called_once()
 
     def test_sessions_close_on_successful_indexing(
         self,
@@ -534,7 +537,7 @@ class TestDatasetIndexingTaskIntegration:
             _document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
 
         # Assert
-        assert task_dispatch_spy.delay.call_count == concurrency_limit
+        assert task_dispatch_spy.apply_async.call_count == concurrency_limit
         assert set_waiting_spy.call_count == concurrency_limit
 
     def test_task_queue_fifo_ordering(self, db_session_with_containers, patched_external_dependencies):
@@ -565,9 +568,10 @@ class TestDatasetIndexingTaskIntegration:
             _document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
 
         # Assert
-        assert task_dispatch_spy.delay.call_count == 3
+        assert task_dispatch_spy.apply_async.call_count == 3
         for index, expected_task in enumerate(ordered_tasks):
-            assert task_dispatch_spy.delay.call_args_list[index].kwargs["document_ids"] == expected_task["document_ids"]
+            call_kwargs = task_dispatch_spy.apply_async.call_args_list[index].kwargs.get("kwargs", {})
+            assert call_kwargs.get("document_ids") == expected_task["document_ids"]
 
     def test_billing_disabled_skips_limit_checks(self, db_session_with_containers, patched_external_dependencies):
         """Skip limit checks when billing feature is disabled."""

View File

@@ -762,11 +762,12 @@ class TestDocumentIndexingTasks:
         mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
 
         # Verify task function was called for each waiting task
-        assert mock_task_func.delay.call_count == 1
+        assert mock_task_func.apply_async.call_count == 1
 
         # Verify correct parameters for each call
-        calls = mock_task_func.delay.call_args_list
-        assert calls[0][1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
+        calls = mock_task_func.apply_async.call_args_list
+        sent_kwargs = calls[0][1]["kwargs"]
+        assert sent_kwargs == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
 
         # Verify queue is empty after processing (tasks were pulled)
         remaining_tasks = queue.pull_tasks(count=10)  # Pull more than we added
@@ -830,11 +831,15 @@ class TestDocumentIndexingTasks:
         assert updated_document.processing_started_at is not None
 
         # Verify waiting task was still processed despite core processing error
-        mock_task_func.delay.assert_called_once()
+        mock_task_func.apply_async.assert_called_once()
 
         # Verify correct parameters for the call
-        call = mock_task_func.delay.call_args
-        assert call[1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]}
+        call = mock_task_func.apply_async.call_args
+        assert call[1]["kwargs"] == {
+            "tenant_id": tenant_id,
+            "dataset_id": dataset_id,
+            "document_ids": ["waiting-doc-1"],
+        }
 
         # Verify queue is empty after processing (task was pulled)
         remaining_tasks = queue.pull_tasks(count=10)
@@ -896,9 +901,13 @@ class TestDocumentIndexingTasks:
         mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
 
         # Verify only tenant1's waiting task was processed
-        mock_task_func.delay.assert_called_once()
-        call = mock_task_func.delay.call_args
-        assert call[1] == {"tenant_id": tenant1_id, "dataset_id": dataset1_id, "document_ids": ["tenant1-doc-1"]}
+        mock_task_func.apply_async.assert_called_once()
+        call = mock_task_func.apply_async.call_args
+        assert call[1]["kwargs"] == {
+            "tenant_id": tenant1_id,
+            "dataset_id": dataset1_id,
+            "document_ids": ["tenant1-doc-1"],
+        }
 
         # Verify tenant1's queue is empty
         remaining_tasks1 = queue1.pull_tasks(count=10)

View File

@@ -1,6 +1,6 @@
 import json
 import uuid
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 import pytest
 from faker import Faker
@@ -388,8 +388,10 @@ class TestRagPipelineRunTasks:
# Set the task key to indicate there are waiting tasks (legacy behavior)
redis_client.set(legacy_task_key, 1, ex=60 * 60)
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Mock the Celery group scheduling used by the implementation
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
mock_group.return_value.apply_async = MagicMock()
# Act: Execute the priority task with new code but legacy queue data
rag_pipeline_run_task(file_id, tenant.id)
@@ -398,13 +400,14 @@ class TestRagPipelineRunTasks:
mock_file_service["delete_file"].assert_called_once_with(file_id)
assert mock_pipeline_generator.call_count == 1
# Verify waiting tasks were processed, pull 1 task at a time by default
assert mock_delay.call_count == 1
# Verify waiting tasks were processed via group, pull 1 task at a time by default
assert mock_group.return_value.apply_async.called
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0]
assert call_kwargs.get("tenant_id") == tenant.id
# Verify correct parameters for the first scheduled job signature
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
first_kwargs = jobs[0].kwargs if jobs else {}
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0]
assert first_kwargs.get("tenant_id") == tenant.id
# Verify that new code can process legacy queue entries
# The new TenantIsolatedTaskQueue should be able to read from the legacy format
@@ -446,8 +449,10 @@ class TestRagPipelineRunTasks:
waiting_file_ids = [str(uuid.uuid4()) for _ in range(3)]
queue.push_tasks(waiting_file_ids)
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Mock the Celery group scheduling used by the implementation
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
mock_group.return_value.apply_async = MagicMock()
# Act: Execute the regular task
rag_pipeline_run_task(file_id, tenant.id)
@@ -456,13 +461,14 @@ class TestRagPipelineRunTasks:
mock_file_service["delete_file"].assert_called_once_with(file_id)
assert mock_pipeline_generator.call_count == 1
# Verify waiting tasks were processed, pull 1 task at a time by default
assert mock_delay.call_count == 1
# Verify waiting tasks were processed via group.apply_async
assert mock_group.return_value.apply_async.called
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
assert call_kwargs.get("tenant_id") == tenant.id
# Verify correct parameters for the first scheduled job signature
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
first_kwargs = jobs[0].kwargs if jobs else {}
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0]
assert first_kwargs.get("tenant_id") == tenant.id
# Verify queue still has remaining tasks (only 1 was pulled)
remaining_tasks = queue.pull_tasks(count=10)
@@ -557,8 +563,10 @@ class TestRagPipelineRunTasks:
waiting_file_id = str(uuid.uuid4())
queue.push_tasks([waiting_file_id])
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Mock the Celery group scheduling used by the implementation
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
mock_group.return_value.apply_async = MagicMock()
# Act: Execute the regular task (should not raise exception)
rag_pipeline_run_task(file_id, tenant.id)
@@ -569,12 +577,13 @@ class TestRagPipelineRunTasks:
assert mock_pipeline_generator.call_count == 1
# Verify waiting task was still processed despite core processing error
mock_delay.assert_called_once()
assert mock_group.return_value.apply_async.called
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
assert call_kwargs.get("tenant_id") == tenant.id
# Verify correct parameters for the first scheduled job signature
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
first_kwargs = jobs[0].kwargs if jobs else {}
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
assert first_kwargs.get("tenant_id") == tenant.id
# Verify queue is empty after processing (task was pulled)
remaining_tasks = queue.pull_tasks(count=10)
@@ -684,8 +693,10 @@ class TestRagPipelineRunTasks:
queue1.push_tasks([waiting_file_id1])
queue2.push_tasks([waiting_file_id2])
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Mock the Celery group scheduling used by the implementation
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
mock_group.return_value.apply_async = MagicMock()
# Act: Execute the regular task for tenant1 only
rag_pipeline_run_task(file_id1, tenant1.id)
@@ -694,11 +705,12 @@ class TestRagPipelineRunTasks:
assert mock_file_service["delete_file"].call_count == 1
assert mock_pipeline_generator.call_count == 1
# Verify only tenant1's waiting task was processed
mock_delay.assert_called_once()
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
assert call_kwargs.get("tenant_id") == tenant1.id
# Verify only tenant1's waiting task was processed (via group)
assert mock_group.return_value.apply_async.called
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
first_kwargs = jobs[0].kwargs if jobs else {}
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1
assert first_kwargs.get("tenant_id") == tenant1.id
# Verify tenant1's queue is empty
remaining_tasks1 = queue1.pull_tasks(count=10)
@@ -913,8 +925,10 @@ class TestRagPipelineRunTasks:
waiting_file_id = str(uuid.uuid4())
queue.push_tasks([waiting_file_id])
# Mock the task function calls
with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay:
# Mock the Celery group scheduling used by the implementation
with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group:
mock_group.return_value.apply_async = MagicMock()
# Act & Assert: Execute the regular task (should raise Exception)
with pytest.raises(Exception, match="File not found"):
rag_pipeline_run_task(file_id, tenant.id)
@@ -924,12 +938,13 @@ class TestRagPipelineRunTasks:
mock_pipeline_generator.assert_not_called()
# Verify waiting task was still processed despite file error
mock_delay.assert_called_once()
assert mock_group.return_value.apply_async.called
# Verify correct parameters for the call
call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {}
assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
assert call_kwargs.get("tenant_id") == tenant.id
# Verify correct parameters for the first scheduled job signature
jobs = mock_group.call_args.args[0] if mock_group.call_args else []
first_kwargs = jobs[0].kwargs if jobs else {}
assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id
assert first_kwargs.get("tenant_id") == tenant.id
# Verify queue is empty after processing (task was pulled)
remaining_tasks = queue.pull_tasks(count=10)
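The rewritten assertions read the first scheduled signature out of the patched `group` via `mock_group.call_args.args[0]`. A stdlib-only sketch of that extraction, with `MagicMock` objects standing in for both the patched `group` and Celery signatures (the function name and ids are illustrative assumptions):

```python
from unittest.mock import MagicMock

# Stand-in for the patched tasks.rag_pipeline.rag_pipeline_run_task.group;
# each signature-like object exposes .kwargs the way a Celery signature does.
group = MagicMock()
group.return_value.apply_async = MagicMock()

def schedule_waiting(file_ids, tenant_id):
    # Build one signature per waiting file, then schedule them together,
    # mirroring the implementation under test.
    jobs = [
        MagicMock(kwargs={"rag_pipeline_invoke_entities_file_id": fid, "tenant_id": tenant_id})
        for fid in file_ids
    ]
    group(jobs).apply_async()

schedule_waiting(["file-1"], "tenant-1")

assert group.return_value.apply_async.called
jobs = group.call_args.args[0]   # first positional arg: the list of signatures
first_kwargs = jobs[0].kwargs
assert first_kwargs["tenant_id"] == "tenant-1"
```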

View File

@@ -105,18 +105,26 @@ def app_model(
class MockCeleryGroup:
"""Mock for celery group() function that collects dispatched tasks."""
"""Mock for celery group() function that collects dispatched tasks.
Matches the Celery group API loosely, accepting arbitrary kwargs on apply_async
(e.g. producer) so production code can pass broker-related options without
breaking tests.
"""
def __init__(self) -> None:
self.collected: list[dict[str, Any]] = []
self._applied = False
self.last_apply_async_kwargs: dict[str, Any] | None = None
def __call__(self, items: Any) -> MockCeleryGroup:
self.collected = list(items)
return self
def apply_async(self) -> None:
def apply_async(self, **kwargs: Any) -> None:
# Accept arbitrary kwargs like producer to be compatible with Celery
self._applied = True
self.last_apply_async_kwargs = kwargs
@property
def applied(self) -> bool:
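The `applied` property body is elided by the diff above. A self-contained sketch of the full test double and how production code exercises it (the `producer` value is a placeholder; the property body here is a reconstruction from the `_applied` flag, not the real source):

```python
from typing import Any

class MockCeleryGroup:
    """Collects dispatched signatures and accepts arbitrary apply_async
    kwargs (e.g. producer), loosely matching the Celery group API."""

    def __init__(self) -> None:
        self.collected: list[Any] = []
        self._applied = False
        self.last_apply_async_kwargs: dict[str, Any] | None = None

    def __call__(self, items: Any) -> "MockCeleryGroup":
        self.collected = list(items)
        return self

    def apply_async(self, **kwargs: Any) -> None:
        # Accept arbitrary kwargs like producer to be compatible with Celery
        self._applied = True
        self.last_apply_async_kwargs = kwargs

    @property
    def applied(self) -> bool:
        return self._applied

# Usage mirroring production code: group(signatures).apply_async(producer=...)
mock_group = MockCeleryGroup()
mock_group(["sig-1", "sig-2"]).apply_async(producer="my-producer")
assert mock_group.applied
assert mock_group.collected == ["sig-1", "sig-2"]
```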

View File

@@ -0,0 +1,181 @@
import datetime
import re
from unittest.mock import MagicMock, patch
import click
import pytest
from commands import clean_expired_messages
def _mock_service() -> MagicMock:
service = MagicMock()
service.run.return_value = {
"batches": 1,
"total_messages": 10,
"filtered_messages": 5,
"total_deleted": 5,
}
return service
def test_absolute_mode_calls_from_time_range():
policy = object()
service = _mock_service()
start_from = datetime.datetime(2024, 1, 1, 0, 0, 0)
end_before = datetime.datetime(2024, 2, 1, 0, 0, 0)
with (
patch("commands.create_message_clean_policy", return_value=policy),
patch("commands.MessagesCleanService.from_time_range", return_value=service) as mock_from_time_range,
patch("commands.MessagesCleanService.from_days") as mock_from_days,
):
clean_expired_messages.callback(
batch_size=200,
graceful_period=21,
start_from=start_from,
end_before=end_before,
from_days_ago=None,
before_days=None,
dry_run=True,
)
mock_from_time_range.assert_called_once_with(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=200,
dry_run=True,
)
mock_from_days.assert_not_called()
def test_relative_mode_before_days_only_calls_from_days():
policy = object()
service = _mock_service()
with (
patch("commands.create_message_clean_policy", return_value=policy),
patch("commands.MessagesCleanService.from_days", return_value=service) as mock_from_days,
patch("commands.MessagesCleanService.from_time_range") as mock_from_time_range,
):
clean_expired_messages.callback(
batch_size=500,
graceful_period=14,
start_from=None,
end_before=None,
from_days_ago=None,
before_days=30,
dry_run=False,
)
mock_from_days.assert_called_once_with(
policy=policy,
days=30,
batch_size=500,
dry_run=False,
)
mock_from_time_range.assert_not_called()
def test_relative_mode_with_from_days_ago_calls_from_time_range():
policy = object()
service = _mock_service()
fixed_now = datetime.datetime(2024, 8, 20, 12, 0, 0)
with (
patch("commands.create_message_clean_policy", return_value=policy),
patch("commands.MessagesCleanService.from_time_range", return_value=service) as mock_from_time_range,
patch("commands.MessagesCleanService.from_days") as mock_from_days,
patch("commands.naive_utc_now", return_value=fixed_now),
):
clean_expired_messages.callback(
batch_size=1000,
graceful_period=21,
start_from=None,
end_before=None,
from_days_ago=60,
before_days=30,
dry_run=False,
)
mock_from_time_range.assert_called_once_with(
policy=policy,
start_from=fixed_now - datetime.timedelta(days=60),
end_before=fixed_now - datetime.timedelta(days=30),
batch_size=1000,
dry_run=False,
)
mock_from_days.assert_not_called()
@pytest.mark.parametrize(
("kwargs", "message"),
[
(
{
"start_from": datetime.datetime(2024, 1, 1),
"end_before": datetime.datetime(2024, 2, 1),
"from_days_ago": None,
"before_days": 30,
},
"mutually exclusive",
),
(
{
"start_from": datetime.datetime(2024, 1, 1),
"end_before": None,
"from_days_ago": None,
"before_days": None,
},
"Both --start-from and --end-before are required",
),
(
{
"start_from": None,
"end_before": None,
"from_days_ago": 10,
"before_days": None,
},
"--from-days-ago must be used together with --before-days",
),
(
{
"start_from": None,
"end_before": None,
"from_days_ago": None,
"before_days": -1,
},
"--before-days must be >= 0",
),
(
{
"start_from": None,
"end_before": None,
"from_days_ago": 30,
"before_days": 30,
},
"--from-days-ago must be greater than --before-days",
),
(
{
"start_from": None,
"end_before": None,
"from_days_ago": None,
"before_days": None,
},
"You must provide either (--start-from,--end-before) or (--before-days [--from-days-ago])",
),
],
)
def test_invalid_inputs_raise_usage_error(kwargs: dict, message: str):
with pytest.raises(click.UsageError, match=re.escape(message)):
clean_expired_messages.callback(
batch_size=1000,
graceful_period=21,
start_from=kwargs["start_from"],
end_before=kwargs["end_before"],
from_days_ago=kwargs["from_days_ago"],
before_days=kwargs["before_days"],
dry_run=False,
)
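The parametrized cases above pin down the option-validation rules of `clean_expired_messages`. A stdlib-only sketch of that rule set, with a local `UsageError` standing in for `click.UsageError` (the helper name and structure are assumptions; the error messages are taken from the test cases):

```python
import datetime

class UsageError(Exception):
    """Stand-in for click.UsageError; the real command raises that type."""

def resolve_window(start_from, end_before, from_days_ago, before_days, now=None):
    """Validate the absolute/relative time-range options and return a
    (start, end) pair, where either bound may be None."""
    now = now or datetime.datetime.now()
    absolute = start_from is not None or end_before is not None
    relative = from_days_ago is not None or before_days is not None
    if absolute and relative:
        raise UsageError("Absolute and relative ranges are mutually exclusive")
    if absolute:
        if start_from is None or end_before is None:
            raise UsageError("Both --start-from and --end-before are required")
        return start_from, end_before
    if before_days is None:
        if from_days_ago is not None:
            raise UsageError("--from-days-ago must be used together with --before-days")
        raise UsageError(
            "You must provide either (--start-from,--end-before) or (--before-days [--from-days-ago])"
        )
    if before_days < 0:
        raise UsageError("--before-days must be >= 0")
    if from_days_ago is None:
        return None, now - datetime.timedelta(days=before_days)
    if from_days_ago <= before_days:
        raise UsageError("--from-days-ago must be greater than --before-days")
    return (now - datetime.timedelta(days=from_days_ago),
            now - datetime.timedelta(days=before_days))

# Mirrors test_relative_mode_with_from_days_ago_calls_from_time_range:
now = datetime.datetime(2024, 8, 20, 12, 0, 0)
start, end = resolve_window(None, None, 60, 30, now=now)
assert start == now - datetime.timedelta(days=60)
assert end == now - datetime.timedelta(days=30)
```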

View File

@@ -0,0 +1,70 @@
from controllers.common.errors import (
BlockedFileExtensionError,
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
RemoteFileUploadError,
TooManyFilesError,
UnsupportedFileTypeError,
)
class TestFilenameNotExistsError:
def test_defaults(self):
error = FilenameNotExistsError()
assert error.code == 400
assert error.description == "The specified filename does not exist."
class TestRemoteFileUploadError:
def test_defaults(self):
error = RemoteFileUploadError()
assert error.code == 400
assert error.description == "Error uploading remote file."
class TestFileTooLargeError:
def test_defaults(self):
error = FileTooLargeError()
assert error.code == 413
assert error.error_code == "file_too_large"
assert error.description == "File size exceeded. {message}"
class TestUnsupportedFileTypeError:
def test_defaults(self):
error = UnsupportedFileTypeError()
assert error.code == 415
assert error.error_code == "unsupported_file_type"
assert error.description == "File type not allowed."
class TestBlockedFileExtensionError:
def test_defaults(self):
error = BlockedFileExtensionError()
assert error.code == 400
assert error.error_code == "file_extension_blocked"
assert error.description == "The file extension is blocked for security reasons."
class TestTooManyFilesError:
def test_defaults(self):
error = TooManyFilesError()
assert error.code == 400
assert error.error_code == "too_many_files"
assert error.description == "Only one file is allowed."
class TestNoFileUploadedError:
def test_defaults(self):
error = NoFileUploadedError()
assert error.code == 400
assert error.error_code == "no_file_uploaded"
assert error.description == "Please upload your file."
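The default-value tests above imply a common error-class shape: an HTTP status `code`, a machine-readable `error_code`, and a human-readable `description` as class attributes. A minimal sketch of that shape (the `BaseHTTPException` name and body are assumptions, not the real `controllers.common.errors` base class):

```python
# Minimal sketch of the error hierarchy these defaults imply.
class BaseHTTPException(Exception):
    code: int = 500
    error_code: str = "unknown"
    description: str = "Internal server error."

class TooManyFilesError(BaseHTTPException):
    code = 400
    error_code = "too_many_files"
    description = "Only one file is allowed."

# Instances expose the class-level defaults, which is exactly what the
# test classes above assert.
error = TooManyFilesError()
assert (error.code, error.error_code) == (400, "too_many_files")
```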

View File

@@ -1,22 +1,95 @@
from flask import Response
from controllers.common.file_response import enforce_download_for_html, is_html_content
from controllers.common.file_response import (
_normalize_mime_type,
enforce_download_for_html,
is_html_content,
)
class TestFileResponseHelpers:
def test_is_html_content_detects_mime_type(self):
class TestNormalizeMimeType:
def test_returns_empty_string_for_none(self):
assert _normalize_mime_type(None) == ""
def test_returns_empty_string_for_empty_string(self):
assert _normalize_mime_type("") == ""
def test_normalizes_mime_type(self):
assert _normalize_mime_type("Text/HTML; Charset=UTF-8") == "text/html"
class TestIsHtmlContent:
def test_detects_html_via_mime_type(self):
mime_type = "text/html; charset=UTF-8"
result = is_html_content(mime_type, filename="file.txt", extension="txt")
result = is_html_content(
mime_type=mime_type,
filename="file.txt",
extension="txt",
)
assert result is True
def test_is_html_content_detects_extension(self):
result = is_html_content("text/plain", filename="report.html", extension=None)
def test_detects_html_via_extension_argument(self):
result = is_html_content(
mime_type="text/plain",
filename=None,
extension="html",
)
assert result is True
def test_enforce_download_for_html_sets_headers(self):
def test_detects_html_via_filename_extension(self):
result = is_html_content(
mime_type="text/plain",
filename="report.html",
extension=None,
)
assert result is True
def test_returns_false_when_no_html_detected_anywhere(self):
"""
Negative case where no signal indicates HTML:
- MIME type is not HTML
- filename has no HTML extension
- extension argument is not HTML
"""
result = is_html_content(
mime_type="application/json",
filename="data.json",
extension="json",
)
assert result is False
def test_returns_false_when_all_inputs_are_none(self):
result = is_html_content(
mime_type=None,
filename=None,
extension=None,
)
assert result is False
class TestEnforceDownloadForHtml:
def test_sets_attachment_when_filename_missing(self):
response = Response("payload", mimetype="text/html")
updated = enforce_download_for_html(
response,
mime_type="text/html",
filename=None,
extension="html",
)
assert updated is True
assert response.headers["Content-Disposition"] == "attachment"
assert response.headers["Content-Type"] == "application/octet-stream"
assert response.headers["X-Content-Type-Options"] == "nosniff"
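The headers asserted above fully determine the observable behavior: for HTML payloads, force a download and disable MIME sniffing. A hedged, stdlib-only sketch reconstructed from those assertions, with a plain dict standing in for `response.headers` (the function body is an assumption, not the real implementation):

```python
def enforce_download_for_html_sketch(headers, mime_type, filename, extension):
    """Apply download-forcing headers to a dict standing in for
    response.headers. Returns True when the response was modified."""
    is_html = (
        (mime_type or "").split(";")[0].strip().lower() == "text/html"
        or (extension or "").lstrip(".").lower() == "html"
        or (filename or "").lower().endswith(".html")
    )
    if not is_html:
        return False  # non-HTML content is left untouched
    if filename:
        headers["Content-Disposition"] = f'attachment; filename="{filename}"'
    else:
        headers["Content-Disposition"] = "attachment"
    headers["Content-Type"] = "application/octet-stream"
    headers["X-Content-Type-Options"] = "nosniff"
    return True

headers: dict = {}
assert enforce_download_for_html_sketch(headers, "text/html", None, "html") is True
assert headers["Content-Disposition"] == "attachment"
```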
def test_sets_headers_when_filename_present(self):
response = Response("payload", mimetype="text/html")
updated = enforce_download_for_html(
@@ -27,11 +100,12 @@ class TestFileResponseHelpers:
)
assert updated is True
assert "attachment" in response.headers["Content-Disposition"]
assert response.headers["Content-Disposition"].startswith("attachment")
assert "unsafe.html" in response.headers["Content-Disposition"]
assert response.headers["Content-Type"] == "application/octet-stream"
assert response.headers["X-Content-Type-Options"] == "nosniff"
def test_enforce_download_for_html_no_change_for_non_html(self):
def test_does_not_modify_response_for_non_html_content(self):
response = Response("payload", mimetype="text/plain")
updated = enforce_download_for_html(

View File

@@ -0,0 +1,188 @@
from uuid import UUID
import httpx
import pytest
from controllers.common import helpers
from controllers.common.helpers import FileInfo, guess_file_info_from_response
def make_response(
url="https://example.com/file.txt",
headers=None,
content=None,
):
return httpx.Response(
200,
request=httpx.Request("GET", url),
headers=headers or {},
content=content or b"",
)
class TestGuessFileInfoFromResponse:
def test_filename_from_url(self):
response = make_response(
url="https://example.com/test.pdf",
content=b"Hello World",
)
info = guess_file_info_from_response(response)
assert info.filename == "test.pdf"
assert info.extension == ".pdf"
assert info.mimetype == "application/pdf"
def test_filename_from_content_disposition(self):
headers = {
"Content-Disposition": "attachment; filename=myfile.csv",
"Content-Type": "text/csv",
}
response = make_response(
url="https://example.com/",
headers=headers,
content=b"Hello World",
)
info = guess_file_info_from_response(response)
assert info.filename == "myfile.csv"
assert info.extension == ".csv"
assert info.mimetype == "text/csv"
@pytest.mark.parametrize(
("magic_available", "expected_ext"),
[
(True, "txt"),
(False, "bin"),
],
)
def test_generated_filename_when_missing(self, monkeypatch, magic_available, expected_ext):
if magic_available:
if helpers.magic is None:
pytest.skip("python-magic is not installed, cannot run 'magic_available=True' test variant")
else:
monkeypatch.setattr(helpers, "magic", None)
response = make_response(
url="https://example.com/",
content=b"Hello World",
)
info = guess_file_info_from_response(response)
name, ext = info.filename.split(".")
UUID(name)
assert ext == expected_ext
def test_mimetype_from_header_when_unknown(self):
headers = {"Content-Type": "application/json"}
response = make_response(
url="https://example.com/file.unknown",
headers=headers,
content=b'{"a": 1}',
)
info = guess_file_info_from_response(response)
assert info.mimetype == "application/json"
def test_extension_added_when_missing(self):
headers = {"Content-Type": "image/png"}
response = make_response(
url="https://example.com/image",
headers=headers,
content=b"fakepngdata",
)
info = guess_file_info_from_response(response)
assert info.extension == ".png"
assert info.filename.endswith(".png")
def test_content_length_used_as_size(self):
headers = {
"Content-Length": "1234",
"Content-Type": "text/plain",
}
response = make_response(
url="https://example.com/a.txt",
headers=headers,
content=b"a" * 1234,
)
info = guess_file_info_from_response(response)
assert info.size == 1234
def test_size_minus_one_when_header_missing(self):
response = make_response(url="https://example.com/a.txt")
info = guess_file_info_from_response(response)
assert info.size == -1
def test_fallback_to_bin_extension(self):
headers = {"Content-Type": "application/octet-stream"}
response = make_response(
url="https://example.com/download",
headers=headers,
content=b"\x00\x01\x02\x03",
)
info = guess_file_info_from_response(response)
assert info.extension == ".bin"
assert info.filename.endswith(".bin")
def test_return_type(self):
response = make_response()
info = guess_file_info_from_response(response)
assert isinstance(info, FileInfo)
class TestMagicImportWarnings:
@pytest.mark.parametrize(
("platform_name", "expected_message"),
[
("Windows", "pip install python-magic-bin"),
("Darwin", "brew install libmagic"),
("Linux", "sudo apt-get install libmagic1"),
("Other", "install `libmagic`"),
],
)
def test_magic_import_warning_per_platform(
self,
monkeypatch,
platform_name,
expected_message,
):
import builtins
import importlib
# Force ImportError when "magic" is imported
real_import = builtins.__import__
def fake_import(name, *args, **kwargs):
if name == "magic":
raise ImportError("No module named magic")
return real_import(name, *args, **kwargs)
monkeypatch.setattr(builtins, "__import__", fake_import)
monkeypatch.setattr("platform.system", lambda: platform_name)
# Remove helpers so it imports fresh
import sys
original_helpers = sys.modules.get(helpers.__name__)
sys.modules.pop(helpers.__name__, None)
try:
with pytest.warns(UserWarning, match="To use python-magic") as warning:
imported_helpers = importlib.import_module(helpers.__name__)
assert expected_message in str(warning[0].message)
finally:
if original_helpers is not None:
sys.modules[helpers.__name__] = original_helpers
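The tests above exercise a fallback chain: filename from the URL path or `Content-Disposition`, mimetype guessed from the name else taken from `Content-Type`, size from `Content-Length` else `-1`, and `.bin` as the final extension fallback. A stdlib-only sketch of that chain (`FileInfoSketch` and the function are stand-ins, not the real `controllers.common.helpers` code):

```python
import mimetypes
import uuid
from dataclasses import dataclass

@dataclass
class FileInfoSketch:
    filename: str
    extension: str
    mimetype: str
    size: int

def guess_file_info_sketch(url_path: str, headers: dict) -> FileInfoSketch:
    filename = url_path.rsplit("/", 1)[-1]
    if not filename:
        filename = str(uuid.uuid4())  # generated name when the URL gives none
    # Guess from the name first, then fall back to the Content-Type header.
    mimetype = mimetypes.guess_type(filename)[0] or headers.get("Content-Type", "")
    extension = "." + filename.rsplit(".", 1)[-1] if "." in filename else ""
    if not extension:
        extension = mimetypes.guess_extension(mimetype) or ".bin"
        filename += extension
    size = int(headers.get("Content-Length", -1))
    return FileInfoSketch(filename, extension, mimetype, size)

info = guess_file_info_sketch("/test.pdf", {})
assert info.extension == ".pdf" and info.mimetype == "application/pdf"
assert guess_file_info_sketch("/a.txt", {}).size == -1
```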

View File

@@ -0,0 +1,189 @@
import sys
from enum import StrEnum
from unittest.mock import MagicMock, patch
import pytest
from flask_restx import Namespace
from pydantic import BaseModel
class UserModel(BaseModel):
id: int
name: str
class ProductModel(BaseModel):
id: int
price: float
@pytest.fixture(autouse=True)
def mock_console_ns():
"""Mock the console_ns to avoid circular imports during test collection."""
mock_ns = MagicMock(spec=Namespace)
mock_ns.models = {}
# Inject mock before importing schema module
with patch.dict(sys.modules, {"controllers.console": MagicMock(console_ns=mock_ns)}):
yield mock_ns
def test_default_ref_template_value():
from controllers.common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0
assert DEFAULT_REF_TEMPLATE_SWAGGER_2_0 == "#/definitions/{model}"
def test_register_schema_model_calls_namespace_schema_model():
from controllers.common.schema import register_schema_model
namespace = MagicMock(spec=Namespace)
register_schema_model(namespace, UserModel)
namespace.schema_model.assert_called_once()
model_name, schema = namespace.schema_model.call_args.args
assert model_name == "UserModel"
assert isinstance(schema, dict)
assert "properties" in schema
def test_register_schema_model_passes_schema_from_pydantic():
from controllers.common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_model
namespace = MagicMock(spec=Namespace)
register_schema_model(namespace, UserModel)
schema = namespace.schema_model.call_args.args[1]
expected_schema = UserModel.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
assert schema == expected_schema
def test_register_schema_models_registers_multiple_models():
from controllers.common.schema import register_schema_models
namespace = MagicMock(spec=Namespace)
register_schema_models(namespace, UserModel, ProductModel)
assert namespace.schema_model.call_count == 2
called_names = [call.args[0] for call in namespace.schema_model.call_args_list]
assert called_names == ["UserModel", "ProductModel"]
def test_register_schema_models_calls_register_schema_model(monkeypatch):
from controllers.common.schema import register_schema_models
namespace = MagicMock(spec=Namespace)
calls = []
def fake_register(ns, model):
calls.append((ns, model))
monkeypatch.setattr(
"controllers.common.schema.register_schema_model",
fake_register,
)
register_schema_models(namespace, UserModel, ProductModel)
assert calls == [
(namespace, UserModel),
(namespace, ProductModel),
]
class StatusEnum(StrEnum):
ACTIVE = "active"
INACTIVE = "inactive"
class PriorityEnum(StrEnum):
HIGH = "high"
LOW = "low"
def test_get_or_create_model_returns_existing_model(mock_console_ns):
from controllers.common.schema import get_or_create_model
existing_model = MagicMock()
mock_console_ns.models = {"TestModel": existing_model}
result = get_or_create_model("TestModel", {"key": "value"})
assert result == existing_model
mock_console_ns.model.assert_not_called()
def test_get_or_create_model_creates_new_model_when_not_exists(mock_console_ns):
from controllers.common.schema import get_or_create_model
mock_console_ns.models = {}
new_model = MagicMock()
mock_console_ns.model.return_value = new_model
field_def = {"name": {"type": "string"}}
result = get_or_create_model("NewModel", field_def)
assert result == new_model
mock_console_ns.model.assert_called_once_with("NewModel", field_def)
def test_get_or_create_model_does_not_call_model_if_exists(mock_console_ns):
from controllers.common.schema import get_or_create_model
existing_model = MagicMock()
mock_console_ns.models = {"ExistingModel": existing_model}
result = get_or_create_model("ExistingModel", {"key": "value"})
assert result == existing_model
mock_console_ns.model.assert_not_called()
def test_register_enum_models_registers_single_enum():
from controllers.common.schema import register_enum_models
namespace = MagicMock(spec=Namespace)
register_enum_models(namespace, StatusEnum)
namespace.schema_model.assert_called_once()
model_name, schema = namespace.schema_model.call_args.args
assert model_name == "StatusEnum"
assert isinstance(schema, dict)
def test_register_enum_models_registers_multiple_enums():
from controllers.common.schema import register_enum_models
namespace = MagicMock(spec=Namespace)
register_enum_models(namespace, StatusEnum, PriorityEnum)
assert namespace.schema_model.call_count == 2
called_names = [call.args[0] for call in namespace.schema_model.call_args_list]
assert called_names == ["StatusEnum", "PriorityEnum"]
def test_register_enum_models_uses_correct_ref_template():
from controllers.common.schema import register_enum_models
namespace = MagicMock(spec=Namespace)
register_enum_models(namespace, StatusEnum)
schema = namespace.schema_model.call_args.args[1]
# Verify the schema contains enum values
assert "enum" in schema or "anyOf" in schema

View File

@@ -0,0 +1,92 @@
from __future__ import annotations
from controllers.console.app import annotation as annotation_module
def test_annotation_reply_payload_valid():
"""Test AnnotationReplyPayload with valid data."""
payload = annotation_module.AnnotationReplyPayload(
score_threshold=0.5,
embedding_provider_name="openai",
embedding_model_name="text-embedding-3-small",
)
assert payload.score_threshold == 0.5
assert payload.embedding_provider_name == "openai"
assert payload.embedding_model_name == "text-embedding-3-small"
def test_annotation_setting_update_payload_valid():
"""Test AnnotationSettingUpdatePayload with valid data."""
payload = annotation_module.AnnotationSettingUpdatePayload(
score_threshold=0.75,
)
assert payload.score_threshold == 0.75
def test_annotation_list_query_defaults():
"""Test AnnotationListQuery with default parameters."""
query = annotation_module.AnnotationListQuery()
assert query.page == 1
assert query.limit == 20
assert query.keyword == ""
def test_annotation_list_query_custom_page():
"""Test AnnotationListQuery with custom page."""
query = annotation_module.AnnotationListQuery(page=3, limit=50)
assert query.page == 3
assert query.limit == 50
def test_annotation_list_query_with_keyword():
"""Test AnnotationListQuery with keyword."""
query = annotation_module.AnnotationListQuery(keyword="test")
assert query.keyword == "test"
def test_create_annotation_payload_with_message_id():
"""Test CreateAnnotationPayload with message ID."""
payload = annotation_module.CreateAnnotationPayload(
message_id="550e8400-e29b-41d4-a716-446655440000",
question="What is AI?",
)
assert payload.message_id == "550e8400-e29b-41d4-a716-446655440000"
assert payload.question == "What is AI?"
def test_create_annotation_payload_with_text():
"""Test CreateAnnotationPayload with text content."""
payload = annotation_module.CreateAnnotationPayload(
question="What is ML?",
answer="Machine learning is...",
)
assert payload.question == "What is ML?"
assert payload.answer == "Machine learning is..."
def test_update_annotation_payload():
"""Test UpdateAnnotationPayload."""
payload = annotation_module.UpdateAnnotationPayload(
question="Updated question",
answer="Updated answer",
)
assert payload.question == "Updated question"
assert payload.answer == "Updated answer"
def test_annotation_reply_status_query_enable():
"""Test AnnotationReplyStatusQuery with enable action."""
query = annotation_module.AnnotationReplyStatusQuery(action="enable")
assert query.action == "enable"
def test_annotation_reply_status_query_disable():
"""Test AnnotationReplyStatusQuery with disable action."""
query = annotation_module.AnnotationReplyStatusQuery(action="disable")
assert query.action == "disable"
def test_annotation_file_payload_valid():
"""Test AnnotationFilePayload with valid message ID."""
payload = annotation_module.AnnotationFilePayload(message_id="550e8400-e29b-41d4-a716-446655440000")
assert payload.message_id == "550e8400-e29b-41d4-a716-446655440000"
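The payloads above accept `message_id` as a UUID string. A minimal stdlib check of the same shape (the real models rely on pydantic validation; this helper is purely illustrative):

```python
import uuid

def is_valid_message_id(value) -> bool:
    # Round-trip through uuid.UUID to reject malformed ids.
    try:
        return str(uuid.UUID(value)) == value.lower()
    except (ValueError, AttributeError, TypeError):
        return False

assert is_valid_message_id("550e8400-e29b-41d4-a716-446655440000")
assert not is_valid_message_id("not-a-uuid")
```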

View File

@@ -13,6 +13,9 @@ from pandas.errors import ParserError
from werkzeug.datastructures import FileStorage
from configs import dify_config
from controllers.console.wraps import annotation_import_concurrency_limit, annotation_import_rate_limit
from services.annotation_service import AppAnnotationService
from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task
class TestAnnotationImportRateLimiting:
@@ -33,8 +36,6 @@ class TestAnnotationImportRateLimiting:
def test_rate_limit_per_minute_enforced(self, mock_redis, mock_current_account):
"""Test that per-minute rate limit is enforced."""
from controllers.console.wraps import annotation_import_rate_limit
# Simulate exceeding per-minute limit
mock_redis.zcard.side_effect = [
dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE + 1, # Minute check
@@ -54,7 +55,6 @@ class TestAnnotationImportRateLimiting:
def test_rate_limit_per_hour_enforced(self, mock_redis, mock_current_account):
"""Test that per-hour rate limit is enforced."""
from controllers.console.wraps import annotation_import_rate_limit
# Simulate exceeding per-hour limit
mock_redis.zcard.side_effect = [
@@ -74,7 +74,6 @@ class TestAnnotationImportRateLimiting:
def test_rate_limit_within_limits_passes(self, mock_redis, mock_current_account):
"""Test that requests within limits are allowed."""
from controllers.console.wraps import annotation_import_rate_limit
# Simulate being under both limits
mock_redis.zcard.return_value = 2
@@ -110,7 +109,6 @@ class TestAnnotationImportConcurrencyControl:
def test_concurrency_limit_enforced(self, mock_redis, mock_current_account):
"""Test that concurrent task limit is enforced."""
from controllers.console.wraps import annotation_import_concurrency_limit
# Simulate max concurrent tasks already running
mock_redis.zcard.return_value = dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT
@@ -127,7 +125,6 @@ class TestAnnotationImportConcurrencyControl:
def test_concurrency_within_limit_passes(self, mock_redis, mock_current_account):
"""Test that requests within concurrency limits are allowed."""
from controllers.console.wraps import annotation_import_concurrency_limit
# Simulate being under concurrent task limit
mock_redis.zcard.return_value = 1
@@ -142,7 +139,6 @@ class TestAnnotationImportConcurrencyControl:
def test_stale_jobs_are_cleaned_up(self, mock_redis, mock_current_account):
"""Test that old/stale job entries are removed."""
from controllers.console.wraps import annotation_import_concurrency_limit
mock_redis.zcard.return_value = 0
@@ -203,7 +199,6 @@ class TestAnnotationImportServiceValidation:
def test_max_records_limit_enforced(self, mock_app, mock_db_session):
"""Test that files with too many records are rejected."""
from services.annotation_service import AppAnnotationService
# Create CSV with too many records
max_records = dify_config.ANNOTATION_IMPORT_MAX_RECORDS
@@ -229,7 +224,6 @@ class TestAnnotationImportServiceValidation:
def test_min_records_limit_enforced(self, mock_app, mock_db_session):
"""Test that files with too few valid records are rejected."""
from services.annotation_service import AppAnnotationService
# Create CSV with only header (no data rows)
csv_content = "question,answer\n"
@@ -249,7 +243,6 @@ class TestAnnotationImportServiceValidation:
def test_invalid_csv_format_handled(self, mock_app, mock_db_session):
"""Test that invalid CSV format is handled gracefully."""
from services.annotation_service import AppAnnotationService
# Any content is fine once we force ParserError
csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff'
@@ -270,7 +263,6 @@ class TestAnnotationImportServiceValidation:
def test_valid_import_succeeds(self, mock_app, mock_db_session):
"""Test that valid import request succeeds."""
from services.annotation_service import AppAnnotationService
# Create valid CSV
csv_content = "question,answer\nWhat is AI?,Artificial Intelligence\nWhat is ML?,Machine Learning\n"
@@ -300,18 +292,10 @@ class TestAnnotationImportServiceValidation:
class TestAnnotationImportTaskOptimization:
"""Test optimizations in batch import task."""
def test_task_has_timeout_configured(self):
"""Test that task has proper timeout configuration."""
from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task
# Verify task configuration
assert hasattr(batch_import_annotations_task, "time_limit")
assert hasattr(batch_import_annotations_task, "soft_time_limit")
# Check timeout values are reasonable
# Hard limit should be 6 minutes (360s)
# Soft limit should be 5 minutes (300s)
# Note: actual values depend on Celery configuration
def test_task_is_registered_with_queue(self):
"""Test that task is registered with the correct queue."""
assert hasattr(batch_import_annotations_task, "apply_async")
assert hasattr(batch_import_annotations_task, "delay")
class TestConfigurationValues:


@@ -0,0 +1,585 @@
"""
Additional tests to improve coverage for low-coverage modules in controllers/console/app.
Target: increase coverage for files with <75% coverage.
"""
from __future__ import annotations
import uuid
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.app import (
annotation as annotation_module,
)
from controllers.console.app import (
completion as completion_module,
)
from controllers.console.app import (
message as message_module,
)
from controllers.console.app import (
ops_trace as ops_trace_module,
)
from controllers.console.app import (
site as site_module,
)
from controllers.console.app import (
statistic as statistic_module,
)
from controllers.console.app import (
workflow_app_log as workflow_app_log_module,
)
from controllers.console.app import (
workflow_draft_variable as workflow_draft_variable_module,
)
from controllers.console.app import (
workflow_statistic as workflow_statistic_module,
)
from controllers.console.app import (
workflow_trigger as workflow_trigger_module,
)
from controllers.console.app import (
wraps as wraps_module,
)
from controllers.console.app.completion import ChatMessagePayload, CompletionMessagePayload
from controllers.console.app.mcp_server import MCPServerCreatePayload, MCPServerUpdatePayload
from controllers.console.app.ops_trace import TraceConfigPayload, TraceProviderQuery
from controllers.console.app.site import AppSiteUpdatePayload
from controllers.console.app.workflow import AdvancedChatWorkflowRunPayload, SyncDraftWorkflowPayload
from controllers.console.app.workflow_app_log import WorkflowAppLogQuery
from controllers.console.app.workflow_draft_variable import WorkflowDraftVariableUpdatePayload
from controllers.console.app.workflow_statistic import WorkflowStatisticQuery
from controllers.console.app.workflow_trigger import Parser, ParserEnable
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
class _ConnContext:
def __init__(self, rows):
self._rows = rows
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, _query, _args):
return self._rows
# ========== Completion Tests ==========
class TestCompletionEndpoints:
"""Tests for completion API endpoints."""
def test_completion_create_payload(self):
"""Test completion creation payload."""
payload = CompletionMessagePayload(inputs={"prompt": "test"}, model_config={})
assert payload.inputs == {"prompt": "test"}
def test_chat_message_payload_uuid_validation(self):
payload = ChatMessagePayload(
inputs={},
model_config={},
query="hi",
conversation_id=str(uuid.uuid4()),
parent_message_id=str(uuid.uuid4()),
)
assert payload.query == "hi"
def test_completion_api_success(self, app, monkeypatch):
api = completion_module.CompletionMessageApi()
method = _unwrap(api.post)
class DummyAccount:
pass
dummy_account = DummyAccount()
monkeypatch.setattr(completion_module, "current_user", dummy_account)
monkeypatch.setattr(completion_module, "Account", DummyAccount)
monkeypatch.setattr(
completion_module.AppGenerateService,
"generate",
lambda **_kwargs: {"text": "ok"},
)
monkeypatch.setattr(
completion_module.helper,
"compact_generate_response",
lambda response: {"result": response},
)
with app.test_request_context(
"/",
json={"inputs": {}, "model_config": {}, "query": "hi"},
):
resp = method(app_model=MagicMock(id="app-1"))
assert resp == {"result": {"text": "ok"}}
def test_completion_api_conversation_not_exists(self, app, monkeypatch):
api = completion_module.CompletionMessageApi()
method = _unwrap(api.post)
class DummyAccount:
pass
dummy_account = DummyAccount()
monkeypatch.setattr(completion_module, "current_user", dummy_account)
monkeypatch.setattr(completion_module, "Account", DummyAccount)
monkeypatch.setattr(
completion_module.AppGenerateService,
"generate",
lambda **_kwargs: (_ for _ in ()).throw(
completion_module.services.errors.conversation.ConversationNotExistsError()
),
)
with app.test_request_context(
"/",
json={"inputs": {}, "model_config": {}, "query": "hi"},
):
with pytest.raises(NotFound):
method(app_model=MagicMock(id="app-1"))
def test_completion_api_provider_not_initialized(self, app, monkeypatch):
api = completion_module.CompletionMessageApi()
method = _unwrap(api.post)
class DummyAccount:
pass
dummy_account = DummyAccount()
monkeypatch.setattr(completion_module, "current_user", dummy_account)
monkeypatch.setattr(completion_module, "Account", DummyAccount)
monkeypatch.setattr(
completion_module.AppGenerateService,
"generate",
lambda **_kwargs: (_ for _ in ()).throw(completion_module.ProviderTokenNotInitError("x")),
)
with app.test_request_context(
"/",
json={"inputs": {}, "model_config": {}, "query": "hi"},
):
with pytest.raises(completion_module.ProviderNotInitializeError):
method(app_model=MagicMock(id="app-1"))
def test_completion_api_quota_exceeded(self, app, monkeypatch):
api = completion_module.CompletionMessageApi()
method = _unwrap(api.post)
class DummyAccount:
pass
dummy_account = DummyAccount()
monkeypatch.setattr(completion_module, "current_user", dummy_account)
monkeypatch.setattr(completion_module, "Account", DummyAccount)
monkeypatch.setattr(
completion_module.AppGenerateService,
"generate",
lambda **_kwargs: (_ for _ in ()).throw(completion_module.QuotaExceededError()),
)
with app.test_request_context(
"/",
json={"inputs": {}, "model_config": {}, "query": "hi"},
):
with pytest.raises(completion_module.ProviderQuotaExceededError):
method(app_model=MagicMock(id="app-1"))
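The error-path mocks above use `(_ for _ in ()).throw(exc)` to raise from inside a lambda: `raise` is a statement and cannot appear in a lambda body, so an empty generator expression's `.throw()` is used to raise the exception at call time. A minimal sketch:

```python
# An empty generator; calling .throw() on it raises the given exception
# immediately, which propagates out of the lambda.
boom = lambda **_kwargs: (_ for _ in ()).throw(ValueError("x"))

try:
    boom()
except ValueError as exc:
    assert str(exc) == "x"
```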
# ========== OpsTrace Tests ==========
class TestOpsTraceEndpoints:
"""Tests for ops_trace endpoint."""
def test_ops_trace_query_basic(self):
"""Test ops_trace query."""
query = TraceProviderQuery(tracing_provider="langfuse")
assert query.tracing_provider == "langfuse"
def test_ops_trace_config_payload(self):
payload = TraceConfigPayload(tracing_provider="langfuse", tracing_config={"api_key": "k"})
assert payload.tracing_config["api_key"] == "k"
def test_trace_app_config_get_empty(self, app, monkeypatch):
api = ops_trace_module.TraceAppConfigApi()
method = _unwrap(api.get)
monkeypatch.setattr(
ops_trace_module.OpsService,
"get_tracing_app_config",
lambda **_kwargs: None,
)
with app.test_request_context("/?tracing_provider=langfuse"):
result = method(app_id="app-1")
assert result == {"has_not_configured": True}
def test_trace_app_config_post_invalid(self, app, monkeypatch):
api = ops_trace_module.TraceAppConfigApi()
method = _unwrap(api.post)
monkeypatch.setattr(
ops_trace_module.OpsService,
"create_tracing_app_config",
lambda **_kwargs: {"error": True},
)
with app.test_request_context(
"/",
json={"tracing_provider": "langfuse", "tracing_config": {"api_key": "k"}},
):
with pytest.raises(BadRequest):
method(app_id="app-1")
def test_trace_app_config_delete_not_found(self, app, monkeypatch):
api = ops_trace_module.TraceAppConfigApi()
method = _unwrap(api.delete)
monkeypatch.setattr(
ops_trace_module.OpsService,
"delete_tracing_app_config",
lambda **_kwargs: False,
)
with app.test_request_context("/?tracing_provider=langfuse"):
with pytest.raises(BadRequest):
method(app_id="app-1")
# ========== Site Tests ==========
class TestSiteEndpoints:
"""Tests for site endpoint."""
def test_site_response_structure(self):
"""Test site response structure."""
payload = AppSiteUpdatePayload(title="My Site", description="Test site")
assert payload.title == "My Site"
def test_site_default_language_validation(self):
payload = AppSiteUpdatePayload(default_language="en-US")
assert payload.default_language == "en-US"
def test_app_site_update_post(self, app, monkeypatch):
api = site_module.AppSite()
method = _unwrap(api.post)
site = MagicMock()
query = MagicMock()
query.where.return_value.first.return_value = site
monkeypatch.setattr(
site_module.db,
"session",
MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None),
)
monkeypatch.setattr(
site_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(id="u1"), "t1"),
)
monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now")
with app.test_request_context("/", json={"title": "My Site"}):
result = method(app_model=SimpleNamespace(id="app-1"))
assert result is site
def test_app_site_access_token_reset(self, app, monkeypatch):
api = site_module.AppSiteAccessTokenReset()
method = _unwrap(api.post)
site = MagicMock()
query = MagicMock()
query.where.return_value.first.return_value = site
monkeypatch.setattr(
site_module.db,
"session",
MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None),
)
monkeypatch.setattr(site_module.Site, "generate_code", lambda *_args, **_kwargs: "code")
monkeypatch.setattr(
site_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(id="u1"), "t1"),
)
monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now")
with app.test_request_context("/"):
result = method(app_model=SimpleNamespace(id="app-1"))
assert result is site
# ========== Workflow Tests ==========
class TestWorkflowEndpoints:
"""Tests for workflow endpoints."""
def test_workflow_copy_payload(self):
"""Test workflow copy payload."""
payload = SyncDraftWorkflowPayload(graph={}, features={})
assert payload.graph == {}
def test_workflow_mode_query(self):
"""Test workflow mode query."""
payload = AdvancedChatWorkflowRunPayload(inputs={}, query="hi")
assert payload.query == "hi"
# ========== Workflow App Log Tests ==========
class TestWorkflowAppLogEndpoints:
"""Tests for workflow app log endpoints."""
def test_workflow_app_log_query(self):
"""Test workflow app log query."""
query = WorkflowAppLogQuery(keyword="test", page=1, limit=20)
assert query.keyword == "test"
def test_workflow_app_log_query_detail_bool(self):
query = WorkflowAppLogQuery(detail="true")
assert query.detail is True
def test_workflow_app_log_api_get(self, app, monkeypatch):
api = workflow_app_log_module.WorkflowAppLogApi()
method = _unwrap(api.get)
monkeypatch.setattr(workflow_app_log_module, "db", SimpleNamespace(engine=MagicMock()))
class DummySession:
def __enter__(self):
return "session"
def __exit__(self, exc_type, exc, tb):
return False
monkeypatch.setattr(workflow_app_log_module, "Session", lambda *args, **kwargs: DummySession())
def fake_get_paginate(self, **_kwargs):
return {"items": [], "total": 0}
monkeypatch.setattr(
workflow_app_log_module.WorkflowAppService,
"get_paginate_workflow_app_logs",
fake_get_paginate,
)
with app.test_request_context("/?page=1&limit=20"):
result = method(app_model=SimpleNamespace(id="app-1"))
assert result == {"items": [], "total": 0}
# ========== Workflow Draft Variable Tests ==========
class TestWorkflowDraftVariableEndpoints:
"""Tests for workflow draft variable endpoints."""
def test_workflow_variable_creation(self):
"""Test workflow variable creation."""
payload = WorkflowDraftVariableUpdatePayload(name="var1", value="test")
assert payload.name == "var1"
def test_workflow_variable_collection_get(self, app, monkeypatch):
api = workflow_draft_variable_module.WorkflowVariableCollectionApi()
method = _unwrap(api.get)
monkeypatch.setattr(workflow_draft_variable_module, "db", SimpleNamespace(engine=MagicMock()))
class DummySession:
def __enter__(self):
return "session"
def __exit__(self, exc_type, exc, tb):
return False
class DummyDraftService:
def __init__(self, session):
self.session = session
def list_variables_without_values(self, **_kwargs):
return {"items": [], "total": 0}
monkeypatch.setattr(workflow_draft_variable_module, "Session", lambda *args, **kwargs: DummySession())
class DummyWorkflowService:
def is_workflow_exist(self, *args, **kwargs):
return True
monkeypatch.setattr(workflow_draft_variable_module, "WorkflowDraftVariableService", DummyDraftService)
monkeypatch.setattr(workflow_draft_variable_module, "WorkflowService", DummyWorkflowService)
with app.test_request_context("/?page=1&limit=20"):
result = method(app_model=SimpleNamespace(id="app-1"))
assert result == {"items": [], "total": 0}
# ========== Workflow Statistic Tests ==========
class TestWorkflowStatisticEndpoints:
"""Tests for workflow statistic endpoints."""
def test_workflow_statistic_time_range(self):
"""Test workflow statistic time range query."""
query = WorkflowStatisticQuery(start="2024-01-01", end="2024-12-31")
assert query.start == "2024-01-01"
def test_workflow_statistic_blank_to_none(self):
query = WorkflowStatisticQuery(start="", end="")
assert query.start is None
assert query.end is None
def test_workflow_daily_runs_statistic(self, app, monkeypatch):
monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock()))
monkeypatch.setattr(
workflow_statistic_module.DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: SimpleNamespace(get_daily_runs_statistics=lambda **_kw: [{"date": "2024-01-01"}]),
)
monkeypatch.setattr(
workflow_statistic_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
)
monkeypatch.setattr(
workflow_statistic_module,
"parse_time_range",
lambda *_args, **_kwargs: (None, None),
)
api = workflow_statistic_module.WorkflowDailyRunsStatistic()
method = _unwrap(api.get)
with app.test_request_context("/"):
response = method(app_model=SimpleNamespace(tenant_id="t1", id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-01"}]}
def test_workflow_daily_terminals_statistic(self, app, monkeypatch):
monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock()))
monkeypatch.setattr(
workflow_statistic_module.DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: SimpleNamespace(
get_daily_terminals_statistics=lambda **_kw: [{"date": "2024-01-02"}]
),
)
monkeypatch.setattr(
workflow_statistic_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
)
monkeypatch.setattr(
workflow_statistic_module,
"parse_time_range",
lambda *_args, **_kwargs: (None, None),
)
api = workflow_statistic_module.WorkflowDailyTerminalsStatistic()
method = _unwrap(api.get)
with app.test_request_context("/"):
response = method(app_model=SimpleNamespace(tenant_id="t1", id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-02"}]}
# ========== Workflow Trigger Tests ==========
class TestWorkflowTriggerEndpoints:
"""Tests for workflow trigger endpoints."""
def test_webhook_trigger_payload(self):
"""Test webhook trigger payload."""
payload = Parser(node_id="node-1")
assert payload.node_id == "node-1"
enable_payload = ParserEnable(trigger_id="trigger-1", enable_trigger=True)
assert enable_payload.enable_trigger is True
def test_webhook_trigger_api_get(self, app, monkeypatch):
api = workflow_trigger_module.WebhookTriggerApi()
method = _unwrap(api.get)
monkeypatch.setattr(workflow_trigger_module, "db", SimpleNamespace(engine=MagicMock()))
trigger = MagicMock()
session = MagicMock()
session.query.return_value.where.return_value.first.return_value = trigger
class DummySession:
def __enter__(self):
return session
def __exit__(self, exc_type, exc, tb):
return False
monkeypatch.setattr(workflow_trigger_module, "Session", lambda *_args, **_kwargs: DummySession())
with app.test_request_context("/?node_id=node-1"):
result = method(app_model=SimpleNamespace(id="app-1"))
assert result is trigger
# ========== Wraps Tests ==========
class TestWrapsEndpoints:
"""Tests for wraps utility functions."""
def test_get_app_model_context(self):
"""Test get_app_model wrapper context."""
        # These are decorator functions, so we only assert they are exposed
assert hasattr(wraps_module, "get_app_model")
# ========== MCP Server Tests ==========
class TestMCPServerEndpoints:
"""Tests for MCP server endpoints."""
def test_mcp_server_connection(self):
"""Test MCP server connection."""
payload = MCPServerCreatePayload(parameters={"url": "http://localhost:3000"})
assert payload.parameters["url"] == "http://localhost:3000"
def test_mcp_server_update_payload(self):
payload = MCPServerUpdatePayload(id="server-1", parameters={"timeout": 30}, status="active")
assert payload.status == "active"
# ========== Error Handling Tests ==========
class TestErrorHandling:
"""Tests for error handling in various endpoints."""
def test_annotation_list_query_validation(self):
"""Test annotation list query validation."""
with pytest.raises(ValueError):
annotation_module.AnnotationListQuery(page=0)
# ========== Integration-like Tests ==========
class TestPayloadIntegration:
"""Integration tests for payload handling."""
def test_multiple_payload_types(self):
"""Test handling of multiple payload types."""
payloads = [
annotation_module.AnnotationReplyPayload(
score_threshold=0.5, embedding_provider_name="openai", embedding_model_name="text-embedding-3-small"
),
message_module.MessageFeedbackPayload(message_id=str(uuid.uuid4()), rating="like"),
statistic_module.StatisticTimeRangeQuery(start="2024-01-01"),
]
assert len(payloads) == 3
assert all(p is not None for p in payloads)


@@ -0,0 +1,157 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from controllers.console.app import app_import as app_import_module
from services.app_dsl_service import ImportStatus
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
class _Result:
def __init__(self, status: ImportStatus, app_id: str | None = "app-1"):
self.status = status
self.app_id = app_id
def model_dump(self, mode: str = "json"):
return {"status": self.status, "app_id": self.app_id}
class _SessionContext:
def __init__(self, session):
self._session = session
def __enter__(self):
return self._session
def __exit__(self, exc_type, exc, tb):
return False
def _install_session(monkeypatch: pytest.MonkeyPatch, session: MagicMock) -> None:
monkeypatch.setattr(app_import_module, "Session", lambda *_: _SessionContext(session))
monkeypatch.setattr(app_import_module, "db", SimpleNamespace(engine=object()))
def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None:
features = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=enabled))
monkeypatch.setattr(app_import_module.FeatureService, "get_system_features", lambda: features)
def test_import_post_returns_failed_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
session = MagicMock()
_install_session(monkeypatch, session)
_install_features(monkeypatch, enabled=False)
monkeypatch.setattr(
app_import_module.AppDslService,
"import_app",
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method()
session.commit.assert_called_once()
assert status == 400
assert response["status"] == ImportStatus.FAILED
def test_import_post_returns_pending_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
session = MagicMock()
_install_session(monkeypatch, session)
_install_features(monkeypatch, enabled=False)
monkeypatch.setattr(
app_import_module.AppDslService,
"import_app",
lambda *_args, **_kwargs: _Result(ImportStatus.PENDING),
)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method()
session.commit.assert_called_once()
assert status == 202
assert response["status"] == ImportStatus.PENDING
def test_import_post_updates_webapp_auth_when_enabled(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
session = MagicMock()
_install_session(monkeypatch, session)
_install_features(monkeypatch, enabled=True)
monkeypatch.setattr(
app_import_module.AppDslService,
"import_app",
lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"),
)
update_access = MagicMock()
monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method()
session.commit.assert_called_once()
update_access.assert_called_once_with("app-123", "private")
assert status == 200
assert response["status"] == ImportStatus.COMPLETED
def test_import_confirm_returns_failed_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportConfirmApi()
method = _unwrap(api.post)
session = MagicMock()
_install_session(monkeypatch, session)
monkeypatch.setattr(
app_import_module.AppDslService,
"confirm_import",
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED),
)
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"):
response, status = method(import_id="import-1")
session.commit.assert_called_once()
assert status == 400
assert response["status"] == ImportStatus.FAILED
def test_import_check_dependencies_returns_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportCheckDependenciesApi()
method = _unwrap(api.get)
session = MagicMock()
_install_session(monkeypatch, session)
monkeypatch.setattr(
app_import_module.AppDslService,
"check_dependencies",
lambda *_args, **_kwargs: SimpleNamespace(model_dump=lambda mode="json": {"leaked_dependencies": []}),
)
with app.test_request_context("/console/api/apps/imports/app-1/check-dependencies", method="GET"):
response, status = method(app_model=SimpleNamespace(id="app-1"))
assert status == 200
assert response["leaked_dependencies"] == []
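All of these tests drive the unwrapped handlers inside Flask's `test_request_context`, which pushes a fake HTTP request so `flask.request` is populated without a running server. A minimal sketch of the pattern (the view function and payload are illustrative, assuming Flask is installed):

```python
from flask import Flask, request

app = Flask(__name__)

def view():
    # Mirrors how the unwrapped handlers read their JSON payload
    # via flask.request inside the pushed context.
    return {"mode": request.get_json()["mode"]}

with app.test_request_context(
    "/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}
):
    assert view() == {"mode": "yaml-content"}
```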


@@ -0,0 +1,292 @@
from __future__ import annotations
import io
from types import SimpleNamespace
import pytest
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import InternalServerError
from controllers.console.app.audio import ChatMessageAudioApi, ChatMessageTextApi, TextModesApi
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
CompletionRequestError,
NoAudioUploadedError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderNotSupportSpeechToTextError,
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
)
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from dify_graph.model_runtime.errors.invoke import InvokeError
from services.audio_service import AudioService
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
ProviderNotSupportSpeechToTextServiceError,
ProviderNotSupportTextToSpeechLanageServiceError,
UnsupportedAudioTypeServiceError,
)
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def _file_data():
return FileStorage(stream=io.BytesIO(b"audio"), filename="audio.wav", content_type="audio/wav")
def test_console_audio_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "ok"})
api = ChatMessageAudioApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="a1")
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
response = handler(app_model=app_model)
assert response == {"text": "ok"}
@pytest.mark.parametrize(
("exc", "expected"),
[
(AppModelConfigBrokenError(), AppUnavailableError),
(NoAudioUploadedServiceError(), NoAudioUploadedError),
(AudioTooLargeServiceError("too big"), AudioTooLargeError),
(UnsupportedAudioTypeServiceError(), UnsupportedAudioTypeError),
(ProviderNotSupportSpeechToTextServiceError(), ProviderNotSupportSpeechToTextError),
(ProviderTokenNotInitError("token"), ProviderNotInitializeError),
(QuotaExceededError(), ProviderQuotaExceededError),
(ModelCurrentlyNotSupportError(), ProviderModelCurrentlyNotSupportError),
(InvokeError("invoke"), CompletionRequestError),
],
)
def test_console_audio_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None:
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(exc))
api = ChatMessageAudioApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="a1")
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
with pytest.raises(expected):
handler(app_model=app_model)
def test_console_audio_api_unhandled_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("boom")))
api = ChatMessageAudioApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="a1")
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
with pytest.raises(InternalServerError):
handler(app_model=app_model)
def test_console_text_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
api = ChatMessageTextApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="a1")
with app.test_request_context(
"/console/api/apps/app/text-to-audio",
method="POST",
json={"text": "hello", "voice": "v"},
):
response = handler(app_model=app_model)
assert response == {"audio": "ok"}
def test_console_text_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: (_ for _ in ()).throw(QuotaExceededError()))
api = ChatMessageTextApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(id="a1")
with app.test_request_context(
"/console/api/apps/app/text-to-audio",
method="POST",
json={"text": "hello"},
):
with pytest.raises(ProviderQuotaExceededError):
handler(app_model=app_model)
def test_console_text_modes_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
api = TextModesApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(tenant_id="t1")
with app.test_request_context("/console/api/apps/app/text-to-audio/voices?language=en", method="GET"):
response = handler(app_model=app_model)
assert response == ["voice-1"]
def test_console_text_modes_language_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AudioService,
"transcript_tts_voices",
lambda **_kwargs: (_ for _ in ()).throw(ProviderNotSupportTextToSpeechLanageServiceError()),
)
api = TextModesApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(tenant_id="t1")
with app.test_request_context("/console/api/apps/app/text-to-audio/voices?language=en", method="GET"):
with pytest.raises(AppUnavailableError):
handler(app_model=app_model)
def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = ChatMessageAudioApi()
method = _unwrap(api.post)
response_payload = {"text": "hello"}
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: response_payload)
app_model = SimpleNamespace(id="app-1")
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
with app.test_request_context(
"/console/api/apps/app-1/audio-to-text",
method="POST",
data=data,
content_type="multipart/form-data",
):
response = method(app_model=app_model)
assert response == response_payload
def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = ChatMessageAudioApi()
method = _unwrap(api.post)
monkeypatch.setattr(
AudioService,
"transcript_asr",
lambda **_kwargs: (_ for _ in ()).throw(AudioTooLargeServiceError("too large")),
)
app_model = SimpleNamespace(id="app-1")
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
with app.test_request_context(
"/console/api/apps/app-1/audio-to-text",
method="POST",
data=data,
content_type="multipart/form-data",
):
with pytest.raises(AudioTooLargeError):
method(app_model=app_model)
def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = ChatMessageTextApi()
method = _unwrap(api.post)
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
app_model = SimpleNamespace(id="app-1")
with app.test_request_context(
"/console/api/apps/app-1/text-to-audio",
method="POST",
json={"text": "hello"},
):
response = method(app_model=app_model)
assert response == {"audio": "ok"}
def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = TextModesApi()
method = _unwrap(api.get)
monkeypatch.setattr(AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
app_model = SimpleNamespace(tenant_id="tenant-1")
with app.test_request_context(
"/console/api/apps/app-1/text-to-audio/voices",
method="GET",
query_string={"language": "en-US"},
):
response = method(app_model=app_model)
assert response == ["voice-1"]
def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = ChatMessageAudioApi()
method = _unwrap(api.post)
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "test"})
app_model = SimpleNamespace(id="app-1")
data = {"file": (io.BytesIO(b"invalid"), "sample.xyz")}
with app.test_request_context(
"/console/api/apps/app-1/audio-to-text",
method="POST",
data=data,
content_type="multipart/form-data",
):
        # Should not raise; AudioService is mocked, so the unsupported extension is never validated
response = method(app_model=app_model)
assert response == {"text": "test"}
def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = ChatMessageTextApi()
method = _unwrap(api.post)
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "test"})
app_model = SimpleNamespace(id="app-1")
with app.test_request_context(
"/console/api/apps/app-1/text-to-audio",
method="POST",
json={"text": "hello", "language": "en-US"},
):
response = method(app_model=app_model)
assert response == {"audio": "test"}
def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = TextModesApi()
method = _unwrap(api.get)
monkeypatch.setattr(
AudioService,
"transcript_tts_voices",
lambda **_kwargs: [{"id": "voice-1", "name": "Voice 1"}],
)
app_model = SimpleNamespace(tenant_id="tenant-1")
with app.test_request_context(
"/console/api/apps/app-1/text-to-audio/voices?language=en-US",
method="GET",
):
response = method(app_model=app_model)
assert isinstance(response, list)
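These tests all rely on `monkeypatch.setattr(AudioService, "transcript_asr", ...)` patching the attribute on the class itself, so every caller sees the stub and pytest restores the original afterwards. A dependency-free sketch of what the fixture does under the hood (the `AudioService` class here is a stand-in, not the real service):

```python
class AudioService:
    @staticmethod
    def transcript_asr(**kwargs):
        raise RuntimeError("would hit a real model")


# Grab the raw descriptor so the restore is exact, as monkeypatch does.
original = AudioService.__dict__["transcript_asr"]
try:
    # Patch on the class: every call site now sees the stub.
    AudioService.transcript_asr = lambda **_kwargs: {"text": "hello"}
    result = AudioService.transcript_asr(file=b"x")
finally:
    # pytest's monkeypatch fixture performs this restore automatically.
    setattr(AudioService, "transcript_asr", original)

print(result)  # prints {'text': 'hello'}
```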

View File

@@ -0,0 +1,156 @@
from __future__ import annotations
import io
from types import SimpleNamespace
import pytest
from controllers.console.app import audio as audio_module
from controllers.console.app.error import AudioTooLargeError
from services.errors.audio import AudioTooLargeServiceError
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
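The `_unwrap` helper strips decorator layers by following `__wrapped__` (which `functools.wraps` records) and then re-binds the bare function to the original instance via the descriptor protocol, so decorators like login checks are bypassed in unit tests. A minimal sketch of that behavior, using a hypothetical `require_login` decorator:

```python
import functools


def require_login(func):
    # Hypothetical decorator; functools.wraps records the original
    # function on wrapper.__wrapped__, which _unwrap then follows.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        raise RuntimeError("login required")

    return wrapper


def _unwrap(func):
    bound_self = getattr(func, "__self__", None)
    while hasattr(func, "__wrapped__"):
        func = func.__wrapped__
    if bound_self is not None:
        # Re-bind the plain function to the instance the original
        # bound method belonged to, via the descriptor protocol.
        return func.__get__(bound_self, bound_self.__class__)
    return func


class Api:
    @require_login
    def post(self):
        return "handled"


api = Api()
inner = _unwrap(api.post)
print(inner())  # prints "handled"
```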
def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.ChatMessageAudioApi()
method = _unwrap(api.post)
response_payload = {"text": "hello"}
monkeypatch.setattr(audio_module.AudioService, "transcript_asr", lambda **_kwargs: response_payload)
app_model = SimpleNamespace(id="app-1")
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
with app.test_request_context(
"/console/api/apps/app-1/audio-to-text",
method="POST",
data=data,
content_type="multipart/form-data",
):
response = method(app_model=app_model)
assert response == response_payload
def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.ChatMessageAudioApi()
method = _unwrap(api.post)
monkeypatch.setattr(
audio_module.AudioService,
"transcript_asr",
lambda **_kwargs: (_ for _ in ()).throw(AudioTooLargeServiceError("too large")),
)
app_model = SimpleNamespace(id="app-1")
data = {"file": (io.BytesIO(b"x"), "sample.wav")}
with app.test_request_context(
"/console/api/apps/app-1/audio-to-text",
method="POST",
data=data,
content_type="multipart/form-data",
):
with pytest.raises(AudioTooLargeError):
method(app_model=app_model)
def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.ChatMessageTextApi()
method = _unwrap(api.post)
monkeypatch.setattr(audio_module.AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
app_model = SimpleNamespace(id="app-1")
with app.test_request_context(
"/console/api/apps/app-1/text-to-audio",
method="POST",
json={"text": "hello"},
):
response = method(app_model=app_model)
assert response == {"audio": "ok"}
def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.TextModesApi()
method = _unwrap(api.get)
monkeypatch.setattr(audio_module.AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
app_model = SimpleNamespace(tenant_id="tenant-1")
with app.test_request_context(
"/console/api/apps/app-1/text-to-audio/voices",
method="GET",
query_string={"language": "en-US"},
):
response = method(app_model=app_model)
assert response == ["voice-1"]
def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.ChatMessageAudioApi()
method = _unwrap(api.post)
monkeypatch.setattr(audio_module.AudioService, "transcript_asr", lambda **_kwargs: {"text": "test"})
app_model = SimpleNamespace(id="app-1")
data = {"file": (io.BytesIO(b"invalid"), "sample.xyz")}
with app.test_request_context(
"/console/api/apps/app-1/audio-to-text",
method="POST",
data=data,
content_type="multipart/form-data",
):
        # Should not raise; AudioService is mocked, so the unsupported extension is never validated
response = method(app_model=app_model)
assert response == {"text": "test"}
def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.ChatMessageTextApi()
method = _unwrap(api.post)
monkeypatch.setattr(audio_module.AudioService, "transcript_tts", lambda **_kwargs: {"audio": "test"})
app_model = SimpleNamespace(id="app-1")
with app.test_request_context(
"/console/api/apps/app-1/text-to-audio",
method="POST",
json={"text": "hello", "language": "en-US"},
):
response = method(app_model=app_model)
assert response == {"audio": "test"}
def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.TextModesApi()
method = _unwrap(api.get)
monkeypatch.setattr(
audio_module.AudioService,
"transcript_tts_voices",
lambda **_kwargs: [{"id": "voice-1", "name": "Voice 1"}],
)
app_model = SimpleNamespace(tenant_id="tenant-1")
with app.test_request_context(
"/console/api/apps/app-1/text-to-audio/voices?language=en-US",
method="GET",
):
response = method(app_model=app_model)
assert isinstance(response, list)
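Every test here calls the unwrapped handler inside `app.test_request_context(...)`, which pushes a Flask request context so that `flask.request` (form fields, uploaded files, JSON body) resolves without a real HTTP round trip. A minimal self-contained sketch, assuming only Flask itself (the route and filename are illustrative):

```python
import io

from flask import Flask, request

app = Flask(__name__)

with app.test_request_context(
    "/apps/app-1/audio-to-text",
    method="POST",
    data={"file": (io.BytesIO(b"x"), "sample.wav")},
    content_type="multipart/form-data",
):
    # Inside the context, request.* behaves as if the POST had arrived.
    uploaded = request.files["file"]
    print(uploaded.filename)  # prints "sample.wav"
```

Werkzeug's test builder turns the `(stream, filename)` tuple into a proper multipart file part, which is why these tests can exercise upload handlers without a client.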

View File

@@ -0,0 +1,130 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.app import conversation as conversation_module
from models.model import AppMode
from services.errors.conversation import ConversationNotExistsError
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def _make_account():
return SimpleNamespace(timezone="UTC", id="u1")
def test_completion_conversation_list_returns_paginated_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = conversation_module.CompletionConversationApi()
method = _unwrap(api.get)
account = _make_account()
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None))
paginate_result = MagicMock()
monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result)
with app.test_request_context("/console/api/apps/app-1/completion-conversations", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
assert response is paginate_result
def test_completion_conversation_list_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = conversation_module.CompletionConversationApi()
method = _unwrap(api.get)
account = _make_account()
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
monkeypatch.setattr(
conversation_module,
"parse_time_range",
lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("bad range")),
)
with app.test_request_context(
"/console/api/apps/app-1/completion-conversations",
method="GET",
query_string={"start": "bad"},
):
with pytest.raises(BadRequest):
method(app_model=SimpleNamespace(id="app-1"))
def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = conversation_module.ChatConversationApi()
method = _unwrap(api.get)
account = _make_account()
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1"))
monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None))
paginate_result = MagicMock()
monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result)
with app.test_request_context("/console/api/apps/app-1/chat-conversations", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1", mode=AppMode.ADVANCED_CHAT))
assert response is paginate_result
def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> None:
conversation = SimpleNamespace(id="c1", app_id="app-1")
query = MagicMock()
query.where.return_value = query
query.first.return_value = conversation
session = MagicMock()
session.query.return_value = query
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
monkeypatch.setattr(conversation_module.db, "session", session)
result = conversation_module._get_conversation(SimpleNamespace(id="app-1"), "c1")
assert result is conversation
session.execute.assert_called_once()
session.commit.assert_called_once()
session.refresh.assert_called_once_with(conversation)
def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
query = MagicMock()
query.where.return_value = query
query.first.return_value = None
session = MagicMock()
session.query.return_value = query
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
monkeypatch.setattr(conversation_module.db, "session", session)
with pytest.raises(NotFound):
conversation_module._get_conversation(SimpleNamespace(id="app-1"), "missing")
def test_completion_conversation_delete_maps_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
api = conversation_module.CompletionConversationDetailApi()
method = _unwrap(api.delete)
monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1"))
monkeypatch.setattr(
conversation_module.ConversationService,
"delete",
lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()),
)
with pytest.raises(NotFound):
method(app_model=SimpleNamespace(id="app-1"), conversation_id="c1")
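The conversation tests fake SQLAlchemy-style fluent queries by making `where` return the mock itself, so any chain like `session.query(...).where(...).where(...).first()` collapses to a canned value. A sketch of that pattern using only `unittest.mock`:

```python
from unittest.mock import MagicMock

sentinel = object()

query = MagicMock()
query.where.return_value = query   # every .where(...) yields the same mock
query.first.return_value = sentinel

session = MagicMock()
session.query.return_value = query

# Whatever chain the production code builds ends in the canned result.
result = session.query("Conversation").where(True).where(True).first()
print(result is sentinel)  # prints True
```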

View File

@@ -0,0 +1,260 @@
from __future__ import annotations
from types import SimpleNamespace
import pytest
from controllers.console.app import generator as generator_module
from controllers.console.app.error import ProviderNotInitializeError
from core.errors.error import ProviderTokenNotInitError
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def _model_config_payload():
return {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}}
def _install_workflow_service(monkeypatch: pytest.MonkeyPatch, workflow):
class _Service:
def get_draft_workflow(self, app_model):
return workflow
monkeypatch.setattr(generator_module, "WorkflowService", lambda: _Service())
def test_rule_generate_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.RuleGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
monkeypatch.setattr(generator_module.LLMGenerator, "generate_rule_config", lambda **_kwargs: {"rules": []})
with app.test_request_context(
"/console/api/rule-generate",
method="POST",
json={"instruction": "do it", "model_config": _model_config_payload()},
):
response = method()
assert response == {"rules": []}
def test_rule_code_generate_maps_token_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.RuleCodeGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
def _raise(*_args, **_kwargs):
raise ProviderTokenNotInitError("missing token")
monkeypatch.setattr(generator_module.LLMGenerator, "generate_code", _raise)
with app.test_request_context(
"/console/api/rule-code-generate",
method="POST",
json={"instruction": "do it", "model_config": _model_config_payload()},
):
with pytest.raises(ProviderNotInitializeError):
method()
def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: None)
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
with app.test_request_context(
"/console/api/instruction-generate",
method="POST",
json={
"flow_id": "app-1",
"node_id": "node-1",
"instruction": "do",
"model_config": _model_config_payload(),
},
):
response, status = method()
assert status == 400
assert response["error"] == "app app-1 not found"
def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
app_model = SimpleNamespace(id="app-1")
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
_install_workflow_service(monkeypatch, workflow=None)
with app.test_request_context(
"/console/api/instruction-generate",
method="POST",
json={
"flow_id": "app-1",
"node_id": "node-1",
"instruction": "do",
"model_config": _model_config_payload(),
},
):
response, status = method()
assert status == 400
assert response["error"] == "workflow app-1 not found"
def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
app_model = SimpleNamespace(id="app-1")
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
workflow = SimpleNamespace(graph_dict={"nodes": []})
_install_workflow_service(monkeypatch, workflow=workflow)
with app.test_request_context(
"/console/api/instruction-generate",
method="POST",
json={
"flow_id": "app-1",
"node_id": "node-1",
"instruction": "do",
"model_config": _model_config_payload(),
},
):
response, status = method()
assert status == 400
assert response["error"] == "node node-1 not found"
def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
app_model = SimpleNamespace(id="app-1")
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
workflow = SimpleNamespace(
graph_dict={
"nodes": [
{"id": "node-1", "data": {"type": "code"}},
]
}
)
_install_workflow_service(monkeypatch, workflow=workflow)
monkeypatch.setattr(generator_module.LLMGenerator, "generate_code", lambda **_kwargs: {"code": "x"})
with app.test_request_context(
"/console/api/instruction-generate",
method="POST",
json={
"flow_id": "app-1",
"node_id": "node-1",
"instruction": "do",
"model_config": _model_config_payload(),
},
):
response = method()
assert response == {"code": "x"}
def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
monkeypatch.setattr(
generator_module.LLMGenerator,
"instruction_modify_legacy",
lambda **_kwargs: {"instruction": "ok"},
)
with app.test_request_context(
"/console/api/instruction-generate",
method="POST",
json={
"flow_id": "app-1",
"node_id": "",
"current": "old",
"instruction": "do",
"model_config": _model_config_payload(),
},
):
response = method()
assert response == {"instruction": "ok"}
def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1"))
with app.test_request_context(
"/console/api/instruction-generate",
method="POST",
json={
"flow_id": "app-1",
"node_id": "",
"current": "",
"instruction": "do",
"model_config": _model_config_payload(),
},
):
response, status = method()
assert status == 400
assert response["error"] == "incompatible parameters"
def test_instruction_template_prompt(app) -> None:
api = generator_module.InstructionGenerationTemplateApi()
method = _unwrap(api.post)
with app.test_request_context(
"/console/api/instruction-generate/template",
method="POST",
json={"type": "prompt"},
):
response = method()
assert "data" in response
def test_instruction_template_invalid_type(app) -> None:
api = generator_module.InstructionGenerationTemplateApi()
method = _unwrap(api.post)
with app.test_request_context(
"/console/api/instruction-generate/template",
method="POST",
json={"type": "unknown"},
):
with pytest.raises(ValueError):
method()

View File

@@ -0,0 +1,122 @@
from __future__ import annotations
import pytest
from controllers.console.app import message as message_module
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def test_chat_messages_query_valid(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test valid ChatMessagesQuery with all fields."""
query = message_module.ChatMessagesQuery(
conversation_id="550e8400-e29b-41d4-a716-446655440000",
first_id="550e8400-e29b-41d4-a716-446655440001",
limit=50,
)
assert query.limit == 50
def test_chat_messages_query_defaults(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test ChatMessagesQuery with defaults."""
query = message_module.ChatMessagesQuery(conversation_id="550e8400-e29b-41d4-a716-446655440000")
assert query.first_id is None
assert query.limit == 20
def test_chat_messages_query_empty_first_id(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test ChatMessagesQuery converts empty first_id to None."""
query = message_module.ChatMessagesQuery(
conversation_id="550e8400-e29b-41d4-a716-446655440000",
first_id="",
)
assert query.first_id is None
def test_message_feedback_payload_valid_like(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test MessageFeedbackPayload with like rating."""
payload = message_module.MessageFeedbackPayload(
message_id="550e8400-e29b-41d4-a716-446655440000",
rating="like",
content="Good answer",
)
assert payload.rating == "like"
assert payload.content == "Good answer"
def test_message_feedback_payload_valid_dislike(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test MessageFeedbackPayload with dislike rating."""
payload = message_module.MessageFeedbackPayload(
message_id="550e8400-e29b-41d4-a716-446655440000",
rating="dislike",
)
assert payload.rating == "dislike"
def test_message_feedback_payload_no_rating(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test MessageFeedbackPayload without rating."""
payload = message_module.MessageFeedbackPayload(message_id="550e8400-e29b-41d4-a716-446655440000")
assert payload.rating is None
def test_feedback_export_query_defaults(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with default format."""
query = message_module.FeedbackExportQuery()
assert query.format == "csv"
assert query.from_source is None
def test_feedback_export_query_json_format(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with JSON format."""
query = message_module.FeedbackExportQuery(format="json")
assert query.format == "json"
def test_feedback_export_query_has_comment_true(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with has_comment as true string."""
query = message_module.FeedbackExportQuery(has_comment="true")
assert query.has_comment is True
def test_feedback_export_query_has_comment_false(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with has_comment as false string."""
query = message_module.FeedbackExportQuery(has_comment="false")
assert query.has_comment is False
def test_feedback_export_query_has_comment_1(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with has_comment as 1."""
query = message_module.FeedbackExportQuery(has_comment="1")
assert query.has_comment is True
def test_feedback_export_query_has_comment_0(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with has_comment as 0."""
query = message_module.FeedbackExportQuery(has_comment="0")
assert query.has_comment is False
def test_feedback_export_query_rating_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test FeedbackExportQuery with rating filter."""
query = message_module.FeedbackExportQuery(rating="like")
assert query.rating == "like"
def test_annotation_count_response(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test AnnotationCountResponse creation."""
response = message_module.AnnotationCountResponse(count=10)
assert response.count == 10
def test_suggested_questions_response(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test SuggestedQuestionsResponse creation."""
response = message_module.SuggestedQuestionsResponse(data=["What is AI?", "How does ML work?"])
assert len(response.data) == 2
assert response.data[0] == "What is AI?"
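The `has_comment` tests pin down string-to-bool coercion: `"true"` and `"1"` become `True`, `"false"` and `"0"` become `False`. A dependency-free sketch of such a validator (the real models presumably implement this through their validation framework; this function is an illustration, not Dify's code):

```python
def coerce_bool(value):
    """Coerce common query-string representations to bool; pass None through."""
    if value is None:
        return None
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in {"true", "1", "yes"}:
        return True
    if lowered in {"false", "0", "no"}:
        return False
    raise ValueError(f"not a boolean: {value!r}")


print(coerce_bool("true"), coerce_bool("0"))  # prints True False
```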

View File

@@ -0,0 +1,151 @@
from __future__ import annotations
import json
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from controllers.console.app import model_config as model_config_module
from models.model import AppMode, AppModelConfig
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def test_post_updates_app_model_config_for_chat(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = model_config_module.ModelConfigResource()
method = _unwrap(api.post)
app_model = SimpleNamespace(
id="app-1",
mode=AppMode.CHAT.value,
is_agent=False,
app_model_config_id=None,
updated_by=None,
updated_at=None,
)
monkeypatch.setattr(
model_config_module.AppModelConfigService,
"validate_configuration",
lambda **_kwargs: {"pre_prompt": "hi"},
)
monkeypatch.setattr(model_config_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
session = MagicMock()
monkeypatch.setattr(model_config_module.db, "session", session)
def _from_model_config_dict(self, model_config):
self.pre_prompt = model_config["pre_prompt"]
self.id = "config-1"
return self
monkeypatch.setattr(AppModelConfig, "from_model_config_dict", _from_model_config_dict)
send_mock = MagicMock()
monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock)
with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}):
response = method(app_model=app_model)
session.add.assert_called_once()
session.flush.assert_called_once()
session.commit.assert_called_once()
send_mock.assert_called_once()
assert app_model.app_model_config_id == "config-1"
assert response["result"] == "success"
def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = model_config_module.ModelConfigResource()
method = _unwrap(api.post)
app_model = SimpleNamespace(
id="app-1",
mode=AppMode.AGENT_CHAT.value,
is_agent=True,
app_model_config_id="config-0",
updated_by=None,
updated_at=None,
)
original_config = AppModelConfig(app_id="app-1", created_by="u1", updated_by="u1")
original_config.agent_mode = json.dumps(
{
"enabled": True,
"strategy": "function-calling",
"tools": [
{
"provider_id": "provider",
"provider_type": "builtin",
"tool_name": "tool",
"tool_parameters": {"secret": "masked"},
}
],
"prompt": None,
}
)
session = MagicMock()
query = MagicMock()
query.where.return_value = query
query.first.return_value = original_config
session.query.return_value = query
monkeypatch.setattr(model_config_module.db, "session", session)
monkeypatch.setattr(
model_config_module.AppModelConfigService,
"validate_configuration",
lambda **_kwargs: {
"pre_prompt": "hi",
"agent_mode": {
"enabled": True,
"strategy": "function-calling",
"tools": [
{
"provider_id": "provider",
"provider_type": "builtin",
"tool_name": "tool",
"tool_parameters": {"secret": "masked"},
}
],
"prompt": None,
},
},
)
monkeypatch.setattr(model_config_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
monkeypatch.setattr(model_config_module.ToolManager, "get_agent_tool_runtime", lambda **_kwargs: object())
class _ParamManager:
def __init__(self, **_kwargs):
self.delete_called = False
def decrypt_tool_parameters(self, _value):
return {"secret": "decrypted"}
def mask_tool_parameters(self, _value):
return {"secret": "masked"}
def encrypt_tool_parameters(self, _value):
return {"secret": "encrypted"}
def delete_tool_parameters_cache(self):
self.delete_called = True
monkeypatch.setattr(model_config_module, "ToolParameterConfigurationManager", _ParamManager)
send_mock = MagicMock()
monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock)
with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}):
response = method(app_model=app_model)
stored_config = session.add.call_args[0][0]
stored_agent_mode = json.loads(stored_config.agent_mode)
assert stored_agent_mode["tools"][0]["tool_parameters"]["secret"] == "encrypted"
assert response["result"] == "success"

View File

@@ -0,0 +1,215 @@
from __future__ import annotations
from decimal import Decimal
from types import SimpleNamespace
import pytest
from werkzeug.exceptions import BadRequest
from controllers.console.app import statistic as statistic_module
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
class _ConnContext:
def __init__(self, rows):
self._rows = rows
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, _query, _args):
return self._rows
def _install_db(monkeypatch: pytest.MonkeyPatch, rows) -> None:
engine = SimpleNamespace(begin=lambda: _ConnContext(rows))
monkeypatch.setattr(statistic_module, "db", SimpleNamespace(engine=engine))
def _install_common(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
statistic_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
)
monkeypatch.setattr(
statistic_module,
"parse_time_range",
lambda *_args, **_kwargs: (None, None),
)
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
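`_install_db` swaps the real `db.engine` for a namespace whose `begin()` returns `_ConnContext`, a context manager whose `execute` yields canned rows. That is the only seam the statistics endpoints need, since they run `with db.engine.begin() as conn: conn.execute(...)`. A self-contained sketch of the same double:

```python
from types import SimpleNamespace


class FakeConn:
    """Context-manager test double standing in for an SQLAlchemy connection."""

    def __init__(self, rows):
        self._rows = rows

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        return False  # do not swallow exceptions

    def execute(self, _query, _params=None):
        return self._rows


rows = [SimpleNamespace(date="2024-01-01", message_count=3)]
db = SimpleNamespace(engine=SimpleNamespace(begin=lambda: FakeConn(rows)))

with db.engine.begin() as conn:
    result = list(conn.execute("SELECT ...", {}))
print(result[0].message_count)  # prints 3
```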
def test_daily_message_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyMessageStatistic()
method = _unwrap(api.get)
rows = [SimpleNamespace(date="2024-01-01", message_count=3)]
_install_common(monkeypatch)
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-01", "message_count": 3}]}
def test_daily_conversation_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyConversationStatistic()
method = _unwrap(api.get)
rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)]
_install_common(monkeypatch)
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]}
def test_daily_token_cost_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyTokenCostStatistic()
method = _unwrap(api.get)
rows = [SimpleNamespace(date="2024-01-03", token_count=10, total_price=0.25, currency="USD")]
_install_common(monkeypatch)
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
data = response.get_json()
assert len(data["data"]) == 1
assert data["data"][0]["date"] == "2024-01-03"
assert data["data"][0]["token_count"] == 10
assert data["data"][0]["total_price"] == 0.25
def test_daily_terminals_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyTerminalsStatistic()
method = _unwrap(api.get)
rows = [SimpleNamespace(date="2024-01-04", terminal_count=7)]
_install_common(monkeypatch)
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-end-users", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-04", "terminal_count": 7}]}
def test_average_session_interaction_statistic_requires_chat_mode(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that AverageSessionInteractionStatistic is limited to chat/agent modes."""
# This just verifies the decorator is applied correctly
# Actual endpoint testing would require complex JOIN mocking
api = statistic_module.AverageSessionInteractionStatistic()
method = _unwrap(api.get)
assert callable(method)
def test_daily_message_statistic_with_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyMessageStatistic()
method = _unwrap(api.get)
def mock_parse(*args, **kwargs):
raise ValueError("Invalid time range")
_install_db(monkeypatch, [])
monkeypatch.setattr(
statistic_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
)
monkeypatch.setattr(statistic_module, "parse_time_range", mock_parse)
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
with pytest.raises(BadRequest):
method(app_model=SimpleNamespace(id="app-1"))
def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyMessageStatistic()
method = _unwrap(api.get)
rows = [
SimpleNamespace(date="2024-01-01", message_count=10),
SimpleNamespace(date="2024-01-02", message_count=15),
SimpleNamespace(date="2024-01-03", message_count=12),
]
_install_common(monkeypatch)
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
data = response.get_json()
assert len(data["data"]) == 3
def test_daily_message_statistic_empty_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyMessageStatistic()
method = _unwrap(api.get)
_install_common(monkeypatch)
_install_db(monkeypatch, [])
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
assert response.get_json() == {"data": []}
def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyConversationStatistic()
method = _unwrap(api.get)
rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)]
_install_db(monkeypatch, rows)
monkeypatch.setattr(
statistic_module,
"current_account_with_tenant",
lambda: (SimpleNamespace(timezone="UTC"), "t1"),
)
monkeypatch.setattr(
statistic_module,
"parse_time_range",
lambda *_args, **_kwargs: ("s", "e"),
)
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]}
def test_daily_token_cost_with_multiple_currencies(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyTokenCostStatistic()
method = _unwrap(api.get)
rows = [
SimpleNamespace(date="2024-01-01", token_count=100, total_price=Decimal("0.50"), currency="USD"),
SimpleNamespace(date="2024-01-02", token_count=200, total_price=Decimal("1.00"), currency="USD"),
]
_install_common(monkeypatch)
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"):
response = method(app_model=SimpleNamespace(id="app-1"))
data = response.get_json()
assert len(data["data"]) == 2
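Every statistic test above strips the console auth decorators with a `_unwrap` helper before calling the handler directly. A minimal sketch of that pattern, using a hypothetical `require_login` decorator and `DailyStat` class as stand-ins (not actual Dify code):

```python
import functools
from types import SimpleNamespace

def require_login(func):
    # Hypothetical stand-in for the console auth decorators.
    @functools.wraps(func)  # records the original as wrapper.__wrapped__
    def wrapper(*args, **kwargs):
        raise RuntimeError("needs a real session")
    return wrapper

class DailyStat:
    @require_login
    def get(self, app_model):
        return {"data": [{"id": app_model.id}]}

def _unwrap(func):
    # Follow the __wrapped__ chain left by functools.wraps back to the
    # undecorated function, so tests can skip the auth plumbing entirely.
    while hasattr(func, "__wrapped__"):
        func = func.__wrapped__
    return func

api = DailyStat()
handler = _unwrap(api.get)  # the bare function; self must be passed explicitly
print(handler(api, app_model=SimpleNamespace(id="app-1")))
```

Note that `_unwrap` only works for decorators that apply `functools.wraps`; a decorator that omits it leaves no `__wrapped__` attribute to follow.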


@@ -0,0 +1,163 @@
from __future__ import annotations
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from werkzeug.exceptions import HTTPException, NotFound
from controllers.console.app import workflow as workflow_module
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
from dify_graph.file.enums import FileTransferMethod, FileType
from dify_graph.file.models import File
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def test_parse_file_no_config(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(workflow_module.FileUploadConfigManager, "convert", lambda *_args, **_kwargs: None)
workflow = SimpleNamespace(features_dict={}, tenant_id="t1")
assert workflow_module._parse_file(workflow, files=[{"id": "f"}]) == []
def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None:
config = object()
file_list = [
File(
tenant_id="t1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="http://u",
)
]
build_mock = Mock(return_value=file_list)
monkeypatch.setattr(workflow_module.FileUploadConfigManager, "convert", lambda *_args, **_kwargs: config)
monkeypatch.setattr(workflow_module.file_factory, "build_from_mappings", build_mock)
workflow = SimpleNamespace(features_dict={}, tenant_id="t1")
result = workflow_module._parse_file(workflow, files=[{"id": "f"}])
assert result == file_list
build_mock.assert_called_once()
def test_sync_draft_workflow_invalid_content_type(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = workflow_module.DraftWorkflowApi()
handler = _unwrap(api.post)
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
with app.test_request_context("/apps/app/workflows/draft", method="POST", data="x", content_type="text/html"):
with pytest.raises(HTTPException) as exc:
handler(api, app_model=SimpleNamespace(id="app"))
assert exc.value.code == 415
def test_sync_draft_workflow_invalid_json(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = workflow_module.DraftWorkflowApi()
handler = _unwrap(api.post)
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
with app.test_request_context(
"/apps/app/workflows/draft",
method="POST",
data="[]",
content_type="application/json",
):
response, status = handler(api, app_model=SimpleNamespace(id="app"))
assert status == 400
assert response["message"] == "Invalid JSON data"
def test_sync_draft_workflow_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow = SimpleNamespace(
unique_hash="h",
updated_at=None,
created_at=datetime(2024, 1, 1),
)
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
monkeypatch.setattr(
workflow_module.variable_factory, "build_environment_variable_from_mapping", lambda *_args: "env"
)
monkeypatch.setattr(
workflow_module.variable_factory, "build_conversation_variable_from_mapping", lambda *_args: "conv"
)
service = SimpleNamespace(sync_draft_workflow=lambda **_kwargs: workflow)
monkeypatch.setattr(workflow_module, "WorkflowService", lambda: service)
api = workflow_module.DraftWorkflowApi()
handler = _unwrap(api.post)
with app.test_request_context(
"/apps/app/workflows/draft",
method="POST",
json={"graph": {}, "features": {}, "hash": "h"},
):
response = handler(api, app_model=SimpleNamespace(id="app"))
assert response["result"] == "success"
def test_sync_draft_workflow_hash_mismatch(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
def _raise(*_args, **_kwargs):
raise workflow_module.WorkflowHashNotEqualError()
service = SimpleNamespace(sync_draft_workflow=_raise)
monkeypatch.setattr(workflow_module, "WorkflowService", lambda: service)
api = workflow_module.DraftWorkflowApi()
handler = _unwrap(api.post)
with app.test_request_context(
"/apps/app/workflows/draft",
method="POST",
json={"graph": {}, "features": {}, "hash": "h"},
):
with pytest.raises(DraftWorkflowNotSync):
handler(api, app_model=SimpleNamespace(id="app"))
def test_draft_workflow_get_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
workflow_module, "WorkflowService", lambda: SimpleNamespace(get_draft_workflow=lambda **_k: None)
)
api = workflow_module.DraftWorkflowApi()
handler = _unwrap(api.get)
with pytest.raises(DraftWorkflowNotExist):
handler(api, app_model=SimpleNamespace(id="app"))
def test_advanced_chat_run_conversation_not_exists(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
workflow_module.AppGenerateService,
"generate",
lambda *_args, **_kwargs: (_ for _ in ()).throw(
workflow_module.services.errors.conversation.ConversationNotExistsError()
),
)
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
api = workflow_module.AdvancedChatDraftWorkflowRunApi()
handler = _unwrap(api.post)
with app.test_request_context(
"/apps/app/advanced-chat/workflows/draft/run",
method="POST",
json={"inputs": {}},
):
with pytest.raises(NotFound):
handler(api, app_model=SimpleNamespace(id="app"))
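`test_advanced_chat_run_conversation_not_exists` above makes the patched `generate` raise from inside a lambda via `(_ for _ in ()).throw(...)`. Because a lambda body cannot contain a `raise` statement, the idiom throws the exception into an empty generator instead; a self-contained sketch (the error class here is a hypothetical stand-in for the service-layer exception):

```python
class ConversationNotExistsError(Exception):
    # Hypothetical stand-in for services.errors.conversation's error.
    pass

# .throw() injects the exception at the (never-started) generator frame,
# which immediately propagates it to the caller of the lambda.
fail = lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError())

try:
    fail(inputs={})
except ConversationNotExistsError:
    print("raised as expected")
```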


@@ -0,0 +1,47 @@
from __future__ import annotations
from types import SimpleNamespace
import pytest
from controllers.console.app import wraps as wraps_module
from controllers.console.app.error import AppNotFoundError
from models.model import AppMode
def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None:
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
@wraps_module.get_app_model
def handler(app_model):
return app_model.id
assert handler(app_id="app-1") == "app-1"
def test_get_app_model_rejects_wrong_mode(monkeypatch: pytest.MonkeyPatch) -> None:
app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1")
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)
monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1"))
monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query))
@wraps_module.get_app_model(mode=[AppMode.COMPLETION])
def handler(app_model):
return app_model.id
with pytest.raises(AppNotFoundError):
handler(app_id="app-1")
def test_get_app_model_requires_app_id() -> None:
@wraps_module.get_app_model
def handler(app_model):
return app_model.id
with pytest.raises(ValueError):
handler()
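The `wraps` tests fake SQLAlchemy's fluent query API with a self-referential `SimpleNamespace`. This works because the lambda closes over the *name* `query`, which Python resolves only at call time (late binding), after the assignment has completed; a standalone sketch:

```python
from types import SimpleNamespace

app_model = SimpleNamespace(id="app-1")

# `query` appears on both sides of the assignment: by the time the lambda
# runs, the name is bound, so a .where() chain of any length returns the
# same stub and .first() yields the canned model.
query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model)

assert query.where("id == 1").where("tenant_id == 't1'").first() is app_model
print(query.where().first().id)
```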


@@ -0,0 +1,817 @@
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden, NotFound
from controllers.console import console_ns
from controllers.console.datasets.rag_pipeline.datasource_auth import (
DatasourceAuth,
DatasourceAuthDefaultApi,
DatasourceAuthDeleteApi,
DatasourceAuthListApi,
DatasourceAuthOauthCustomClient,
DatasourceAuthUpdateApi,
DatasourceHardCodeAuthListApi,
DatasourceOAuthCallback,
DatasourcePluginOAuthAuthorizationUrl,
DatasourceUpdateProviderNameApi,
)
from core.plugin.impl.oauth import OAuthHandler
from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError
from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestDatasourcePluginOAuthAuthorizationUrl:
def test_get_success(self, app):
api = DatasourcePluginOAuthAuthorizationUrl()
method = unwrap(api.get)
user = MagicMock(id="user-1")
with (
app.test_request_context("/?credential_id=cred-1"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_oauth_client",
return_value={"client_id": "abc"},
),
patch.object(
OAuthProxyService,
"create_proxy_context",
return_value="ctx-1",
),
patch.object(
OAuthHandler,
"get_authorization_url",
return_value={"url": "http://auth"},
),
):
response = method(api, "notion")
assert response.status_code == 200
def test_get_no_oauth_config(self, app):
api = DatasourcePluginOAuthAuthorizationUrl()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_oauth_client",
return_value=None,
),
):
with pytest.raises(ValueError):
method(api, "notion")
def test_get_without_credential_id_sets_cookie(self, app):
api = DatasourcePluginOAuthAuthorizationUrl()
method = unwrap(api.get)
user = MagicMock(id="user-1")
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_oauth_client",
return_value={"client_id": "abc"},
),
patch.object(
OAuthProxyService,
"create_proxy_context",
return_value="ctx-123",
),
patch.object(
OAuthHandler,
"get_authorization_url",
return_value={"url": "http://auth"},
),
):
response = method(api, "notion")
assert response.status_code == 200
assert "context_id" in response.headers.get("Set-Cookie")
class TestDatasourceOAuthCallback:
def test_callback_success_new_credential(self, app):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
oauth_response = MagicMock()
oauth_response.credentials = {"token": "abc"}
oauth_response.expires_at = None
oauth_response.metadata = {"name": "test"}
context = {
"user_id": "user-1",
"tenant_id": "tenant-1",
"credential_id": None,
}
with (
app.test_request_context("/?context_id=ctx"),
patch.object(
OAuthProxyService,
"use_proxy_context",
return_value=context,
),
patch.object(
DatasourceProviderService,
"get_oauth_client",
return_value={"client_id": "abc"},
),
patch.object(
OAuthHandler,
"get_credentials",
return_value=oauth_response,
),
patch.object(
DatasourceProviderService,
"add_datasource_oauth_provider",
return_value=None,
),
):
response = method(api, "notion")
assert response.status_code == 302
def test_callback_missing_context(self, app):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
with app.test_request_context("/"):
with pytest.raises(Forbidden):
method(api, "notion")
def test_callback_invalid_context(self, app):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
with (
app.test_request_context("/?context_id=bad"),
patch.object(
OAuthProxyService,
"use_proxy_context",
return_value=None,
),
):
with pytest.raises(Forbidden):
method(api, "notion")
def test_callback_oauth_config_not_found(self, app):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
context = {"user_id": "u", "tenant_id": "t"}
with (
app.test_request_context("/?context_id=ctx"),
patch.object(
OAuthProxyService,
"use_proxy_context",
return_value=context,
),
patch.object(
DatasourceProviderService,
"get_oauth_client",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api, "notion")
def test_callback_reauthorize_existing_credential(self, app):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
oauth_response = MagicMock()
oauth_response.credentials = {"token": "abc"}
oauth_response.expires_at = None
oauth_response.metadata = {} # avatar + name missing
context = {
"user_id": "user-1",
"tenant_id": "tenant-1",
"credential_id": "cred-1",
}
with (
app.test_request_context("/?context_id=ctx"),
patch.object(
OAuthProxyService,
"use_proxy_context",
return_value=context,
),
patch.object(
DatasourceProviderService,
"get_oauth_client",
return_value={"client_id": "abc"},
),
patch.object(
OAuthHandler,
"get_credentials",
return_value=oauth_response,
),
patch.object(
DatasourceProviderService,
"reauthorize_datasource_oauth_provider",
return_value=None,
),
):
response = method(api, "notion")
assert response.status_code == 302
assert "/oauth-callback" in response.location
def test_callback_context_id_from_cookie(self, app):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
oauth_response = MagicMock()
oauth_response.credentials = {"token": "abc"}
oauth_response.expires_at = None
oauth_response.metadata = {}
context = {
"user_id": "user-1",
"tenant_id": "tenant-1",
"credential_id": None,
}
with (
app.test_request_context("/", headers={"Cookie": "context_id=ctx"}),
patch.object(
OAuthProxyService,
"use_proxy_context",
return_value=context,
),
patch.object(
DatasourceProviderService,
"get_oauth_client",
return_value={"client_id": "abc"},
),
patch.object(
OAuthHandler,
"get_credentials",
return_value=oauth_response,
),
patch.object(
DatasourceProviderService,
"add_datasource_oauth_provider",
return_value=None,
),
):
response = method(api, "notion")
assert response.status_code == 302
class TestDatasourceAuth:
def test_post_success(self, app):
api = DatasourceAuth()
method = unwrap(api.post)
payload = {"credentials": {"key": "val"}}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"add_datasource_api_key_provider",
return_value=None,
),
):
response, status = method(api, "notion")
assert status == 200
def test_post_invalid_credentials(self, app):
api = DatasourceAuth()
method = unwrap(api.post)
payload = {"credentials": {"key": "bad"}}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"add_datasource_api_key_provider",
side_effect=CredentialsValidateFailedError("invalid"),
),
):
with pytest.raises(ValueError):
method(api, "notion")
def test_get_success(self, app):
api = DatasourceAuth()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"list_datasource_credentials",
return_value=[{"id": "1"}],
),
):
response, status = method(api, "notion")
assert status == 200
assert response["result"]
def test_post_missing_credentials(self, app):
api = DatasourceAuth()
method = unwrap(api.post)
payload = {}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
):
with pytest.raises(ValueError):
method(api, "notion")
def test_get_empty_list(self, app):
api = DatasourceAuth()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"list_datasource_credentials",
return_value=[],
),
):
response, status = method(api, "notion")
assert status == 200
assert response["result"] == []
class TestDatasourceAuthDeleteApi:
def test_delete_success(self, app):
api = DatasourceAuthDeleteApi()
method = unwrap(api.post)
payload = {"credential_id": "cred-1"}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"remove_datasource_credentials",
return_value=None,
),
):
response, status = method(api, "notion")
assert status == 200
def test_delete_missing_credential_id(self, app):
api = DatasourceAuthDeleteApi()
method = unwrap(api.post)
payload = {}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
):
with pytest.raises(ValueError):
method(api, "notion")
class TestDatasourceAuthUpdateApi:
def test_update_success(self, app):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
payload = {"credential_id": "id", "credentials": {"k": "v"}}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"update_datasource_credentials",
return_value=None,
),
):
response, status = method(api, "notion")
assert status == 201
def test_update_with_credentials_none(self, app):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
payload = {"credential_id": "id", "credentials": None}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"update_datasource_credentials",
return_value=None,
) as update_mock,
):
response, status = method(api, "notion")
update_mock.assert_called_once()
assert status == 201
def test_update_name_only(self, app):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
payload = {"credential_id": "id", "name": "New Name"}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"update_datasource_credentials",
return_value=None,
),
):
_, status = method(api, "notion")
assert status == 201
def test_update_with_empty_credentials_dict(self, app):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
payload = {"credential_id": "id", "credentials": {}}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"update_datasource_credentials",
return_value=None,
) as update_mock,
):
_, status = method(api, "notion")
update_mock.assert_called_once()
assert status == 201
class TestDatasourceAuthListApi:
def test_list_success(self, app):
api = DatasourceAuthListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_all_datasource_credentials",
return_value=[{"id": "1"}],
),
):
response, status = method(api)
assert status == 200
def test_auth_list_empty(self, app):
api = DatasourceAuthListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_all_datasource_credentials",
return_value=[],
),
):
response, status = method(api)
assert status == 200
assert response["result"] == []
def test_hardcode_list_empty(self, app):
api = DatasourceHardCodeAuthListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_hard_code_datasource_credentials",
return_value=[],
),
):
response, status = method(api)
assert status == 200
assert response["result"] == []
class TestDatasourceHardCodeAuthListApi:
def test_list_success(self, app):
api = DatasourceHardCodeAuthListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_hard_code_datasource_credentials",
return_value=[{"id": "1"}],
),
):
response, status = method(api)
assert status == 200
class TestDatasourceAuthOauthCustomClient:
def test_post_success(self, app):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.post)
payload = {"client_params": {}, "enable_oauth_custom_client": True}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"setup_oauth_custom_client_params",
return_value=None,
),
):
response, status = method(api, "notion")
assert status == 200
def test_delete_success(self, app):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.delete)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"remove_oauth_custom_client_params",
return_value=None,
),
):
response, status = method(api, "notion")
assert status == 200
def test_post_empty_payload(self, app):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.post)
payload = {}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"setup_oauth_custom_client_params",
return_value=None,
),
):
_, status = method(api, "notion")
assert status == 200
def test_post_disabled_flag(self, app):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.post)
payload = {
"client_params": {"a": 1},
"enable_oauth_custom_client": False,
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"setup_oauth_custom_client_params",
return_value=None,
) as setup_mock,
):
_, status = method(api, "notion")
setup_mock.assert_called_once()
assert status == 200
class TestDatasourceAuthDefaultApi:
def test_set_default_success(self, app):
api = DatasourceAuthDefaultApi()
method = unwrap(api.post)
payload = {"id": "cred-1"}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"set_default_datasource_provider",
return_value=None,
),
):
response, status = method(api, "notion")
assert status == 200
def test_default_missing_id(self, app):
api = DatasourceAuthDefaultApi()
method = unwrap(api.post)
payload = {}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
):
with pytest.raises(ValueError):
method(api, "notion")
class TestDatasourceUpdateProviderNameApi:
def test_update_name_success(self, app):
api = DatasourceUpdateProviderNameApi()
method = unwrap(api.post)
payload = {"credential_id": "id", "name": "New Name"}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"update_datasource_provider_name",
return_value=None,
),
):
response, status = method(api, "notion")
assert status == 200
def test_update_name_too_long(self, app):
api = DatasourceUpdateProviderNameApi()
method = unwrap(api.post)
payload = {
"credential_id": "id",
"name": "x" * 101,
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
):
with pytest.raises(ValueError):
method(api, "notion")
def test_update_name_missing_credential_id(self, app):
api = DatasourceUpdateProviderNameApi()
method = unwrap(api.post)
payload = {"name": "Valid"}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
):
with pytest.raises(ValueError):
method(api, "notion")
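Each request-body test above stubs the namespace payload with `patch.object(type(console_ns), "payload", payload)`. Patching on the *class* rather than the instance matters when the attribute is a property, because property descriptors live on the type; a minimal sketch with a hypothetical `Namespace` class:

```python
from unittest.mock import patch

class Namespace:
    @property
    def payload(self):
        # In the real app this would parse the current request's JSON body.
        raise RuntimeError("no active request")

ns = Namespace()

# patch.object on type(ns) swaps the property descriptor itself for a plain
# dict, so attribute access on the instance returns the stub directly
# instead of invoking the property getter.
with patch.object(type(ns), "payload", {"credential_id": "cred-1"}):
    print(ns.payload)
```

Patching `ns.payload` on the instance would instead raise an `AttributeError`, since the property on the class has no setter.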


@@ -0,0 +1,143 @@
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
from controllers.console.datasets.rag_pipeline.datasource_content_preview import (
DataSourceContentPreviewApi,
)
from models import Account
from models.dataset import Pipeline
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestDataSourceContentPreviewApi:
def _valid_payload(self):
return {
"inputs": {"query": "hello"},
"datasource_type": "notion",
"credential_id": "cred-1",
}
def test_post_success(self, app):
api = DataSourceContentPreviewApi()
method = unwrap(api.post)
payload = self._valid_payload()
pipeline = MagicMock(spec=Pipeline)
node_id = "node-1"
account = MagicMock(spec=Account)
preview_result = {"content": "preview data"}
service_instance = MagicMock()
service_instance.run_datasource_node_preview.return_value = preview_result
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
account,
),
patch(
"controllers.console.datasets.rag_pipeline.datasource_content_preview.RagPipelineService",
return_value=service_instance,
),
):
response, status = method(api, pipeline, node_id)
service_instance.run_datasource_node_preview.assert_called_once_with(
pipeline=pipeline,
node_id=node_id,
user_inputs=payload["inputs"],
account=account,
datasource_type=payload["datasource_type"],
is_published=True,
credential_id=payload["credential_id"],
)
assert status == 200
assert response == preview_result
def test_post_forbidden_non_account_user(self, app):
api = DataSourceContentPreviewApi()
method = unwrap(api.post)
payload = self._valid_payload()
pipeline = MagicMock(spec=Pipeline)
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
MagicMock(), # NOT Account
),
):
with pytest.raises(Forbidden):
method(api, pipeline, "node-1")
def test_post_invalid_payload(self, app):
api = DataSourceContentPreviewApi()
method = unwrap(api.post)
payload = {
"inputs": {"query": "hello"},
# datasource_type missing
}
pipeline = MagicMock(spec=Pipeline)
account = MagicMock(spec=Account)
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
account,
),
):
with pytest.raises(ValueError):
method(api, pipeline, "node-1")
def test_post_without_credential_id(self, app):
api = DataSourceContentPreviewApi()
method = unwrap(api.post)
payload = {
"inputs": {"query": "hello"},
"datasource_type": "notion",
"credential_id": None,
}
pipeline = MagicMock(spec=Pipeline)
account = MagicMock(spec=Account)
service_instance = MagicMock()
service_instance.run_datasource_node_preview.return_value = {"ok": True}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user",
account,
),
patch(
"controllers.console.datasets.rag_pipeline.datasource_content_preview.RagPipelineService",
return_value=service_instance,
),
):
response, status = method(api, pipeline, "node-1")
service_instance.run_datasource_node_preview.assert_called_once()
assert status == 200
assert response == {"ok": True}
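`test_post_forbidden_non_account_user` above depends on `MagicMock(spec=Account)` passing an `isinstance` check that a bare `MagicMock()` fails: a spec'd mock reports the spec class as its `__class__`. A standalone sketch (the `Account` class here is a hypothetical stand-in for `models.Account`):

```python
from unittest.mock import MagicMock

class Account:
    # Hypothetical stand-in for models.Account.
    pass

typed = MagicMock(spec=Account)
bare = MagicMock()

# spec sets the mock's __class__ to Account, so isinstance-based guards
# (like the Forbidden check in the controller) accept the typed mock only.
assert isinstance(typed, Account)
assert not isinstance(bare, Account)
print("spec mock passes isinstance, bare mock does not")
```

A spec also restricts the mock's attribute surface: accessing a name that `Account` does not define raises `AttributeError`, which catches typos the bare mock would silently absorb.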


@@ -0,0 +1,187 @@
from unittest.mock import MagicMock, patch
import pytest
from controllers.console import console_ns
from controllers.console.datasets.rag_pipeline.rag_pipeline import (
CustomizedPipelineTemplateApi,
PipelineTemplateDetailApi,
PipelineTemplateListApi,
PublishCustomizedPipelineTemplateApi,
)
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestPipelineTemplateListApi:
    def test_get_success(self, app):
        api = PipelineTemplateListApi()
        method = unwrap(api.get)
        templates = [{"id": "t1"}]
        with (
            app.test_request_context("/?type=built-in&language=en-US"),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.get_pipeline_templates",
                return_value=templates,
            ),
        ):
            response, status = method(api)
            assert status == 200
            assert response == templates


class TestPipelineTemplateDetailApi:
    def test_get_success(self, app):
        api = PipelineTemplateDetailApi()
        method = unwrap(api.get)
        template = {"id": "tpl-1"}
        service = MagicMock()
        service.get_pipeline_template_detail.return_value = template
        with (
            app.test_request_context("/?type=built-in"),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService",
                return_value=service,
            ),
        ):
            response, status = method(api, "tpl-1")
            assert status == 200
            assert response == template


class TestCustomizedPipelineTemplateApi:
    def test_patch_success(self, app):
        api = CustomizedPipelineTemplateApi()
        method = unwrap(api.patch)
        payload = {
            "name": "Template",
            "description": "Desc",
            "icon_info": {"icon": "📘"},
        }
        with (
            app.test_request_context("/", json=payload),
            patch.object(type(console_ns), "payload", payload),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.update_customized_pipeline_template"
            ) as update_mock,
        ):
            response = method(api, "tpl-1")
            update_mock.assert_called_once()
            assert response == 200

    def test_delete_success(self, app):
        api = CustomizedPipelineTemplateApi()
        method = unwrap(api.delete)
        with (
            app.test_request_context("/"),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.delete_customized_pipeline_template"
            ) as delete_mock,
        ):
            response = method(api, "tpl-1")
            delete_mock.assert_called_once_with("tpl-1")
            assert response == 200

    def test_post_success(self, app):
        api = CustomizedPipelineTemplateApi()
        method = unwrap(api.post)
        template = MagicMock()
        template.yaml_content = "yaml-data"
        fake_db = MagicMock()
        fake_db.engine = MagicMock()
        session = MagicMock()
        session.query.return_value.where.return_value.first.return_value = template
        session_ctx = MagicMock()
        session_ctx.__enter__.return_value = session
        session_ctx.__exit__.return_value = None
        with (
            app.test_request_context("/"),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline.db",
                fake_db,
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
                return_value=session_ctx,
            ),
        ):
            response, status = method(api, "tpl-1")
            assert status == 200
            assert response == {"data": "yaml-data"}

    def test_post_template_not_found(self, app):
        api = CustomizedPipelineTemplateApi()
        method = unwrap(api.post)
        fake_db = MagicMock()
        fake_db.engine = MagicMock()
        session = MagicMock()
        session.query.return_value.where.return_value.first.return_value = None
        session_ctx = MagicMock()
        session_ctx.__enter__.return_value = session
        session_ctx.__exit__.return_value = None
        with (
            app.test_request_context("/"),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline.db",
                fake_db,
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
                return_value=session_ctx,
            ),
        ):
            with pytest.raises(ValueError):
                method(api, "tpl-1")


class TestPublishCustomizedPipelineTemplateApi:
    def test_post_success(self, app):
        api = PublishCustomizedPipelineTemplateApi()
        method = unwrap(api.post)
        payload = {
            "name": "Template",
            "description": "Desc",
            "icon_info": {"icon": "📘"},
        }
        service = MagicMock()
        with (
            app.test_request_context("/", json=payload),
            patch.object(type(console_ns), "payload", payload),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService",
                return_value=service,
            ),
        ):
            response = method(api, "pipeline-1")
            service.publish_customized_pipeline_template.assert_called_once()
            assert response == {"result": "success"}
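The tests above repeatedly stub SQLAlchemy's `Session` by hand: a `MagicMock` whose `__enter__` yields the fake session so `with Session(db.engine) as session:` in the controller sees the stub. The pattern in isolation, using only `unittest.mock`:

```python
from unittest.mock import MagicMock

# Fake session: configure the fluent query chain in one line.
fake_session = MagicMock()
fake_session.query.return_value.where.return_value.first.return_value = {"id": "tpl-1"}

# Fake context manager: what `Session(...)` is patched to return.
session_ctx = MagicMock()
session_ctx.__enter__.return_value = fake_session  # `with ... as s` binds s to fake_session
session_ctx.__exit__.return_value = None           # falsy, so exceptions still propagate

with session_ctx as s:
    row = s.query("templates").where("id == 'tpl-1'").first()
```

Keeping `__exit__` falsy matters: a truthy return would silently swallow exceptions raised inside the `with` block, hiding test failures.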


@@ -0,0 +1,187 @@
from unittest.mock import MagicMock, patch

import pytest
from werkzeug.exceptions import Forbidden

import services
from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.datasets.rag_pipeline.rag_pipeline_datasets import (
    CreateEmptyRagPipelineDatasetApi,
    CreateRagPipelineDatasetApi,
)


def unwrap(func):
    while hasattr(func, "__wrapped__"):
        func = func.__wrapped__
    return func


class TestCreateRagPipelineDatasetApi:
    def _valid_payload(self):
        return {"yaml_content": "name: test"}

    def test_post_success(self, app):
        api = CreateRagPipelineDatasetApi()
        method = unwrap(api.post)
        payload = self._valid_payload()
        user = MagicMock(is_dataset_editor=True)
        import_info = {"dataset_id": "ds-1"}
        mock_service = MagicMock()
        mock_service.create_rag_pipeline_dataset.return_value = import_info
        mock_session_ctx = MagicMock()
        mock_session_ctx.__enter__.return_value = MagicMock()
        mock_session_ctx.__exit__.return_value = None
        fake_db = MagicMock()
        fake_db.engine = MagicMock()
        with (
            app.test_request_context("/", json=payload),
            patch.object(type(console_ns), "payload", payload),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
                return_value=(user, "tenant-1"),
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db",
                fake_db,
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session",
                return_value=mock_session_ctx,
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
                return_value=mock_service,
            ),
        ):
            response, status = method(api)
            assert status == 201
            assert response == import_info

    def test_post_forbidden_non_editor(self, app):
        api = CreateRagPipelineDatasetApi()
        method = unwrap(api.post)
        payload = self._valid_payload()
        user = MagicMock(is_dataset_editor=False)
        with (
            app.test_request_context("/", json=payload),
            patch.object(type(console_ns), "payload", payload),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
                return_value=(user, "tenant-1"),
            ),
        ):
            with pytest.raises(Forbidden):
                method(api)

    def test_post_dataset_name_duplicate(self, app):
        api = CreateRagPipelineDatasetApi()
        method = unwrap(api.post)
        payload = self._valid_payload()
        user = MagicMock(is_dataset_editor=True)
        mock_service = MagicMock()
        mock_service.create_rag_pipeline_dataset.side_effect = services.errors.dataset.DatasetNameDuplicateError()
        mock_session_ctx = MagicMock()
        mock_session_ctx.__enter__.return_value = MagicMock()
        mock_session_ctx.__exit__.return_value = None
        fake_db = MagicMock()
        fake_db.engine = MagicMock()
        with (
            app.test_request_context("/", json=payload),
            patch.object(type(console_ns), "payload", payload),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
                return_value=(user, "tenant-1"),
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db",
                fake_db,
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session",
                return_value=mock_session_ctx,
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
                return_value=mock_service,
            ),
        ):
            with pytest.raises(DatasetNameDuplicateError):
                method(api)

    def test_post_invalid_payload(self, app):
        api = CreateRagPipelineDatasetApi()
        method = unwrap(api.post)
        payload = {}
        user = MagicMock(is_dataset_editor=True)
        with (
            app.test_request_context("/", json=payload),
            patch.object(type(console_ns), "payload", payload),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
                return_value=(user, "tenant-1"),
            ),
        ):
            with pytest.raises(ValueError):
                method(api)


class TestCreateEmptyRagPipelineDatasetApi:
    def test_post_success(self, app):
        api = CreateEmptyRagPipelineDatasetApi()
        method = unwrap(api.post)
        user = MagicMock(is_dataset_editor=True)
        dataset = MagicMock()
        with (
            app.test_request_context("/"),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
                return_value=(user, "tenant-1"),
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.DatasetService.create_empty_rag_pipeline_dataset",
                return_value=dataset,
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.marshal",
                return_value={"id": "ds-1"},
            ),
        ):
            response, status = method(api)
            assert status == 201
            assert response == {"id": "ds-1"}

    def test_post_forbidden_non_editor(self, app):
        api = CreateEmptyRagPipelineDatasetApi()
        method = unwrap(api.post)
        user = MagicMock(is_dataset_editor=False)
        with (
            app.test_request_context("/"),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
                return_value=(user, "tenant-1"),
            ),
        ):
            with pytest.raises(Forbidden):
                method(api)
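The duplicate-name test above drives the error path by assigning an exception instance to `side_effect`, which makes every call to the mocked method raise it. A self-contained sketch of the mechanism, with a stand-in exception class in place of `services.errors.dataset.DatasetNameDuplicateError`:

```python
from unittest.mock import MagicMock


class DatasetNameDuplicateError(Exception):  # stand-in for the real service error
    pass


service = MagicMock()
# An exception class or instance as side_effect is raised instead of returning.
service.create_rag_pipeline_dataset.side_effect = DatasetNameDuplicateError("name taken")

try:
    service.create_rag_pipeline_dataset(name="dup")
    raised = False
except DatasetNameDuplicateError:
    raised = True
```

`side_effect` can also be a function (called with the mock's arguments) or an iterable (one value or exception per call), which is useful when a test needs the first call to succeed and a later one to fail.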


@@ -0,0 +1,324 @@
from unittest.mock import MagicMock, patch

import pytest
from flask import Response

from controllers.console import console_ns
from controllers.console.app.error import DraftWorkflowNotExist
from controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable import (
    RagPipelineEnvironmentVariableCollectionApi,
    RagPipelineNodeVariableCollectionApi,
    RagPipelineSystemVariableCollectionApi,
    RagPipelineVariableApi,
    RagPipelineVariableCollectionApi,
    RagPipelineVariableResetApi,
)
from controllers.web.error import InvalidArgumentError, NotFoundError
from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
from dify_graph.variables.types import SegmentType
from models.account import Account


def unwrap(func):
    while hasattr(func, "__wrapped__"):
        func = func.__wrapped__
    return func


@pytest.fixture
def fake_db():
    db = MagicMock()
    db.engine = MagicMock()
    db.session.return_value = MagicMock()
    return db


@pytest.fixture
def editor_user():
    user = MagicMock(spec=Account)
    user.has_edit_permission = True
    return user


@pytest.fixture
def restx_config(app):
    return patch.dict(app.config, {"RESTX_MASK_HEADER": "X-Fields"})


class TestRagPipelineVariableCollectionApi:
    def test_get_variables_success(self, app, fake_db, editor_user, restx_config):
        api = RagPipelineVariableCollectionApi()
        method = unwrap(api.get)
        pipeline = MagicMock(id="p1")
        rag_srv = MagicMock()
        rag_srv.is_workflow_exist.return_value = True
        # IMPORTANT: RESTX expects .variables
        var_list = MagicMock()
        var_list.variables = []
        draft_srv = MagicMock()
        draft_srv.list_variables_without_values.return_value = var_list
        with (
            app.test_request_context("/?page=1&limit=10"),
            restx_config,
            patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
            patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
                return_value=rag_srv,
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
                return_value=draft_srv,
            ),
        ):
            result = method(api, pipeline)
            assert result["items"] == []

    def test_get_variables_workflow_not_exist(self, app, fake_db, editor_user):
        api = RagPipelineVariableCollectionApi()
        method = unwrap(api.get)
        pipeline = MagicMock()
        rag_srv = MagicMock()
        rag_srv.is_workflow_exist.return_value = False
        with (
            app.test_request_context("/"),
            patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
            patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
                return_value=rag_srv,
            ),
        ):
            with pytest.raises(DraftWorkflowNotExist):
                method(api, pipeline)

    def test_delete_variables_success(self, app, fake_db, editor_user):
        api = RagPipelineVariableCollectionApi()
        method = unwrap(api.delete)
        pipeline = MagicMock(id="p1")
        with (
            app.test_request_context("/"),
            patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
            patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
            patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService"),
        ):
            result = method(api, pipeline)
            assert isinstance(result, Response)
            assert result.status_code == 204


class TestRagPipelineNodeVariableCollectionApi:
    def test_get_node_variables_success(self, app, fake_db, editor_user, restx_config):
        api = RagPipelineNodeVariableCollectionApi()
        method = unwrap(api.get)
        pipeline = MagicMock(id="p1")
        var_list = MagicMock()
        var_list.variables = []
        srv = MagicMock()
        srv.list_node_variables.return_value = var_list
        with (
            app.test_request_context("/"),
            restx_config,
            patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
            patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
                return_value=srv,
            ),
        ):
            result = method(api, pipeline, "node1")
            assert result["items"] == []

    def test_get_node_variables_invalid_node(self, app, editor_user):
        api = RagPipelineNodeVariableCollectionApi()
        method = unwrap(api.get)
        with (
            app.test_request_context("/"),
            patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
        ):
            with pytest.raises(InvalidArgumentError):
                method(api, MagicMock(), SYSTEM_VARIABLE_NODE_ID)


class TestRagPipelineVariableApi:
    def test_get_variable_not_found(self, app, fake_db, editor_user):
        api = RagPipelineVariableApi()
        method = unwrap(api.get)
        srv = MagicMock()
        srv.get_variable.return_value = None
        with (
            app.test_request_context("/"),
            patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
            patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
                return_value=srv,
            ),
        ):
            with pytest.raises(NotFoundError):
                method(api, MagicMock(), "v1")

    def test_patch_variable_invalid_file_payload(self, app, fake_db, editor_user):
        api = RagPipelineVariableApi()
        method = unwrap(api.patch)
        pipeline = MagicMock(id="p1", tenant_id="t1")
        variable = MagicMock(app_id="p1", value_type=SegmentType.FILE)
        srv = MagicMock()
        srv.get_variable.return_value = variable
        payload = {"value": "invalid"}
        with (
            app.test_request_context("/", json=payload),
            patch.object(type(console_ns), "payload", payload),
            patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
            patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
                return_value=srv,
            ),
        ):
            with pytest.raises(InvalidArgumentError):
                method(api, pipeline, "v1")

    def test_delete_variable_success(self, app, fake_db, editor_user):
        api = RagPipelineVariableApi()
        method = unwrap(api.delete)
        pipeline = MagicMock(id="p1")
        variable = MagicMock(app_id="p1")
        srv = MagicMock()
        srv.get_variable.return_value = variable
        with (
            app.test_request_context("/"),
            patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
            patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
                return_value=srv,
            ),
        ):
            result = method(api, pipeline, "v1")
            assert result.status_code == 204


class TestRagPipelineVariableResetApi:
    def test_reset_variable_success(self, app, fake_db, editor_user):
        api = RagPipelineVariableResetApi()
        method = unwrap(api.put)
        pipeline = MagicMock(id="p1")
        workflow = MagicMock()
        variable = MagicMock(app_id="p1")
        srv = MagicMock()
        srv.get_variable.return_value = variable
        srv.reset_variable.return_value = variable
        rag_srv = MagicMock()
        rag_srv.get_draft_workflow.return_value = workflow
        with (
            app.test_request_context("/"),
            patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
            patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
                return_value=rag_srv,
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
                return_value=srv,
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.marshal",
                return_value={"id": "v1"},
            ),
        ):
            result = method(api, pipeline, "v1")
            assert result == {"id": "v1"}


class TestSystemAndEnvironmentVariablesApi:
    def test_system_variables_success(self, app, fake_db, editor_user, restx_config):
        api = RagPipelineSystemVariableCollectionApi()
        method = unwrap(api.get)
        pipeline = MagicMock(id="p1")
        var_list = MagicMock()
        var_list.variables = []
        srv = MagicMock()
        srv.list_system_variables.return_value = var_list
        with (
            app.test_request_context("/"),
            restx_config,
            patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
            patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService",
                return_value=srv,
            ),
        ):
            result = method(api, pipeline)
            assert result["items"] == []

    def test_environment_variables_success(self, app, editor_user):
        api = RagPipelineEnvironmentVariableCollectionApi()
        method = unwrap(api.get)
        env_var = MagicMock(
            id="e1",
            name="ENV",
            description="d",
            selector="s",
            value_type=MagicMock(value="string"),
            value="x",
        )
        workflow = MagicMock(environment_variables=[env_var])
        pipeline = MagicMock(id="p1")
        rag_srv = MagicMock()
        rag_srv.get_draft_workflow.return_value = workflow
        with (
            app.test_request_context("/"),
            patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService",
                return_value=rag_srv,
            ),
        ):
            result = method(api, pipeline)
            assert len(result["items"]) == 1
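The `restx_config` fixture above patches the Flask config with `patch.dict`, which overlays keys for the duration of the context and restores the mapping afterwards. The same mechanism against a plain dict:

```python
from unittest.mock import patch

config = {"DEBUG": False}

# Overlay a key only while the context is active.
with patch.dict(config, {"RESTX_MASK_HEADER": "X-Fields"}):
    inside = dict(config)  # original keys plus the override

after = dict(config)  # patch.dict undoes the change on exit
```

Because the fixture returns the patcher unstarted, each test activates it inside its own `with` block, so config changes never leak between tests.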


@@ -0,0 +1,329 @@
from unittest.mock import MagicMock, patch

from controllers.console import console_ns
from controllers.console.datasets.rag_pipeline.rag_pipeline_import import (
    RagPipelineExportApi,
    RagPipelineImportApi,
    RagPipelineImportCheckDependenciesApi,
    RagPipelineImportConfirmApi,
)
from models.dataset import Pipeline
from services.app_dsl_service import ImportStatus


def unwrap(func):
    while hasattr(func, "__wrapped__"):
        func = func.__wrapped__
    return func


class TestRagPipelineImportApi:
    def _payload(self, mode="create"):
        return {
            "mode": mode,
            "yaml_content": "content",
            "name": "Test",
        }

    def test_post_success_200(self, app):
        api = RagPipelineImportApi()
        method = unwrap(api.post)
        payload = self._payload()
        user = MagicMock()
        result = MagicMock()
        result.status = "completed"
        result.model_dump.return_value = {"status": "success"}
        service = MagicMock()
        service.import_rag_pipeline.return_value = result
        fake_db = MagicMock()
        fake_db.engine = MagicMock()
        session_ctx = MagicMock()
        session_ctx.__enter__.return_value = MagicMock()
        session_ctx.__exit__.return_value = None
        with (
            app.test_request_context("/", json=payload),
            patch.object(type(console_ns), "payload", payload),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
                return_value=(user, "tenant"),
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
                fake_db,
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
                return_value=session_ctx,
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
                return_value=service,
            ),
        ):
            response, status = method(api)
            assert status == 200
            assert response == {"status": "success"}

    def test_post_failed_400(self, app):
        api = RagPipelineImportApi()
        method = unwrap(api.post)
        payload = self._payload()
        user = MagicMock()
        result = MagicMock()
        result.status = ImportStatus.FAILED
        result.model_dump.return_value = {"status": "failed"}
        service = MagicMock()
        service.import_rag_pipeline.return_value = result
        fake_db = MagicMock()
        fake_db.engine = MagicMock()
        session_ctx = MagicMock()
        session_ctx.__enter__.return_value = MagicMock()
        session_ctx.__exit__.return_value = None
        with (
            app.test_request_context("/", json=payload),
            patch.object(type(console_ns), "payload", payload),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
                return_value=(user, "tenant"),
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
                fake_db,
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
                return_value=session_ctx,
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
                return_value=service,
            ),
        ):
            response, status = method(api)
            assert status == 400
            assert response == {"status": "failed"}

    def test_post_pending_202(self, app):
        api = RagPipelineImportApi()
        method = unwrap(api.post)
        payload = self._payload()
        user = MagicMock()
        result = MagicMock()
        result.status = ImportStatus.PENDING
        result.model_dump.return_value = {"status": "pending"}
        service = MagicMock()
        service.import_rag_pipeline.return_value = result
        fake_db = MagicMock()
        fake_db.engine = MagicMock()
        session_ctx = MagicMock()
        session_ctx.__enter__.return_value = MagicMock()
        session_ctx.__exit__.return_value = None
        with (
            app.test_request_context("/", json=payload),
            patch.object(type(console_ns), "payload", payload),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
                return_value=(user, "tenant"),
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
                fake_db,
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
                return_value=session_ctx,
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
                return_value=service,
            ),
        ):
            response, status = method(api)
            assert status == 202
            assert response == {"status": "pending"}


class TestRagPipelineImportConfirmApi:
    def test_confirm_success(self, app):
        api = RagPipelineImportConfirmApi()
        method = unwrap(api.post)
        user = MagicMock()
        result = MagicMock()
        result.status = "completed"
        result.model_dump.return_value = {"ok": True}
        service = MagicMock()
        service.confirm_import.return_value = result
        fake_db = MagicMock()
        fake_db.engine = MagicMock()
        session_ctx = MagicMock()
        session_ctx.__enter__.return_value = MagicMock()
        session_ctx.__exit__.return_value = None
        with (
            app.test_request_context("/"),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
                return_value=(user, "tenant"),
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
                fake_db,
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
                return_value=session_ctx,
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
                return_value=service,
            ),
        ):
            response, status = method(api, "import-1")
            assert status == 200
            assert response == {"ok": True}

    def test_confirm_failed(self, app):
        api = RagPipelineImportConfirmApi()
        method = unwrap(api.post)
        user = MagicMock()
        result = MagicMock()
        result.status = ImportStatus.FAILED
        result.model_dump.return_value = {"ok": False}
        service = MagicMock()
        service.confirm_import.return_value = result
        fake_db = MagicMock()
        fake_db.engine = MagicMock()
        session_ctx = MagicMock()
        session_ctx.__enter__.return_value = MagicMock()
        session_ctx.__exit__.return_value = None
        with (
            app.test_request_context("/"),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
                return_value=(user, "tenant"),
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
                fake_db,
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
                return_value=session_ctx,
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
                return_value=service,
            ),
        ):
            response, status = method(api, "import-1")
            assert status == 400
            assert response == {"ok": False}


class TestRagPipelineImportCheckDependenciesApi:
    def test_get_success(self, app):
        api = RagPipelineImportCheckDependenciesApi()
        method = unwrap(api.get)
        pipeline = MagicMock(spec=Pipeline)
        result = MagicMock()
        result.model_dump.return_value = {"deps": []}
        service = MagicMock()
        service.check_dependencies.return_value = result
        fake_db = MagicMock()
        fake_db.engine = MagicMock()
        session_ctx = MagicMock()
        session_ctx.__enter__.return_value = MagicMock()
        session_ctx.__exit__.return_value = None
        with (
            app.test_request_context("/"),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
                fake_db,
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
                return_value=session_ctx,
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
                return_value=service,
            ),
        ):
            response, status = method(api, pipeline)
            assert status == 200
            assert response == {"deps": []}


class TestRagPipelineExportApi:
    def test_get_with_include_secret(self, app):
        api = RagPipelineExportApi()
        method = unwrap(api.get)
        pipeline = MagicMock(spec=Pipeline)
        service = MagicMock()
        service.export_rag_pipeline_dsl.return_value = {"yaml": "data"}
        fake_db = MagicMock()
        fake_db.engine = MagicMock()
        session_ctx = MagicMock()
        session_ctx.__enter__.return_value = MagicMock()
        session_ctx.__exit__.return_value = None
        with (
            app.test_request_context("/?include_secret=true"),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
                fake_db,
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
                return_value=session_ctx,
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
                return_value=service,
            ),
        ):
            response, status = method(api, pipeline)
            assert status == 200
            assert response == {"data": {"yaml": "data"}}
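The import tests above configure nested behaviour purely through attribute access: `MagicMock` creates child mocks on demand, so one assignment like `result.model_dump.return_value = {...}` wires up a whole call chain. The mechanism in isolation:

```python
from unittest.mock import MagicMock

result = MagicMock()
result.status = "completed"  # plain attribute assignment
result.model_dump.return_value = {"status": "success"}  # child mock created on first access

service = MagicMock()
service.import_rag_pipeline.return_value = result

outcome = service.import_rag_pipeline(yaml_content="content")
dumped = outcome.model_dump()
```

The trade-off is that unconfigured attributes also succeed silently and return fresh mocks, which is why the controller's assertions on concrete values (status codes, dict payloads) carry the real verification weight.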


@@ -0,0 +1,688 @@
from datetime import datetime
from unittest.mock import MagicMock, patch

import pytest
from werkzeug.exceptions import Forbidden, NotFound

import services
from controllers.console import console_ns
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import (
    DefaultRagPipelineBlockConfigApi,
    DraftRagPipelineApi,
    DraftRagPipelineRunApi,
    PublishedAllRagPipelineApi,
    PublishedRagPipelineApi,
    PublishedRagPipelineRunApi,
    RagPipelineByIdApi,
    RagPipelineDatasourceVariableApi,
    RagPipelineDraftNodeRunApi,
    RagPipelineDraftRunIterationNodeApi,
    RagPipelineDraftRunLoopNodeApi,
    RagPipelineRecommendedPluginApi,
    RagPipelineTaskStopApi,
    RagPipelineTransformApi,
    RagPipelineWorkflowLastRunApi,
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError


def unwrap(func):
    while hasattr(func, "__wrapped__"):
        func = func.__wrapped__
    return func


class TestDraftWorkflowApi:
    def test_get_draft_success(self, app):
        api = DraftRagPipelineApi()
        method = unwrap(api.get)
        pipeline = MagicMock()
        workflow = MagicMock()
        service = MagicMock()
        service.get_draft_workflow.return_value = workflow
        with (
            app.test_request_context("/"),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
                return_value=service,
            ),
        ):
            result = method(api, pipeline)
            assert result == workflow

    def test_get_draft_not_exist(self, app):
        api = DraftRagPipelineApi()
        method = unwrap(api.get)
        pipeline = MagicMock()
        service = MagicMock()
        service.get_draft_workflow.return_value = None
        with (
            app.test_request_context("/"),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
                return_value=service,
            ),
        ):
            with pytest.raises(DraftWorkflowNotExist):
                method(api, pipeline)

    def test_sync_hash_not_match(self, app):
        api = DraftRagPipelineApi()
        method = unwrap(api.post)
        pipeline = MagicMock()
        user = MagicMock()
        service = MagicMock()
        service.sync_draft_workflow.side_effect = WorkflowHashNotEqualError()
        with (
            app.test_request_context("/", json={"graph": {}, "features": {}}),
            patch.object(type(console_ns), "payload", {"graph": {}, "features": {}}),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
                return_value=(user, "t"),
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
                return_value=service,
            ),
        ):
            with pytest.raises(DraftWorkflowNotSync):
                method(api, pipeline)

    def test_sync_invalid_text_plain(self, app):
        api = DraftRagPipelineApi()
        method = unwrap(api.post)
        pipeline = MagicMock()
        user = MagicMock()
        with (
            app.test_request_context("/", data="bad-json", headers={"Content-Type": "text/plain"}),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
                return_value=(user, "t"),
            ),
        ):
            response, status = method(api, pipeline)
            assert status == 400


class TestDraftRunNodes:
    def test_iteration_node_success(self, app):
        api = RagPipelineDraftRunIterationNodeApi()
        method = unwrap(api.post)
        pipeline = MagicMock()
        user = MagicMock()
        with (
            app.test_request_context("/", json={"inputs": {}}),
            patch.object(type(console_ns), "payload", {"inputs": {}}),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
                return_value=(user, "t"),
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration",
                return_value=MagicMock(),
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
                return_value={"ok": True},
            ),
        ):
            result = method(api, pipeline, "node")
            assert result == {"ok": True}

    def test_iteration_node_conversation_not_exists(self, app):
        api = RagPipelineDraftRunIterationNodeApi()
        method = unwrap(api.post)
        pipeline = MagicMock()
        user = MagicMock()
        with (
            app.test_request_context("/", json={"inputs": {}}),
            patch.object(type(console_ns), "payload", {"inputs": {}}),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
                return_value=(user, "t"),
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration",
                side_effect=services.errors.conversation.ConversationNotExistsError(),
            ),
        ):
            with pytest.raises(NotFound):
                method(api, pipeline, "node")

    def test_loop_node_success(self, app):
        api = RagPipelineDraftRunLoopNodeApi()
        method = unwrap(api.post)
        pipeline = MagicMock()
        user = MagicMock()
        with (
            app.test_request_context("/", json={"inputs": {}}),
            patch.object(type(console_ns), "payload", {"inputs": {}}),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
                return_value=(user, "t"),
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_loop",
                return_value=MagicMock(),
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
                return_value={"ok": True},
            ),
        ):
            assert method(api, pipeline, "node") == {"ok": True}


class TestPipelineRunApis:
    def test_draft_run_success(self, app):
        api = DraftRagPipelineRunApi()
        method = unwrap(api.post)
        pipeline = MagicMock()
        user = MagicMock()
        payload = {
            "inputs": {},
            "datasource_type": "x",
            "datasource_info_list": [],
            "start_node_id": "n",
        }
        with (
            app.test_request_context("/", json=payload),
            patch.object(type(console_ns), "payload", payload),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
                return_value=(user, "t"),
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
                return_value=MagicMock(),
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
                return_value={"ok": True},
            ),
        ):
            assert method(api, pipeline) == {"ok": True}

    def test_draft_run_rate_limit(self, app):
        api = DraftRagPipelineRunApi()
        method = unwrap(api.post)
        pipeline = MagicMock()
        user = MagicMock()
        with (
            app.test_request_context(
                "/", json={"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"}
            ),
            patch.object(
                type(console_ns),
                "payload",
                {"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"},
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
                return_value=(user, "t"),
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
                side_effect=InvokeRateLimitError("limit"),
            ),
        ):
            with pytest.raises(InvokeRateLimitHttpError):
                method(api, pipeline)


class TestDraftNodeRun:
    def test_execution_not_found(self, app):
        api = RagPipelineDraftNodeRunApi()
        method = unwrap(api.post)
        pipeline = MagicMock()
        user = MagicMock()
        service = MagicMock()
        service.run_draft_workflow_node.return_value = None
        with (
            app.test_request_context("/", json={"inputs": {}}),
            patch.object(type(console_ns), "payload", {"inputs": {}}),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
                return_value=(user, "t"),
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
                return_value=service,
            ),
        ):
            with pytest.raises(ValueError):
                method(api, pipeline, "node")


class TestPublishedPipelineApis:
    def test_publish_success(self, app):
        api = PublishedRagPipelineApi()
        method = unwrap(api.post)
        pipeline = MagicMock()
        user = MagicMock(id="u1")
        workflow = MagicMock(
            id="w1",
            created_at=datetime.utcnow(),
        )
        session = MagicMock()
        session.merge.return_value = pipeline
        session_ctx = MagicMock()
        session_ctx.__enter__.return_value = session
        session_ctx.__exit__.return_value = None
        service = MagicMock()
        service.publish_workflow.return_value = workflow
        fake_db = MagicMock()
        fake_db.engine = MagicMock()
        with (
            app.test_request_context("/"),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
                return_value=(user, "t"),
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
                fake_db,
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
                return_value=session_ctx,
            ),
            patch(
                "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, pipeline)
assert result["result"] == "success"
assert "created_at" in result
class TestMiscApis:
def test_task_stop(self, app):
api = RagPipelineTaskStopApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock(id="u1")
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.AppQueueManager.set_stop_flag"
) as stop_mock,
):
result = method(api, pipeline, "task-1")
stop_mock.assert_called_once()
assert result["result"] == "success"
def test_transform_forbidden(self, app):
api = RagPipelineTransformApi()
method = unwrap(api.post)
user = MagicMock(has_edit_permission=False, is_dataset_operator=False)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
):
with pytest.raises(Forbidden):
method(api, "ds1")
def test_recommended_plugins(self, app):
api = RagPipelineRecommendedPluginApi()
method = unwrap(api.get)
service = MagicMock()
service.get_recommended_plugins.return_value = [{"id": "p1"}]
with (
app.test_request_context("/?type=all"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api)
assert result == [{"id": "p1"}]
class TestPublishedRagPipelineRunApi:
def test_published_run_success(self, app):
api = PublishedRagPipelineRunApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
payload = {
"inputs": {},
"datasource_type": "x",
"datasource_info_list": [],
"start_node_id": "n",
"response_mode": "blocking",
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response",
return_value={"ok": True},
),
):
result = method(api, pipeline)
assert result == {"ok": True}
def test_published_run_rate_limit(self, app):
api = PublishedRagPipelineRunApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
payload = {
"inputs": {},
"datasource_type": "x",
"datasource_info_list": [],
"start_node_id": "n",
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate",
side_effect=InvokeRateLimitError("limit"),
),
):
with pytest.raises(InvokeRateLimitHttpError):
method(api, pipeline)
class TestDefaultBlockConfigApi:
def test_get_block_config_success(self, app):
api = DefaultRagPipelineBlockConfigApi()
method = unwrap(api.get)
pipeline = MagicMock()
service = MagicMock()
service.get_default_block_config.return_value = {"k": "v"}
with (
app.test_request_context("/?q={}"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, pipeline, "llm")
assert result == {"k": "v"}
def test_get_block_config_invalid_json(self, app):
api = DefaultRagPipelineBlockConfigApi()
method = unwrap(api.get)
pipeline = MagicMock()
with app.test_request_context("/?q=bad-json"):
with pytest.raises(ValueError):
method(api, pipeline, "llm")
class TestPublishedAllRagPipelineApi:
def test_get_published_workflows_success(self, app):
api = PublishedAllRagPipelineApi()
method = unwrap(api.get)
pipeline = MagicMock()
user = MagicMock(id="u1")
service = MagicMock()
service.get_all_published_workflow.return_value = ([{"id": "w1"}], False)
session = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
fake_db = MagicMock()
fake_db.engine = MagicMock()
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, pipeline)
assert result["items"] == [{"id": "w1"}]
assert result["has_more"] is False
def test_get_published_workflows_forbidden(self, app):
api = PublishedAllRagPipelineApi()
method = unwrap(api.get)
pipeline = MagicMock()
user = MagicMock(id="u1")
with (
app.test_request_context("/?user_id=u2"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
):
with pytest.raises(Forbidden):
method(api, pipeline)
class TestRagPipelineByIdApi:
def test_patch_success(self, app):
api = RagPipelineByIdApi()
method = unwrap(api.patch)
pipeline = MagicMock(tenant_id="t1")
user = MagicMock(id="u1")
workflow = MagicMock()
service = MagicMock()
service.update_workflow.return_value = workflow
session = MagicMock()
session_ctx = MagicMock()
session_ctx.__enter__.return_value = session
session_ctx.__exit__.return_value = None
fake_db = MagicMock()
fake_db.engine = MagicMock()
payload = {"marked_name": "test"}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
fake_db,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
return_value=session_ctx,
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, pipeline, "w1")
assert result == workflow
def test_patch_no_fields(self, app):
api = RagPipelineByIdApi()
method = unwrap(api.patch)
pipeline = MagicMock()
user = MagicMock()
with (
app.test_request_context("/", json={}),
patch.object(type(console_ns), "payload", {}),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
):
result, status = method(api, pipeline, "w1")
assert status == 400
class TestRagPipelineWorkflowLastRunApi:
def test_last_run_success(self, app):
api = RagPipelineWorkflowLastRunApi()
method = unwrap(api.get)
pipeline = MagicMock()
workflow = MagicMock()
node_exec = MagicMock()
service = MagicMock()
service.get_draft_workflow.return_value = workflow
service.get_node_last_run.return_value = node_exec
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, pipeline, "node1")
assert result == node_exec
def test_last_run_not_found(self, app):
api = RagPipelineWorkflowLastRunApi()
method = unwrap(api.get)
pipeline = MagicMock()
service = MagicMock()
service.get_draft_workflow.return_value = None
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
with pytest.raises(NotFound):
method(api, pipeline, "node1")
class TestRagPipelineDatasourceVariableApi:
def test_set_datasource_variables_success(self, app):
api = RagPipelineDatasourceVariableApi()
method = unwrap(api.post)
pipeline = MagicMock()
user = MagicMock()
payload = {
"datasource_type": "db",
"datasource_info": {},
"start_node_id": "n1",
"start_node_title": "Node",
}
service = MagicMock()
service.set_datasource_variables.return_value = MagicMock()
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
return_value=(user, "t"),
),
patch(
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
return_value=service,
),
):
result = method(api, pipeline)
assert result is not None
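These tests repeatedly stub `Session(...)` used as a context manager by wiring `__enter__`/`__exit__` on a `MagicMock`. The pattern in isolation looks like the sketch below (the `lookup` helper is illustrative, not from the module under test):

```python
from unittest.mock import MagicMock

# Code under test opens a session via `with session_factory() as s:`.
def lookup(session_factory, key):
    with session_factory() as s:
        return s.get(key)

session = MagicMock()
session.get.return_value = "row-1"

session_ctx = MagicMock()
session_ctx.__enter__.return_value = session   # `as s` binds to our mock
session_ctx.__exit__.return_value = None       # falsy: don't swallow exceptions

factory = MagicMock(return_value=session_ctx)
assert lookup(factory, "k1") == "row-1"
session.get.assert_called_once_with("k1")
```

Returning a falsy value from `__exit__` matters: a bare `MagicMock` return would be truthy and silently suppress any exception raised inside the `with` block.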


@@ -0,0 +1,444 @@
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from werkzeug.exceptions import NotFound
from controllers.console.datasets import data_source
from controllers.console.datasets.data_source import (
DataSourceApi,
DataSourceNotionApi,
DataSourceNotionDatasetSyncApi,
DataSourceNotionDocumentSyncApi,
DataSourceNotionListApi,
)
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
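The `unwrap` helper relies on the `__wrapped__` attribute that `functools.wraps` attaches to each wrapper, so walking the chain recovers the undecorated function even under stacked decorators. A minimal sketch of that convention (the `logged` decorator is illustrative):

```python
import functools

def logged(func):
    # functools.wraps copies metadata and sets wrapper.__wrapped__ = func
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

def unwrap(func):
    # Follow the __wrapped__ chain back to the original function
    while hasattr(func, "__wrapped__"):
        func = func.__wrapped__
    return func

@logged
@logged
def add(a, b):
    return a + b

assert unwrap(add) is add.__wrapped__.__wrapped__
assert unwrap(add)(1, 2) == 3
```

This only works for decorators that use `functools.wraps` (or set `__wrapped__` themselves); a decorator that omits it would stop the walk early.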
@pytest.fixture
def tenant_ctx():
return (MagicMock(id="u1"), "tenant-1")
@pytest.fixture
def patch_tenant(tenant_ctx):
with patch(
"controllers.console.datasets.data_source.current_account_with_tenant",
return_value=tenant_ctx,
):
yield
@pytest.fixture
def mock_engine():
with patch.object(
type(data_source.db),
"engine",
new_callable=PropertyMock,
return_value=MagicMock(),
):
yield
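The `mock_engine` fixture patches a class-level attribute with `PropertyMock`; outside this suite the same idiom looks like the following (the `Database` class is a stand-in, not from the code under test):

```python
from unittest.mock import MagicMock, PropertyMock, patch

class Database:
    @property
    def engine(self):
        raise RuntimeError("no real connection in tests")

db = Database()

# PropertyMock must be patched on the type, not the instance,
# because properties live on the class.
with patch.object(
    Database, "engine", new_callable=PropertyMock, return_value=MagicMock()
) as mock_prop:
    assert db.engine is not None  # returns the MagicMock, no RuntimeError
    mock_prop.assert_called_once()
```

When the patch context exits, the original property is restored, so accessing `db.engine` outside the block raises again.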
class TestDataSourceApi:
def test_get_success(self, app, patch_tenant):
api = DataSourceApi()
method = unwrap(api.get)
binding = MagicMock(
id="b1",
provider="notion",
created_at="now",
disabled=False,
source_info={},
)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.data_source.db.session.scalars",
return_value=MagicMock(all=lambda: [binding]),
),
):
response, status = method(api)
assert status == 200
assert response["data"][0]["is_bound"] is True
def test_get_no_bindings(self, app, patch_tenant):
api = DataSourceApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.data_source.db.session.scalars",
return_value=MagicMock(all=lambda: []),
),
):
response, status = method(api)
assert status == 200
assert response["data"] == []
def test_patch_enable_binding(self, app, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
binding = MagicMock(id="b1", disabled=True)
with (
app.test_request_context("/"),
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
patch("controllers.console.datasets.data_source.db.session.add"),
patch("controllers.console.datasets.data_source.db.session.commit"),
):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
response, status = method(api, "b1", "enable")
assert status == 200
assert binding.disabled is False
def test_patch_disable_binding(self, app, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
binding = MagicMock(id="b1", disabled=False)
with (
app.test_request_context("/"),
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
patch("controllers.console.datasets.data_source.db.session.add"),
patch("controllers.console.datasets.data_source.db.session.commit"),
):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
response, status = method(api, "b1", "disable")
assert status == 200
assert binding.disabled is True
def test_patch_binding_not_found(self, app, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
with (
app.test_request_context("/"),
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalar_one_or_none.return_value = None
with pytest.raises(NotFound):
method(api, "b1", "enable")
def test_patch_enable_already_enabled(self, app, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
binding = MagicMock(id="b1", disabled=False)
with (
app.test_request_context("/"),
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
with pytest.raises(ValueError):
method(api, "b1", "enable")
def test_patch_disable_already_disabled(self, app, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
binding = MagicMock(id="b1", disabled=True)
with (
app.test_request_context("/"),
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
with pytest.raises(ValueError):
method(api, "b1", "disable")
class TestDataSourceNotionListApi:
def test_get_credential_not_found(self, app, patch_tenant):
api = DataSourceNotionListApi()
method = unwrap(api.get)
with (
app.test_request_context("/?credential_id=c1"),
patch(
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api)
def test_get_success_no_dataset_id(self, app, patch_tenant, mock_engine):
api = DataSourceNotionListApi()
method = unwrap(api.get)
page = MagicMock(
page_id="p1",
page_name="Page 1",
type="page",
parent_id="parent",
page_icon=None,
)
online_document_message = MagicMock(
result=[
MagicMock(
workspace_id="w1",
workspace_name="My Workspace",
workspace_icon="icon",
pages=[page],
)
]
)
with (
app.test_request_context("/?credential_id=c1"),
patch(
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
return_value={"token": "t"},
),
patch(
"core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime",
return_value=MagicMock(
get_online_document_pages=lambda **kw: iter([online_document_message]),
datasource_provider_type=lambda: None,
),
),
):
response, status = method(api)
assert status == 200
def test_get_success_with_dataset_id(self, app, patch_tenant, mock_engine):
api = DataSourceNotionListApi()
method = unwrap(api.get)
page = MagicMock(
page_id="p1",
page_name="Page 1",
type="page",
parent_id="parent",
page_icon=None,
)
online_document_message = MagicMock(
result=[
MagicMock(
workspace_id="w1",
workspace_name="My Workspace",
workspace_icon="icon",
pages=[page],
)
]
)
dataset = MagicMock(data_source_type="notion_import")
document = MagicMock(data_source_info='{"notion_page_id": "p1"}')
with (
app.test_request_context("/?credential_id=c1&dataset_id=ds1"),
patch(
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
return_value={"token": "t"},
),
patch(
"controllers.console.datasets.data_source.DatasetService.get_dataset",
return_value=dataset,
),
patch("controllers.console.datasets.data_source.Session") as mock_session_class,
patch(
"core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime",
return_value=MagicMock(
get_online_document_pages=lambda **kw: iter([online_document_message]),
datasource_provider_type=lambda: None,
),
),
):
mock_session = MagicMock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.scalars.return_value.all.return_value = [document]
response, status = method(api)
assert status == 200
def test_get_invalid_dataset_type(self, app, patch_tenant, mock_engine):
api = DataSourceNotionListApi()
method = unwrap(api.get)
dataset = MagicMock(data_source_type="other_type")
with (
app.test_request_context("/?credential_id=c1&dataset_id=ds1"),
patch(
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
return_value={"token": "t"},
),
patch(
"controllers.console.datasets.data_source.DatasetService.get_dataset",
return_value=dataset,
),
patch("controllers.console.datasets.data_source.Session"),
):
with pytest.raises(ValueError):
method(api)
class TestDataSourceNotionApi:
def test_get_preview_success(self, app, patch_tenant):
api = DataSourceNotionApi()
method = unwrap(api.get)
extractor = MagicMock(extract=lambda: [MagicMock(page_content="hello")])
with (
app.test_request_context("/?credential_id=c1"),
patch(
"controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials",
return_value={"integration_secret": "t"},
),
patch(
"controllers.console.datasets.data_source.NotionExtractor",
return_value=extractor,
),
):
response, status = method(api, "p1", "page")
assert status == 200
def test_post_indexing_estimate_success(self, app, patch_tenant):
api = DataSourceNotionApi()
method = unwrap(api.post)
payload = {
"notion_info_list": [
{
"workspace_id": "w1",
"credential_id": "c1",
"pages": [{"page_id": "p1", "type": "page"}],
}
],
"process_rule": {"rules": {}},
"doc_form": "text_model",
"doc_language": "English",
}
with (
app.test_request_context("/", method="POST", json=payload, headers={"Content-Type": "application/json"}),
patch(
"controllers.console.datasets.data_source.DocumentService.estimate_args_validate",
),
patch(
"controllers.console.datasets.data_source.IndexingRunner.indexing_estimate",
return_value=MagicMock(model_dump=lambda: {"total_pages": 1}),
),
):
response, status = method(api)
assert status == 200
class TestDataSourceNotionDatasetSyncApi:
def test_get_success(self, app, patch_tenant):
api = DataSourceNotionDatasetSyncApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.data_source.DatasetService.get_dataset",
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.data_source.DocumentService.get_document_by_dataset_id",
return_value=[MagicMock(id="d1")],
),
patch(
"controllers.console.datasets.data_source.document_indexing_sync_task.delay",
return_value=None,
),
):
response, status = method(api, "ds-1")
assert status == 200
def test_get_dataset_not_found(self, app, patch_tenant):
api = DataSourceNotionDatasetSyncApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.data_source.DatasetService.get_dataset",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api, "ds-1")
class TestDataSourceNotionDocumentSyncApi:
def test_get_success(self, app, patch_tenant):
api = DataSourceNotionDocumentSyncApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.data_source.DatasetService.get_dataset",
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.data_source.DocumentService.get_document",
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.data_source.document_indexing_sync_task.delay",
return_value=None,
),
):
response, status = method(api, "ds-1", "doc-1")
assert status == 200
def test_get_document_not_found(self, app, patch_tenant):
api = DataSourceNotionDocumentSyncApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.data_source.DatasetService.get_dataset",
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.data_source.DocumentService.get_document",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -0,0 +1,399 @@
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.console import console_ns
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.datasets.external import (
BedrockRetrievalApi,
ExternalApiTemplateApi,
ExternalApiTemplateListApi,
ExternalDatasetCreateApi,
ExternalKnowledgeHitTestingApi,
)
from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService
from services.knowledge_service import ExternalDatasetTestService
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def app():
app = Flask("test_external_dataset")
app.config["TESTING"] = True
return app
@pytest.fixture
def current_user():
user = MagicMock()
user.id = "user-1"
user.is_dataset_editor = True
user.has_edit_permission = True
user.is_dataset_operator = True
return user
@pytest.fixture(autouse=True)
def mock_auth(mocker, current_user):
mocker.patch(
"controllers.console.datasets.external.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
)
class TestExternalApiTemplateListApi:
def test_get_success(self, app):
api = ExternalApiTemplateListApi()
method = unwrap(api.get)
api_item = MagicMock()
api_item.to_dict.return_value = {"id": "1"}
with (
app.test_request_context("/?page=1&limit=20"),
patch.object(
ExternalDatasetService,
"get_external_knowledge_apis",
return_value=([api_item], 1),
),
):
resp, status = method(api)
assert status == 200
assert resp["total"] == 1
assert resp["data"][0]["id"] == "1"
def test_post_forbidden(self, app, current_user):
current_user.is_dataset_editor = False
api = ExternalApiTemplateListApi()
method = unwrap(api.post)
payload = {"name": "x", "settings": {"k": "v"}}
with (
app.test_request_context("/"),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
patch.object(ExternalDatasetService, "validate_api_list"),
):
with pytest.raises(Forbidden):
method(api)
def test_post_duplicate_name(self, app):
api = ExternalApiTemplateListApi()
method = unwrap(api.post)
payload = {"name": "x", "settings": {"k": "v"}}
with (
app.test_request_context("/"),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
patch.object(ExternalDatasetService, "validate_api_list"),
patch.object(
ExternalDatasetService,
"create_external_knowledge_api",
side_effect=services.errors.dataset.DatasetNameDuplicateError(),
),
):
with pytest.raises(DatasetNameDuplicateError):
method(api)
class TestExternalApiTemplateApi:
def test_get_not_found(self, app):
api = ExternalApiTemplateApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch.object(
ExternalDatasetService,
"get_external_knowledge_api",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api, "api-id")
def test_delete_forbidden(self, app, current_user):
current_user.has_edit_permission = False
current_user.is_dataset_operator = False
api = ExternalApiTemplateApi()
method = unwrap(api.delete)
with app.test_request_context("/"):
with pytest.raises(Forbidden):
method(api, "api-id")
class TestExternalDatasetCreateApi:
def test_create_success(self, app):
api = ExternalDatasetCreateApi()
method = unwrap(api.post)
payload = {
"external_knowledge_api_id": "api",
"external_knowledge_id": "kid",
"name": "dataset",
}
dataset = MagicMock()
dataset.embedding_available = False
dataset.built_in_field_enabled = False
dataset.is_published = False
dataset.enable_api = False
dataset.enable_qa = False
dataset.enable_vector_store = False
dataset.vector_store_setting = None
dataset.is_multimodal = False
dataset.retrieval_model_dict = {}
dataset.tags = []
dataset.external_knowledge_info = None
dataset.external_retrieval_model = None
dataset.doc_metadata = []
dataset.icon_info = None
dataset.summary_index_setting = MagicMock()
dataset.summary_index_setting.enable = False
with (
app.test_request_context("/"),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
patch.object(
ExternalDatasetService,
"create_external_dataset",
return_value=dataset,
),
):
_, status = method(api)
assert status == 201
def test_create_forbidden(self, app, current_user):
current_user.is_dataset_editor = False
api = ExternalDatasetCreateApi()
method = unwrap(api.post)
payload = {
"external_knowledge_api_id": "api",
"external_knowledge_id": "kid",
"name": "dataset",
}
with (
app.test_request_context("/"),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
):
with pytest.raises(Forbidden):
method(api)
class TestExternalKnowledgeHitTestingApi:
def test_hit_testing_dataset_not_found(self, app):
api = ExternalKnowledgeHitTestingApi()
method = unwrap(api.post)
with (
app.test_request_context("/"),
patch.object(
DatasetService,
"get_dataset",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api, "dataset-id")
def test_hit_testing_success(self, app):
api = ExternalKnowledgeHitTestingApi()
method = unwrap(api.post)
payload = {"query": "hello"}
dataset = MagicMock()
with (
app.test_request_context("/"),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
patch.object(DatasetService, "get_dataset", return_value=dataset),
patch.object(DatasetService, "check_dataset_permission"),
patch.object(
HitTestingService,
"external_retrieve",
return_value={"ok": True},
),
):
resp = method(api, "dataset-id")
assert resp["ok"] is True
class TestBedrockRetrievalApi:
def test_bedrock_retrieval(self, app):
api = BedrockRetrievalApi()
method = unwrap(api.post)
payload = {
"retrieval_setting": {},
"query": "hello",
"knowledge_id": "kid",
}
with (
app.test_request_context("/"),
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
patch.object(
ExternalDatasetTestService,
"knowledge_retrieval",
return_value={"ok": True},
),
):
resp, status = method()
assert status == 200
assert resp["ok"] is True
class TestExternalApiTemplateListApiAdvanced:
def test_post_duplicate_name_error(self, app, mock_auth, current_user):
api = ExternalApiTemplateListApi()
method = unwrap(api.post)
payload = {"name": "duplicate_api", "settings": {"key": "value"}}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch("controllers.console.datasets.external.ExternalDatasetService.validate_api_list"),
patch(
"controllers.console.datasets.external.ExternalDatasetService.create_external_knowledge_api",
side_effect=services.errors.dataset.DatasetNameDuplicateError("Duplicate"),
),
):
with pytest.raises(DatasetNameDuplicateError):
method(api)
def test_get_with_pagination(self, app, mock_auth, current_user):
api = ExternalApiTemplateListApi()
method = unwrap(api.get)
templates = [MagicMock(id=f"api-{i}") for i in range(3)]
with (
app.test_request_context("/?page=1&limit=20"),
patch(
"controllers.console.datasets.external.ExternalDatasetService.get_external_knowledge_apis",
return_value=(templates, 25),
),
):
resp, status = method(api)
assert status == 200
assert resp["total"] == 25
assert len(resp["data"]) == 3
class TestExternalDatasetCreateApiAdvanced:
def test_create_forbidden(self, app, mock_auth, current_user):
"""Test creating external dataset without permission"""
api = ExternalDatasetCreateApi()
method = unwrap(api.post)
current_user.is_dataset_editor = False
payload = {
"external_knowledge_api_id": "api-1",
"external_knowledge_id": "ek-1",
"name": "new_dataset",
"description": "A dataset",
}
with app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload):
with pytest.raises(Forbidden):
method(api)
class TestExternalKnowledgeHitTestingApiAdvanced:
def test_hit_testing_dataset_not_found(self, app, mock_auth, current_user):
"""Test hit testing on non-existent dataset"""
api = ExternalKnowledgeHitTestingApi()
method = unwrap(api.post)
payload = {
"query": "test query",
"external_retrieval_model": None,
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.external.DatasetService.get_dataset",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api, "ds-1")
def test_hit_testing_with_custom_retrieval_model(self, app, mock_auth, current_user):
api = ExternalKnowledgeHitTestingApi()
method = unwrap(api.post)
dataset = MagicMock()
payload = {
"query": "test query",
"external_retrieval_model": {"type": "bm25"},
"metadata_filtering_conditions": {"status": "active"},
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.external.DatasetService.get_dataset",
return_value=dataset,
),
patch("controllers.console.datasets.external.DatasetService.check_dataset_permission"),
patch(
"controllers.console.datasets.external.HitTestingService.external_retrieve",
return_value={"results": []},
),
):
resp = method(api, "ds-1")
assert resp["results"] == []
class TestBedrockRetrievalApiAdvanced:
def test_bedrock_retrieval_with_invalid_setting(self, app, mock_auth, current_user):
api = BedrockRetrievalApi()
method = unwrap(api.post)
payload = {
"retrieval_setting": {},
"query": "test",
"knowledge_id": "k-1",
}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.external.ExternalDatasetTestService.knowledge_retrieval",
side_effect=ValueError("Invalid settings"),
),
):
with pytest.raises(ValueError):
method()


@@ -0,0 +1,160 @@
import uuid
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
from controllers.console.datasets.hit_testing import HitTestingApi
from controllers.console.datasets.hit_testing_base import HitTestingPayload
def unwrap(func):
"""Recursively unwrap decorated functions."""
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
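`functools.wraps` records the original callable on `__wrapped__`, which is the attribute this helper peels off layer by layer; a minimal standalone sketch:

```python
import functools

def noisy(func):
    @functools.wraps(func)  # sets wrapper.__wrapped__ = func
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

def unwrap(func):
    """Recursively unwrap decorated functions."""
    while hasattr(func, "__wrapped__"):
        func = func.__wrapped__
    return func

@noisy
@noisy
def greet():
    return "hi"

original = unwrap(greet)
assert original() == "hi"
assert not hasattr(original, "__wrapped__")  # fully unwrapped
```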
@pytest.fixture
def app():
app = Flask("test_hit_testing")
app.config["TESTING"] = True
return app
@pytest.fixture
def dataset_id():
return uuid.uuid4()
@pytest.fixture
def dataset():
return MagicMock(id="dataset-1")
@pytest.fixture(autouse=True)
def bypass_decorators(mocker):
"""Bypass all decorators on the API method."""
mocker.patch(
"controllers.console.datasets.hit_testing.setup_required",
lambda f: f,
)
mocker.patch(
"controllers.console.datasets.hit_testing.login_required",
return_value=lambda f: f,
)
mocker.patch(
"controllers.console.datasets.hit_testing.account_initialization_required",
return_value=lambda f: f,
)
mocker.patch(
"controllers.console.datasets.hit_testing.cloud_edition_billing_rate_limit_check",
return_value=lambda *_: (lambda f: f),
)
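The fixture above distinguishes plain decorators (stubbed with `lambda f: f`) from decorator factories (stubbed so that *calling* them yields the identity decorator). A generic sketch of why the two shapes differ (not Dify-specific):

```python
from unittest.mock import MagicMock

def apply_plain(decorator):
    @decorator          # used directly: decorator(view)
    def view():
        return "ok"
    return view

def apply_factory(factory):
    @factory("some-arg")  # called first: factory(...)(view)
    def view():
        return "ok"
    return view

# Plain decorator stub: the identity function itself.
assert apply_plain(lambda f: f)() == "ok"

# Factory stub: the *call* must return the identity decorator.
factory_stub = MagicMock(return_value=lambda f: f)
assert apply_factory(factory_stub)() == "ok"
```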
class TestHitTestingApi:
def test_hit_testing_success(self, app, dataset, dataset_id):
api = HitTestingApi()
method = unwrap(api.post)
payload = {
"query": "what is vector search",
"top_k": 3,
}
with (
app.test_request_context("/"),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch.object(
HitTestingPayload,
"model_validate",
return_value=MagicMock(model_dump=lambda **_: payload),
),
patch.object(
HitTestingApi,
"get_and_validate_dataset",
return_value=dataset,
),
patch.object(
HitTestingApi,
"hit_testing_args_check",
),
patch.object(
HitTestingApi,
"perform_hit_testing",
return_value={"query": "what is vector search", "records": []},
),
):
result = method(api, dataset_id)
assert "query" in result
assert "records" in result
assert result["records"] == []
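Patching `payload` on `type(console_ns)` rather than on the instance is deliberate: properties are resolved on the class, so an instance-level patch would be ignored. A minimal sketch with a hypothetical stand-in class (`Namespace` here is illustrative, not the flask_restx one):

```python
from unittest.mock import PropertyMock, patch

class Namespace:
    @property
    def payload(self):
        raise RuntimeError("requires an active request")

ns = Namespace()
# The property lives on the class, so patch the type, not the instance.
with patch.object(
    type(ns), "payload", new_callable=PropertyMock, return_value={"query": "hi"}
):
    assert ns.payload == {"query": "hi"}
```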
def test_hit_testing_dataset_not_found(self, app, dataset_id):
api = HitTestingApi()
method = unwrap(api.post)
payload = {
"query": "test",
}
with (
app.test_request_context("/"),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch.object(
HitTestingApi,
"get_and_validate_dataset",
side_effect=NotFound("Dataset not found"),
),
):
with pytest.raises(NotFound, match="Dataset not found"):
method(api, dataset_id)
def test_hit_testing_invalid_args(self, app, dataset, dataset_id):
api = HitTestingApi()
method = unwrap(api.post)
payload = {
"query": "",
}
with (
app.test_request_context("/"),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch.object(
HitTestingPayload,
"model_validate",
return_value=MagicMock(model_dump=lambda **_: payload),
),
patch.object(
HitTestingApi,
"get_and_validate_dataset",
return_value=dataset,
),
patch.object(
HitTestingApi,
"hit_testing_args_check",
side_effect=ValueError("Invalid parameters"),
),
):
with pytest.raises(ValueError, match="Invalid parameters"):
method(api, dataset_id)


@@ -0,0 +1,207 @@
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.datasets.error import DatasetNotInitializedError
from controllers.console.datasets.hit_testing_base import (
DatasetsHitTestingBase,
)
from core.errors.error import (
LLMBadRequestError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from dify_graph.model_runtime.errors.invoke import InvokeError
from models.account import Account
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService
@pytest.fixture
def account():
acc = MagicMock(spec=Account)
return acc
@pytest.fixture(autouse=True)
def patch_current_user(mocker, account):
"""Patch current_user to a valid Account."""
mocker.patch(
"controllers.console.datasets.hit_testing_base.current_user",
account,
)
@pytest.fixture
def dataset():
return MagicMock(id="dataset-1")
class TestGetAndValidateDataset:
def test_success(self, dataset):
with (
patch.object(
DatasetService,
"get_dataset",
return_value=dataset,
),
patch.object(
DatasetService,
"check_dataset_permission",
),
):
result = DatasetsHitTestingBase.get_and_validate_dataset("dataset-1")
assert result == dataset
def test_dataset_not_found(self):
with patch.object(
DatasetService,
"get_dataset",
return_value=None,
):
with pytest.raises(NotFound, match="Dataset not found"):
DatasetsHitTestingBase.get_and_validate_dataset("dataset-1")
def test_permission_denied(self, dataset):
with (
patch.object(
DatasetService,
"get_dataset",
return_value=dataset,
),
patch.object(
DatasetService,
"check_dataset_permission",
side_effect=services.errors.account.NoPermissionError("no access"),
),
):
with pytest.raises(Forbidden, match="no access"):
DatasetsHitTestingBase.get_and_validate_dataset("dataset-1")
class TestHitTestingArgsCheck:
def test_args_check_called(self):
args = {"query": "test"}
with patch.object(
HitTestingService,
"hit_testing_args_check",
) as check_mock:
DatasetsHitTestingBase.hit_testing_args_check(args)
check_mock.assert_called_once_with(args)
class TestParseArgs:
def test_parse_args_success(self):
payload = {"query": "hello"}
result = DatasetsHitTestingBase.parse_args(payload)
assert result["query"] == "hello"
def test_parse_args_invalid(self):
payload = {"query": "x" * 300}
with pytest.raises(ValueError):
DatasetsHitTestingBase.parse_args(payload)
class TestPerformHitTesting:
def test_success(self, dataset):
response = {
"query": "hello",
"records": [],
}
with patch.object(
HitTestingService,
"retrieve",
return_value=response,
):
result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
assert result["query"] == "hello"
assert result["records"] == []
def test_index_not_initialized(self, dataset):
with patch.object(
HitTestingService,
"retrieve",
side_effect=services.errors.index.IndexNotInitializedError(),
):
with pytest.raises(DatasetNotInitializedError):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
def test_provider_token_not_init(self, dataset):
with patch.object(
HitTestingService,
"retrieve",
side_effect=ProviderTokenNotInitError("token missing"),
):
with pytest.raises(ProviderNotInitializeError):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
def test_quota_exceeded(self, dataset):
with patch.object(
HitTestingService,
"retrieve",
side_effect=QuotaExceededError(),
):
with pytest.raises(ProviderQuotaExceededError):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
def test_model_not_supported(self, dataset):
with patch.object(
HitTestingService,
"retrieve",
side_effect=ModelCurrentlyNotSupportError(),
):
with pytest.raises(ProviderModelCurrentlyNotSupportError):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
def test_llm_bad_request(self, dataset):
with patch.object(
HitTestingService,
"retrieve",
side_effect=LLMBadRequestError("bad request"),
):
with pytest.raises(ProviderNotInitializeError):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
def test_invoke_error(self, dataset):
with patch.object(
HitTestingService,
"retrieve",
side_effect=InvokeError("invoke failed"),
):
with pytest.raises(CompletionRequestError):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
def test_value_error(self, dataset):
with patch.object(
HitTestingService,
"retrieve",
side_effect=ValueError("bad args"),
):
with pytest.raises(ValueError, match="bad args"):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})
def test_unexpected_error(self, dataset):
with patch.object(
HitTestingService,
"retrieve",
side_effect=Exception("boom"),
):
with pytest.raises(InternalServerError, match="boom"):
DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"})


@@ -0,0 +1,362 @@
import uuid
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
from controllers.console.datasets.metadata import (
DatasetMetadataApi,
DatasetMetadataBuiltInFieldActionApi,
DatasetMetadataBuiltInFieldApi,
DatasetMetadataCreateApi,
DocumentMetadataEditApi,
)
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
MetadataArgs,
MetadataOperationData,
)
from services.metadata_service import MetadataService
def unwrap(func):
"""Recursively unwrap decorated functions."""
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def app():
app = Flask("test_dataset_metadata")
app.config["TESTING"] = True
return app
@pytest.fixture
def current_user():
user = MagicMock()
user.id = "user-1"
return user
@pytest.fixture
def dataset():
ds = MagicMock()
ds.id = "dataset-1"
return ds
@pytest.fixture
def dataset_id():
return uuid.uuid4()
@pytest.fixture
def metadata_id():
return uuid.uuid4()
@pytest.fixture(autouse=True)
def bypass_decorators(mocker):
"""Bypass setup/login/license decorators."""
mocker.patch(
"controllers.console.datasets.metadata.setup_required",
lambda f: f,
)
mocker.patch(
"controllers.console.datasets.metadata.login_required",
lambda f: f,
)
mocker.patch(
"controllers.console.datasets.metadata.account_initialization_required",
lambda f: f,
)
mocker.patch(
"controllers.console.datasets.metadata.enterprise_license_required",
lambda f: f,
)
class TestDatasetMetadataCreateApi:
def test_create_metadata_success(self, app, current_user, dataset, dataset_id):
api = DatasetMetadataCreateApi()
method = unwrap(api.post)
payload = {"name": "author"}
with (
app.test_request_context("/"),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch(
"controllers.console.datasets.metadata.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
),
patch.object(
MetadataArgs,
"model_validate",
return_value=MagicMock(),
),
patch.object(
DatasetService,
"get_dataset",
return_value=dataset,
),
patch.object(
DatasetService,
"check_dataset_permission",
),
patch.object(
MetadataService,
"create_metadata",
return_value={"id": "m1", "name": "author"},
),
):
result, status = method(api, dataset_id)
assert status == 201
assert result["name"] == "author"
def test_create_metadata_dataset_not_found(self, app, current_user, dataset_id):
api = DatasetMetadataCreateApi()
method = unwrap(api.post)
valid_payload = {
"type": "string",
"name": "author",
}
with (
app.test_request_context("/"),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=valid_payload,
),
patch(
"controllers.console.datasets.metadata.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
),
patch.object(
MetadataArgs,
"model_validate",
return_value=MagicMock(),
),
patch.object(
DatasetService,
"get_dataset",
return_value=None,
),
):
with pytest.raises(NotFound, match="Dataset not found"):
method(api, dataset_id)
class TestDatasetMetadataGetApi:
def test_get_metadata_success(self, app, dataset, dataset_id):
api = DatasetMetadataCreateApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch.object(
DatasetService,
"get_dataset",
return_value=dataset,
),
patch.object(
MetadataService,
"get_dataset_metadatas",
return_value=[{"id": "m1"}],
),
):
result, status = method(api, dataset_id)
assert status == 200
assert isinstance(result, list)
def test_get_metadata_dataset_not_found(self, app, dataset_id):
api = DatasetMetadataCreateApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch.object(
DatasetService,
"get_dataset",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api, dataset_id)
class TestDatasetMetadataApi:
def test_update_metadata_success(self, app, current_user, dataset, dataset_id, metadata_id):
api = DatasetMetadataApi()
method = unwrap(api.patch)
payload = {"name": "updated-name"}
with (
app.test_request_context("/"),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch(
"controllers.console.datasets.metadata.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
),
patch.object(
DatasetService,
"get_dataset",
return_value=dataset,
),
patch.object(
DatasetService,
"check_dataset_permission",
),
patch.object(
MetadataService,
"update_metadata_name",
return_value={"id": "m1", "name": "updated-name"},
),
):
result, status = method(api, dataset_id, metadata_id)
assert status == 200
assert result["name"] == "updated-name"
def test_delete_metadata_success(self, app, current_user, dataset, dataset_id, metadata_id):
api = DatasetMetadataApi()
method = unwrap(api.delete)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.metadata.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
),
patch.object(
DatasetService,
"get_dataset",
return_value=dataset,
),
patch.object(
DatasetService,
"check_dataset_permission",
),
patch.object(
MetadataService,
"delete_metadata",
),
):
result, status = method(api, dataset_id, metadata_id)
assert status == 204
assert result["result"] == "success"
class TestDatasetMetadataBuiltInFieldApi:
def test_get_built_in_fields(self, app):
api = DatasetMetadataBuiltInFieldApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch.object(
MetadataService,
"get_built_in_fields",
return_value=["title", "source"],
),
):
result, status = method(api)
assert status == 200
assert result["fields"] == ["title", "source"]
class TestDatasetMetadataBuiltInFieldActionApi:
def test_enable_built_in_field(self, app, current_user, dataset, dataset_id):
api = DatasetMetadataBuiltInFieldActionApi()
method = unwrap(api.post)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.metadata.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
),
patch.object(
DatasetService,
"get_dataset",
return_value=dataset,
),
patch.object(
DatasetService,
"check_dataset_permission",
),
patch.object(
MetadataService,
"enable_built_in_field",
),
):
result, status = method(api, dataset_id, "enable")
assert status == 200
assert result["result"] == "success"
class TestDocumentMetadataEditApi:
def test_update_document_metadata_success(self, app, current_user, dataset, dataset_id):
api = DocumentMetadataEditApi()
method = unwrap(api.post)
payload = {"operation": "add", "metadata": {}}
with (
app.test_request_context("/"),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch(
"controllers.console.datasets.metadata.current_account_with_tenant",
return_value=(current_user, "tenant-1"),
),
patch.object(
DatasetService,
"get_dataset",
return_value=dataset,
),
patch.object(
DatasetService,
"check_dataset_permission",
),
patch.object(
MetadataOperationData,
"model_validate",
return_value=MagicMock(),
),
patch.object(
MetadataService,
"update_documents_metadata",
),
):
result, status = method(api, dataset_id)
assert status == 200
assert result["result"] == "success"


@@ -0,0 +1,233 @@
from unittest.mock import Mock, PropertyMock, patch
import pytest
from flask import Flask
from controllers.console import console_ns
from controllers.console.datasets.error import WebsiteCrawlError
from controllers.console.datasets.website import (
WebsiteCrawlApi,
WebsiteCrawlStatusApi,
)
from services.website_service import (
WebsiteCrawlApiRequest,
WebsiteCrawlStatusApiRequest,
WebsiteService,
)
def unwrap(func):
"""Recursively unwrap decorated functions."""
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def app():
app = Flask("test_website_crawl")
app.config["TESTING"] = True
return app
@pytest.fixture(autouse=True)
def bypass_auth_and_setup(mocker):
"""Bypass setup/login/account decorators."""
mocker.patch(
"controllers.console.datasets.website.login_required",
lambda f: f,
)
mocker.patch(
"controllers.console.datasets.website.setup_required",
lambda f: f,
)
mocker.patch(
"controllers.console.datasets.website.account_initialization_required",
lambda f: f,
)
class TestWebsiteCrawlApi:
def test_crawl_success(self, app, mocker):
api = WebsiteCrawlApi()
method = unwrap(api.post)
payload = {
"provider": "firecrawl",
"url": "https://example.com",
"options": {"depth": 1},
}
with (
app.test_request_context("/", json=payload),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
):
mock_request = Mock(spec=WebsiteCrawlApiRequest)
mocker.patch.object(
WebsiteCrawlApiRequest,
"from_args",
return_value=mock_request,
)
mocker.patch.object(
WebsiteService,
"crawl_url",
return_value={"job_id": "job-1"},
)
result, status = method(api)
assert status == 200
assert result["job_id"] == "job-1"
def test_crawl_invalid_payload(self, app, mocker):
api = WebsiteCrawlApi()
method = unwrap(api.post)
payload = {
"provider": "firecrawl",
"url": "bad-url",
"options": {},
}
with (
app.test_request_context("/", json=payload),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
):
mocker.patch.object(
WebsiteCrawlApiRequest,
"from_args",
side_effect=ValueError("invalid payload"),
)
with pytest.raises(WebsiteCrawlError, match="invalid payload"):
method(api)
def test_crawl_service_error(self, app, mocker):
api = WebsiteCrawlApi()
method = unwrap(api.post)
payload = {
"provider": "firecrawl",
"url": "https://example.com",
"options": {},
}
with (
app.test_request_context("/", json=payload),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
):
mock_request = Mock(spec=WebsiteCrawlApiRequest)
mocker.patch.object(
WebsiteCrawlApiRequest,
"from_args",
return_value=mock_request,
)
mocker.patch.object(
WebsiteService,
"crawl_url",
side_effect=Exception("crawl failed"),
)
with pytest.raises(WebsiteCrawlError, match="crawl failed"):
method(api)
class TestWebsiteCrawlStatusApi:
def test_get_status_success(self, app, mocker):
api = WebsiteCrawlStatusApi()
method = unwrap(api.get)
job_id = "job-123"
args = {"provider": "firecrawl"}
with app.test_request_context("/?provider=firecrawl"):
mocker.patch(
"controllers.console.datasets.website.request.args.to_dict",
return_value=args,
)
mock_request = Mock(spec=WebsiteCrawlStatusApiRequest)
mocker.patch.object(
WebsiteCrawlStatusApiRequest,
"from_args",
return_value=mock_request,
)
mocker.patch.object(
WebsiteService,
"get_crawl_status_typed",
return_value={"status": "completed"},
)
result, status = method(api, job_id)
assert status == 200
assert result["status"] == "completed"
def test_get_status_invalid_provider(self, app, mocker):
api = WebsiteCrawlStatusApi()
method = unwrap(api.get)
job_id = "job-123"
args = {"provider": "firecrawl"}
with app.test_request_context("/?provider=firecrawl"):
mocker.patch(
"controllers.console.datasets.website.request.args.to_dict",
return_value=args,
)
mocker.patch.object(
WebsiteCrawlStatusApiRequest,
"from_args",
side_effect=ValueError("invalid provider"),
)
with pytest.raises(WebsiteCrawlError, match="invalid provider"):
method(api, job_id)
def test_get_status_service_error(self, app, mocker):
api = WebsiteCrawlStatusApi()
method = unwrap(api.get)
job_id = "job-123"
args = {"provider": "firecrawl"}
with app.test_request_context("/?provider=firecrawl"):
mocker.patch(
"controllers.console.datasets.website.request.args.to_dict",
return_value=args,
)
mock_request = Mock(spec=WebsiteCrawlStatusApiRequest)
mocker.patch.object(
WebsiteCrawlStatusApiRequest,
"from_args",
return_value=mock_request,
)
mocker.patch.object(
WebsiteService,
"get_crawl_status_typed",
side_effect=Exception("status lookup failed"),
)
with pytest.raises(WebsiteCrawlError, match="status lookup failed"):
method(api, job_id)


@@ -0,0 +1,117 @@
from unittest.mock import Mock
import pytest
from controllers.console.datasets.error import PipelineNotFoundError
from controllers.console.datasets.wraps import get_rag_pipeline
from models.dataset import Pipeline
class TestGetRagPipeline:
def test_missing_pipeline_id(self):
@get_rag_pipeline
def dummy_view(**kwargs):
return "ok"
with pytest.raises(ValueError, match="missing pipeline_id"):
dummy_view()
def test_pipeline_not_found(self, mocker):
@get_rag_pipeline
def dummy_view(**kwargs):
return "ok"
mocker.patch(
"controllers.console.datasets.wraps.current_account_with_tenant",
return_value=(Mock(), "tenant-1"),
)
mock_query = Mock()
mock_query.where.return_value.first.return_value = None
mocker.patch(
"controllers.console.datasets.wraps.db.session.query",
return_value=mock_query,
)
with pytest.raises(PipelineNotFoundError):
dummy_view(pipeline_id="pipeline-1")
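These tests stub the SQLAlchemy call chain `db.session.query(...).where(...).first()` by assigning through nested `return_value` attributes; a minimal sketch of the pattern:

```python
from unittest.mock import Mock

query = Mock()
# Configure query.where(<anything>).first() to return None.
query.where.return_value.first.return_value = None

assert query.where("any", "filters").first() is None
query.where.assert_called_once()
```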
def test_pipeline_found_and_injected(self, mocker):
pipeline = Mock(spec=Pipeline)
pipeline.id = "pipeline-1"
pipeline.tenant_id = "tenant-1"
@get_rag_pipeline
def dummy_view(**kwargs):
return kwargs["pipeline"]
mocker.patch(
"controllers.console.datasets.wraps.current_account_with_tenant",
return_value=(Mock(), "tenant-1"),
)
mock_query = Mock()
mock_query.where.return_value.first.return_value = pipeline
mocker.patch(
"controllers.console.datasets.wraps.db.session.query",
return_value=mock_query,
)
result = dummy_view(pipeline_id="pipeline-1")
assert result is pipeline
def test_pipeline_id_removed_from_kwargs(self, mocker):
pipeline = Mock(spec=Pipeline)
@get_rag_pipeline
def dummy_view(**kwargs):
assert "pipeline_id" not in kwargs
return "ok"
mocker.patch(
"controllers.console.datasets.wraps.current_account_with_tenant",
return_value=(Mock(), "tenant-1"),
)
mock_query = Mock()
mock_query.where.return_value.first.return_value = pipeline
mocker.patch(
"controllers.console.datasets.wraps.db.session.query",
return_value=mock_query,
)
result = dummy_view(pipeline_id="pipeline-1")
assert result == "ok"
def test_pipeline_id_cast_to_string(self, mocker):
pipeline = Mock(spec=Pipeline)
@get_rag_pipeline
def dummy_view(**kwargs):
return kwargs["pipeline"]
mocker.patch(
"controllers.console.datasets.wraps.current_account_with_tenant",
return_value=(Mock(), "tenant-1"),
)
def where_side_effect(*args, **kwargs):
assert args[0].right.value == "123"
return Mock(first=lambda: pipeline)
mock_query = Mock()
mock_query.where.side_effect = where_side_effect
mocker.patch(
"controllers.console.datasets.wraps.db.session.query",
return_value=mock_query,
)
result = dummy_view(pipeline_id=123)
assert result is pipeline


@@ -1,13 +1,483 @@
"""Final working unit tests for admin endpoints - tests business logic directly."""
import uuid
-from unittest.mock import Mock, patch
+from unittest.mock import Mock, PropertyMock, patch
import pytest
from werkzeug.exceptions import NotFound, Unauthorized
-from controllers.console.admin import InsertExploreAppPayload
-from models.model import App, RecommendedApp
+from controllers.console.admin import (
+    DeleteExploreBannerApi,
+    InsertExploreAppApi,
+    InsertExploreAppListApi,
+    InsertExploreAppPayload,
+    InsertExploreBannerApi,
+    InsertExploreBannerPayload,
+)
+from models.model import App, InstalledApp, RecommendedApp
@pytest.fixture(autouse=True)
def bypass_only_edition_cloud(mocker):
"""
Bypass only_edition_cloud decorator by setting EDITION to "CLOUD".
"""
mocker.patch(
"controllers.console.wraps.dify_config.EDITION",
new="CLOUD",
)
@pytest.fixture
def mock_admin_auth(mocker):
"""
Provide valid admin authentication for controller tests.
"""
mocker.patch(
"controllers.console.admin.dify_config.ADMIN_API_KEY",
"test-admin-key",
)
mocker.patch(
"controllers.console.admin.extract_access_token",
return_value="test-admin-key",
)
@pytest.fixture
def mock_console_payload(mocker):
payload = {
"app_id": str(uuid.uuid4()),
"language": "en-US",
"category": "Productivity",
"position": 1,
}
mocker.patch(
"flask_restx.namespace.Namespace.payload",
new_callable=PropertyMock,
return_value=payload,
)
return payload
@pytest.fixture
def mock_banner_payload(mocker):
mocker.patch(
"flask_restx.namespace.Namespace.payload",
new_callable=PropertyMock,
return_value={
"title": "Test Banner",
"description": "Banner description",
"img-src": "https://example.com/banner.png",
"link": "https://example.com",
"sort": 1,
"category": "homepage",
},
)
@pytest.fixture
def mock_session_factory(mocker):
mock_session = Mock()
mock_session.execute = Mock()
mock_session.add = Mock()
mock_session.commit = Mock()
mocker.patch(
"controllers.console.admin.session_factory.create_session",
return_value=Mock(
__enter__=lambda s: mock_session,
__exit__=Mock(return_value=False),
),
)
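The `create_session` stub builds a context manager from a plain `Mock` by passing `__enter__`/`__exit__` as constructor kwargs; because every `Mock` instance gets its own subclass, the dunders land on that subclass, where the `with` statement's type-level lookup finds them. A minimal sketch:

```python
from unittest.mock import Mock

session = Mock()
cm = Mock(
    __enter__=lambda s: session,        # `with cm as s` binds s to session
    __exit__=Mock(return_value=False),  # False: do not swallow exceptions
)
with cm as s:
    s.add("row")
session.add.assert_called_once_with("row")
```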
class TestDeleteExploreBannerApi:
def setup_method(self):
self.api = DeleteExploreBannerApi()
def test_delete_banner_not_found(self, mocker, mock_admin_auth):
mocker.patch(
"controllers.console.admin.db.session.execute",
return_value=Mock(scalar_one_or_none=lambda: None),
)
with pytest.raises(NotFound, match="is not found"):
self.api.delete(uuid.uuid4())
def test_delete_banner_success(self, mocker, mock_admin_auth):
mock_banner = Mock()
mocker.patch(
"controllers.console.admin.db.session.execute",
return_value=Mock(scalar_one_or_none=lambda: mock_banner),
)
mocker.patch("controllers.console.admin.db.session.delete")
mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.delete(uuid.uuid4())
assert status == 204
assert response["result"] == "success"
class TestInsertExploreBannerApi:
def setup_method(self):
self.api = InsertExploreBannerApi()
def test_insert_banner_success(self, mocker, mock_admin_auth, mock_banner_payload):
mocker.patch("controllers.console.admin.db.session.add")
mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.post()
assert status == 201
assert response["result"] == "success"
def test_banner_payload_valid_language(self):
payload = {
"title": "Test Banner",
"description": "Banner description",
"img-src": "https://example.com/banner.png",
"link": "https://example.com",
"sort": 1,
"category": "homepage",
"language": "en-US",
}
model = InsertExploreBannerPayload.model_validate(payload)
assert model.language == "en-US"
def test_banner_payload_invalid_language(self):
payload = {
"title": "Test Banner",
"description": "Banner description",
"img-src": "https://example.com/banner.png",
"link": "https://example.com",
"sort": 1,
"category": "homepage",
"language": "invalid-lang",
}
with pytest.raises(ValueError, match="invalid-lang is not a valid language"):
InsertExploreBannerPayload.model_validate(payload)
class TestInsertExploreAppApiDelete:
def setup_method(self):
self.api = InsertExploreAppApi()
def test_delete_when_not_in_explore(self, mocker, mock_admin_auth):
mocker.patch(
"controllers.console.admin.session_factory.create_session",
return_value=Mock(
__enter__=lambda s: s,
__exit__=Mock(return_value=False),
execute=lambda *_: Mock(scalar_one_or_none=lambda: None),
),
)
response, status = self.api.delete(uuid.uuid4())
assert status == 204
assert response["result"] == "success"
def test_delete_when_in_explore_with_trial_app(self, mocker, mock_admin_auth):
"""Test deleting an app from explore that has a trial app."""
app_id = uuid.uuid4()
mock_recommended = Mock(spec=RecommendedApp)
mock_recommended.app_id = "app-123"
mock_app = Mock(spec=App)
mock_app.is_public = True
mock_trial = Mock()
# Mock session context manager and its execute
mock_session = Mock()
mock_session.execute = Mock()
mock_session.delete = Mock()
# Set up side effects for execute calls
mock_session.execute.side_effect = [
Mock(scalar_one_or_none=lambda: mock_recommended),
Mock(scalar_one_or_none=lambda: mock_app),
Mock(scalars=Mock(return_value=Mock(all=lambda: []))),
Mock(scalar_one_or_none=lambda: mock_trial),
]
mocker.patch(
"controllers.console.admin.session_factory.create_session",
return_value=Mock(
__enter__=lambda s: mock_session,
__exit__=Mock(return_value=False),
),
)
mocker.patch("controllers.console.admin.db.session.delete")
mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.delete(app_id)
assert status == 204
assert response["result"] == "success"
assert mock_app.is_public is False
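Giving `side_effect` a list makes each successive call return the next item, which is how the four `execute()` results above are delivered in order; a generic sketch:

```python
from unittest.mock import Mock

execute = Mock(side_effect=["first", "second", "third"])
assert execute() == "first"
assert execute() == "second"
assert execute() == "third"
# A fourth call would raise StopIteration.
```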
def test_delete_with_installed_apps(self, mocker, mock_admin_auth):
"""Test deleting an app that has installed apps in other tenants."""
app_id = uuid.uuid4()
mock_recommended = Mock(spec=RecommendedApp)
mock_recommended.app_id = "app-123"
mock_app = Mock(spec=App)
mock_app.is_public = True
mock_installed_app = Mock(spec=InstalledApp)
# Mock session
mock_session = Mock()
mock_session.execute = Mock()
mock_session.delete = Mock()
mock_session.execute.side_effect = [
Mock(scalar_one_or_none=lambda: mock_recommended),
Mock(scalar_one_or_none=lambda: mock_app),
Mock(scalars=Mock(return_value=Mock(all=lambda: [mock_installed_app]))),
Mock(scalar_one_or_none=lambda: None),
]
mocker.patch(
"controllers.console.admin.session_factory.create_session",
return_value=Mock(
__enter__=lambda s: mock_session,
__exit__=Mock(return_value=False),
),
)
mocker.patch("controllers.console.admin.db.session.delete")
mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.delete(app_id)
assert status == 204
assert mock_session.delete.called
class TestInsertExploreAppListApi:
def setup_method(self):
self.api = InsertExploreAppListApi()
def test_app_not_found(self, mocker, mock_admin_auth, mock_console_payload):
mocker.patch(
"controllers.console.admin.db.session.execute",
return_value=Mock(scalar_one_or_none=lambda: None),
)
with pytest.raises(NotFound, match="is not found"):
self.api.post()
def test_create_recommended_app(
self,
mocker,
mock_admin_auth,
mock_console_payload,
):
mock_app = Mock(spec=App)
mock_app.id = "app-id"
mock_app.site = None
mock_app.tenant_id = "tenant"
mock_app.is_public = False
# db.session.execute → fetch App
mocker.patch(
"controllers.console.admin.db.session.execute",
return_value=Mock(scalar_one_or_none=lambda: mock_app),
)
# session_factory.create_session → recommended_app lookup
mock_session = Mock()
mock_session.execute = Mock(return_value=Mock(scalar_one_or_none=lambda: None))
mocker.patch(
"controllers.console.admin.session_factory.create_session",
return_value=Mock(
__enter__=lambda s: mock_session,
__exit__=Mock(return_value=False),
),
)
mocker.patch("controllers.console.admin.db.session.add")
mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.post()
assert status == 201
assert response["result"] == "success"
assert mock_app.is_public is True
def test_update_recommended_app(self, mocker, mock_admin_auth, mock_console_payload, mock_session_factory):
mock_app = Mock(spec=App)
mock_app.id = "app-id"
mock_app.site = None
mock_app.is_public = False
mock_recommended = Mock(spec=RecommendedApp)
mocker.patch(
"controllers.console.admin.db.session.execute",
side_effect=[
Mock(scalar_one_or_none=lambda: mock_app),
Mock(scalar_one_or_none=lambda: mock_recommended),
],
)
mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.post()
assert status == 200
assert response["result"] == "success"
assert mock_app.is_public is True
def test_site_data_overrides_payload(
self,
mocker,
mock_admin_auth,
mock_console_payload,
mock_session_factory,
):
site = Mock()
site.description = "Site Desc"
site.copyright = "Site Copyright"
site.privacy_policy = "Site Privacy"
site.custom_disclaimer = "Site Disclaimer"
mock_app = Mock(spec=App)
mock_app.id = "app-id"
mock_app.site = site
mock_app.tenant_id = "tenant"
mock_app.is_public = False
mocker.patch(
"controllers.console.admin.db.session.execute",
side_effect=[
Mock(scalar_one_or_none=lambda: mock_app),
Mock(scalar_one_or_none=lambda: None),
Mock(scalar_one_or_none=lambda: None),
],
)
commit_spy = mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.post()
assert status == 200
assert response["result"] == "success"
assert mock_app.is_public is True
commit_spy.assert_called_once()
def test_create_trial_app_when_can_trial_enabled(
self,
mocker,
mock_admin_auth,
mock_console_payload,
mock_session_factory,
):
mock_console_payload["can_trial"] = True
mock_console_payload["trial_limit"] = 5
mock_app = Mock(spec=App)
mock_app.id = "app-id"
mock_app.site = None
mock_app.tenant_id = "tenant"
mock_app.is_public = False
mocker.patch(
"controllers.console.admin.db.session.execute",
side_effect=[
Mock(scalar_one_or_none=lambda: mock_app),
Mock(scalar_one_or_none=lambda: None),
Mock(scalar_one_or_none=lambda: None),
],
)
add_spy = mocker.patch("controllers.console.admin.db.session.add")
mocker.patch("controllers.console.admin.db.session.commit")
self.api.post()
assert any(call.args[0].__class__.__name__ == "TrialApp" for call in add_spy.call_args_list)
def test_update_recommended_app_with_trial(
self,
mocker,
mock_admin_auth,
mock_console_payload,
mock_session_factory,
):
"""Test updating a recommended app when trial is enabled."""
mock_console_payload["can_trial"] = True
mock_console_payload["trial_limit"] = 10
mock_app = Mock(spec=App)
mock_app.id = "app-id"
mock_app.site = None
mock_app.is_public = False
mock_app.tenant_id = "tenant-123"
mock_recommended = Mock(spec=RecommendedApp)
mocker.patch(
"controllers.console.admin.db.session.execute",
side_effect=[
Mock(scalar_one_or_none=lambda: mock_app),
Mock(scalar_one_or_none=lambda: mock_recommended),
Mock(scalar_one_or_none=lambda: None),
],
)
add_spy = mocker.patch("controllers.console.admin.db.session.add")
mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.post()
assert status == 200
assert response["result"] == "success"
assert mock_app.is_public is True
def test_update_recommended_app_without_trial(
self,
mocker,
mock_admin_auth,
mock_console_payload,
mock_session_factory,
):
"""Test updating a recommended app without trial enabled."""
mock_app = Mock(spec=App)
mock_app.id = "app-id"
mock_app.site = None
mock_app.is_public = False
mock_recommended = Mock(spec=RecommendedApp)
mocker.patch(
"controllers.console.admin.db.session.execute",
side_effect=[
Mock(scalar_one_or_none=lambda: mock_app),
Mock(scalar_one_or_none=lambda: mock_recommended),
],
)
mocker.patch("controllers.console.admin.db.session.commit")
response, status = self.api.post()
assert status == 200
assert response["result"] == "success"
assert mock_app.is_public is True
class TestInsertExploreAppPayload:

@@ -0,0 +1,138 @@
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden
from controllers.console.apikey import (
BaseApiKeyListResource,
BaseApiKeyResource,
_get_resource,
)
@pytest.fixture
def tenant_context_admin():
with patch("controllers.console.apikey.current_account_with_tenant") as mock:
user = MagicMock()
user.is_admin_or_owner = True
mock.return_value = (user, "tenant-123")
yield mock
@pytest.fixture
def tenant_context_non_admin():
with patch("controllers.console.apikey.current_account_with_tenant") as mock:
user = MagicMock()
user.is_admin_or_owner = False
mock.return_value = (user, "tenant-123")
yield mock
@pytest.fixture
def db_mock():
with patch("controllers.console.apikey.db") as mock_db:
mock_db.session = MagicMock()
yield mock_db
@pytest.fixture(autouse=True)
def bypass_permissions():
with patch(
"controllers.console.apikey.edit_permission_required",
lambda f: f,
):
yield
class DummyApiKeyListResource(BaseApiKeyListResource):
resource_type = "app"
resource_model = MagicMock()
resource_id_field = "app_id"
token_prefix = "app-"
class DummyApiKeyResource(BaseApiKeyResource):
resource_type = "app"
resource_model = MagicMock()
resource_id_field = "app_id"
class TestGetResource:
def test_get_resource_success(self):
fake_resource = MagicMock()
with (
patch("controllers.console.apikey.select") as mock_select,
patch("controllers.console.apikey.Session") as mock_session,
patch("controllers.console.apikey.db") as mock_db,
):
mock_db.engine = MagicMock()
mock_select.return_value.filter_by.return_value = MagicMock()
session = mock_session.return_value.__enter__.return_value
session.execute.return_value.scalar_one_or_none.return_value = fake_resource
result = _get_resource("rid", "tid", MagicMock)
assert result == fake_resource
def test_get_resource_not_found(self):
with (
patch("controllers.console.apikey.select") as mock_select,
patch("controllers.console.apikey.Session") as mock_session,
patch("controllers.console.apikey.db") as mock_db,
patch("controllers.console.apikey.flask_restx.abort") as abort,
):
mock_db.engine = MagicMock()
mock_select.return_value.filter_by.return_value = MagicMock()
session = mock_session.return_value.__enter__.return_value
session.execute.return_value.scalar_one_or_none.return_value = None
_get_resource("rid", "tid", MagicMock)
abort.assert_called_once()
class TestBaseApiKeyListResource:
def test_get_apikeys_success(self, tenant_context_admin, db_mock):
resource = DummyApiKeyListResource()
with patch("controllers.console.apikey._get_resource"):
db_mock.session.scalars.return_value.all.return_value = [MagicMock(), MagicMock()]
result = DummyApiKeyListResource.get.__wrapped__(resource, "resource-id")
assert "items" in result
class TestBaseApiKeyResource:
def test_delete_forbidden(self, tenant_context_non_admin, db_mock):
resource = DummyApiKeyResource()
with patch("controllers.console.apikey._get_resource"):
with pytest.raises(Forbidden):
DummyApiKeyResource.delete(resource, "rid", "kid")
def test_delete_key_not_found(self, tenant_context_admin, db_mock):
resource = DummyApiKeyResource()
db_mock.session.query.return_value.where.return_value.first.return_value = None
with patch("controllers.console.apikey._get_resource"):
with pytest.raises(Exception) as exc_info:
DummyApiKeyResource.delete(resource, "rid", "kid")
# flask_restx.abort raises HTTPException with message in data attribute
assert exc_info.value.data["message"] == "API key not found"
def test_delete_success(self, tenant_context_admin, db_mock):
resource = DummyApiKeyResource()
db_mock.session.query.return_value.where.return_value.first.return_value = MagicMock()
with (
patch("controllers.console.apikey._get_resource"),
patch("controllers.console.apikey.ApiTokenCache.delete"),
):
result, status = DummyApiKeyResource.delete(resource, "rid", "kid")
assert status == 204
assert result == {"result": "success"}
db_mock.session.commit.assert_called_once()

@@ -1,46 +0,0 @@
import builtins
from unittest.mock import patch
import pytest
from flask import Flask
from flask.views import MethodView
from extensions import ext_fastopenapi
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
@pytest.fixture
def app() -> Flask:
app = Flask(__name__)
app.config["TESTING"] = True
app.secret_key = "test-secret-key"
return app
def test_console_init_get_returns_finished_when_no_init_password(app: Flask, monkeypatch: pytest.MonkeyPatch):
ext_fastopenapi.init_app(app)
monkeypatch.delenv("INIT_PASSWORD", raising=False)
with patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"):
client = app.test_client()
response = client.get("/console/api/init")
assert response.status_code == 200
assert response.get_json() == {"status": "finished"}
def test_console_init_post_returns_success(app: Flask, monkeypatch: pytest.MonkeyPatch):
ext_fastopenapi.init_app(app)
monkeypatch.setenv("INIT_PASSWORD", "test-init-password")
with (
patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"),
patch("controllers.console.init_validate.TenantService.get_tenant_count", return_value=0),
):
client = app.test_client()
response = client.post("/console/api/init", json={"password": "test-init-password"})
assert response.status_code == 201
assert response.get_json() == {"result": "success"}

@@ -1,286 +0,0 @@
"""Tests for remote file upload API endpoints using Flask-RESTX."""
import contextlib
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import Mock, patch
import httpx
import pytest
from flask import Flask, g
@pytest.fixture
def app() -> Flask:
"""Create Flask app for testing."""
app = Flask(__name__)
app.config["TESTING"] = True
app.config["SECRET_KEY"] = "test-secret-key"
return app
@pytest.fixture
def client(app):
"""Create test client with console blueprint registered."""
from controllers.console import bp
app.register_blueprint(bp)
return app.test_client()
@pytest.fixture
def mock_account():
"""Create a mock account for testing."""
from models import Account
account = Mock(spec=Account)
account.id = "test-account-id"
account.current_tenant_id = "test-tenant-id"
return account
@pytest.fixture
def auth_ctx(app, mock_account):
"""Context manager to set auth/tenant context in flask.g for a request."""
@contextlib.contextmanager
def _ctx():
with app.test_request_context():
g._login_user = mock_account
g._current_tenant = mock_account.current_tenant_id
yield
return _ctx
class TestGetRemoteFileInfo:
"""Test GET /console/api/remote-files/<path:url> endpoint."""
def test_get_remote_file_info_success(self, app, client, mock_account):
"""Test successful retrieval of remote file info."""
response = httpx.Response(
200,
request=httpx.Request("HEAD", "http://example.com/file.txt"),
headers={"Content-Type": "text/plain", "Content-Length": "1024"},
)
with (
patch(
"controllers.console.remote_files.current_account_with_tenant",
return_value=(mock_account, "test-tenant-id"),
),
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=response),
patch("libs.login.check_csrf_token", return_value=None),
):
with app.test_request_context():
g._login_user = mock_account
g._current_tenant = mock_account.current_tenant_id
encoded_url = "http%3A%2F%2Fexample.com%2Ffile.txt"
resp = client.get(f"/console/api/remote-files/{encoded_url}")
assert resp.status_code == 200
data = resp.get_json()
assert data["file_type"] == "text/plain"
assert data["file_length"] == 1024
def test_get_remote_file_info_fallback_to_get_on_head_failure(self, app, client, mock_account):
"""Test fallback to GET when HEAD returns non-200 status."""
head_response = httpx.Response(
404,
request=httpx.Request("HEAD", "http://example.com/file.pdf"),
)
get_response = httpx.Response(
200,
request=httpx.Request("GET", "http://example.com/file.pdf"),
headers={"Content-Type": "application/pdf", "Content-Length": "2048"},
)
with (
patch(
"controllers.console.remote_files.current_account_with_tenant",
return_value=(mock_account, "test-tenant-id"),
),
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_response),
patch("controllers.console.remote_files.ssrf_proxy.get", return_value=get_response),
patch("libs.login.check_csrf_token", return_value=None),
):
with app.test_request_context():
g._login_user = mock_account
g._current_tenant = mock_account.current_tenant_id
encoded_url = "http%3A%2F%2Fexample.com%2Ffile.pdf"
resp = client.get(f"/console/api/remote-files/{encoded_url}")
assert resp.status_code == 200
data = resp.get_json()
assert data["file_type"] == "application/pdf"
assert data["file_length"] == 2048
class TestRemoteFileUpload:
"""Test POST /console/api/remote-files/upload endpoint."""
@pytest.mark.parametrize(
("head_status", "use_get"),
[
(200, False), # HEAD succeeds
(405, True), # HEAD fails -> fallback GET
],
)
def test_upload_remote_file_success_paths(self, client, mock_account, auth_ctx, head_status, use_get):
url = "http://example.com/file.pdf"
head_resp = httpx.Response(
head_status,
request=httpx.Request("HEAD", url),
headers={"Content-Type": "application/pdf", "Content-Length": "1024"},
)
get_resp = httpx.Response(
200,
request=httpx.Request("GET", url),
headers={"Content-Type": "application/pdf", "Content-Length": "1024"},
content=b"file content",
)
file_info = SimpleNamespace(
extension="pdf",
size=1024,
filename="file.pdf",
mimetype="application/pdf",
)
uploaded_file = SimpleNamespace(
id="uploaded-file-id",
name="file.pdf",
size=1024,
extension="pdf",
mime_type="application/pdf",
created_by="test-account-id",
created_at=datetime(2024, 1, 1, 12, 0, 0),
)
with (
patch(
"controllers.console.remote_files.current_account_with_tenant",
return_value=(mock_account, "test-tenant-id"),
),
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_resp) as p_head,
patch("controllers.console.remote_files.ssrf_proxy.get", return_value=get_resp) as p_get,
patch(
"controllers.console.remote_files.helpers.guess_file_info_from_response",
return_value=file_info,
),
patch(
"controllers.console.remote_files.FileService.is_file_size_within_limit",
return_value=True,
),
patch("controllers.console.remote_files.db", spec=["engine"]),
patch("controllers.console.remote_files.FileService") as mock_file_service,
patch(
"controllers.console.remote_files.file_helpers.get_signed_file_url",
return_value="http://example.com/signed-url",
),
patch("libs.login.check_csrf_token", return_value=None),
):
mock_file_service.return_value.upload_file.return_value = uploaded_file
with auth_ctx():
resp = client.post(
"/console/api/remote-files/upload",
json={"url": url},
)
assert resp.status_code == 201
p_head.assert_called_once()
# GET is used either for fallback (HEAD fails) or to fetch content after HEAD succeeds
p_get.assert_called_once()
mock_file_service.return_value.upload_file.assert_called_once()
data = resp.get_json()
assert data["id"] == "uploaded-file-id"
assert data["name"] == "file.pdf"
assert data["size"] == 1024
assert data["extension"] == "pdf"
assert data["url"] == "http://example.com/signed-url"
assert data["mime_type"] == "application/pdf"
assert data["created_by"] == "test-account-id"
@pytest.mark.parametrize(
("size_ok", "raises", "expected_status", "expected_msg"),
[
# When size check fails in controller, API returns 413 with message "File size exceeded..."
(False, None, 413, "file size exceeded"),
# When service raises unsupported type, controller maps to 415 with message "File type not allowed."
(True, "unsupported", 415, "file type not allowed"),
],
)
def test_upload_remote_file_errors(
self, client, mock_account, auth_ctx, size_ok, raises, expected_status, expected_msg
):
url = "http://example.com/x.pdf"
head_resp = httpx.Response(
200,
request=httpx.Request("HEAD", url),
headers={"Content-Type": "application/pdf", "Content-Length": "9"},
)
file_info = SimpleNamespace(extension="pdf", size=9, filename="x.pdf", mimetype="application/pdf")
with (
patch(
"controllers.console.remote_files.current_account_with_tenant",
return_value=(mock_account, "test-tenant-id"),
),
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_resp),
patch(
"controllers.console.remote_files.helpers.guess_file_info_from_response",
return_value=file_info,
),
patch(
"controllers.console.remote_files.FileService.is_file_size_within_limit",
return_value=size_ok,
),
patch("controllers.console.remote_files.db", spec=["engine"]),
patch("libs.login.check_csrf_token", return_value=None),
):
if raises == "unsupported":
from services.errors.file import UnsupportedFileTypeError
with patch("controllers.console.remote_files.FileService") as mock_file_service:
mock_file_service.return_value.upload_file.side_effect = UnsupportedFileTypeError("bad")
with auth_ctx():
resp = client.post(
"/console/api/remote-files/upload",
json={"url": url},
)
else:
with auth_ctx():
resp = client.post(
"/console/api/remote-files/upload",
json={"url": url},
)
assert resp.status_code == expected_status
data = resp.get_json()
msg = (data.get("error") or {}).get("message") or data.get("message", "")
assert expected_msg in msg.lower()
def test_upload_remote_file_fetch_failure(self, client, mock_account, auth_ctx):
"""Test upload when fetching of remote file fails."""
with (
patch(
"controllers.console.remote_files.current_account_with_tenant",
return_value=(mock_account, "test-tenant-id"),
),
patch(
"controllers.console.remote_files.ssrf_proxy.head",
side_effect=httpx.RequestError("Connection failed"),
),
patch("libs.login.check_csrf_token", return_value=None),
):
with auth_ctx():
resp = client.post(
"/console/api/remote-files/upload",
json={"url": "http://unreachable.com/file.pdf"},
)
assert resp.status_code == 400
data = resp.get_json()
msg = (data.get("error") or {}).get("message") or data.get("message", "")
assert "failed to fetch" in msg.lower()

@@ -0,0 +1,81 @@
from werkzeug.exceptions import Unauthorized
def unwrap(func):
"""
Recursively unwrap decorated functions.
"""
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestFeatureApi:
def test_get_tenant_features_success(self, mocker):
from controllers.console.feature import FeatureApi
mocker.patch(
"controllers.console.feature.current_account_with_tenant",
return_value=("account_id", "tenant_123"),
)
mocker.patch("controllers.console.feature.FeatureService.get_features").return_value.model_dump.return_value = {
"features": {"feature_a": True}
}
api = FeatureApi()
raw_get = unwrap(FeatureApi.get)
result = raw_get(api)
assert result == {"features": {"feature_a": True}}
class TestSystemFeatureApi:
def test_get_system_features_authenticated(self, mocker):
"""
current_user.is_authenticated == True
"""
from controllers.console.feature import SystemFeatureApi
fake_user = mocker.Mock()
fake_user.is_authenticated = True
mocker.patch(
"controllers.console.feature.current_user",
fake_user,
)
mocker.patch(
"controllers.console.feature.FeatureService.get_system_features"
).return_value.model_dump.return_value = {"features": {"sys_feature": True}}
api = SystemFeatureApi()
result = api.get()
assert result == {"features": {"sys_feature": True}}
def test_get_system_features_unauthenticated(self, mocker):
"""
current_user.is_authenticated raises Unauthorized
"""
from controllers.console.feature import SystemFeatureApi
fake_user = mocker.Mock()
type(fake_user).is_authenticated = mocker.PropertyMock(side_effect=Unauthorized())
mocker.patch(
"controllers.console.feature.current_user",
fake_user,
)
mocker.patch(
"controllers.console.feature.FeatureService.get_system_features"
).return_value.model_dump.return_value = {"features": {"sys_feature": False}}
api = SystemFeatureApi()
result = api.get()
assert result == {"features": {"sys_feature": False}}

@@ -0,0 +1,300 @@
import io
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden
from constants import DOCUMENT_EXTENSIONS
from controllers.common.errors import (
BlockedFileExtensionError,
FilenameNotExistsError,
FileTooLargeError,
NoFileUploadedError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.console.files import (
FileApi,
FilePreviewApi,
FileSupportTypeApi,
)
def unwrap(func):
"""
Recursively unwrap decorated functions.
"""
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def app():
app = Flask(__name__)
app.testing = True
return app
@pytest.fixture(autouse=True)
def mock_decorators():
"""
Make decorators no-ops so logic is directly testable
"""
with (
patch("controllers.console.files.setup_required", new=lambda f: f),
patch("controllers.console.files.login_required", new=lambda f: f),
patch("controllers.console.files.account_initialization_required", new=lambda f: f),
patch("controllers.console.files.cloud_edition_billing_resource_check", return_value=lambda f: f),
):
yield
@pytest.fixture
def mock_current_user():
user = MagicMock()
user.is_dataset_editor = True
return user
@pytest.fixture
def mock_account_context(mock_current_user):
with patch(
"controllers.console.files.current_account_with_tenant",
return_value=(mock_current_user, None),
):
yield
@pytest.fixture
def mock_db():
with patch("controllers.console.files.db") as db_mock:
db_mock.engine = MagicMock()
yield db_mock
@pytest.fixture
def mock_file_service(mock_db):
with patch("controllers.console.files.FileService") as fs:
instance = fs.return_value
yield instance
class TestFileApiGet:
def test_get_upload_config(self, app):
api = FileApi()
get_method = unwrap(api.get)
with app.test_request_context():
data, status = get_method(api)
assert status == 200
assert "file_size_limit" in data
assert "batch_count_limit" in data
class TestFileApiPost:
def test_no_file_uploaded(self, app, mock_account_context):
api = FileApi()
post_method = unwrap(api.post)
with app.test_request_context(method="POST", data={}):
with pytest.raises(NoFileUploadedError):
post_method(api)
def test_too_many_files(self, app, mock_account_context):
api = FileApi()
post_method = unwrap(api.post)
with app.test_request_context(method="POST"):
with patch("controllers.console.files.request") as mock_request:
mock_request.files = MagicMock()
mock_request.files.__len__.return_value = 2
mock_request.files.__contains__.return_value = True
mock_request.form = MagicMock()
mock_request.form.get.return_value = None
with pytest.raises(TooManyFilesError):
post_method(api)
def test_filename_missing(self, app, mock_account_context):
api = FileApi()
post_method = unwrap(api.post)
data = {
"file": (io.BytesIO(b"abc"), ""),
}
with app.test_request_context(method="POST", data=data):
with pytest.raises(FilenameNotExistsError):
post_method(api)
def test_dataset_upload_without_permission(self, app, mock_current_user):
mock_current_user.is_dataset_editor = False
with patch(
"controllers.console.files.current_account_with_tenant",
return_value=(mock_current_user, None),
):
api = FileApi()
post_method = unwrap(api.post)
data = {
"file": (io.BytesIO(b"abc"), "test.txt"),
"source": "datasets",
}
with app.test_request_context(method="POST", data=data):
with pytest.raises(Forbidden):
post_method(api)
def test_successful_upload(self, app, mock_account_context, mock_file_service):
api = FileApi()
post_method = unwrap(api.post)
mock_file = MagicMock()
mock_file.id = "file-id-123"
mock_file.filename = "test.txt"
mock_file.name = "test.txt"
mock_file.size = 1024
mock_file.extension = "txt"
mock_file.mime_type = "text/plain"
mock_file.created_by = "user-123"
mock_file.created_at = 1234567890
mock_file.preview_url = "http://example.com/preview/file-id-123"
mock_file.source_url = "http://example.com/source/file-id-123"
mock_file.original_url = None
mock_file.user_id = "user-123"
mock_file.tenant_id = "tenant-123"
mock_file.conversation_id = None
mock_file.file_key = "file-key-123"
mock_file_service.upload_file.return_value = mock_file
data = {
"file": (io.BytesIO(b"hello"), "test.txt"),
}
with app.test_request_context(method="POST", data=data):
response, status = post_method(api)
assert status == 201
assert response["id"] == "file-id-123"
assert response["name"] == "test.txt"
def test_upload_with_invalid_source(self, app, mock_account_context, mock_file_service):
"""Test that invalid source parameter gets normalized to None"""
api = FileApi()
post_method = unwrap(api.post)
# Create a properly structured mock file object
mock_file = MagicMock()
mock_file.id = "file-id-456"
mock_file.filename = "test.txt"
mock_file.name = "test.txt"
mock_file.size = 512
mock_file.extension = "txt"
mock_file.mime_type = "text/plain"
mock_file.created_by = "user-456"
mock_file.created_at = 1234567890
mock_file.preview_url = None
mock_file.source_url = None
mock_file.original_url = None
mock_file.user_id = "user-456"
mock_file.tenant_id = "tenant-456"
mock_file.conversation_id = None
mock_file.file_key = "file-key-456"
mock_file_service.upload_file.return_value = mock_file
data = {
"file": (io.BytesIO(b"content"), "test.txt"),
"source": "invalid_source", # Should be normalized to None
}
with app.test_request_context(method="POST", data=data):
response, status = post_method(api)
assert status == 201
assert response["id"] == "file-id-456"
# Verify that FileService was called with source=None
mock_file_service.upload_file.assert_called_once()
call_kwargs = mock_file_service.upload_file.call_args[1]
assert call_kwargs["source"] is None
def test_file_too_large_error(self, app, mock_account_context, mock_file_service):
api = FileApi()
post_method = unwrap(api.post)
from services.errors.file import FileTooLargeError as ServiceFileTooLargeError
error = ServiceFileTooLargeError("File is too large")
mock_file_service.upload_file.side_effect = error
data = {
"file": (io.BytesIO(b"x" * 1000000), "big.txt"),
}
with app.test_request_context(method="POST", data=data):
with pytest.raises(FileTooLargeError):
post_method(api)
def test_unsupported_file_type(self, app, mock_account_context, mock_file_service):
api = FileApi()
post_method = unwrap(api.post)
from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError
error = ServiceUnsupportedFileTypeError()
mock_file_service.upload_file.side_effect = error
data = {
"file": (io.BytesIO(b"x"), "bad.exe"),
}
with app.test_request_context(method="POST", data=data):
with pytest.raises(UnsupportedFileTypeError):
post_method(api)
def test_blocked_extension(self, app, mock_account_context, mock_file_service):
api = FileApi()
post_method = unwrap(api.post)
from services.errors.file import BlockedFileExtensionError as ServiceBlockedFileExtensionError
error = ServiceBlockedFileExtensionError("File extension is blocked")
mock_file_service.upload_file.side_effect = error
data = {
"file": (io.BytesIO(b"x"), "blocked.txt"),
}
with app.test_request_context(method="POST", data=data):
with pytest.raises(BlockedFileExtensionError):
post_method(api)
class TestFilePreviewApi:
def test_get_preview(self, app, mock_file_service):
api = FilePreviewApi()
get_method = unwrap(api.get)
mock_file_service.get_file_preview.return_value = "preview text"
with app.test_request_context():
result = get_method(api, "1234")
assert result == {"content": "preview text"}
class TestFileSupportTypeApi:
def test_get_supported_types(self, app):
api = FileSupportTypeApi()
get_method = unwrap(api.get)
with app.test_request_context():
result = get_method(api)
assert result == {"allowed_extensions": list(DOCUMENT_EXTENSIONS)}

@@ -0,0 +1,293 @@
from __future__ import annotations
import json
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from flask import Response
from controllers.console.human_input_form import (
ConsoleHumanInputFormApi,
ConsoleWorkflowEventsApi,
DifyAPIRepositoryFactory,
WorkflowResponseConverter,
_jsonify_form_definition,
)
from controllers.web.error import NotFoundError
from models.enums import CreatorUserRole
from models.human_input import RecipientType
from models.model import AppMode
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def test_jsonify_form_definition() -> None:
expiration = datetime(2024, 1, 1, tzinfo=UTC)
definition = SimpleNamespace(model_dump=lambda: {"fields": []})
form = SimpleNamespace(get_definition=lambda: definition, expiration_time=expiration)
response = _jsonify_form_definition(form)
assert isinstance(response, Response)
payload = json.loads(response.get_data(as_text=True))
assert payload["expiration_time"] == int(expiration.timestamp())
def test_ensure_console_access_rejects(monkeypatch: pytest.MonkeyPatch) -> None:
form = SimpleNamespace(tenant_id="tenant-1")
monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-2"))
with pytest.raises(NotFoundError):
ConsoleHumanInputFormApi._ensure_console_access(form)
def test_get_form_definition_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
expiration = datetime(2024, 1, 1, tzinfo=UTC)
definition = SimpleNamespace(model_dump=lambda: {"fields": ["a"]})
form = SimpleNamespace(tenant_id="tenant-1", get_definition=lambda: definition, expiration_time=expiration)
class _ServiceStub:
def __init__(self, *_args, **_kwargs):
pass
def get_form_definition_by_token_for_console(self, _token):
return form
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-1"))
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleHumanInputFormApi()
handler = _unwrap(api.get)
with app.test_request_context("/console/api/form/human_input/token", method="GET"):
response = handler(api, form_token="token")
payload = json.loads(response.get_data(as_text=True))
assert payload["fields"] == ["a"]
def test_get_form_definition_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
class _ServiceStub:
def __init__(self, *_args, **_kwargs):
pass
def get_form_definition_by_token_for_console(self, _token):
return None
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-1"))
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleHumanInputFormApi()
handler = _unwrap(api.get)
with app.test_request_context("/console/api/form/human_input/token", method="GET"):
with pytest.raises(NotFoundError):
handler(api, form_token="token")
def test_post_form_invalid_recipient_type(app, monkeypatch: pytest.MonkeyPatch) -> None:
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.EMAIL_MEMBER)
class _ServiceStub:
def __init__(self, *_args, **_kwargs):
pass
def get_form_by_token(self, _token):
return form
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleHumanInputFormApi()
handler = _unwrap(api.post)
with app.test_request_context(
"/console/api/form/human_input/token",
method="POST",
json={"inputs": {"content": "ok"}, "action": "approve"},
):
with pytest.raises(NotFoundError):
handler(api, form_token="token")
def test_post_form_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
submit_mock = Mock()
form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.CONSOLE)
class _ServiceStub:
def __init__(self, *_args, **_kwargs):
pass
def get_form_by_token(self, _token):
return form
def submit_form_by_token(self, **kwargs):
submit_mock(**kwargs)
monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="user-1"), "tenant-1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleHumanInputFormApi()
handler = _unwrap(api.post)
with app.test_request_context(
"/console/api/form/human_input/token",
method="POST",
json={"inputs": {"content": "ok"}, "action": "approve"},
):
response = handler(api, form_token="token")
assert response.get_json() == {}
submit_mock.assert_called_once()
def test_workflow_events_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
class _RepoStub:
def get_workflow_run_by_id_and_tenant_id(self, **_kwargs):
return None
monkeypatch.setattr(
DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: _RepoStub(),
)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="u1"), "t1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleWorkflowEventsApi()
handler = _unwrap(api.get)
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
with pytest.raises(NotFoundError):
handler(api, workflow_run_id="run-1")
def test_workflow_events_requires_account(app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
created_by_role=CreatorUserRole.END_USER,
created_by="user-1",
tenant_id="t1",
)
class _RepoStub:
def get_workflow_run_by_id_and_tenant_id(self, **_kwargs):
return workflow_run
monkeypatch.setattr(
DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: _RepoStub(),
)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="u1"), "t1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleWorkflowEventsApi()
handler = _unwrap(api.get)
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
with pytest.raises(NotFoundError):
handler(api, workflow_run_id="run-1")
def test_workflow_events_requires_creator(app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
created_by_role=CreatorUserRole.ACCOUNT,
created_by="user-2",
tenant_id="t1",
)
class _RepoStub:
def get_workflow_run_by_id_and_tenant_id(self, **_kwargs):
return workflow_run
monkeypatch.setattr(
DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: _RepoStub(),
)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="u1"), "t1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleWorkflowEventsApi()
handler = _unwrap(api.get)
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
with pytest.raises(NotFoundError):
handler(api, workflow_run_id="run-1")
def test_workflow_events_finished(app, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
created_by_role=CreatorUserRole.ACCOUNT,
created_by="user-1",
tenant_id="t1",
app_id="app-1",
finished_at=datetime(2024, 1, 1, tzinfo=UTC),
)
app_model = SimpleNamespace(mode=AppMode.WORKFLOW)
class _RepoStub:
def get_workflow_run_by_id_and_tenant_id(self, **_kwargs):
return workflow_run
response_obj = SimpleNamespace(
event=SimpleNamespace(value="finished"),
model_dump=lambda mode="json": {"status": "done"},
)
monkeypatch.setattr(
DifyAPIRepositoryFactory,
"create_api_workflow_run_repository",
lambda *_args, **_kwargs: _RepoStub(),
)
monkeypatch.setattr(
"controllers.console.human_input_form._retrieve_app_for_workflow_run",
lambda *_args, **_kwargs: app_model,
)
monkeypatch.setattr(
WorkflowResponseConverter,
"workflow_run_result_to_finish_response",
lambda **_kwargs: response_obj,
)
monkeypatch.setattr(
"controllers.console.human_input_form.current_account_with_tenant",
lambda: (SimpleNamespace(id="user-1"), "t1"),
)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleWorkflowEventsApi()
handler = _unwrap(api.get)
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
response = handler(api, workflow_run_id="run-1")
assert response.mimetype == "text/event-stream"
assert "data" in response.get_data(as_text=True)


@@ -0,0 +1,108 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from controllers.console import init_validate
from controllers.console.error import AlreadySetupError, InitValidateFailedError
class _SessionStub:
def __init__(self, has_setup: bool):
self._has_setup = has_setup
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def execute(self, *_args, **_kwargs):
return SimpleNamespace(scalar_one_or_none=lambda: Mock() if self._has_setup else None)
def test_get_init_status_finished(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate, "get_init_validate_status", lambda: True)
result = init_validate.get_init_status()
assert result.status == "finished"
def test_get_init_status_not_started(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate, "get_init_validate_status", lambda: False)
result = init_validate.get_init_status()
assert result.status == "not_started"
def test_validate_init_password_already_setup(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 1)
app.secret_key = "test-secret"
with app.test_request_context("/console/api/init", method="POST"):
with pytest.raises(AlreadySetupError):
init_validate.validate_init_password(init_validate.InitValidatePayload(password="pw"))
def test_validate_init_password_wrong_password(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0)
monkeypatch.setenv("INIT_PASSWORD", "expected")
app.secret_key = "test-secret"
with app.test_request_context("/console/api/init", method="POST"):
with pytest.raises(InitValidateFailedError):
init_validate.validate_init_password(init_validate.InitValidatePayload(password="wrong"))
assert init_validate.session.get("is_init_validated") is False
def test_validate_init_password_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0)
monkeypatch.setenv("INIT_PASSWORD", "expected")
app.secret_key = "test-secret"
with app.test_request_context("/console/api/init", method="POST"):
result = init_validate.validate_init_password(init_validate.InitValidatePayload(password="expected"))
assert result.result == "success"
assert init_validate.session.get("is_init_validated") is True
def test_get_init_validate_status_not_self_hosted(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "CLOUD")
assert init_validate.get_init_validate_status() is True
def test_get_init_validate_status_validated_session(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setenv("INIT_PASSWORD", "expected")
app.secret_key = "test-secret"
with app.test_request_context("/console/api/init", method="GET"):
init_validate.session["is_init_validated"] = True
assert init_validate.get_init_validate_status() is True
def test_get_init_validate_status_setup_exists(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setenv("INIT_PASSWORD", "expected")
monkeypatch.setattr(init_validate, "Session", lambda *_args, **_kwargs: _SessionStub(True))
monkeypatch.setattr(init_validate, "db", SimpleNamespace(engine=object()))
app.secret_key = "test-secret"
with app.test_request_context("/console/api/init", method="GET"):
init_validate.session.pop("is_init_validated", None)
assert init_validate.get_init_validate_status() is True
def test_get_init_validate_status_not_validated(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED")
monkeypatch.setenv("INIT_PASSWORD", "expected")
monkeypatch.setattr(init_validate, "Session", lambda *_args, **_kwargs: _SessionStub(False))
monkeypatch.setattr(init_validate, "db", SimpleNamespace(engine=object()))
app.secret_key = "test-secret"
with app.test_request_context("/console/api/init", method="GET"):
init_validate.session.pop("is_init_validated", None)
assert init_validate.get_init_validate_status() is False


@@ -0,0 +1,281 @@
from __future__ import annotations
import urllib.parse
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import MagicMock
import httpx
import pytest
from controllers.common.errors import FileTooLargeError, RemoteFileUploadError, UnsupportedFileTypeError
from controllers.console import remote_files as remote_files_module
from services.errors.file import FileTooLargeError as ServiceFileTooLargeError
from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class _FakeResponse:
def __init__(
self,
*,
status_code: int = 200,
headers: dict[str, str] | None = None,
method: str = "GET",
content: bytes = b"",
text: str = "",
error: Exception | None = None,
) -> None:
self.status_code = status_code
self.headers = headers or {}
self.request = SimpleNamespace(method=method)
self.content = content
self.text = text
self._error = error
def raise_for_status(self) -> None:
if self._error:
raise self._error
def _mock_upload_dependencies(
monkeypatch: pytest.MonkeyPatch,
*,
file_size_within_limit: bool = True,
):
file_info = SimpleNamespace(
filename="report.txt",
extension=".txt",
mimetype="text/plain",
size=3,
)
monkeypatch.setattr(
remote_files_module.helpers,
"guess_file_info_from_response",
MagicMock(return_value=file_info),
)
file_service_cls = MagicMock()
file_service_cls.is_file_size_within_limit.return_value = file_size_within_limit
monkeypatch.setattr(remote_files_module, "FileService", file_service_cls)
monkeypatch.setattr(remote_files_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), None))
monkeypatch.setattr(remote_files_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(
remote_files_module.file_helpers,
"get_signed_file_url",
lambda upload_file_id: f"https://signed.example/{upload_file_id}",
)
return file_service_cls
def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.GetRemoteFileInfo()
handler = _unwrap(api.get)
decoded_url = "https://example.com/test.txt"
encoded_url = urllib.parse.quote(decoded_url, safe="")
head_resp = _FakeResponse(
status_code=200,
headers={"Content-Type": "text/plain", "Content-Length": "128"},
method="HEAD",
)
head_mock = MagicMock(return_value=head_resp)
get_mock = MagicMock()
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", head_mock)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
with app.test_request_context(method="GET"):
payload = handler(api, url=encoded_url)
assert payload == {"file_type": "text/plain", "file_length": 128}
head_mock.assert_called_once_with(decoded_url)
get_mock.assert_not_called()
def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.GetRemoteFileInfo()
handler = _unwrap(api.get)
decoded_url = "https://example.com/test.txt"
encoded_url = urllib.parse.quote(decoded_url, safe="")
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=503)))
get_mock = MagicMock(return_value=_FakeResponse(status_code=200, headers={}, method="GET"))
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
with app.test_request_context(method="GET"):
payload = handler(api, url=encoded_url)
assert payload == {"file_type": "application/octet-stream", "file_length": 0}
get_mock.assert_called_once_with(decoded_url, timeout=3)
def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
url = "https://example.com/report.txt"
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=404)))
get_resp = _FakeResponse(status_code=200, method="GET", content=b"fallback-content")
get_mock = MagicMock(return_value=get_resp)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
file_service_cls = _mock_upload_dependencies(monkeypatch)
upload_file = SimpleNamespace(
id="file-1",
name="report.txt",
size=16,
extension=".txt",
mime_type="text/plain",
created_by="u1",
created_at=datetime(2024, 1, 1, tzinfo=UTC),
)
file_service_cls.return_value.upload_file.return_value = upload_file
with app.test_request_context(method="POST", json={"url": url}):
payload, status = handler(api)
assert status == 201
assert payload["id"] == "file-1"
assert payload["url"] == "https://signed.example/file-1"
get_mock.assert_called_once_with(url=url, timeout=3, follow_redirects=True)
file_service_cls.return_value.upload_file.assert_called_once_with(
filename="report.txt",
content=b"fallback-content",
mimetype="text/plain",
user=SimpleNamespace(id="u1"),
source_url=url,
)
def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds(
app, monkeypatch: pytest.MonkeyPatch
) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
url = "https://example.com/photo.jpg"
monkeypatch.setattr(
remote_files_module.ssrf_proxy,
"head",
MagicMock(return_value=_FakeResponse(status_code=200, method="HEAD", content=b"head-content")),
)
extra_get_resp = _FakeResponse(status_code=200, method="GET", content=b"downloaded-content")
get_mock = MagicMock(return_value=extra_get_resp)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock)
file_service_cls = _mock_upload_dependencies(monkeypatch)
upload_file = SimpleNamespace(
id="file-2",
name="photo.jpg",
size=18,
extension=".jpg",
mime_type="image/jpeg",
created_by="u1",
created_at=datetime(2024, 1, 2, tzinfo=UTC),
)
file_service_cls.return_value.upload_file.return_value = upload_file
with app.test_request_context(method="POST", json={"url": url}):
payload, status = handler(api)
assert status == 201
assert payload["id"] == "file-2"
get_mock.assert_called_once_with(url)
assert file_service_cls.return_value.upload_file.call_args.kwargs["content"] == b"downloaded-content"
def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
url = "https://example.com/fail.txt"
monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=500)))
monkeypatch.setattr(
remote_files_module.ssrf_proxy,
"get",
MagicMock(return_value=_FakeResponse(status_code=502, text="bad gateway")),
)
with app.test_request_context(method="POST", json={"url": url}):
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: bad gateway"):
handler(api)
def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
url = "https://example.com/fail.txt"
request = httpx.Request("HEAD", url)
monkeypatch.setattr(
remote_files_module.ssrf_proxy,
"head",
MagicMock(side_effect=httpx.RequestError("network down", request=request)),
)
with app.test_request_context(method="POST", json={"url": url}):
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: network down"):
handler(api)
def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
url = "https://example.com/large.bin"
monkeypatch.setattr(
remote_files_module.ssrf_proxy,
"head",
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
_mock_upload_dependencies(monkeypatch, file_size_within_limit=False)
with app.test_request_context(method="POST", json={"url": url}):
with pytest.raises(FileTooLargeError):
handler(api)
def test_remote_file_upload_translates_service_file_too_large_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
url = "https://example.com/large.bin"
monkeypatch.setattr(
remote_files_module.ssrf_proxy,
"head",
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
file_service_cls = _mock_upload_dependencies(monkeypatch)
file_service_cls.return_value.upload_file.side_effect = ServiceFileTooLargeError("size exceeded")
with app.test_request_context(method="POST", json={"url": url}):
with pytest.raises(FileTooLargeError, match="size exceeded"):
handler(api)
def test_remote_file_upload_translates_service_unsupported_type_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
url = "https://example.com/file.exe"
monkeypatch.setattr(
remote_files_module.ssrf_proxy,
"head",
MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")),
)
monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock())
file_service_cls = _mock_upload_dependencies(monkeypatch)
file_service_cls.return_value.upload_file.side_effect = ServiceUnsupportedFileTypeError()
with app.test_request_context(method="POST", json={"url": url}):
with pytest.raises(UnsupportedFileTypeError):
handler(api)


@@ -0,0 +1,49 @@
from unittest.mock import patch
import controllers.console.spec as spec_module
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestSpecSchemaDefinitionsApi:
def test_get_success(self):
api = spec_module.SpecSchemaDefinitionsApi()
method = unwrap(api.get)
schema_definitions = [{"type": "string"}]
with patch.object(
spec_module,
"SchemaManager",
) as schema_manager_cls:
schema_manager_cls.return_value.get_all_schema_definitions.return_value = schema_definitions
resp, status = method(api)
assert status == 200
assert resp == schema_definitions
def test_get_exception_returns_empty_list(self):
api = spec_module.SpecSchemaDefinitionsApi()
method = unwrap(api.get)
with (
patch.object(
spec_module,
"SchemaManager",
side_effect=Exception("boom"),
),
patch.object(
spec_module.logger,
"exception",
) as log_exception,
):
resp, status = method(api)
assert status == 200
assert resp == []
log_exception.assert_called_once()


@@ -0,0 +1,162 @@
from unittest.mock import MagicMock, patch
import controllers.console.version as version_module
class TestHasNewVersion:
def test_has_new_version_true(self):
result = version_module._has_new_version(
latest_version="1.2.0",
current_version="1.1.0",
)
assert result is True
def test_has_new_version_false(self):
result = version_module._has_new_version(
latest_version="1.0.0",
current_version="1.1.0",
)
assert result is False
def test_has_new_version_invalid_version(self):
with patch.object(version_module.logger, "warning") as log_warning:
result = version_module._has_new_version(
latest_version="invalid",
current_version="1.0.0",
)
assert result is False
log_warning.assert_called_once()
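The three tests above pin down `_has_new_version`'s contract: a strictly newer release returns True, a same-or-older one returns False, and unparsable input is swallowed (logged) and treated as False. A naive stdlib-only sketch consistent with those assertions; the real module presumably uses a proper semantic-version parser, so the dotted-integer `parse` below is an assumption for illustration:

```python
def parse(version: str) -> tuple[int, ...]:
    # Naive dotted-integer parse: "1.2.0" -> (1, 2, 0).
    # Raises ValueError on anything non-numeric, e.g. "invalid".
    return tuple(int(part) for part in version.split("."))


def has_new_version(latest_version: str, current_version: str) -> bool:
    try:
        # Tuple comparison is lexicographic, matching semver ordering
        # for plain numeric releases.
        return parse(latest_version) > parse(current_version)
    except ValueError:
        # Mirror the tested behavior: bad input means "no update".
        return False


assert has_new_version("1.2.0", "1.1.0") is True
assert has_new_version("1.0.0", "1.1.0") is False
assert has_new_version("invalid", "1.0.0") is False
```

Note this sketch breaks on pre-release suffixes like `1.2.0rc1`, which is exactly why production code reaches for a real version parser.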
class TestCheckVersionUpdate:
def test_no_check_update_url(self):
query = version_module.VersionQuery(current_version="1.0.0")
with (
patch.object(
version_module.dify_config,
"CHECK_UPDATE_URL",
"",
),
patch.object(
version_module.dify_config.project,
"version",
"1.0.0",
),
patch.object(
version_module.dify_config,
"CAN_REPLACE_LOGO",
True,
),
patch.object(
version_module.dify_config,
"MODEL_LB_ENABLED",
False,
),
):
result = version_module.check_version_update(query)
assert result.version == "1.0.0"
assert result.can_auto_update is False
assert result.features.can_replace_logo is True
assert result.features.model_load_balancing_enabled is False
def test_http_error_fallback(self):
query = version_module.VersionQuery(current_version="1.0.0")
with (
patch.object(
version_module.dify_config,
"CHECK_UPDATE_URL",
"http://example.com",
),
patch.object(
version_module.httpx,
"get",
side_effect=Exception("boom"),
),
patch.object(
version_module.logger,
"warning",
) as log_warning,
):
result = version_module.check_version_update(query)
assert result.version == "1.0.0"
log_warning.assert_called_once()
def test_new_version_available(self):
query = version_module.VersionQuery(current_version="1.0.0")
response = MagicMock()
response.json.return_value = {
"version": "1.2.0",
"releaseDate": "2024-01-01",
"releaseNotes": "New features",
"canAutoUpdate": True,
}
with (
patch.object(
version_module.dify_config,
"CHECK_UPDATE_URL",
"http://example.com",
),
patch.object(
version_module.httpx,
"get",
return_value=response,
),
patch.object(
version_module.dify_config.project,
"version",
"1.0.0",
),
patch.object(
version_module.dify_config,
"CAN_REPLACE_LOGO",
False,
),
patch.object(
version_module.dify_config,
"MODEL_LB_ENABLED",
True,
),
):
result = version_module.check_version_update(query)
assert result.version == "1.2.0"
assert result.release_date == "2024-01-01"
assert result.release_notes == "New features"
assert result.can_auto_update is True
def test_no_new_version(self):
query = version_module.VersionQuery(current_version="1.2.0")
response = MagicMock()
response.json.return_value = {
"version": "1.1.0",
}
with (
patch.object(
version_module.dify_config,
"CHECK_UPDATE_URL",
"http://example.com",
),
patch.object(
version_module.httpx,
"get",
return_value=response,
),
patch.object(
version_module.dify_config.project,
"version",
"1.2.0",
),
):
result = version_module.check_version_update(query)
assert result.version == "1.2.0"
assert result.can_auto_update is False


@@ -0,0 +1,341 @@
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from controllers.console import console_ns
from controllers.console.auth.error import (
EmailAlreadyInUseError,
EmailCodeError,
)
from controllers.console.error import AccountInFreezeError
from controllers.console.workspace.account import (
AccountAvatarApi,
AccountDeleteApi,
AccountDeleteVerifyApi,
AccountInitApi,
AccountIntegrateApi,
AccountInterfaceLanguageApi,
AccountInterfaceThemeApi,
AccountNameApi,
AccountPasswordApi,
AccountProfileApi,
AccountTimezoneApi,
ChangeEmailCheckApi,
ChangeEmailResetApi,
CheckEmailUnique,
)
from controllers.console.workspace.error import (
AccountAlreadyInitedError,
CurrentPasswordIncorrectError,
InvalidAccountDeletionCodeError,
)
from services.errors.account import CurrentPasswordIncorrectError as ServicePwdError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestAccountInitApi:
def test_init_success(self, app):
api = AccountInitApi()
method = unwrap(api.post)
account = MagicMock(status="inactive")
payload = {
"interface_language": "en-US",
"timezone": "UTC",
"invitation_code": "code123",
}
with (
app.test_request_context("/account/init", json=payload),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
patch("controllers.console.workspace.account.db.session.commit", return_value=None),
patch("controllers.console.workspace.account.dify_config.EDITION", "CLOUD"),
patch("controllers.console.workspace.account.db.session.query") as query_mock,
):
query_mock.return_value.where.return_value.first.return_value = MagicMock(status="unused")
resp = method(api)
assert resp["result"] == "success"
def test_init_already_initialized(self, app):
api = AccountInitApi()
method = unwrap(api.post)
account = MagicMock(status="active")
with (
app.test_request_context("/account/init"),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
):
with pytest.raises(AccountAlreadyInitedError):
method(api)
class TestAccountProfileApi:
def test_get_profile_success(self, app):
api = AccountProfileApi()
method = unwrap(api.get)
user = MagicMock()
user.id = "u1"
user.name = "John"
user.email = "john@test.com"
user.avatar = "avatar.png"
user.interface_language = "en-US"
user.interface_theme = "light"
user.timezone = "UTC"
user.last_login_ip = "127.0.0.1"
with (
app.test_request_context("/account/profile"),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
):
result = method(api)
assert result["id"] == "u1"
class TestAccountUpdateApis:
@pytest.mark.parametrize(
("api_cls", "payload"),
[
(AccountNameApi, {"name": "test"}),
(AccountAvatarApi, {"avatar": "img.png"}),
(AccountInterfaceLanguageApi, {"interface_language": "en-US"}),
(AccountInterfaceThemeApi, {"interface_theme": "dark"}),
(AccountTimezoneApi, {"timezone": "UTC"}),
],
)
def test_update_success(self, app, api_cls, payload):
api = api_cls()
method = unwrap(api.post)
user = MagicMock()
user.id = "u1"
user.name = "John"
user.email = "john@test.com"
user.avatar = "avatar.png"
user.interface_language = "en-US"
user.interface_theme = "light"
user.timezone = "UTC"
user.last_login_ip = "127.0.0.1"
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.account.AccountService.update_account", return_value=user),
):
result = method(api)
assert result["id"] == "u1"
class TestAccountPasswordApi:
def test_password_success(self, app):
api = AccountPasswordApi()
method = unwrap(api.post)
payload = {
"password": "old",
"new_password": "new123",
"repeat_new_password": "new123",
}
user = MagicMock()
user.id = "u1"
user.name = "John"
user.email = "john@test.com"
user.avatar = "avatar.png"
user.interface_language = "en-US"
user.interface_theme = "light"
user.timezone = "UTC"
user.last_login_ip = "127.0.0.1"
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.account.AccountService.update_account_password", return_value=None),
):
result = method(api)
assert result["id"] == "u1"
def test_password_wrong_current(self, app):
api = AccountPasswordApi()
method = unwrap(api.post)
payload = {
"password": "bad",
"new_password": "new123",
"repeat_new_password": "new123",
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.account.AccountService.update_account_password",
side_effect=ServicePwdError(),
),
):
with pytest.raises(CurrentPasswordIncorrectError):
method(api)
class TestAccountIntegrateApi:
def test_get_integrates(self, app):
api = AccountIntegrateApi()
method = unwrap(api.get)
account = MagicMock(id="acc1")
with (
app.test_request_context("/"),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
patch("controllers.console.workspace.account.db.session.scalars") as scalars_mock,
):
scalars_mock.return_value.all.return_value = []
result = method(api)
assert "data" in result
assert len(result["data"]) == 2
class TestAccountDeleteApi:
def test_delete_verify_success(self, app):
api = AccountDeleteVerifyApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.account.AccountService.generate_account_deletion_verification_code",
return_value=("token", "1234"),
),
patch(
"controllers.console.workspace.account.AccountService.send_account_deletion_verification_email",
return_value=None,
),
):
result = method(api)
assert result["result"] == "success"
def test_delete_invalid_code(self, app):
api = AccountDeleteApi()
method = unwrap(api.post)
payload = {"token": "t", "code": "x"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.account.AccountService.verify_account_deletion_code",
return_value=False,
),
):
with pytest.raises(InvalidAccountDeletionCodeError):
method(api)
class TestChangeEmailApis:
def test_check_email_code_invalid(self, app):
api = ChangeEmailCheckApi()
method = unwrap(api.post)
payload = {"email": "a@test.com", "code": "x", "token": "t"}
with (
app.test_request_context("/", json=payload),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch(
"controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit",
return_value=False,
),
patch(
"controllers.console.workspace.account.AccountService.get_change_email_data",
return_value={"email": "a@test.com", "code": "y"},
),
):
with pytest.raises(EmailCodeError):
method(api)
def test_reset_email_already_used(self, app):
api = ChangeEmailResetApi()
method = unwrap(api.post)
payload = {"new_email": "x@test.com", "token": "t"}
with (
app.test_request_context("/", json=payload),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=False),
patch("controllers.console.workspace.account.AccountService.check_email_unique", return_value=False),
):
with pytest.raises(EmailAlreadyInUseError):
method(api)
class TestCheckEmailUniqueApi:
def test_email_unique_success(self, app):
api = CheckEmailUnique()
method = unwrap(api.post)
payload = {"email": "ok@test.com"}
with (
app.test_request_context("/", json=payload),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=False),
patch("controllers.console.workspace.account.AccountService.check_email_unique", return_value=True),
):
result = method(api)
assert result["result"] == "success"
def test_email_in_freeze(self, app):
api = CheckEmailUnique()
method = unwrap(api.post)
payload = {"email": "x@test.com"}
with (
app.test_request_context("/", json=payload),
patch.object(
type(console_ns),
"payload",
new_callable=PropertyMock,
return_value=payload,
),
patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=True),
):
with pytest.raises(AccountInFreezeError):
method(api)
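The `ChangeEmail*` and `CheckEmailUnique` tests patch `payload` with `patch.object(type(console_ns), "payload", new_callable=PropertyMock, ...)` because properties live on the class, not the instance, so instance-level patching would not intercept the attribute read. A minimal demonstration of that technique; the `Namespace` class below is illustrative, not Dify's:

```python
from unittest.mock import PropertyMock, patch


class Namespace:
    @property
    def payload(self) -> dict:
        # In the real code this would parse the current request body.
        return {}


ns = Namespace()

# Patch the property on type(ns): attribute lookup finds the
# PropertyMock on the class and calls it, returning our payload.
with patch.object(
    type(ns), "payload", new_callable=PropertyMock,
    return_value={"email": "a@test.com"},
):
    assert ns.payload == {"email": "a@test.com"}

# Outside the context manager the original property is restored.
assert ns.payload == {}
```

Patching `ns.payload` directly would raise `AttributeError` (properties have no instance slot to overwrite), which is why the tests target the type.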


@@ -0,0 +1,139 @@
from unittest.mock import MagicMock, patch
import pytest
from controllers.console.error import AccountNotFound
from controllers.console.workspace.agent_providers import (
AgentProviderApi,
AgentProviderListApi,
)
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestAgentProviderListApi:
def test_get_success(self, app):
api = AgentProviderListApi()
method = unwrap(api.get)
user = MagicMock(id="user1")
tenant_id = "tenant1"
providers = [{"name": "openai"}, {"name": "anthropic"}]
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
return_value=(user, tenant_id),
),
patch(
"controllers.console.workspace.agent_providers.AgentService.list_agent_providers",
return_value=providers,
),
):
result = method(api)
assert result == providers
def test_get_empty_list(self, app):
api = AgentProviderListApi()
method = unwrap(api.get)
user = MagicMock(id="user1")
tenant_id = "tenant1"
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
return_value=(user, tenant_id),
),
patch(
"controllers.console.workspace.agent_providers.AgentService.list_agent_providers",
return_value=[],
),
):
result = method(api)
assert result == []
def test_get_account_not_found(self, app):
api = AgentProviderListApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
side_effect=AccountNotFound(),
),
):
with pytest.raises(AccountNotFound):
method(api)
class TestAgentProviderApi:
def test_get_success(self, app):
api = AgentProviderApi()
method = unwrap(api.get)
user = MagicMock(id="user1")
tenant_id = "tenant1"
provider_name = "openai"
provider_data = {"name": "openai", "models": ["gpt-4"]}
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
return_value=(user, tenant_id),
),
patch(
"controllers.console.workspace.agent_providers.AgentService.get_agent_provider",
return_value=provider_data,
),
):
result = method(api, provider_name)
assert result == provider_data
def test_get_provider_not_found(self, app):
api = AgentProviderApi()
method = unwrap(api.get)
user = MagicMock(id="user1")
tenant_id = "tenant1"
provider_name = "unknown"
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
return_value=(user, tenant_id),
),
patch(
"controllers.console.workspace.agent_providers.AgentService.get_agent_provider",
return_value=None,
),
):
result = method(api, provider_name)
assert result is None
def test_get_account_not_found(self, app):
api = AgentProviderApi()
method = unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.agent_providers.current_account_with_tenant",
side_effect=AccountNotFound(),
),
):
with pytest.raises(AccountNotFound):
method(api, "openai")


@@ -0,0 +1,305 @@
from unittest.mock import MagicMock, patch
import pytest
from controllers.console.workspace.endpoint import (
EndpointCreateApi,
EndpointDeleteApi,
EndpointDisableApi,
EndpointEnableApi,
EndpointListApi,
EndpointListForSinglePluginApi,
EndpointUpdateApi,
)
from core.plugin.impl.exc import PluginPermissionDeniedError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
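The `unwrap` helper works because `functools.wraps` records each decorator layer's original function as `__wrapped__`; walking that chain lets the tests call the handler body directly, bypassing the console auth decorators. A minimal sketch with a hypothetical decorator (`login_required` stands in for Dify's real wrappers):

```python
import functools

def login_required(func):
    # Hypothetical stand-in for the console auth decorators.
    @functools.wraps(func)  # wraps() sets wrapper.__wrapped__ = func
    def wrapper(*args, **kwargs):
        raise RuntimeError("not authenticated")
    return wrapper

@login_required
def get_providers():
    return [{"name": "openai"}]

def unwrap(func):
    # Walk the __wrapped__ chain back to the undecorated function.
    while hasattr(func, "__wrapped__"):
        func = func.__wrapped__
    return func

# The decorated call raises; the unwrapped call reaches the handler body.
assert unwrap(get_providers)() == [{"name": "openai"}]
```

This is why the tests never need a logged-in session: only the innermost function runs.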
@pytest.fixture
def user_and_tenant():
return MagicMock(id="u1"), "t1"
@pytest.fixture
def patch_current_account(user_and_tenant):
with patch(
"controllers.console.workspace.endpoint.current_account_with_tenant",
return_value=user_and_tenant,
):
yield
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointCreateApi:
def test_create_success(self, app):
api = EndpointCreateApi()
method = unwrap(api.post)
payload = {
"plugin_unique_identifier": "plugin-1",
"name": "endpoint",
"settings": {"a": 1},
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.create_endpoint", return_value=True),
):
result = method(api)
assert result["success"] is True
def test_create_permission_denied(self, app):
api = EndpointCreateApi()
method = unwrap(api.post)
payload = {
"plugin_unique_identifier": "plugin-1",
"name": "endpoint",
"settings": {},
}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.endpoint.EndpointService.create_endpoint",
side_effect=PluginPermissionDeniedError("denied"),
),
):
with pytest.raises(ValueError):
method(api)
def test_create_validation_error(self, app):
api = EndpointCreateApi()
method = unwrap(api.post)
payload = {
"plugin_unique_identifier": "p1",
"name": "",
"settings": {},
}
with (
app.test_request_context("/", json=payload),
):
with pytest.raises(ValueError):
method(api)
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointListApi:
def test_list_success(self, app):
api = EndpointListApi()
method = unwrap(api.get)
with (
app.test_request_context("/?page=1&page_size=10"),
patch("controllers.console.workspace.endpoint.EndpointService.list_endpoints", return_value=[{"id": "e1"}]),
):
result = method(api)
assert "endpoints" in result
assert len(result["endpoints"]) == 1
def test_list_invalid_query(self, app):
api = EndpointListApi()
method = unwrap(api.get)
with (
app.test_request_context("/?page=0&page_size=10"),
):
with pytest.raises(ValueError):
method(api)
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointListForSinglePluginApi:
def test_list_for_plugin_success(self, app):
api = EndpointListForSinglePluginApi()
method = unwrap(api.get)
with (
app.test_request_context("/?page=1&page_size=10&plugin_id=p1"),
patch(
"controllers.console.workspace.endpoint.EndpointService.list_endpoints_for_single_plugin",
return_value=[{"id": "e1"}],
),
):
result = method(api)
assert "endpoints" in result
def test_list_for_plugin_missing_param(self, app):
api = EndpointListForSinglePluginApi()
method = unwrap(api.get)
with (
app.test_request_context("/?page=1&page_size=10"),
):
with pytest.raises(ValueError):
method(api)
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointDeleteApi:
def test_delete_success(self, app):
api = EndpointDeleteApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=True),
):
result = method(api)
assert result["success"] is True
def test_delete_invalid_payload(self, app):
api = EndpointDeleteApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
):
with pytest.raises(ValueError):
method(api)
def test_delete_service_failure(self, app):
api = EndpointDeleteApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=False),
):
result = method(api)
assert result["success"] is False
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointUpdateApi:
def test_update_success(self, app):
api = EndpointUpdateApi()
method = unwrap(api.post)
payload = {
"endpoint_id": "e1",
"name": "new-name",
"settings": {"x": 1},
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=True),
):
result = method(api)
assert result["success"] is True
def test_update_validation_error(self, app):
api = EndpointUpdateApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1", "settings": {}}
with (
app.test_request_context("/", json=payload),
):
with pytest.raises(ValueError):
method(api)
def test_update_service_failure(self, app):
api = EndpointUpdateApi()
method = unwrap(api.post)
payload = {
"endpoint_id": "e1",
"name": "n",
"settings": {},
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=False),
):
result = method(api)
assert result["success"] is False
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointEnableApi:
def test_enable_success(self, app):
api = EndpointEnableApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.enable_endpoint", return_value=True),
):
result = method(api)
assert result["success"] is True
def test_enable_invalid_payload(self, app):
api = EndpointEnableApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
):
with pytest.raises(ValueError):
method(api)
def test_enable_service_failure(self, app):
api = EndpointEnableApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.enable_endpoint", return_value=False),
):
result = method(api)
assert result["success"] is False
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointDisableApi:
def test_disable_success(self, app):
api = EndpointDisableApi()
method = unwrap(api.post)
payload = {"endpoint_id": "e1"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.disable_endpoint", return_value=True),
):
result = method(api)
assert result["success"] is True
def test_disable_invalid_payload(self, app):
api = EndpointDisableApi()
method = unwrap(api.post)
with (
app.test_request_context("/", json={}),
):
with pytest.raises(ValueError):
method(api)


@@ -0,0 +1,607 @@
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import HTTPException
import services
from controllers.console.auth.error import (
CannotTransferOwnerToSelfError,
EmailCodeError,
InvalidEmailError,
InvalidTokenError,
MemberNotInTenantError,
NotOwnerError,
OwnerTransferLimitError,
)
from controllers.console.error import EmailSendIpLimitError, WorkspaceMembersLimitExceeded
from controllers.console.workspace.members import (
DatasetOperatorMemberListApi,
MemberCancelInviteApi,
MemberInviteEmailApi,
MemberListApi,
MemberUpdateRoleApi,
OwnerTransfer,
OwnerTransferCheckApi,
SendOwnerTransferEmailApi,
)
from services.errors.account import AccountAlreadyInTenantError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestMemberListApi:
def test_get_success(self, app):
api = MemberListApi()
method = unwrap(api.get)
tenant = MagicMock()
user = MagicMock(current_tenant=tenant)
member = MagicMock()
member.id = "m1"
member.name = "Member"
member.email = "member@test.com"
member.avatar = "avatar.png"
member.role = "admin"
member.status = "active"
members = [member]
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.get_tenant_members", return_value=members),
):
result, status = method(api)
assert status == 200
assert len(result["accounts"]) == 1
def test_get_no_tenant(self, app):
api = MemberListApi()
method = unwrap(api.get)
user = MagicMock(current_tenant=None)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
):
with pytest.raises(ValueError):
method(api)
class TestMemberInviteEmailApi:
def test_invite_success(self, app):
api = MemberInviteEmailApi()
method = unwrap(api.post)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
features = MagicMock()
features.workspace_members.is_available.return_value = True
payload = {
"emails": ["a@test.com"],
"role": "normal",
"language": "en-US",
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
patch("controllers.console.workspace.members.RegisterService.invite_new_member", return_value="token"),
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"),
):
result, status = method(api)
assert status == 201
assert result["result"] == "success"
def test_invite_limit_exceeded(self, app):
api = MemberInviteEmailApi()
method = unwrap(api.post)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
features = MagicMock()
features.workspace_members.is_available.return_value = False
payload = {
"emails": ["a@test.com"],
"role": "normal",
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
):
with pytest.raises(WorkspaceMembersLimitExceeded):
method(api)
def test_invite_already_member(self, app):
api = MemberInviteEmailApi()
method = unwrap(api.post)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
features = MagicMock()
features.workspace_members.is_available.return_value = True
payload = {
"emails": ["a@test.com"],
"role": "normal",
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
patch(
"controllers.console.workspace.members.RegisterService.invite_new_member",
side_effect=AccountAlreadyInTenantError(),
),
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"),
):
result, status = method(api)
assert result["invitation_results"][0]["status"] == "success"
def test_invite_invalid_role(self, app):
api = MemberInviteEmailApi()
method = unwrap(api.post)
payload = {
"emails": ["a@test.com"],
"role": "owner",
}
with app.test_request_context("/", json=payload):
result, status = method(api)
assert status == 400
assert result["code"] == "invalid-role"
def test_invite_generic_exception(self, app):
api = MemberInviteEmailApi()
method = unwrap(api.post)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
features = MagicMock()
features.workspace_members.is_available.return_value = True
payload = {
"emails": ["a@test.com"],
"role": "normal",
}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features),
patch(
"controllers.console.workspace.members.RegisterService.invite_new_member",
side_effect=Exception("boom"),
),
patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"),
):
result, _ = method(api)
assert result["invitation_results"][0]["status"] == "failed"
class TestMemberCancelInviteApi:
def test_cancel_success(self, app):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
member = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch("controllers.console.workspace.members.TenantService.remove_member_from_tenant"),
):
q.return_value.where.return_value.first.return_value = member
result, status = method(api, member.id)
assert status == 200
assert result["result"] == "success"
def test_cancel_not_found(self, app):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
):
q.return_value.where.return_value.first.return_value = None
with pytest.raises(HTTPException):
method(api, "x")
def test_cancel_cannot_operate_self(self, app):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
member = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.CannotOperateSelfError("x"),
),
):
q.return_value.where.return_value.first.return_value = member
result, status = method(api, member.id)
assert status == 400
def test_cancel_no_permission(self, app):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
member = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.NoPermissionError("x"),
),
):
q.return_value.where.return_value.first.return_value = member
result, status = method(api, member.id)
assert status == 403
def test_cancel_member_not_in_tenant(self, app):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
tenant = MagicMock(id="t1")
user = MagicMock(current_tenant=tenant)
member = MagicMock()
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.query") as q,
patch(
"controllers.console.workspace.members.TenantService.remove_member_from_tenant",
side_effect=services.errors.account.MemberNotInTenantError(),
),
):
q.return_value.where.return_value.first.return_value = member
result, status = method(api, member.id)
assert status == 404
class TestMemberUpdateRoleApi:
def test_update_success(self, app):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
tenant = MagicMock()
user = MagicMock(current_tenant=tenant)
member = MagicMock()
payload = {"role": "normal"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.db.session.get", return_value=member),
patch("controllers.console.workspace.members.TenantService.update_member_role"),
):
result = method(api, "id")
if isinstance(result, tuple):
result = result[0]
assert result["result"] == "success"
def test_update_invalid_role(self, app):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
payload = {"role": "invalid-role"}
with app.test_request_context("/", json=payload):
result, status = method(api, "id")
assert status == 400
def test_update_member_not_found(self, app):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
payload = {"role": "normal"}
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.members.current_account_with_tenant",
return_value=(MagicMock(current_tenant=MagicMock()), "t1"),
),
patch("controllers.console.workspace.members.db.session.get", return_value=None),
):
with pytest.raises(HTTPException):
method(api, "id")
class TestDatasetOperatorMemberListApi:
def test_get_success(self, app):
api = DatasetOperatorMemberListApi()
method = unwrap(api.get)
tenant = MagicMock()
user = MagicMock(current_tenant=tenant)
member = MagicMock()
member.id = "op1"
member.name = "Operator"
member.email = "operator@test.com"
member.avatar = "avatar.png"
member.role = "operator"
member.status = "active"
members = [member]
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch(
"controllers.console.workspace.members.TenantService.get_dataset_operator_members", return_value=members
),
):
result, status = method(api)
assert status == 200
assert len(result["accounts"]) == 1
def test_get_no_tenant(self, app):
api = DatasetOperatorMemberListApi()
method = unwrap(api.get)
user = MagicMock(current_tenant=None)
with (
app.test_request_context("/"),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
):
with pytest.raises(ValueError):
method(api)
class TestSendOwnerTransferEmailApi:
def test_send_success(self, app):
api = SendOwnerTransferEmailApi()
method = unwrap(api.post)
        tenant = MagicMock()
        tenant.name = "ws"  # MagicMock(name=...) names the mock's repr, not a .name attribute
        user = MagicMock(email="a@test.com", current_tenant=tenant)
payload = {}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"),
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=False),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.send_owner_transfer_email", return_value="token"
),
):
result = method(api)
assert result["result"] == "success"
def test_send_ip_limit(self, app):
api = SendOwnerTransferEmailApi()
method = unwrap(api.post)
payload = {}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"),
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=True),
):
with pytest.raises(EmailSendIpLimitError):
method(api)
def test_send_not_owner(self, app):
api = SendOwnerTransferEmailApi()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(current_tenant=tenant)
with (
app.test_request_context("/", json={}),
patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"),
patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=False),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=False),
):
with pytest.raises(NotOwnerError):
method(api)
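One `unittest.mock` subtlety these tenant mocks depend on: the `name` keyword is reserved for naming the mock itself (its repr), so `MagicMock(name="ws")` does not create a string `.name` attribute. A short illustration:

```python
from unittest.mock import MagicMock

named = MagicMock(name="ws")
# name= configures the mock's repr, so .name is a child mock, not "ws":
assert not isinstance(named.name, str)
assert "ws" in repr(named)

# To mock an object with a real .name attribute, set it after construction:
tenant = MagicMock()
tenant.name = "ws"
assert tenant.name == "ws"
```

Setting the attribute after construction (or via `configure_mock`) is the documented workaround.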
class TestOwnerTransferCheckApi:
def test_check_invalid_code(self, app):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(email="a@test.com", current_tenant=tenant)
payload = {"code": "x", "token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
return_value=False,
),
patch(
"controllers.console.workspace.members.AccountService.get_owner_transfer_data",
return_value={"email": "a@test.com", "code": "y"},
),
):
with pytest.raises(EmailCodeError):
method(api)
def test_rate_limited(self, app):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(email="a@test.com", current_tenant=tenant)
payload = {"code": "x", "token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
return_value=True,
),
):
with pytest.raises(OwnerTransferLimitError):
method(api)
def test_invalid_token(self, app):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(email="a@test.com", current_tenant=tenant)
payload = {"code": "x", "token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
return_value=False,
),
patch("controllers.console.workspace.members.AccountService.get_owner_transfer_data", return_value=None),
):
with pytest.raises(InvalidTokenError):
method(api)
def test_invalid_email(self, app):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(email="a@test.com", current_tenant=tenant)
payload = {"code": "x", "token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit",
return_value=False,
),
patch(
"controllers.console.workspace.members.AccountService.get_owner_transfer_data",
return_value={"email": "b@test.com", "code": "x"},
),
):
with pytest.raises(InvalidEmailError):
method(api)
class TestOwnerTransferApi:
def test_transfer_self(self, app):
api = OwnerTransfer()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(id="1", email="a@test.com", current_tenant=tenant)
payload = {"token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
):
with pytest.raises(CannotTransferOwnerToSelfError):
method(api, "1")
def test_invalid_token(self, app):
api = OwnerTransfer()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(id="1", email="a@test.com", current_tenant=tenant)
payload = {"token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch("controllers.console.workspace.members.AccountService.get_owner_transfer_data", return_value=None),
):
with pytest.raises(InvalidTokenError):
method(api, "2")
def test_member_not_in_tenant(self, app):
api = OwnerTransfer()
method = unwrap(api.post)
tenant = MagicMock()
user = MagicMock(id="1", email="a@test.com", current_tenant=tenant)
member = MagicMock()
payload = {"token": "t"}
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True),
patch(
"controllers.console.workspace.members.AccountService.get_owner_transfer_data",
return_value={"email": "a@test.com"},
),
patch("controllers.console.workspace.members.db.session.get", return_value=member),
patch("controllers.console.workspace.members.TenantService.is_member", return_value=False),
):
with pytest.raises(MemberNotInTenantError):
method(api, "2")
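All of the service stubs above rely on `unittest.mock.patch` swapping an attribute at a dotted import path for the duration of a `with` block, then restoring it on exit. A self-contained sketch of that mechanism (`FakeService` is hypothetical, standing in for `TenantService` and friends):

```python
from unittest.mock import patch

class FakeService:
    @staticmethod
    def is_owner(user, tenant):
        raise RuntimeError("would hit the database")

# Inside the block the attribute is replaced by a MagicMock
# configured with return_value, exactly as in the tests above.
with patch(f"{__name__}.FakeService.is_owner", return_value=True):
    assert FakeService.is_owner(None, None) is True

# On exit the original is restored.
try:
    FakeService.is_owner(None, None)
    restored = False
except RuntimeError:
    restored = True
assert restored
```

`side_effect=SomeError()` works the same way, except the replacement raises instead of returning, which is how the error-path tests trigger `AccountNotFound`, `PluginPermissionDeniedError`, and the member-removal errors.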

Some files were not shown because too many files have changed in this diff.