mirror of
https://github.com/langgenius/dify.git
synced 2026-03-31 21:16:50 +00:00
Compare commits
35 Commits
dependabot
...
refactor/i
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ee98705ecc | ||
|
|
1063e021f2 | ||
|
|
303f548408 | ||
|
|
cc68f0e640 | ||
|
|
9b7b432e08 | ||
|
|
88863609e9 | ||
|
|
adc6c6c13b | ||
|
|
2de818530b | ||
|
|
7e4754392d | ||
|
|
01c857a67a | ||
|
|
2c2cc72150 | ||
|
|
f7b78b08fd | ||
|
|
f0e6f11c1c | ||
|
|
a19243068b | ||
|
|
323c51e095 | ||
|
|
bbc3f90928 | ||
|
|
1344c3b280 | ||
|
|
5897b28355 | ||
|
|
15aa8071f8 | ||
|
|
097095a69b | ||
|
|
daebe26089 | ||
|
|
c58170f5b8 | ||
|
|
3a7885819d | ||
|
|
5fc4dfaf7b | ||
|
|
953bcc33b1 | ||
|
|
bc14ad6a8f | ||
|
|
cc89b57c1f | ||
|
|
623c8ae803 | ||
|
|
dede190be2 | ||
|
|
a1513f06c3 | ||
|
|
3c7180bfd5 | ||
|
|
51f6ca2bed | ||
|
|
ae9a16a397 | ||
|
|
52a4bea88f | ||
|
|
1aaba80211 |
1
.github/actions/setup-web/action.yml
vendored
1
.github/actions/setup-web/action.yml
vendored
@@ -6,7 +6,6 @@ runs:
|
||||
- name: Setup Vite+
|
||||
uses: voidzero-dev/setup-vp@20553a7a7429c429a74894104a2835d7fed28a72 # v1.3.0
|
||||
with:
|
||||
working-directory: web
|
||||
node-version-file: .nvmrc
|
||||
cache: true
|
||||
run-install: true
|
||||
|
||||
4
.github/workflows/autofix.yml
vendored
4
.github/workflows/autofix.yml
vendored
@@ -39,6 +39,10 @@ jobs:
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
pnpm-workspace.yaml
|
||||
.nvmrc
|
||||
- name: Check api inputs
|
||||
if: github.event_name != 'merge_group'
|
||||
id: api-changes
|
||||
|
||||
30
.github/workflows/build-push.yml
vendored
30
.github/workflows/build-push.yml
vendored
@@ -24,27 +24,39 @@ env:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ${{ matrix.platform == 'linux/arm64' && 'arm64_runner' || 'ubuntu-latest' }}
|
||||
runs-on: ${{ matrix.runs_on }}
|
||||
if: github.repository == 'langgenius/dify'
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- service_name: "build-api-amd64"
|
||||
image_name_env: "DIFY_API_IMAGE_NAME"
|
||||
context: "api"
|
||||
artifact_context: "api"
|
||||
build_context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
platform: linux/amd64
|
||||
runs_on: ubuntu-latest
|
||||
- service_name: "build-api-arm64"
|
||||
image_name_env: "DIFY_API_IMAGE_NAME"
|
||||
context: "api"
|
||||
artifact_context: "api"
|
||||
build_context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
platform: linux/arm64
|
||||
runs_on: ubuntu-24.04-arm
|
||||
- service_name: "build-web-amd64"
|
||||
image_name_env: "DIFY_WEB_IMAGE_NAME"
|
||||
context: "web"
|
||||
artifact_context: "web"
|
||||
build_context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
platform: linux/amd64
|
||||
runs_on: ubuntu-latest
|
||||
- service_name: "build-web-arm64"
|
||||
image_name_env: "DIFY_WEB_IMAGE_NAME"
|
||||
context: "web"
|
||||
artifact_context: "web"
|
||||
build_context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
platform: linux/arm64
|
||||
runs_on: ubuntu-24.04-arm
|
||||
|
||||
steps:
|
||||
- name: Prepare
|
||||
@@ -58,9 +70,6 @@ jobs:
|
||||
username: ${{ env.DOCKERHUB_USER }}
|
||||
password: ${{ env.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
|
||||
@@ -74,7 +83,8 @@ jobs:
|
||||
id: build
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
|
||||
with:
|
||||
context: "{{defaultContext}}:${{ matrix.context }}"
|
||||
context: ${{ matrix.build_context }}
|
||||
file: ${{ matrix.file }}
|
||||
platforms: ${{ matrix.platform }}
|
||||
build-args: COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
@@ -93,7 +103,7 @@ jobs:
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
|
||||
with:
|
||||
name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }}
|
||||
name: digests-${{ matrix.artifact_context }}-${{ env.PLATFORM_PAIR }}
|
||||
path: /tmp/digests/*
|
||||
if-no-files-found: error
|
||||
retention-days: 1
|
||||
|
||||
30
.github/workflows/docker-build.yml
vendored
30
.github/workflows/docker-build.yml
vendored
@@ -6,7 +6,12 @@ on:
|
||||
- "main"
|
||||
paths:
|
||||
- api/Dockerfile
|
||||
- web/docker/**
|
||||
- web/Dockerfile
|
||||
- package.json
|
||||
- pnpm-lock.yaml
|
||||
- pnpm-workspace.yaml
|
||||
- .nvmrc
|
||||
|
||||
concurrency:
|
||||
group: docker-build-${{ github.head_ref || github.run_id }}
|
||||
@@ -14,26 +19,31 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
build-docker:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ${{ matrix.runs_on }}
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- service_name: "api-amd64"
|
||||
platform: linux/amd64
|
||||
context: "api"
|
||||
runs_on: ubuntu-latest
|
||||
context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
- service_name: "api-arm64"
|
||||
platform: linux/arm64
|
||||
context: "api"
|
||||
runs_on: ubuntu-24.04-arm
|
||||
context: "{{defaultContext}}:api"
|
||||
file: "Dockerfile"
|
||||
- service_name: "web-amd64"
|
||||
platform: linux/amd64
|
||||
context: "web"
|
||||
runs_on: ubuntu-latest
|
||||
context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
- service_name: "web-arm64"
|
||||
platform: linux/arm64
|
||||
context: "web"
|
||||
runs_on: ubuntu-24.04-arm
|
||||
context: "{{defaultContext}}"
|
||||
file: "web/Dockerfile"
|
||||
steps:
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0
|
||||
|
||||
@@ -41,8 +51,8 @@ jobs:
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0
|
||||
with:
|
||||
push: false
|
||||
context: "{{defaultContext}}:${{ matrix.context }}"
|
||||
file: "${{ matrix.file }}"
|
||||
context: ${{ matrix.context }}
|
||||
file: ${{ matrix.file }}
|
||||
platforms: ${{ matrix.platform }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
8
.github/workflows/main-ci.yml
vendored
8
.github/workflows/main-ci.yml
vendored
@@ -65,6 +65,10 @@ jobs:
|
||||
- 'docker/volumes/sandbox/conf/**'
|
||||
web:
|
||||
- 'web/**'
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.nvmrc'
|
||||
- '.github/workflows/web-tests.yml'
|
||||
- '.github/actions/setup-web/**'
|
||||
e2e:
|
||||
@@ -73,6 +77,10 @@ jobs:
|
||||
- 'api/uv.lock'
|
||||
- 'e2e/**'
|
||||
- 'web/**'
|
||||
- 'package.json'
|
||||
- 'pnpm-lock.yaml'
|
||||
- 'pnpm-workspace.yaml'
|
||||
- '.nvmrc'
|
||||
- 'docker/docker-compose.middleware.yaml'
|
||||
- 'docker/middleware.env.example'
|
||||
- '.github/workflows/web-e2e.yml'
|
||||
|
||||
13
.github/workflows/pyrefly-diff.yml
vendored
13
.github/workflows/pyrefly-diff.yml
vendored
@@ -50,6 +50,17 @@ jobs:
|
||||
run: |
|
||||
diff -u /tmp/pyrefly_base.txt /tmp/pyrefly_pr.txt > pyrefly_diff.txt || true
|
||||
|
||||
- name: Check if line counts match
|
||||
id: line_count_check
|
||||
run: |
|
||||
base_lines=$(wc -l < /tmp/pyrefly_base.txt)
|
||||
pr_lines=$(wc -l < /tmp/pyrefly_pr.txt)
|
||||
if [ "$base_lines" -eq "$pr_lines" ]; then
|
||||
echo "same=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "same=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Save PR number
|
||||
run: |
|
||||
echo ${{ github.event.pull_request.number }} > pr_number.txt
|
||||
@@ -63,7 +74,7 @@ jobs:
|
||||
pr_number.txt
|
||||
|
||||
- name: Comment PR with pyrefly diff
|
||||
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
|
||||
if: ${{ github.event.pull_request.head.repo.full_name == github.repository && steps.line_count_check.outputs.same == 'false' }}
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
8
.github/workflows/style.yml
vendored
8
.github/workflows/style.yml
vendored
@@ -77,6 +77,10 @@ jobs:
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
package.json
|
||||
pnpm-lock.yaml
|
||||
pnpm-workspace.yaml
|
||||
.nvmrc
|
||||
.github/workflows/style.yml
|
||||
.github/actions/setup-web/**
|
||||
|
||||
@@ -90,9 +94,9 @@ jobs:
|
||||
uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4
|
||||
with:
|
||||
path: web/.eslintcache
|
||||
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
|
||||
key: ${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-${{ github.sha }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'web/pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
|
||||
${{ runner.os }}-web-eslint-${{ hashFiles('web/package.json', 'pnpm-lock.yaml', 'web/eslint.config.mjs', 'web/eslint.constants.mjs', 'web/plugins/eslint/**') }}-
|
||||
|
||||
- name: Web style check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
|
||||
3
.github/workflows/tool-test-sdks.yaml
vendored
3
.github/workflows/tool-test-sdks.yaml
vendored
@@ -6,6 +6,9 @@ on:
|
||||
- main
|
||||
paths:
|
||||
- sdks/**
|
||||
- package.json
|
||||
- pnpm-lock.yaml
|
||||
- pnpm-workspace.yaml
|
||||
|
||||
concurrency:
|
||||
group: sdk-tests-${{ github.head_ref || github.run_id }}
|
||||
|
||||
95
.github/workflows/vdb-tests-full.yml
vendored
Normal file
95
.github/workflows/vdb-tests-full.yml
vendored
Normal file
@@ -0,0 +1,95 @@
|
||||
name: Run Full VDB Tests
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 3 * * 1'
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: vdb-tests-full-${{ github.ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Full VDB Tests
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version:
|
||||
- "3.12"
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Free Disk Space
|
||||
uses: endersonmenezes/free-disk-space@7901478139cff6e9d44df5972fd8ab8fcade4db1 # v3.2.2
|
||||
with:
|
||||
remove_dotnet: true
|
||||
remove_haskell: true
|
||||
remove_tool_cache: true
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0
|
||||
with:
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache-dependency-glob: api/uv.lock
|
||||
|
||||
- name: Check UV lockfile
|
||||
run: uv lock --project api --check
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --project api --dev
|
||||
|
||||
- name: Set up dotenvs
|
||||
run: |
|
||||
cp docker/.env.example docker/.env
|
||||
cp docker/middleware.env.example docker/middleware.env
|
||||
|
||||
- name: Expose Service Ports
|
||||
run: sh .github/workflows/expose_service_ports.sh
|
||||
|
||||
# - name: Set up Vector Store (TiDB)
|
||||
# uses: hoverkraft-tech/compose-action@v2.0.2
|
||||
# with:
|
||||
# compose-file: docker/tidb/docker-compose.yaml
|
||||
# services: |
|
||||
# tidb
|
||||
# tiflash
|
||||
|
||||
- name: Set up Full Vector Store Matrix
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.yaml
|
||||
services: |
|
||||
weaviate
|
||||
qdrant
|
||||
couchbase-server
|
||||
etcd
|
||||
minio
|
||||
milvus-standalone
|
||||
pgvecto-rs
|
||||
pgvector
|
||||
chroma
|
||||
elasticsearch
|
||||
oceanbase
|
||||
|
||||
- name: setup test config
|
||||
run: |
|
||||
echo $(pwd)
|
||||
ls -lah .
|
||||
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
|
||||
|
||||
# - name: Check VDB Ready (TiDB)
|
||||
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: uv run --project api bash dev/pytest/pytest_vdb.sh
|
||||
25
.github/workflows/vdb-tests.yml
vendored
25
.github/workflows/vdb-tests.yml
vendored
@@ -1,15 +1,18 @@
|
||||
name: Run VDB Tests
|
||||
name: Run VDB Smoke Tests
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: vdb-tests-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: VDB Tests
|
||||
name: VDB Smoke Tests
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
@@ -58,23 +61,18 @@ jobs:
|
||||
# tidb
|
||||
# tiflash
|
||||
|
||||
- name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase)
|
||||
- name: Set up Vector Stores for Smoke Coverage
|
||||
uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0
|
||||
with:
|
||||
compose-file: |
|
||||
docker/docker-compose.yaml
|
||||
services: |
|
||||
db_postgres
|
||||
redis
|
||||
weaviate
|
||||
qdrant
|
||||
couchbase-server
|
||||
etcd
|
||||
minio
|
||||
milvus-standalone
|
||||
pgvecto-rs
|
||||
pgvector
|
||||
chroma
|
||||
elasticsearch
|
||||
oceanbase
|
||||
|
||||
- name: setup test config
|
||||
run: |
|
||||
@@ -86,4 +84,9 @@ jobs:
|
||||
# run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
|
||||
|
||||
- name: Test Vector Stores
|
||||
run: uv run --project api bash dev/pytest/pytest_vdb.sh
|
||||
run: |
|
||||
uv run --project api pytest --timeout "${PYTEST_TIMEOUT:-180}" \
|
||||
api/tests/integration_tests/vdb/chroma \
|
||||
api/tests/integration_tests/vdb/pgvector \
|
||||
api/tests/integration_tests/vdb/qdrant \
|
||||
api/tests/integration_tests/vdb/weaviate
|
||||
|
||||
4
.github/workflows/web-e2e.yml
vendored
4
.github/workflows/web-e2e.yml
vendored
@@ -27,10 +27,6 @@ jobs:
|
||||
- name: Setup web dependencies
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Install E2E package dependencies
|
||||
working-directory: ./e2e
|
||||
run: vp install --frozen-lockfile
|
||||
|
||||
- name: Setup UV and Python
|
||||
uses: astral-sh/setup-uv@cec208311dfd045dd5311c1add060b2062131d57 # v8.0.0
|
||||
with:
|
||||
|
||||
31
.github/workflows/web-tests.yml
vendored
31
.github/workflows/web-tests.yml
vendored
@@ -89,34 +89,3 @@ jobs:
|
||||
flags: web
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }}
|
||||
|
||||
web-build:
|
||||
name: Web Build
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ./web
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5
|
||||
with:
|
||||
files: |
|
||||
web/**
|
||||
.github/workflows/web-tests.yml
|
||||
.github/actions/setup-web/**
|
||||
|
||||
- name: Setup web environment
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Web build check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
working-directory: ./web
|
||||
run: vp run build
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -212,6 +212,7 @@ api/.vscode
|
||||
|
||||
# pnpm
|
||||
/.pnpm-store
|
||||
/node_modules
|
||||
|
||||
# plugin migrate
|
||||
plugins.jsonl
|
||||
@@ -239,4 +240,4 @@ scripts/stress-test/reports/
|
||||
*.local.md
|
||||
|
||||
# Code Agent Folder
|
||||
.qoder/*
|
||||
.qoder/*
|
||||
|
||||
6
Makefile
6
Makefile
@@ -24,8 +24,8 @@ prepare-docker:
|
||||
# Step 2: Prepare web environment
|
||||
prepare-web:
|
||||
@echo "🌐 Setting up web environment..."
|
||||
@cp -n web/.env.example web/.env 2>/dev/null || echo "Web .env already exists"
|
||||
@cd web && pnpm install
|
||||
@cp -n web/.env.example web/.env.local 2>/dev/null || echo "Web .env.local already exists"
|
||||
@pnpm install
|
||||
@echo "✅ Web environment prepared (not started)"
|
||||
|
||||
# Step 3: Prepare API environment
|
||||
@@ -93,7 +93,7 @@ test:
|
||||
# Build Docker images
|
||||
build-web:
|
||||
@echo "Building web Docker image: $(WEB_IMAGE):$(VERSION)..."
|
||||
docker build -t $(WEB_IMAGE):$(VERSION) ./web
|
||||
docker build -f web/Dockerfile -t $(WEB_IMAGE):$(VERSION) .
|
||||
@echo "Web Docker image built successfully: $(WEB_IMAGE):$(VERSION)"
|
||||
|
||||
build-api:
|
||||
|
||||
@@ -40,6 +40,8 @@ The scripts resolve paths relative to their location, so you can run them from a
|
||||
./dev/start-web
|
||||
```
|
||||
|
||||
`./dev/setup` and `./dev/start-web` install JavaScript dependencies through the repository root workspace, so you do not need a separate `cd web && pnpm install` step.
|
||||
|
||||
1. Set up your application by visiting `http://localhost:3000`.
|
||||
|
||||
1. Start the worker service (async and scheduler tasks, runs from `api`).
|
||||
|
||||
@@ -287,12 +287,10 @@ class ModelProviderModelCredentialApi(Resource):
|
||||
provider=provider,
|
||||
)
|
||||
else:
|
||||
# Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM)
|
||||
normalized_model_type = args.model_type.to_origin_model_type()
|
||||
available_credentials = model_provider_service.get_provider_model_available_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model_type=normalized_model_type,
|
||||
model_type=args.model_type,
|
||||
model=args.model,
|
||||
)
|
||||
|
||||
|
||||
@@ -174,6 +174,7 @@ class MCPAppApi(Resource):
|
||||
required=variable.get("required", False),
|
||||
max_length=variable.get("max_length"),
|
||||
options=variable.get("options") or [],
|
||||
json_schema=variable.get("json_schema"),
|
||||
)
|
||||
|
||||
def _parse_mcp_request(self, args: dict) -> mcp_types.ClientRequest | mcp_types.ClientNotification:
|
||||
|
||||
@@ -29,6 +29,31 @@ from services.entities.knowledge_entities.knowledge_entities import SegmentUpdat
|
||||
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
|
||||
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
|
||||
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
|
||||
def _marshal_segment_with_summary(segment, dataset_id: str) -> dict:
|
||||
"""Marshal a single segment and enrich it with summary content."""
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
|
||||
segment_dict["summary"] = summary.summary_content if summary else None
|
||||
return segment_dict
|
||||
|
||||
|
||||
def _marshal_segments_with_summary(segments, dataset_id: str) -> list[dict]:
|
||||
"""Marshal multiple segments and enrich them with summary content (batch query)."""
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
summaries: dict = {}
|
||||
if segment_ids:
|
||||
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
|
||||
summaries = {chunk_id: record.summary_content for chunk_id, record in summary_records.items()}
|
||||
|
||||
result = []
|
||||
for segment in segments:
|
||||
segment_dict = dict(marshal(segment, segment_fields)) # type: ignore[arg-type]
|
||||
segment_dict["summary"] = summaries.get(segment.id)
|
||||
result.append(segment_dict)
|
||||
return result
|
||||
|
||||
|
||||
class SegmentCreatePayload(BaseModel):
|
||||
@@ -132,7 +157,7 @@ class SegmentApi(DatasetApiResource):
|
||||
for args_item in payload.segments:
|
||||
SegmentService.segment_create_args_validate(args_item, document)
|
||||
segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
|
||||
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
|
||||
return {"data": _marshal_segments_with_summary(segments, dataset_id), "doc_form": document.doc_form}, 200
|
||||
else:
|
||||
return {"error": "Segments is required"}, 400
|
||||
|
||||
@@ -196,7 +221,7 @@ class SegmentApi(DatasetApiResource):
|
||||
)
|
||||
|
||||
response = {
|
||||
"data": marshal(segments, segment_fields),
|
||||
"data": _marshal_segments_with_summary(segments, dataset_id),
|
||||
"doc_form": document.doc_form,
|
||||
"total": total,
|
||||
"has_more": len(segments) == limit,
|
||||
@@ -296,7 +321,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
|
||||
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
return {"data": _marshal_segment_with_summary(updated_segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||
|
||||
@service_api_ns.doc("get_segment")
|
||||
@service_api_ns.doc(description="Get a specific segment by ID")
|
||||
@@ -326,7 +351,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
return {"data": _marshal_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||
|
||||
|
||||
@service_api_ns.route(
|
||||
|
||||
@@ -81,7 +81,7 @@ def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LL
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
|
||||
@@ -8,6 +8,7 @@ associates with the node span.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from contextvars import Token
|
||||
from dataclasses import dataclass
|
||||
from typing import cast, final
|
||||
|
||||
@@ -35,7 +36,7 @@ logger = logging.getLogger(__name__)
|
||||
@dataclass(slots=True)
|
||||
class _NodeSpanContext:
|
||||
span: "Span"
|
||||
token: object
|
||||
token: Token[context_api.Context]
|
||||
|
||||
|
||||
@final
|
||||
|
||||
@@ -403,7 +403,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -753,7 +753,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModel.tenant_id == self.tenant_id,
|
||||
ProviderModel.provider_name.in_(provider_names),
|
||||
ProviderModel.model_name == model,
|
||||
ProviderModel.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModel.model_type == model_type,
|
||||
)
|
||||
|
||||
return session.execute(stmt).scalar_one_or_none()
|
||||
@@ -778,7 +778,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
@@ -825,7 +825,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
ProviderModelCredential.credential_name == credential_name,
|
||||
)
|
||||
if exclude_id:
|
||||
@@ -901,7 +901,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
credential_record = s.execute(stmt).scalar_one_or_none()
|
||||
original_credentials = (
|
||||
@@ -970,7 +970,7 @@ class ProviderConfiguration(BaseModel):
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_name=model,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
encrypted_config=json.dumps(credentials),
|
||||
credential_name=credential_name,
|
||||
)
|
||||
@@ -983,7 +983,7 @@ class ProviderConfiguration(BaseModel):
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_name=model,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
credential_id=credential.id,
|
||||
is_valid=True,
|
||||
)
|
||||
@@ -1038,7 +1038,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
if not credential_record:
|
||||
@@ -1083,7 +1083,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
if not credential_record:
|
||||
@@ -1116,7 +1116,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
available_credentials_count = session.execute(count_stmt).scalar() or 0
|
||||
session.delete(credential_record)
|
||||
@@ -1156,7 +1156,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
if not credential_record:
|
||||
@@ -1171,7 +1171,7 @@ class ProviderConfiguration(BaseModel):
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_name=model,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
is_valid=True,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
@@ -1207,7 +1207,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
if not credential_record:
|
||||
@@ -1263,7 +1263,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(ProviderModelSetting).where(
|
||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||
ProviderModelSetting.provider_name.in_(self._get_provider_names()),
|
||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||
ProviderModelSetting.model_type == model_type,
|
||||
ProviderModelSetting.model_name == model,
|
||||
)
|
||||
return session.execute(stmt).scalars().first()
|
||||
@@ -1286,7 +1286,7 @@ class ProviderConfiguration(BaseModel):
|
||||
model_setting = ProviderModelSetting(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
model_name=model,
|
||||
enabled=True,
|
||||
)
|
||||
@@ -1312,7 +1312,7 @@ class ProviderConfiguration(BaseModel):
|
||||
model_setting = ProviderModelSetting(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
model_name=model,
|
||||
enabled=False,
|
||||
)
|
||||
@@ -1348,7 +1348,7 @@ class ProviderConfiguration(BaseModel):
|
||||
stmt = select(func.count(LoadBalancingModelConfig.id)).where(
|
||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||
LoadBalancingModelConfig.provider_name.in_(provider_names),
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_type == model_type,
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
)
|
||||
load_balancing_config_count = session.execute(stmt).scalar() or 0
|
||||
@@ -1364,7 +1364,7 @@ class ProviderConfiguration(BaseModel):
|
||||
model_setting = ProviderModelSetting(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
model_name=model,
|
||||
load_balancing_enabled=True,
|
||||
)
|
||||
@@ -1391,7 +1391,7 @@ class ProviderConfiguration(BaseModel):
|
||||
model_setting = ProviderModelSetting(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
model_name=model,
|
||||
load_balancing_enabled=False,
|
||||
)
|
||||
|
||||
@@ -260,4 +260,12 @@ def convert_input_form_to_parameters(
|
||||
parameters[item.variable]["enum"] = item.options
|
||||
elif item.type == VariableEntityType.NUMBER:
|
||||
parameters[item.variable]["type"] = "number"
|
||||
elif item.type == VariableEntityType.CHECKBOX:
|
||||
parameters[item.variable]["type"] = "boolean"
|
||||
elif item.type == VariableEntityType.JSON_OBJECT:
|
||||
parameters[item.variable]["type"] = "object"
|
||||
if item.json_schema:
|
||||
for key in ("properties", "required", "additionalProperties"):
|
||||
if key in item.json_schema:
|
||||
parameters[item.variable][key] = item.json_schema[key]
|
||||
return parameters, required
|
||||
|
||||
@@ -16,7 +16,13 @@ from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExport
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import ReadableSpan
|
||||
from opentelemetry.sdk.util.instrumentation import InstrumentationScope
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
from opentelemetry.semconv._incubating.attributes.deployment_attributes import ( # type: ignore[import-untyped]
|
||||
DEPLOYMENT_ENVIRONMENT,
|
||||
)
|
||||
from opentelemetry.semconv._incubating.attributes.host_attributes import ( # type: ignore[import-untyped]
|
||||
HOST_NAME,
|
||||
)
|
||||
from opentelemetry.semconv.attributes import service_attributes
|
||||
from opentelemetry.trace import Link, SpanContext, TraceFlags
|
||||
|
||||
from configs import dify_config
|
||||
@@ -45,10 +51,10 @@ class TraceClient:
|
||||
self.endpoint = endpoint
|
||||
self.resource = Resource(
|
||||
attributes={
|
||||
ResourceAttributes.SERVICE_NAME: service_name,
|
||||
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
|
||||
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
|
||||
ResourceAttributes.HOST_NAME: socket.gethostname(),
|
||||
service_attributes.SERVICE_NAME: service_name,
|
||||
service_attributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
|
||||
DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
|
||||
HOST_NAME: socket.gethostname(),
|
||||
ACS_ARMS_SERVICE_FEATURE: "genai_app",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -19,7 +19,7 @@ from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExport
|
||||
from opentelemetry.sdk import trace as trace_sdk
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
||||
from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes
|
||||
from opentelemetry.semconv.attributes import exception_attributes
|
||||
from opentelemetry.trace import Span, Status, StatusCode, set_span_in_context, use_span
|
||||
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
||||
from opentelemetry.util.types import AttributeValue
|
||||
@@ -134,10 +134,10 @@ def set_span_status(current_span: Span, error: Exception | str | None = None):
|
||||
if not exception_message:
|
||||
exception_message = repr(error)
|
||||
attributes: dict[str, AttributeValue] = {
|
||||
OTELSpanAttributes.EXCEPTION_TYPE: exception_type,
|
||||
OTELSpanAttributes.EXCEPTION_MESSAGE: exception_message,
|
||||
OTELSpanAttributes.EXCEPTION_ESCAPED: False,
|
||||
OTELSpanAttributes.EXCEPTION_STACKTRACE: error_string,
|
||||
exception_attributes.EXCEPTION_TYPE: exception_type,
|
||||
exception_attributes.EXCEPTION_MESSAGE: exception_message,
|
||||
exception_attributes.EXCEPTION_ESCAPED: False,
|
||||
exception_attributes.EXCEPTION_STACKTRACE: error_string,
|
||||
}
|
||||
current_span.add_event(name="exception", attributes=attributes)
|
||||
else:
|
||||
|
||||
@@ -1,9 +1,19 @@
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from langfuse import Langfuse
|
||||
from langfuse.api import (
|
||||
CreateGenerationBody,
|
||||
CreateSpanBody,
|
||||
IngestionEvent_GenerationCreate,
|
||||
IngestionEvent_SpanCreate,
|
||||
IngestionEvent_TraceCreate,
|
||||
TraceBody,
|
||||
)
|
||||
from langfuse.api.commons.types.usage import Usage
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
@@ -396,18 +406,61 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
)
|
||||
self.add_span(langfuse_span_data=name_generation_span_data)
|
||||
|
||||
def _make_event_id(self) -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
def _now_iso(self) -> str:
|
||||
return datetime.now(UTC).isoformat()
|
||||
|
||||
def add_trace(self, langfuse_trace_data: LangfuseTrace | None = None):
|
||||
format_trace_data = filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {}
|
||||
data = filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {}
|
||||
try:
|
||||
self.langfuse_client.trace(**format_trace_data)
|
||||
body = TraceBody(
|
||||
id=data.get("id"),
|
||||
name=data.get("name"),
|
||||
user_id=data.get("user_id"),
|
||||
input=data.get("input"),
|
||||
output=data.get("output"),
|
||||
metadata=data.get("metadata"),
|
||||
session_id=data.get("session_id"),
|
||||
version=data.get("version"),
|
||||
release=data.get("release"),
|
||||
tags=data.get("tags"),
|
||||
public=data.get("public"),
|
||||
)
|
||||
event = IngestionEvent_TraceCreate(
|
||||
body=body,
|
||||
id=self._make_event_id(),
|
||||
timestamp=self._now_iso(),
|
||||
)
|
||||
self.langfuse_client.api.ingestion.batch(batch=[event])
|
||||
logger.debug("LangFuse Trace created successfully")
|
||||
except Exception as e:
|
||||
raise ValueError(f"LangFuse Failed to create trace: {str(e)}")
|
||||
|
||||
def add_span(self, langfuse_span_data: LangfuseSpan | None = None):
|
||||
format_span_data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
|
||||
data = filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
|
||||
try:
|
||||
self.langfuse_client.span(**format_span_data)
|
||||
body = CreateSpanBody(
|
||||
id=data.get("id"),
|
||||
trace_id=data.get("trace_id"),
|
||||
name=data.get("name"),
|
||||
start_time=data.get("start_time"),
|
||||
end_time=data.get("end_time"),
|
||||
input=data.get("input"),
|
||||
output=data.get("output"),
|
||||
metadata=data.get("metadata"),
|
||||
level=data.get("level"),
|
||||
status_message=data.get("status_message"),
|
||||
parent_observation_id=data.get("parent_observation_id"),
|
||||
version=data.get("version"),
|
||||
)
|
||||
event = IngestionEvent_SpanCreate(
|
||||
body=body,
|
||||
id=self._make_event_id(),
|
||||
timestamp=self._now_iso(),
|
||||
)
|
||||
self.langfuse_client.api.ingestion.batch(batch=[event])
|
||||
logger.debug("LangFuse Span created successfully")
|
||||
except Exception as e:
|
||||
raise ValueError(f"LangFuse Failed to create span: {str(e)}")
|
||||
@@ -418,11 +471,45 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
span.end(**format_span_data)
|
||||
|
||||
def add_generation(self, langfuse_generation_data: LangfuseGeneration | None = None):
|
||||
format_generation_data = (
|
||||
filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {}
|
||||
)
|
||||
data = filter_none_values(langfuse_generation_data.model_dump()) if langfuse_generation_data else {}
|
||||
try:
|
||||
self.langfuse_client.generation(**format_generation_data)
|
||||
usage_data = data.pop("usage", None)
|
||||
usage = None
|
||||
if usage_data:
|
||||
usage = Usage(
|
||||
input=usage_data.get("input", 0) or 0,
|
||||
output=usage_data.get("output", 0) or 0,
|
||||
total=usage_data.get("total", 0) or 0,
|
||||
unit=usage_data.get("unit"),
|
||||
input_cost=usage_data.get("inputCost"),
|
||||
output_cost=usage_data.get("outputCost"),
|
||||
total_cost=usage_data.get("totalCost"),
|
||||
)
|
||||
|
||||
body = CreateGenerationBody(
|
||||
id=data.get("id"),
|
||||
trace_id=data.get("trace_id"),
|
||||
name=data.get("name"),
|
||||
start_time=data.get("start_time"),
|
||||
end_time=data.get("end_time"),
|
||||
model=data.get("model"),
|
||||
model_parameters=data.get("model_parameters"),
|
||||
input=data.get("input"),
|
||||
output=data.get("output"),
|
||||
usage=usage,
|
||||
metadata=data.get("metadata"),
|
||||
level=data.get("level"),
|
||||
status_message=data.get("status_message"),
|
||||
parent_observation_id=data.get("parent_observation_id"),
|
||||
version=data.get("version"),
|
||||
completion_start_time=data.get("completion_start_time"),
|
||||
)
|
||||
event = IngestionEvent_GenerationCreate(
|
||||
body=body,
|
||||
id=self._make_event_id(),
|
||||
timestamp=self._now_iso(),
|
||||
)
|
||||
self.langfuse_client.api.ingestion.batch(batch=[event])
|
||||
logger.debug("LangFuse Generation created successfully")
|
||||
except Exception as e:
|
||||
raise ValueError(f"LangFuse Failed to create generation: {str(e)}")
|
||||
@@ -443,7 +530,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
|
||||
def get_project_key(self):
|
||||
try:
|
||||
projects = self.langfuse_client.client.projects.get()
|
||||
projects = self.langfuse_client.api.projects.get()
|
||||
return projects.data[0].id
|
||||
except Exception as e:
|
||||
logger.debug("LangFuse get project key failed: %s", str(e))
|
||||
|
||||
@@ -26,7 +26,13 @@ from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExport
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
from opentelemetry.semconv._incubating.attributes.deployment_attributes import ( # type: ignore[import-untyped]
|
||||
DEPLOYMENT_ENVIRONMENT,
|
||||
)
|
||||
from opentelemetry.semconv._incubating.attributes.host_attributes import ( # type: ignore[import-untyped]
|
||||
HOST_NAME,
|
||||
)
|
||||
from opentelemetry.semconv.attributes import service_attributes
|
||||
from opentelemetry.trace import SpanKind
|
||||
from opentelemetry.util.types import AttributeValue
|
||||
|
||||
@@ -73,13 +79,13 @@ class TencentTraceClient:
|
||||
|
||||
self.resource = Resource(
|
||||
attributes={
|
||||
ResourceAttributes.SERVICE_NAME: service_name,
|
||||
ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
|
||||
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
|
||||
ResourceAttributes.HOST_NAME: socket.gethostname(),
|
||||
ResourceAttributes.TELEMETRY_SDK_LANGUAGE: "python",
|
||||
ResourceAttributes.TELEMETRY_SDK_NAME: "opentelemetry",
|
||||
ResourceAttributes.TELEMETRY_SDK_VERSION: _get_opentelemetry_sdk_version(),
|
||||
service_attributes.SERVICE_NAME: service_name,
|
||||
service_attributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
|
||||
DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
|
||||
HOST_NAME: socket.gethostname(),
|
||||
"telemetry.sdk.language": "python",
|
||||
"telemetry.sdk.name": "opentelemetry",
|
||||
"telemetry.sdk.version": _get_opentelemetry_sdk_version(),
|
||||
}
|
||||
)
|
||||
# Prepare gRPC endpoint/metadata
|
||||
|
||||
@@ -306,7 +306,7 @@ class ProviderManager:
|
||||
"""
|
||||
stmt = select(TenantDefaultModel).where(
|
||||
TenantDefaultModel.tenant_id == tenant_id,
|
||||
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
|
||||
TenantDefaultModel.model_type == model_type,
|
||||
)
|
||||
default_model = db.session.scalar(stmt)
|
||||
|
||||
@@ -324,7 +324,7 @@ class ProviderManager:
|
||||
|
||||
default_model = TenantDefaultModel(
|
||||
tenant_id=tenant_id,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
provider_name=available_model.provider.provider,
|
||||
model_name=available_model.model,
|
||||
)
|
||||
@@ -391,7 +391,7 @@ class ProviderManager:
|
||||
raise ValueError(f"Model {model} does not exist.")
|
||||
stmt = select(TenantDefaultModel).where(
|
||||
TenantDefaultModel.tenant_id == tenant_id,
|
||||
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
|
||||
TenantDefaultModel.model_type == model_type,
|
||||
)
|
||||
default_model = db.session.scalar(stmt)
|
||||
|
||||
@@ -405,7 +405,7 @@ class ProviderManager:
|
||||
# create default model
|
||||
default_model = TenantDefaultModel(
|
||||
tenant_id=tenant_id,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
provider_name=provider,
|
||||
model_name=model,
|
||||
)
|
||||
@@ -626,9 +626,8 @@ class ProviderManager:
|
||||
if provider_record.provider_type != ProviderType.SYSTEM:
|
||||
continue
|
||||
|
||||
provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
|
||||
provider_record
|
||||
)
|
||||
if provider_record.quota_type is not None:
|
||||
provider_quota_to_provider_record_dict[provider_record.quota_type] = provider_record
|
||||
|
||||
for quota in configuration.quotas:
|
||||
if quota.quota_type in (ProviderQuotaType.TRIAL, ProviderQuotaType.PAID):
|
||||
@@ -641,7 +640,7 @@ class ProviderManager:
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
provider_name=ModelProviderID(provider_name).provider_name,
|
||||
provider_type=ProviderType.SYSTEM,
|
||||
quota_type=quota.quota_type,
|
||||
quota_type=quota.quota_type, # type: ignore[arg-type]
|
||||
quota_limit=0, # type: ignore
|
||||
quota_used=0,
|
||||
is_valid=True,
|
||||
@@ -823,7 +822,7 @@ class ProviderManager:
|
||||
custom_model_configurations.append(
|
||||
CustomModelConfiguration(
|
||||
model=provider_model_record.model_name,
|
||||
model_type=ModelType.value_of(provider_model_record.model_type),
|
||||
model_type=provider_model_record.model_type,
|
||||
credentials=provider_model_credentials,
|
||||
current_credential_id=provider_model_record.credential_id,
|
||||
current_credential_name=provider_model_record.credential_name,
|
||||
@@ -921,9 +920,8 @@ class ProviderManager:
|
||||
if provider_record.provider_type != ProviderType.SYSTEM:
|
||||
continue
|
||||
|
||||
quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = (
|
||||
provider_record
|
||||
)
|
||||
if provider_record.quota_type is not None:
|
||||
quota_type_to_provider_records_dict[provider_record.quota_type] = provider_record # type: ignore[index]
|
||||
quota_configurations = []
|
||||
|
||||
if dify_config.EDITION == "CLOUD":
|
||||
@@ -1203,7 +1201,7 @@ class ProviderManager:
|
||||
model_settings.append(
|
||||
ModelSettings(
|
||||
model=provider_model_setting.model_name,
|
||||
model_type=ModelType.value_of(provider_model_setting.model_type),
|
||||
model_type=provider_model_setting.model_type,
|
||||
enabled=provider_model_setting.enabled,
|
||||
load_balancing_enabled=provider_model_setting.load_balancing_enabled,
|
||||
load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [],
|
||||
|
||||
@@ -27,7 +27,10 @@ from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
from opentelemetry.semconv._incubating.attributes.host_attributes import ( # type: ignore[import-untyped]
|
||||
HOST_NAME,
|
||||
)
|
||||
from opentelemetry.semconv.attributes import service_attributes
|
||||
from opentelemetry.trace import SpanContext, TraceFlags
|
||||
from opentelemetry.util.types import Attributes, AttributeValue
|
||||
|
||||
@@ -114,8 +117,8 @@ class EnterpriseExporter:
|
||||
|
||||
resource = Resource(
|
||||
attributes={
|
||||
ResourceAttributes.SERVICE_NAME: service_name,
|
||||
ResourceAttributes.HOST_NAME: socket.gethostname(),
|
||||
service_attributes.SERVICE_NAME: service_name,
|
||||
HOST_NAME: socket.gethostname(),
|
||||
}
|
||||
)
|
||||
sampler = ParentBasedTraceIdRatio(sampling_rate)
|
||||
|
||||
@@ -157,7 +157,7 @@ def handle(sender: Message, **kwargs):
|
||||
tenant_id=tenant_id,
|
||||
provider_name=ModelProviderID(model_config.provider).provider_name,
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
quota_type=provider_configuration.system_configuration.current_quota_type.value,
|
||||
quota_type=provider_configuration.system_configuration.current_quota_type,
|
||||
),
|
||||
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
|
||||
additional_filters=_ProviderUpdateAdditionalFilters(
|
||||
|
||||
@@ -6,15 +6,24 @@ def init_app(app: DifyApp):
|
||||
if dify_config.SENTRY_DSN:
|
||||
import sentry_sdk
|
||||
from graphon.model_runtime.errors.invoke import InvokeRateLimitError
|
||||
from langfuse import parse_error
|
||||
from sentry_sdk.integrations.celery import CeleryIntegration
|
||||
from sentry_sdk.integrations.flask import FlaskIntegration
|
||||
from werkzeug.exceptions import HTTPException
|
||||
|
||||
try:
|
||||
from langfuse._utils import parse_error
|
||||
|
||||
_langfuse_error_response = parse_error.defaultErrorResponse
|
||||
except (ImportError, AttributeError):
|
||||
_langfuse_error_response = (
|
||||
"Unexpected error occurred. Please check your request"
|
||||
" and contact support: https://langfuse.com/support."
|
||||
)
|
||||
|
||||
def before_send(event, hint):
|
||||
if "exc_info" in hint:
|
||||
_, exc_value, _ = hint["exc_info"]
|
||||
if parse_error.defaultErrorResponse in str(exc_value):
|
||||
if _langfuse_error_response in str(exc_value):
|
||||
return None
|
||||
|
||||
return event
|
||||
@@ -27,7 +36,7 @@ def init_app(app: DifyApp):
|
||||
ValueError,
|
||||
FileNotFoundError,
|
||||
InvokeRateLimitError,
|
||||
parse_error.defaultErrorResponse,
|
||||
_langfuse_error_response,
|
||||
],
|
||||
traces_sample_rate=dify_config.SENTRY_TRACES_SAMPLE_RATE,
|
||||
profiles_sample_rate=dify_config.SENTRY_PROFILES_SAMPLE_RATE,
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Protocol, cast
|
||||
|
||||
import flask
|
||||
from opentelemetry.instrumentation.celery import CeleryInstrumentor
|
||||
@@ -21,6 +23,38 @@ from extensions.otel.runtime import is_celery_worker
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SupportsInstrument(Protocol):
|
||||
def instrument(self, **kwargs: object) -> None: ...
|
||||
|
||||
|
||||
class SupportsFlaskInstrumentor(Protocol):
|
||||
def instrument_app(
|
||||
self, app: DifyApp, response_hook: Callable[[Span, str, list], None] | None = None, **kwargs: object
|
||||
) -> None: ...
|
||||
|
||||
|
||||
# Some OpenTelemetry instrumentor constructors are typed loosely enough that
|
||||
# pyrefly infers `NoneType`. Narrow the instances to just the methods we use
|
||||
# while leaving runtime behavior unchanged.
|
||||
def _new_celery_instrumentor() -> SupportsInstrument:
|
||||
return cast(
|
||||
SupportsInstrument,
|
||||
CeleryInstrumentor(tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider()),
|
||||
)
|
||||
|
||||
|
||||
def _new_httpx_instrumentor() -> SupportsInstrument:
|
||||
return cast(SupportsInstrument, HTTPXClientInstrumentor())
|
||||
|
||||
|
||||
def _new_redis_instrumentor() -> SupportsInstrument:
|
||||
return cast(SupportsInstrument, RedisInstrumentor())
|
||||
|
||||
|
||||
def _new_sqlalchemy_instrumentor() -> SupportsInstrument:
|
||||
return cast(SupportsInstrument, SQLAlchemyInstrumentor())
|
||||
|
||||
|
||||
class ExceptionLoggingHandler(logging.Handler):
|
||||
"""
|
||||
Handler that records exceptions to the current OpenTelemetry span.
|
||||
@@ -97,7 +131,7 @@ def init_flask_instrumentor(app: DifyApp) -> None:
|
||||
|
||||
from opentelemetry.instrumentation.flask import FlaskInstrumentor
|
||||
|
||||
instrumentor = FlaskInstrumentor()
|
||||
instrumentor = cast(SupportsFlaskInstrumentor, FlaskInstrumentor())
|
||||
if dify_config.DEBUG:
|
||||
logger.info("Initializing Flask instrumentor")
|
||||
instrumentor.instrument_app(app, response_hook=response_hook)
|
||||
@@ -106,21 +140,21 @@ def init_flask_instrumentor(app: DifyApp) -> None:
|
||||
def init_sqlalchemy_instrumentor(app: DifyApp) -> None:
|
||||
with app.app_context():
|
||||
engines = list(app.extensions["sqlalchemy"].engines.values())
|
||||
SQLAlchemyInstrumentor().instrument(enable_commenter=True, engines=engines)
|
||||
_new_sqlalchemy_instrumentor().instrument(enable_commenter=True, engines=engines)
|
||||
|
||||
|
||||
def init_redis_instrumentor() -> None:
|
||||
RedisInstrumentor().instrument()
|
||||
_new_redis_instrumentor().instrument()
|
||||
|
||||
|
||||
def init_httpx_instrumentor() -> None:
|
||||
HTTPXClientInstrumentor().instrument()
|
||||
_new_httpx_instrumentor().instrument()
|
||||
|
||||
|
||||
def init_instruments(app: DifyApp) -> None:
|
||||
if not is_celery_worker():
|
||||
init_flask_instrumentor(app)
|
||||
CeleryInstrumentor(tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider()).instrument()
|
||||
_new_celery_instrumentor().instrument()
|
||||
|
||||
instrument_exception_logging()
|
||||
init_sqlalchemy_instrumentor(app)
|
||||
|
||||
@@ -6,6 +6,7 @@ from functools import cached_property
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from sqlalchemy import DateTime, String, func, select, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
@@ -13,7 +14,7 @@ from libs.uuid_utils import uuidv7
|
||||
|
||||
from .base import TypeBase
|
||||
from .engine import db
|
||||
from .enums import CredentialSourceType, PaymentStatus
|
||||
from .enums import CredentialSourceType, PaymentStatus, ProviderQuotaType
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
|
||||
@@ -29,24 +30,6 @@ class ProviderType(StrEnum):
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class ProviderQuotaType(StrEnum):
|
||||
PAID = auto()
|
||||
"""hosted paid quota"""
|
||||
|
||||
FREE = auto()
|
||||
"""third-party free quota"""
|
||||
|
||||
TRIAL = auto()
|
||||
"""hosted trial quota"""
|
||||
|
||||
@staticmethod
|
||||
def value_of(value: str) -> ProviderQuotaType:
|
||||
for member in ProviderQuotaType:
|
||||
if member.value == value:
|
||||
return member
|
||||
raise ValueError(f"No matching enum found for value '{value}'")
|
||||
|
||||
|
||||
class Provider(TypeBase):
|
||||
"""
|
||||
Provider model representing the API providers and their configurations.
|
||||
@@ -77,7 +60,9 @@ class Provider(TypeBase):
|
||||
last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, init=False)
|
||||
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
|
||||
quota_type: Mapped[str | None] = mapped_column(String(40), nullable=True, server_default=text("''"), default="")
|
||||
quota_type: Mapped[ProviderQuotaType | None] = mapped_column(
|
||||
EnumText(ProviderQuotaType, length=40), nullable=True, server_default=text("''"), default=None
|
||||
)
|
||||
quota_limit: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=None)
|
||||
quota_used: Mapped[int | None] = mapped_column(sa.BigInteger, nullable=True, default=0)
|
||||
|
||||
@@ -147,7 +132,7 @@ class ProviderModel(TypeBase):
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
model_type: Mapped[ModelType] = mapped_column(EnumText(ModelType, length=40), nullable=False)
|
||||
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"), default=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
@@ -189,7 +174,7 @@ class TenantDefaultModel(TypeBase):
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
model_type: Mapped[ModelType] = mapped_column(EnumText(ModelType, length=40), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), init=False
|
||||
)
|
||||
@@ -269,7 +254,7 @@ class ProviderModelSetting(TypeBase):
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
model_type: Mapped[ModelType] = mapped_column(EnumText(ModelType, length=40), nullable=False)
|
||||
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"), default=True)
|
||||
load_balancing_enabled: Mapped[bool] = mapped_column(
|
||||
sa.Boolean, nullable=False, server_default=text("false"), default=False
|
||||
@@ -299,7 +284,7 @@ class LoadBalancingModelConfig(TypeBase):
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
model_type: Mapped[ModelType] = mapped_column(EnumText(ModelType, length=40), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
@@ -364,7 +349,7 @@ class ProviderModelCredential(TypeBase):
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
model_type: Mapped[str] = mapped_column(String(40), nullable=False)
|
||||
model_type: Mapped[ModelType] = mapped_column(EnumText(ModelType, length=40), nullable=False)
|
||||
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
encrypted_config: Mapped[str] = mapped_column(LongText, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
|
||||
@@ -144,8 +144,8 @@ class EnumText(TypeDecorator[_E | None], Generic[_E]):
|
||||
return dialect.type_descriptor(VARCHAR(self._length))
|
||||
|
||||
def process_result_value(self, value: str | None, dialect: Dialect) -> _E | None:
|
||||
if value is None:
|
||||
return value
|
||||
if value is None or value == "":
|
||||
return None
|
||||
# Type annotation guarantees value is str at this point
|
||||
return self._enum_class(value)
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ dependencies = [
|
||||
"httpx[socks]~=0.28.0",
|
||||
"jieba==0.42.1",
|
||||
"json-repair>=0.55.1",
|
||||
"langfuse~=2.51.3",
|
||||
"langfuse>=3.0.0,<5.0.0",
|
||||
"langsmith~=0.7.16",
|
||||
"markdown~=3.10.2",
|
||||
"mlflow-skinny>=3.0.0",
|
||||
@@ -41,23 +41,23 @@ dependencies = [
|
||||
"openpyxl~=3.1.5",
|
||||
"opik~=1.10.37",
|
||||
"litellm==1.82.6", # Pinned to avoid madoka dependency issue
|
||||
"opentelemetry-api==1.28.0",
|
||||
"opentelemetry-distro==0.49b0",
|
||||
"opentelemetry-exporter-otlp==1.28.0",
|
||||
"opentelemetry-exporter-otlp-proto-common==1.28.0",
|
||||
"opentelemetry-exporter-otlp-proto-grpc==1.28.0",
|
||||
"opentelemetry-exporter-otlp-proto-http==1.28.0",
|
||||
"opentelemetry-instrumentation==0.49b0",
|
||||
"opentelemetry-instrumentation-celery==0.49b0",
|
||||
"opentelemetry-instrumentation-flask==0.49b0",
|
||||
"opentelemetry-instrumentation-httpx==0.49b0",
|
||||
"opentelemetry-instrumentation-redis==0.49b0",
|
||||
"opentelemetry-instrumentation-sqlalchemy==0.49b0",
|
||||
"opentelemetry-api==1.40.0",
|
||||
"opentelemetry-distro==0.61b0",
|
||||
"opentelemetry-exporter-otlp==1.40.0",
|
||||
"opentelemetry-exporter-otlp-proto-common==1.40.0",
|
||||
"opentelemetry-exporter-otlp-proto-grpc==1.40.0",
|
||||
"opentelemetry-exporter-otlp-proto-http==1.40.0",
|
||||
"opentelemetry-instrumentation==0.61b0",
|
||||
"opentelemetry-instrumentation-celery==0.61b0",
|
||||
"opentelemetry-instrumentation-flask==0.61b0",
|
||||
"opentelemetry-instrumentation-httpx==0.61b0",
|
||||
"opentelemetry-instrumentation-redis==0.61b0",
|
||||
"opentelemetry-instrumentation-sqlalchemy==0.61b0",
|
||||
"opentelemetry-propagator-b3==1.40.0",
|
||||
"opentelemetry-proto==1.28.0",
|
||||
"opentelemetry-sdk==1.28.0",
|
||||
"opentelemetry-semantic-conventions==0.49b0",
|
||||
"opentelemetry-util-http==0.49b0",
|
||||
"opentelemetry-proto==1.40.0",
|
||||
"opentelemetry-sdk==1.40.0",
|
||||
"opentelemetry-semantic-conventions==0.61b0",
|
||||
"opentelemetry-util-http==0.61b0",
|
||||
"pandas[excel,output-formatting,performance]~=3.0.1",
|
||||
"psycogreen~=1.0.2",
|
||||
"psycopg2-binary~=2.9.6",
|
||||
@@ -72,12 +72,12 @@ dependencies = [
|
||||
"readabilipy~=0.3.0",
|
||||
"redis[hiredis]~=7.4.0",
|
||||
"resend~=2.26.0",
|
||||
"sentry-sdk[flask]~=2.56.0",
|
||||
"sentry-sdk[flask]~=2.55.0",
|
||||
"sqlalchemy~=2.0.29",
|
||||
"starlette==1.0.0",
|
||||
"tiktoken~=0.12.0",
|
||||
"transformers~=5.3.0",
|
||||
"unstructured[docx,epub,md,ppt,pptx]~=0.22.6",
|
||||
"unstructured[docx,epub,md,ppt,pptx]~=0.21.5",
|
||||
"pypandoc~=1.13",
|
||||
"yarl~=1.23.0",
|
||||
"sseclient-py~=1.9.0",
|
||||
|
||||
@@ -115,7 +115,7 @@ class ModelLoadBalancingService:
|
||||
.where(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_type == model_type_enum,
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
or_(
|
||||
LoadBalancingModelConfig.credential_source_type == credential_source_type,
|
||||
@@ -240,7 +240,7 @@ class ModelLoadBalancingService:
|
||||
.where(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_type == model_type_enum,
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
LoadBalancingModelConfig.id == config_id,
|
||||
)
|
||||
@@ -288,7 +288,7 @@ class ModelLoadBalancingService:
|
||||
inherit_config = LoadBalancingModelConfig(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_type=model_type,
|
||||
model_name=model,
|
||||
name="__inherit__",
|
||||
)
|
||||
@@ -328,7 +328,7 @@ class ModelLoadBalancingService:
|
||||
select(LoadBalancingModelConfig).where(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_type == model_type_enum,
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
)
|
||||
).all()
|
||||
@@ -368,7 +368,7 @@ class ModelLoadBalancingService:
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_configuration.provider.provider,
|
||||
model_name=model,
|
||||
model_type=model_type_enum.to_origin_model_type(),
|
||||
model_type=model_type_enum,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
@@ -432,7 +432,7 @@ class ModelLoadBalancingService:
|
||||
load_balancing_model_config = LoadBalancingModelConfig(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_configuration.provider.provider,
|
||||
model_type=model_type_enum.to_origin_model_type(),
|
||||
model_type=model_type_enum,
|
||||
model_name=model,
|
||||
name=credential_record.credential_name,
|
||||
encrypted_config=credential_record.encrypted_config,
|
||||
@@ -460,7 +460,7 @@ class ModelLoadBalancingService:
|
||||
load_balancing_model_config = LoadBalancingModelConfig(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_configuration.provider.provider,
|
||||
model_type=model_type_enum.to_origin_model_type(),
|
||||
model_type=model_type_enum,
|
||||
model_name=model,
|
||||
name=name,
|
||||
encrypted_config=json.dumps(credentials),
|
||||
@@ -515,7 +515,7 @@ class ModelLoadBalancingService:
|
||||
.where(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider,
|
||||
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_type == model_type_enum,
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
LoadBalancingModelConfig.id == config_id,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,182 @@
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from models import Tenant
|
||||
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant(flask_req_ctx):
|
||||
with session_factory.create_session() as session:
|
||||
t = Tenant(name="plugin_it_tenant")
|
||||
session.add(t)
|
||||
session.commit()
|
||||
tenant_id = t.id
|
||||
|
||||
yield tenant_id
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
session.execute(delete(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id))
|
||||
session.execute(
|
||||
delete(TenantPluginAutoUpgradeStrategy).where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
)
|
||||
session.execute(delete(Tenant).where(Tenant.id == tenant_id))
|
||||
session.commit()
|
||||
|
||||
|
||||
class TestPluginPermissionLifecycle:
|
||||
def test_get_returns_none_for_new_tenant(self, tenant):
|
||||
assert PluginPermissionService.get_permission(tenant) is None
|
||||
|
||||
def test_change_creates_row(self, tenant):
|
||||
result = PluginPermissionService.change_permission(
|
||||
tenant,
|
||||
TenantPluginPermission.InstallPermission.ADMINS,
|
||||
TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
perm = PluginPermissionService.get_permission(tenant)
|
||||
assert perm is not None
|
||||
assert perm.install_permission == TenantPluginPermission.InstallPermission.ADMINS
|
||||
assert perm.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE
|
||||
|
||||
def test_change_updates_existing_row(self, tenant):
|
||||
PluginPermissionService.change_permission(
|
||||
tenant,
|
||||
TenantPluginPermission.InstallPermission.ADMINS,
|
||||
TenantPluginPermission.DebugPermission.NOBODY,
|
||||
)
|
||||
PluginPermissionService.change_permission(
|
||||
tenant,
|
||||
TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
TenantPluginPermission.DebugPermission.ADMINS,
|
||||
)
|
||||
perm = PluginPermissionService.get_permission(tenant)
|
||||
assert perm is not None
|
||||
assert perm.install_permission == TenantPluginPermission.InstallPermission.EVERYONE
|
||||
assert perm.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
count = session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant).count()
|
||||
assert count == 1
|
||||
|
||||
|
||||
class TestPluginAutoUpgradeLifecycle:
|
||||
def test_get_returns_none_for_new_tenant(self, tenant):
|
||||
assert PluginAutoUpgradeService.get_strategy(tenant) is None
|
||||
|
||||
def test_change_creates_row(self, tenant):
|
||||
result = PluginAutoUpgradeService.change_strategy(
|
||||
tenant,
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
|
||||
upgrade_time_of_day=3,
|
||||
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
|
||||
exclude_plugins=[],
|
||||
include_plugins=[],
|
||||
)
|
||||
assert result is True
|
||||
|
||||
strategy = PluginAutoUpgradeService.get_strategy(tenant)
|
||||
assert strategy is not None
|
||||
assert strategy.strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST
|
||||
assert strategy.upgrade_time_of_day == 3
|
||||
|
||||
def test_change_updates_existing_row(self, tenant):
|
||||
PluginAutoUpgradeService.change_strategy(
|
||||
tenant,
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
|
||||
upgrade_time_of_day=0,
|
||||
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
|
||||
exclude_plugins=[],
|
||||
include_plugins=[],
|
||||
)
|
||||
PluginAutoUpgradeService.change_strategy(
|
||||
tenant,
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
|
||||
upgrade_time_of_day=12,
|
||||
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL,
|
||||
exclude_plugins=[],
|
||||
include_plugins=["plugin-a"],
|
||||
)
|
||||
|
||||
strategy = PluginAutoUpgradeService.get_strategy(tenant)
|
||||
assert strategy is not None
|
||||
assert strategy.strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST
|
||||
assert strategy.upgrade_time_of_day == 12
|
||||
assert strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL
|
||||
assert strategy.include_plugins == ["plugin-a"]
|
||||
|
||||
def test_exclude_plugin_creates_strategy_when_none_exists(self, tenant):
|
||||
PluginAutoUpgradeService.exclude_plugin(tenant, "my-plugin")
|
||||
|
||||
strategy = PluginAutoUpgradeService.get_strategy(tenant)
|
||||
assert strategy is not None
|
||||
assert strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
|
||||
assert "my-plugin" in strategy.exclude_plugins
|
||||
|
||||
def test_exclude_plugin_appends_in_exclude_mode(self, tenant):
|
||||
PluginAutoUpgradeService.change_strategy(
|
||||
tenant,
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
|
||||
upgrade_time_of_day=0,
|
||||
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
|
||||
exclude_plugins=["existing"],
|
||||
include_plugins=[],
|
||||
)
|
||||
PluginAutoUpgradeService.exclude_plugin(tenant, "new-plugin")
|
||||
|
||||
strategy = PluginAutoUpgradeService.get_strategy(tenant)
|
||||
assert strategy is not None
|
||||
assert "existing" in strategy.exclude_plugins
|
||||
assert "new-plugin" in strategy.exclude_plugins
|
||||
|
||||
def test_exclude_plugin_dedup_in_exclude_mode(self, tenant):
|
||||
PluginAutoUpgradeService.change_strategy(
|
||||
tenant,
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
|
||||
upgrade_time_of_day=0,
|
||||
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
|
||||
exclude_plugins=["same-plugin"],
|
||||
include_plugins=[],
|
||||
)
|
||||
PluginAutoUpgradeService.exclude_plugin(tenant, "same-plugin")
|
||||
|
||||
strategy = PluginAutoUpgradeService.get_strategy(tenant)
|
||||
assert strategy is not None
|
||||
assert strategy.exclude_plugins.count("same-plugin") == 1
|
||||
|
||||
def test_exclude_from_partial_mode_removes_from_include(self, tenant):
|
||||
PluginAutoUpgradeService.change_strategy(
|
||||
tenant,
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
|
||||
upgrade_time_of_day=0,
|
||||
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL,
|
||||
exclude_plugins=[],
|
||||
include_plugins=["p1", "p2"],
|
||||
)
|
||||
PluginAutoUpgradeService.exclude_plugin(tenant, "p1")
|
||||
|
||||
strategy = PluginAutoUpgradeService.get_strategy(tenant)
|
||||
assert strategy is not None
|
||||
assert "p1" not in strategy.include_plugins
|
||||
assert "p2" in strategy.include_plugins
|
||||
|
||||
def test_exclude_from_all_mode_switches_to_exclude(self, tenant):
|
||||
PluginAutoUpgradeService.change_strategy(
|
||||
tenant,
|
||||
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
|
||||
upgrade_time_of_day=0,
|
||||
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
|
||||
exclude_plugins=[],
|
||||
include_plugins=[],
|
||||
)
|
||||
PluginAutoUpgradeService.exclude_plugin(tenant, "excluded-plugin")
|
||||
|
||||
strategy = PluginAutoUpgradeService.get_strategy(tenant)
|
||||
assert strategy is not None
|
||||
assert strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
|
||||
assert "excluded-plugin" in strategy.exclude_plugins
|
||||
@@ -0,0 +1,348 @@
|
||||
import datetime
|
||||
import math
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from models import Tenant
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import (
|
||||
App,
|
||||
Conversation,
|
||||
Message,
|
||||
MessageAnnotation,
|
||||
MessageFeedback,
|
||||
)
|
||||
from services.retention.conversation.messages_clean_policy import BillingDisabledPolicy
|
||||
from services.retention.conversation.messages_clean_service import MessagesCleanService
|
||||
|
||||
_NOW = datetime.datetime(2026, 1, 15, 12, 0, 0, tzinfo=datetime.UTC)
|
||||
_OLD = _NOW - datetime.timedelta(days=60)
|
||||
_VERY_OLD = _NOW - datetime.timedelta(days=90)
|
||||
_RECENT = _NOW - datetime.timedelta(days=5)
|
||||
|
||||
_WINDOW_START = _VERY_OLD - datetime.timedelta(hours=1)
|
||||
_WINDOW_END = _RECENT + datetime.timedelta(hours=1)
|
||||
|
||||
_DEFAULT_BATCH_SIZE = 100
|
||||
_PAGINATION_MESSAGE_COUNT = 25
|
||||
_PAGINATION_BATCH_SIZE = 8
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_and_app(flask_req_ctx):
|
||||
"""Creates a Tenant, App and Conversation for the test and cleans up after."""
|
||||
with session_factory.create_session() as session:
|
||||
tenant = Tenant(name="retention_it_tenant")
|
||||
session.add(tenant)
|
||||
session.flush()
|
||||
|
||||
app = App(
|
||||
tenant_id=tenant.id,
|
||||
name="Retention IT App",
|
||||
mode="chat",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
)
|
||||
session.add(app)
|
||||
session.flush()
|
||||
|
||||
conv = Conversation(
|
||||
app_id=app.id,
|
||||
mode="chat",
|
||||
name="test_conv",
|
||||
status="normal",
|
||||
from_source="console",
|
||||
_inputs={},
|
||||
)
|
||||
session.add(conv)
|
||||
session.commit()
|
||||
|
||||
tenant_id = tenant.id
|
||||
app_id = app.id
|
||||
conv_id = conv.id
|
||||
|
||||
yield {"tenant_id": tenant_id, "app_id": app_id, "conversation_id": conv_id}
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
session.execute(delete(Conversation).where(Conversation.id == conv_id))
|
||||
session.execute(delete(App).where(App.id == app_id))
|
||||
session.execute(delete(Tenant).where(Tenant.id == tenant_id))
|
||||
session.commit()
|
||||
|
||||
|
||||
def _make_message(app_id: str, conversation_id: str, created_at: datetime.datetime) -> Message:
|
||||
return Message(
|
||||
app_id=app_id,
|
||||
conversation_id=conversation_id,
|
||||
query="test",
|
||||
message=[{"text": "hello"}],
|
||||
answer="world",
|
||||
message_tokens=1,
|
||||
message_unit_price=0,
|
||||
answer_tokens=1,
|
||||
answer_unit_price=0,
|
||||
from_source="console",
|
||||
currency="USD",
|
||||
_inputs={},
|
||||
created_at=created_at,
|
||||
)
|
||||
|
||||
|
||||
class TestMessagesCleanServiceIntegration:
|
||||
@pytest.fixture
|
||||
def seed_messages(self, tenant_and_app):
|
||||
"""Seeds one message at each of _VERY_OLD, _OLD, and _RECENT.
|
||||
Yields a semantic mapping keyed by age label.
|
||||
"""
|
||||
data = tenant_and_app
|
||||
app_id = data["app_id"]
|
||||
conv_id = data["conversation_id"]
|
||||
# Ordered tuple of (label, timestamp) for deterministic seeding
|
||||
timestamps = [
|
||||
("very_old", _VERY_OLD),
|
||||
("old", _OLD),
|
||||
("recent", _RECENT),
|
||||
]
|
||||
msg_ids: dict[str, str] = {}
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
for label, ts in timestamps:
|
||||
msg = _make_message(app_id, conv_id, ts)
|
||||
session.add(msg)
|
||||
session.flush()
|
||||
msg_ids[label] = msg.id
|
||||
session.commit()
|
||||
|
||||
yield {"msg_ids": msg_ids, **data}
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
session.execute(
|
||||
delete(Message)
|
||||
.where(Message.id.in_(list(msg_ids.values())))
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
@pytest.fixture
|
||||
def paginated_seed_messages(self, tenant_and_app):
|
||||
"""Seeds multiple messages separated by 1-second increments starting at _OLD."""
|
||||
data = tenant_and_app
|
||||
app_id = data["app_id"]
|
||||
conv_id = data["conversation_id"]
|
||||
msg_ids: list[str] = []
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
for i in range(_PAGINATION_MESSAGE_COUNT):
|
||||
ts = _OLD + datetime.timedelta(seconds=i)
|
||||
msg = _make_message(app_id, conv_id, ts)
|
||||
session.add(msg)
|
||||
session.flush()
|
||||
msg_ids.append(msg.id)
|
||||
session.commit()
|
||||
|
||||
yield {"msg_ids": msg_ids, **data}
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
session.execute(delete(Message).where(Message.id.in_(msg_ids)).execution_options(synchronize_session=False))
|
||||
session.commit()
|
||||
|
||||
@pytest.fixture
|
||||
def cascade_test_data(self, tenant_and_app):
|
||||
"""Seeds one Message with an associated Feedback and Annotation."""
|
||||
data = tenant_and_app
|
||||
app_id = data["app_id"]
|
||||
conv_id = data["conversation_id"]
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
msg = _make_message(app_id, conv_id, _OLD)
|
||||
session.add(msg)
|
||||
session.flush()
|
||||
|
||||
feedback = MessageFeedback(
|
||||
app_id=app_id,
|
||||
conversation_id=conv_id,
|
||||
message_id=msg.id,
|
||||
rating=FeedbackRating.LIKE,
|
||||
from_source=FeedbackFromSource.USER,
|
||||
)
|
||||
annotation = MessageAnnotation(
|
||||
app_id=app_id,
|
||||
conversation_id=conv_id,
|
||||
message_id=msg.id,
|
||||
question="q",
|
||||
content="a",
|
||||
account_id=str(uuid.uuid4()),
|
||||
)
|
||||
session.add_all([feedback, annotation])
|
||||
session.commit()
|
||||
|
||||
msg_id = msg.id
|
||||
fb_id = feedback.id
|
||||
ann_id = annotation.id
|
||||
|
||||
yield {"msg_id": msg_id, "fb_id": fb_id, "ann_id": ann_id, **data}
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
session.execute(delete(MessageAnnotation).where(MessageAnnotation.id == ann_id))
|
||||
session.execute(delete(MessageFeedback).where(MessageFeedback.id == fb_id))
|
||||
session.execute(delete(Message).where(Message.id == msg_id))
|
||||
session.commit()
|
||||
|
||||
def test_dry_run_does_not_delete(self, seed_messages):
|
||||
"""Dry-run must count eligible rows without deleting any of them."""
|
||||
data = seed_messages
|
||||
msg_ids = data["msg_ids"]
|
||||
all_ids = list(msg_ids.values())
|
||||
|
||||
svc = MessagesCleanService.from_time_range(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=_WINDOW_START,
|
||||
end_before=_WINDOW_END,
|
||||
batch_size=_DEFAULT_BATCH_SIZE,
|
||||
dry_run=True,
|
||||
)
|
||||
stats = svc.run()
|
||||
|
||||
assert stats["filtered_messages"] == len(all_ids)
|
||||
assert stats["total_deleted"] == 0
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
remaining = session.query(Message).where(Message.id.in_(all_ids)).count()
|
||||
assert remaining == len(all_ids)
|
||||
|
||||
def test_billing_disabled_deletes_all_in_range(self, seed_messages):
|
||||
"""All 3 seeded messages fall within the window and must be deleted."""
|
||||
data = seed_messages
|
||||
msg_ids = data["msg_ids"]
|
||||
all_ids = list(msg_ids.values())
|
||||
|
||||
svc = MessagesCleanService.from_time_range(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=_WINDOW_START,
|
||||
end_before=_WINDOW_END,
|
||||
batch_size=_DEFAULT_BATCH_SIZE,
|
||||
dry_run=False,
|
||||
)
|
||||
stats = svc.run()
|
||||
|
||||
assert stats["total_deleted"] == len(all_ids)
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
remaining = session.query(Message).where(Message.id.in_(all_ids)).count()
|
||||
assert remaining == 0
|
||||
|
||||
def test_start_from_filters_correctly(self, seed_messages):
|
||||
"""Only the message at _OLD falls within the narrow ±1 h window."""
|
||||
data = seed_messages
|
||||
msg_ids = data["msg_ids"]
|
||||
|
||||
start = _OLD - datetime.timedelta(hours=1)
|
||||
end = _OLD + datetime.timedelta(hours=1)
|
||||
|
||||
svc = MessagesCleanService.from_time_range(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=start,
|
||||
end_before=end,
|
||||
batch_size=_DEFAULT_BATCH_SIZE,
|
||||
)
|
||||
stats = svc.run()
|
||||
|
||||
assert stats["total_deleted"] == 1
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
all_ids = list(msg_ids.values())
|
||||
remaining_ids = {r[0] for r in session.query(Message.id).where(Message.id.in_(all_ids)).all()}
|
||||
|
||||
assert msg_ids["old"] not in remaining_ids
|
||||
assert msg_ids["very_old"] in remaining_ids
|
||||
assert msg_ids["recent"] in remaining_ids
|
||||
|
||||
def test_cursor_pagination_across_batches(self, paginated_seed_messages):
|
||||
"""Messages must be deleted across multiple batches."""
|
||||
data = paginated_seed_messages
|
||||
msg_ids = data["msg_ids"]
|
||||
|
||||
# _OLD is the earliest; the last one is _OLD + (_PAGINATION_MESSAGE_COUNT - 1) s.
|
||||
pagination_window_start = _OLD - datetime.timedelta(seconds=1)
|
||||
pagination_window_end = _OLD + datetime.timedelta(seconds=_PAGINATION_MESSAGE_COUNT)
|
||||
|
||||
svc = MessagesCleanService.from_time_range(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=pagination_window_start,
|
||||
end_before=pagination_window_end,
|
||||
batch_size=_PAGINATION_BATCH_SIZE,
|
||||
)
|
||||
stats = svc.run()
|
||||
|
||||
assert stats["total_deleted"] == _PAGINATION_MESSAGE_COUNT
|
||||
expected_batches = math.ceil(_PAGINATION_MESSAGE_COUNT / _PAGINATION_BATCH_SIZE)
|
||||
assert stats["batches"] >= expected_batches
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
remaining = session.query(Message).where(Message.id.in_(msg_ids)).count()
|
||||
assert remaining == 0
|
||||
|
||||
def test_no_messages_in_range_returns_empty_stats(self, seed_messages):
|
||||
"""A window entirely in the future must yield zero matches."""
|
||||
far_future = _NOW + datetime.timedelta(days=365)
|
||||
even_further = far_future + datetime.timedelta(days=1)
|
||||
|
||||
svc = MessagesCleanService.from_time_range(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=far_future,
|
||||
end_before=even_further,
|
||||
batch_size=_DEFAULT_BATCH_SIZE,
|
||||
)
|
||||
stats = svc.run()
|
||||
|
||||
assert stats["total_messages"] == 0
|
||||
assert stats["total_deleted"] == 0
|
||||
|
||||
def test_relation_cascade_deletes(self, cascade_test_data):
|
||||
"""Deleting a Message must cascade to its Feedback and Annotation rows."""
|
||||
data = cascade_test_data
|
||||
msg_id = data["msg_id"]
|
||||
fb_id = data["fb_id"]
|
||||
ann_id = data["ann_id"]
|
||||
|
||||
svc = MessagesCleanService.from_time_range(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=_OLD - datetime.timedelta(hours=1),
|
||||
end_before=_OLD + datetime.timedelta(hours=1),
|
||||
batch_size=_DEFAULT_BATCH_SIZE,
|
||||
)
|
||||
stats = svc.run()
|
||||
|
||||
assert stats["total_deleted"] == 1
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
assert session.query(Message).where(Message.id == msg_id).count() == 0
|
||||
assert session.query(MessageFeedback).where(MessageFeedback.id == fb_id).count() == 0
|
||||
assert session.query(MessageAnnotation).where(MessageAnnotation.id == ann_id).count() == 0
|
||||
|
||||
def test_factory_from_time_range_validation(self):
|
||||
with pytest.raises(ValueError, match="start_from"):
|
||||
MessagesCleanService.from_time_range(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=_NOW,
|
||||
end_before=_OLD,
|
||||
)
|
||||
|
||||
def test_factory_from_days_validation(self):
|
||||
with pytest.raises(ValueError, match="days"):
|
||||
MessagesCleanService.from_days(
|
||||
policy=BillingDisabledPolicy(),
|
||||
days=-1,
|
||||
)
|
||||
|
||||
def test_factory_batch_size_validation(self):
|
||||
with pytest.raises(ValueError, match="batch_size"):
|
||||
MessagesCleanService.from_time_range(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=_OLD,
|
||||
end_before=_NOW,
|
||||
batch_size=0,
|
||||
)
|
||||
@@ -0,0 +1,177 @@
|
||||
import datetime
|
||||
import io
|
||||
import json
|
||||
import uuid
|
||||
import zipfile
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.retention.workflow_run.archive_paid_plan_workflow_run import (
|
||||
ArchiveSummary,
|
||||
WorkflowRunArchiver,
|
||||
)
|
||||
from services.retention.workflow_run.constants import ARCHIVE_SCHEMA_VERSION
|
||||
|
||||
|
||||
class TestWorkflowRunArchiverInit:
|
||||
def test_start_from_without_end_before_raises(self):
|
||||
with pytest.raises(ValueError, match="start_from and end_before must be provided together"):
|
||||
WorkflowRunArchiver(start_from=datetime.datetime(2025, 1, 1))
|
||||
|
||||
def test_end_before_without_start_from_raises(self):
|
||||
with pytest.raises(ValueError, match="start_from and end_before must be provided together"):
|
||||
WorkflowRunArchiver(end_before=datetime.datetime(2025, 1, 1))
|
||||
|
||||
def test_start_equals_end_raises(self):
|
||||
ts = datetime.datetime(2025, 1, 1)
|
||||
with pytest.raises(ValueError, match="start_from must be earlier than end_before"):
|
||||
WorkflowRunArchiver(start_from=ts, end_before=ts)
|
||||
|
||||
def test_start_after_end_raises(self):
|
||||
with pytest.raises(ValueError, match="start_from must be earlier than end_before"):
|
||||
WorkflowRunArchiver(
|
||||
start_from=datetime.datetime(2025, 6, 1),
|
||||
end_before=datetime.datetime(2025, 1, 1),
|
||||
)
|
||||
|
||||
def test_workers_zero_raises(self):
|
||||
with pytest.raises(ValueError, match="workers must be at least 1"):
|
||||
WorkflowRunArchiver(workers=0)
|
||||
|
||||
def test_valid_init_defaults(self):
|
||||
archiver = WorkflowRunArchiver(days=30, batch_size=50)
|
||||
assert archiver.days == 30
|
||||
assert archiver.batch_size == 50
|
||||
assert archiver.dry_run is False
|
||||
assert archiver.delete_after_archive is False
|
||||
assert archiver.start_from is None
|
||||
|
||||
def test_valid_init_with_time_range(self):
|
||||
start = datetime.datetime(2025, 1, 1)
|
||||
end = datetime.datetime(2025, 6, 1)
|
||||
archiver = WorkflowRunArchiver(start_from=start, end_before=end, workers=2)
|
||||
assert archiver.start_from is not None
|
||||
assert archiver.end_before is not None
|
||||
assert archiver.workers == 2
|
||||
|
||||
|
||||
class TestBuildArchiveBundle:
|
||||
def test_bundle_contains_manifest_and_all_tables(self):
|
||||
archiver = WorkflowRunArchiver(days=90)
|
||||
|
||||
manifest_data = json.dumps({"schema_version": ARCHIVE_SCHEMA_VERSION}).encode("utf-8")
|
||||
table_payloads = dict.fromkeys(archiver.ARCHIVED_TABLES, b"")
|
||||
|
||||
bundle_bytes = archiver._build_archive_bundle(manifest_data, table_payloads)
|
||||
|
||||
with zipfile.ZipFile(io.BytesIO(bundle_bytes), "r") as zf:
|
||||
names = set(zf.namelist())
|
||||
assert "manifest.json" in names
|
||||
for table in archiver.ARCHIVED_TABLES:
|
||||
assert f"{table}.jsonl" in names, f"Missing {table}.jsonl in bundle"
|
||||
|
||||
def test_bundle_missing_table_payload_raises(self):
|
||||
archiver = WorkflowRunArchiver(days=90)
|
||||
manifest_data = b"{}"
|
||||
incomplete_payloads = {archiver.ARCHIVED_TABLES[0]: b"data"}
|
||||
|
||||
with pytest.raises(ValueError, match="Missing archive payload"):
|
||||
archiver._build_archive_bundle(manifest_data, incomplete_payloads)
|
||||
|
||||
|
||||
class TestGenerateManifest:
|
||||
def test_manifest_structure(self):
|
||||
archiver = WorkflowRunArchiver(days=90)
|
||||
from services.retention.workflow_run.archive_paid_plan_workflow_run import TableStats
|
||||
|
||||
run = MagicMock()
|
||||
run.id = str(uuid.uuid4())
|
||||
run.tenant_id = str(uuid.uuid4())
|
||||
run.app_id = str(uuid.uuid4())
|
||||
run.workflow_id = str(uuid.uuid4())
|
||||
run.created_at = datetime.datetime(2025, 3, 15, 10, 0, 0)
|
||||
|
||||
stats = [
|
||||
TableStats(table_name="workflow_runs", row_count=1, checksum="abc123", size_bytes=512),
|
||||
TableStats(table_name="workflow_app_logs", row_count=2, checksum="def456", size_bytes=1024),
|
||||
]
|
||||
|
||||
manifest = archiver._generate_manifest(run, stats)
|
||||
|
||||
assert manifest["schema_version"] == ARCHIVE_SCHEMA_VERSION
|
||||
assert manifest["workflow_run_id"] == run.id
|
||||
assert manifest["tenant_id"] == run.tenant_id
|
||||
assert manifest["app_id"] == run.app_id
|
||||
assert "tables" in manifest
|
||||
assert manifest["tables"]["workflow_runs"]["row_count"] == 1
|
||||
assert manifest["tables"]["workflow_runs"]["checksum"] == "abc123"
|
||||
assert manifest["tables"]["workflow_app_logs"]["row_count"] == 2
|
||||
|
||||
|
||||
class TestFilterPaidTenants:
|
||||
def test_all_tenants_paid_when_billing_disabled(self):
|
||||
archiver = WorkflowRunArchiver(days=90)
|
||||
tenant_ids = {"t1", "t2", "t3"}
|
||||
|
||||
with patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg:
|
||||
cfg.BILLING_ENABLED = False
|
||||
result = archiver._filter_paid_tenants(tenant_ids)
|
||||
|
||||
assert result == tenant_ids
|
||||
|
||||
def test_empty_tenants_returns_empty(self):
|
||||
archiver = WorkflowRunArchiver(days=90)
|
||||
|
||||
with patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg:
|
||||
cfg.BILLING_ENABLED = True
|
||||
result = archiver._filter_paid_tenants(set())
|
||||
|
||||
assert result == set()
|
||||
|
||||
def test_only_paid_plans_returned(self):
|
||||
archiver = WorkflowRunArchiver(days=90)
|
||||
|
||||
mock_bulk = {
|
||||
"t1": {"plan": "professional"},
|
||||
"t2": {"plan": "sandbox"},
|
||||
"t3": {"plan": "team"},
|
||||
}
|
||||
|
||||
with (
|
||||
patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg,
|
||||
patch("services.retention.workflow_run.archive_paid_plan_workflow_run.BillingService") as billing,
|
||||
):
|
||||
cfg.BILLING_ENABLED = True
|
||||
billing.get_plan_bulk_with_cache.return_value = mock_bulk
|
||||
result = archiver._filter_paid_tenants({"t1", "t2", "t3"})
|
||||
|
||||
assert "t1" in result
|
||||
assert "t3" in result
|
||||
assert "t2" not in result
|
||||
|
||||
def test_billing_api_failure_returns_empty(self):
|
||||
archiver = WorkflowRunArchiver(days=90)
|
||||
|
||||
with (
|
||||
patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config") as cfg,
|
||||
patch("services.retention.workflow_run.archive_paid_plan_workflow_run.BillingService") as billing,
|
||||
):
|
||||
cfg.BILLING_ENABLED = True
|
||||
billing.get_plan_bulk_with_cache.side_effect = RuntimeError("API down")
|
||||
result = archiver._filter_paid_tenants({"t1"})
|
||||
|
||||
assert result == set()
|
||||
|
||||
|
||||
class TestDryRunArchive:
|
||||
@patch("services.retention.workflow_run.archive_paid_plan_workflow_run.get_archive_storage")
|
||||
def test_dry_run_does_not_call_storage(self, mock_get_storage, flask_req_ctx):
|
||||
archiver = WorkflowRunArchiver(days=90, dry_run=True)
|
||||
|
||||
with patch.object(archiver, "_get_runs_batch", return_value=[]):
|
||||
summary = archiver.run()
|
||||
|
||||
mock_get_storage.assert_not_called()
|
||||
assert isinstance(summary, ArchiveSummary)
|
||||
assert summary.runs_failed == 0
|
||||
@@ -1,7 +1,4 @@
|
||||
"""
|
||||
Additional tests to improve coverage for low-coverage modules in controllers/console/app.
|
||||
Target: increase coverage for files with <75% coverage.
|
||||
"""
|
||||
"""Testcontainers integration tests for controllers/console/app endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -70,26 +67,12 @@ def _unwrap(func):
|
||||
return func
|
||||
|
||||
|
||||
class _ConnContext:
|
||||
def __init__(self, rows):
|
||||
self._rows = rows
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def execute(self, _query, _args):
|
||||
return self._rows
|
||||
|
||||
|
||||
# ========== Completion Tests ==========
|
||||
class TestCompletionEndpoints:
|
||||
"""Tests for completion API endpoints."""
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_completion_create_payload(self):
|
||||
"""Test completion creation payload."""
|
||||
payload = CompletionMessagePayload(inputs={"prompt": "test"}, model_config={})
|
||||
assert payload.inputs == {"prompt": "test"}
|
||||
|
||||
@@ -209,7 +192,9 @@ class TestCompletionEndpoints:
|
||||
|
||||
|
||||
class TestAppEndpoints:
|
||||
"""Tests for app endpoints."""
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app, monkeypatch):
|
||||
api = app_module.AppApi()
|
||||
@@ -250,12 +235,12 @@ class TestAppEndpoints:
|
||||
)
|
||||
|
||||
|
||||
# ========== OpsTrace Tests ==========
|
||||
class TestOpsTraceEndpoints:
|
||||
"""Tests for ops_trace endpoint."""
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_ops_trace_query_basic(self):
|
||||
"""Test ops_trace query."""
|
||||
query = TraceProviderQuery(tracing_provider="langfuse")
|
||||
assert query.tracing_provider == "langfuse"
|
||||
|
||||
@@ -310,12 +295,12 @@ class TestOpsTraceEndpoints:
|
||||
method(app_id="app-1")
|
||||
|
||||
|
||||
# ========== Site Tests ==========
|
||||
class TestSiteEndpoints:
|
||||
"""Tests for site endpoint."""
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_site_response_structure(self):
|
||||
"""Test site response structure."""
|
||||
payload = AppSiteUpdatePayload(title="My Site", description="Test site")
|
||||
assert payload.title == "My Site"
|
||||
|
||||
@@ -369,27 +354,22 @@ class TestSiteEndpoints:
|
||||
assert result is site
|
||||
|
||||
|
||||
# ========== Workflow Tests ==========
|
||||
class TestWorkflowEndpoints:
|
||||
"""Tests for workflow endpoints."""
|
||||
|
||||
def test_workflow_copy_payload(self):
|
||||
"""Test workflow copy payload."""
|
||||
payload = SyncDraftWorkflowPayload(graph={}, features={})
|
||||
assert payload.graph == {}
|
||||
|
||||
def test_workflow_mode_query(self):
|
||||
"""Test workflow mode query."""
|
||||
payload = AdvancedChatWorkflowRunPayload(inputs={}, query="hi")
|
||||
assert payload.query == "hi"
|
||||
|
||||
|
||||
# ========== Workflow App Log Tests ==========
|
||||
class TestWorkflowAppLogEndpoints:
|
||||
"""Tests for workflow app log endpoints."""
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_workflow_app_log_query(self):
|
||||
"""Test workflow app log query."""
|
||||
query = WorkflowAppLogQuery(keyword="test", page=1, limit=20)
|
||||
assert query.keyword == "test"
|
||||
|
||||
@@ -427,12 +407,12 @@ class TestWorkflowAppLogEndpoints:
|
||||
assert result == {"items": [], "total": 0}
|
||||
|
||||
|
||||
# ========== Workflow Draft Variable Tests ==========
|
||||
class TestWorkflowDraftVariableEndpoints:
|
||||
"""Tests for workflow draft variable endpoints."""
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_workflow_variable_creation(self):
|
||||
"""Test workflow variable creation."""
|
||||
payload = WorkflowDraftVariableUpdatePayload(name="var1", value="test")
|
||||
assert payload.name == "var1"
|
||||
|
||||
@@ -472,12 +452,12 @@ class TestWorkflowDraftVariableEndpoints:
|
||||
assert result == {"items": [], "total": 0}
|
||||
|
||||
|
||||
# ========== Workflow Statistic Tests ==========
|
||||
class TestWorkflowStatisticEndpoints:
|
||||
"""Tests for workflow statistic endpoints."""
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_workflow_statistic_time_range(self):
|
||||
"""Test workflow statistic time range query."""
|
||||
query = WorkflowStatisticQuery(start="2024-01-01", end="2024-12-31")
|
||||
assert query.start == "2024-01-01"
|
||||
|
||||
@@ -541,12 +521,12 @@ class TestWorkflowStatisticEndpoints:
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-02"}]}
|
||||
|
||||
|
||||
# ========== Workflow Trigger Tests ==========
|
||||
class TestWorkflowTriggerEndpoints:
|
||||
"""Tests for workflow trigger endpoints."""
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_webhook_trigger_payload(self):
|
||||
"""Test webhook trigger payload."""
|
||||
payload = Parser(node_id="node-1")
|
||||
assert payload.node_id == "node-1"
|
||||
|
||||
@@ -578,22 +558,13 @@ class TestWorkflowTriggerEndpoints:
|
||||
assert result is trigger
|
||||
|
||||
|
||||
# ========== Wraps Tests ==========
|
||||
class TestWrapsEndpoints:
|
||||
"""Tests for wraps utility functions."""
|
||||
|
||||
def test_get_app_model_context(self):
|
||||
"""Test get_app_model wrapper context."""
|
||||
# These are decorator functions, so we test their availability
|
||||
assert hasattr(wraps_module, "get_app_model")
|
||||
|
||||
|
||||
# ========== MCP Server Tests ==========
|
||||
class TestMCPServerEndpoints:
|
||||
"""Tests for MCP server endpoints."""
|
||||
|
||||
def test_mcp_server_connection(self):
|
||||
"""Test MCP server connection."""
|
||||
payload = MCPServerCreatePayload(parameters={"url": "http://localhost:3000"})
|
||||
assert payload.parameters["url"] == "http://localhost:3000"
|
||||
|
||||
@@ -602,22 +573,14 @@ class TestMCPServerEndpoints:
|
||||
assert payload.status == "active"
|
||||
|
||||
|
||||
# ========== Error Handling Tests ==========
|
||||
class TestErrorHandling:
|
||||
"""Tests for error handling in various endpoints."""
|
||||
|
||||
def test_annotation_list_query_validation(self):
|
||||
"""Test annotation list query validation."""
|
||||
with pytest.raises(ValueError):
|
||||
annotation_module.AnnotationListQuery(page=0)
|
||||
|
||||
|
||||
# ========== Integration-like Tests ==========
|
||||
class TestPayloadIntegration:
|
||||
"""Integration tests for payload handling."""
|
||||
|
||||
def test_multiple_payload_types(self):
|
||||
"""Test handling of multiple payload types."""
|
||||
payloads = [
|
||||
annotation_module.AnnotationReplyPayload(
|
||||
score_threshold=0.5, embedding_provider_name="openai", embedding_model_name="text-embedding-3-small"
|
||||
@@ -0,0 +1,142 @@
|
||||
"""Testcontainers integration tests for controllers.console.app.app_import endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import app_import as app_import_module
|
||||
from services.app_dsl_service import ImportStatus
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
class _Result:
|
||||
def __init__(self, status: ImportStatus, app_id: str | None = "app-1"):
|
||||
self.status = status
|
||||
self.app_id = app_id
|
||||
|
||||
def model_dump(self, mode: str = "json"):
|
||||
return {"status": self.status, "app_id": self.app_id}
|
||||
|
||||
|
||||
def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None:
|
||||
features = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=enabled))
|
||||
monkeypatch.setattr(app_import_module.FeatureService, "get_system_features", lambda: features)
|
||||
|
||||
|
||||
class TestAppImportApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_import_post_returns_failed_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
|
||||
def test_import_post_returns_pending_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.PENDING),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
assert status == 202
|
||||
assert response["status"] == ImportStatus.PENDING
|
||||
|
||||
def test_import_post_updates_webapp_auth_when_enabled(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
_install_features(monkeypatch, enabled=True)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"),
|
||||
)
|
||||
update_access = MagicMock()
|
||||
monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
update_access.assert_called_once_with("app-123", "private")
|
||||
assert status == 200
|
||||
assert response["status"] == ImportStatus.COMPLETED
|
||||
|
||||
|
||||
class TestAppImportConfirmApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_import_confirm_returns_failed_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportConfirmApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"confirm_import",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"):
|
||||
response, status = method(import_id="import-1")
|
||||
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
|
||||
|
||||
class TestAppImportCheckDependenciesApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_import_check_dependencies_returns_result(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportCheckDependenciesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"check_dependencies",
|
||||
lambda *_args, **_kwargs: SimpleNamespace(model_dump=lambda mode="json": {"leaked_dependencies": []}),
|
||||
)
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports/app-1/check-dependencies", method="GET"):
|
||||
response, status = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert status == 200
|
||||
assert response["leaked_dependencies"] == []
|
||||
@@ -1,6 +1,12 @@
|
||||
"""Testcontainers integration tests for rag_pipeline controller endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline import (
|
||||
@@ -9,6 +15,7 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline import (
|
||||
PipelineTemplateListApi,
|
||||
PublishCustomizedPipelineTemplateApi,
|
||||
)
|
||||
from models.dataset import PipelineCustomizedTemplate
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
@@ -18,6 +25,10 @@ def unwrap(func):
|
||||
|
||||
|
||||
class TestPipelineTemplateListApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app):
|
||||
api = PipelineTemplateListApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -38,6 +49,10 @@ class TestPipelineTemplateListApi:
|
||||
|
||||
|
||||
class TestPipelineTemplateDetailApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app):
|
||||
api = PipelineTemplateDetailApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -99,6 +114,10 @@ class TestPipelineTemplateDetailApi:
|
||||
|
||||
|
||||
class TestCustomizedPipelineTemplateApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_patch_success(self, app):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.patch)
|
||||
@@ -136,35 +155,29 @@ class TestCustomizedPipelineTemplateApi:
|
||||
delete_mock.assert_called_once_with("tpl-1")
|
||||
assert response == 200
|
||||
|
||||
def test_post_success(self, app):
|
||||
def test_post_success(self, app, db_session_with_containers: Session):
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
template = MagicMock()
|
||||
template.yaml_content = "yaml-data"
|
||||
tenant_id = str(uuid4())
|
||||
template = PipelineCustomizedTemplate(
|
||||
tenant_id=tenant_id,
|
||||
name="Test Template",
|
||||
description="Test",
|
||||
chunk_structure="hierarchical",
|
||||
icon={"icon": "📘"},
|
||||
position=0,
|
||||
yaml_content="yaml-data",
|
||||
install_count=0,
|
||||
language="en-US",
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
db_session_with_containers.add(template)
|
||||
db_session_with_containers.commit()
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = template
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "tpl-1")
|
||||
with app.test_request_context("/"):
|
||||
response, status = method(api, template.id)
|
||||
|
||||
assert status == 200
|
||||
assert response == {"data": "yaml-data"}
|
||||
@@ -173,32 +186,16 @@ class TestCustomizedPipelineTemplateApi:
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "tpl-1")
|
||||
method(api, str(uuid4()))
|
||||
|
||||
|
||||
class TestPublishCustomizedPipelineTemplateApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_post_success(self, app):
|
||||
api = PublishCustomizedPipelineTemplateApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -1,3 +1,7 @@
|
||||
"""Testcontainers integration tests for rag_pipeline_datasets controller endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -19,6 +23,10 @@ def unwrap(func):
|
||||
|
||||
|
||||
class TestCreateRagPipelineDatasetApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def _valid_payload(self):
|
||||
return {"yaml_content": "name: test"}
|
||||
|
||||
@@ -33,13 +41,6 @@ class TestCreateRagPipelineDatasetApi:
|
||||
mock_service = MagicMock()
|
||||
mock_service.create_rag_pipeline_dataset.return_value = import_info
|
||||
|
||||
mock_session_ctx = MagicMock()
|
||||
mock_session_ctx.__enter__.return_value = MagicMock()
|
||||
mock_session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
@@ -47,14 +48,6 @@ class TestCreateRagPipelineDatasetApi:
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session",
|
||||
return_value=mock_session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
|
||||
return_value=mock_service,
|
||||
@@ -93,13 +86,6 @@ class TestCreateRagPipelineDatasetApi:
|
||||
mock_service = MagicMock()
|
||||
mock_service.create_rag_pipeline_dataset.side_effect = services.errors.dataset.DatasetNameDuplicateError()
|
||||
|
||||
mock_session_ctx = MagicMock()
|
||||
mock_session_ctx.__enter__.return_value = MagicMock()
|
||||
mock_session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
@@ -107,14 +93,6 @@ class TestCreateRagPipelineDatasetApi:
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session",
|
||||
return_value=mock_session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService",
|
||||
return_value=mock_service,
|
||||
@@ -143,6 +121,10 @@ class TestCreateRagPipelineDatasetApi:
|
||||
|
||||
|
||||
class TestCreateEmptyRagPipelineDatasetApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_post_success(self, app):
|
||||
api = CreateEmptyRagPipelineDatasetApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -1,5 +1,11 @@
|
||||
"""Testcontainers integration tests for rag_pipeline_import controller endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.rag_pipeline.rag_pipeline_import import (
|
||||
RagPipelineExportApi,
|
||||
@@ -18,6 +24,10 @@ def unwrap(func):
|
||||
|
||||
|
||||
class TestRagPipelineImportApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def _payload(self, mode="create"):
|
||||
return {
|
||||
"mode": mode,
|
||||
@@ -30,7 +40,6 @@ class TestRagPipelineImportApi:
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._payload()
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = "completed"
|
||||
@@ -39,13 +48,6 @@ class TestRagPipelineImportApi:
|
||||
service = MagicMock()
|
||||
service.import_rag_pipeline.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
@@ -53,14 +55,6 @@ class TestRagPipelineImportApi:
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
@@ -76,7 +70,6 @@ class TestRagPipelineImportApi:
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._payload()
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = ImportStatus.FAILED
|
||||
@@ -85,13 +78,6 @@ class TestRagPipelineImportApi:
|
||||
service = MagicMock()
|
||||
service.import_rag_pipeline.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
@@ -99,14 +85,6 @@ class TestRagPipelineImportApi:
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
@@ -122,7 +100,6 @@ class TestRagPipelineImportApi:
|
||||
method = unwrap(api.post)
|
||||
|
||||
payload = self._payload()
|
||||
|
||||
user = MagicMock()
|
||||
result = MagicMock()
|
||||
result.status = ImportStatus.PENDING
|
||||
@@ -131,13 +108,6 @@ class TestRagPipelineImportApi:
|
||||
service = MagicMock()
|
||||
service.import_rag_pipeline.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
@@ -145,14 +115,6 @@ class TestRagPipelineImportApi:
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
@@ -165,6 +127,10 @@ class TestRagPipelineImportApi:
|
||||
|
||||
|
||||
class TestRagPipelineImportConfirmApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_confirm_success(self, app):
|
||||
api = RagPipelineImportConfirmApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -177,27 +143,12 @@ class TestRagPipelineImportConfirmApi:
|
||||
service = MagicMock()
|
||||
service.confirm_import.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
@@ -220,27 +171,12 @@ class TestRagPipelineImportConfirmApi:
|
||||
service = MagicMock()
|
||||
service.confirm_import.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
@@ -253,6 +189,10 @@ class TestRagPipelineImportConfirmApi:
|
||||
|
||||
|
||||
class TestRagPipelineImportCheckDependenciesApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app):
|
||||
api = RagPipelineImportCheckDependenciesApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -264,23 +204,8 @@ class TestRagPipelineImportCheckDependenciesApi:
|
||||
service = MagicMock()
|
||||
service.check_dependencies.return_value = result
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
@@ -293,6 +218,10 @@ class TestRagPipelineImportCheckDependenciesApi:
|
||||
|
||||
|
||||
class TestRagPipelineExportApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_with_include_secret(self, app):
|
||||
api = RagPipelineExportApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -301,23 +230,8 @@ class TestRagPipelineExportApi:
|
||||
service = MagicMock()
|
||||
service.export_rag_pipeline_dsl.return_value = {"yaml": "data"}
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = MagicMock()
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
app.test_request_context("/?include_secret=true"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService",
|
||||
return_value=service,
|
||||
@@ -1,7 +1,13 @@
|
||||
"""Testcontainers integration tests for rag_pipeline_workflow controller endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, HTTPException, NotFound
|
||||
|
||||
import services
|
||||
@@ -38,6 +44,10 @@ def unwrap(func):
|
||||
|
||||
|
||||
class TestDraftWorkflowApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_draft_success(self, app):
|
||||
api = DraftRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -200,6 +210,10 @@ class TestDraftWorkflowApi:
|
||||
|
||||
|
||||
class TestDraftRunNodes:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_iteration_node_success(self, app):
|
||||
api = RagPipelineDraftRunIterationNodeApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -275,6 +289,10 @@ class TestDraftRunNodes:
|
||||
|
||||
|
||||
class TestPipelineRunApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_draft_run_success(self, app):
|
||||
api = DraftRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -337,6 +355,10 @@ class TestPipelineRunApis:
|
||||
|
||||
|
||||
class TestDraftNodeRun:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_execution_not_found(self, app):
|
||||
api = RagPipelineDraftNodeRunApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -364,45 +386,43 @@ class TestDraftNodeRun:
|
||||
|
||||
|
||||
class TestPublishedPipelineApis:
|
||||
def test_publish_success(self, app):
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_publish_success(self, app, db_session_with_containers: Session):
|
||||
from models.dataset import Pipeline
|
||||
|
||||
api = PublishedRagPipelineApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
pipeline = MagicMock()
|
||||
tenant_id = str(uuid4())
|
||||
pipeline = Pipeline(
|
||||
tenant_id=tenant_id,
|
||||
name="test-pipeline",
|
||||
description="test",
|
||||
created_by=str(uuid4()),
|
||||
)
|
||||
db_session_with_containers.add(pipeline)
|
||||
db_session_with_containers.commit()
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
workflow = MagicMock(
|
||||
id="w1",
|
||||
id=str(uuid4()),
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
session = MagicMock()
|
||||
session.merge.return_value = pipeline
|
||||
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
service = MagicMock()
|
||||
service.publish_workflow.return_value = workflow
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
@@ -415,6 +435,10 @@ class TestPublishedPipelineApis:
|
||||
|
||||
|
||||
class TestMiscApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_task_stop(self, app):
|
||||
api = RagPipelineTaskStopApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -471,6 +495,10 @@ class TestMiscApis:
|
||||
|
||||
|
||||
class TestPublishedRagPipelineRunApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_published_run_success(self, app):
|
||||
api = PublishedRagPipelineRunApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -536,6 +564,10 @@ class TestPublishedRagPipelineRunApi:
|
||||
|
||||
|
||||
class TestDefaultBlockConfigApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_block_config_success(self, app):
|
||||
api = DefaultRagPipelineBlockConfigApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -567,6 +599,10 @@ class TestDefaultBlockConfigApi:
|
||||
|
||||
|
||||
class TestPublishedAllRagPipelineApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_published_workflows_success(self, app):
|
||||
api = PublishedAllRagPipelineApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -577,28 +613,12 @@ class TestPublishedAllRagPipelineApi:
|
||||
service = MagicMock()
|
||||
service.get_all_published_workflow.return_value = ([{"id": "w1"}], False)
|
||||
|
||||
session = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
@@ -628,6 +648,10 @@ class TestPublishedAllRagPipelineApi:
|
||||
|
||||
|
||||
class TestRagPipelineByIdApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_patch_success(self, app):
|
||||
api = RagPipelineByIdApi()
|
||||
method = unwrap(api.patch)
|
||||
@@ -640,14 +664,6 @@ class TestRagPipelineByIdApi:
|
||||
service = MagicMock()
|
||||
service.update_workflow.return_value = workflow
|
||||
|
||||
session = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
payload = {"marked_name": "test"}
|
||||
|
||||
with (
|
||||
@@ -657,14 +673,6 @@ class TestRagPipelineByIdApi:
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant",
|
||||
return_value=(user, "t"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService",
|
||||
return_value=service,
|
||||
@@ -700,24 +708,8 @@ class TestRagPipelineByIdApi:
|
||||
|
||||
workflow_service = MagicMock()
|
||||
|
||||
session = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_db.engine = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", method="DELETE"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db",
|
||||
fake_db,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session",
|
||||
return_value=session_ctx,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.WorkflowService",
|
||||
return_value=workflow_service,
|
||||
@@ -725,12 +717,7 @@ class TestRagPipelineByIdApi:
|
||||
):
|
||||
result = method(api, pipeline, "old-workflow")
|
||||
|
||||
workflow_service.delete_workflow.assert_called_once_with(
|
||||
session=session,
|
||||
workflow_id="old-workflow",
|
||||
tenant_id="t1",
|
||||
)
|
||||
session.commit.assert_called_once()
|
||||
workflow_service.delete_workflow.assert_called_once()
|
||||
assert result == (None, 204)
|
||||
|
||||
def test_delete_active_workflow_rejected(self, app):
|
||||
@@ -745,6 +732,10 @@ class TestRagPipelineByIdApi:
|
||||
|
||||
|
||||
class TestRagPipelineWorkflowLastRunApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_last_run_success(self, app):
|
||||
api = RagPipelineWorkflowLastRunApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -788,6 +779,10 @@ class TestRagPipelineWorkflowLastRunApi:
|
||||
|
||||
|
||||
class TestRagPipelineDatasourceVariableApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_set_datasource_variables_success(self, app):
|
||||
api = RagPipelineDatasourceVariableApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -1,3 +1,7 @@
|
||||
"""Testcontainers integration tests for controllers.console.datasets.data_source endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -46,6 +50,10 @@ def mock_engine():
|
||||
|
||||
|
||||
class TestDataSourceApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app, patch_tenant):
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -179,6 +187,10 @@ class TestDataSourceApi:
|
||||
|
||||
|
||||
class TestDataSourceNotionListApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_credential_not_found(self, app, patch_tenant):
|
||||
api = DataSourceNotionListApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -310,6 +322,10 @@ class TestDataSourceNotionListApi:
|
||||
|
||||
|
||||
class TestDataSourceNotionApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_preview_success(self, app, patch_tenant):
|
||||
api = DataSourceNotionApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -364,6 +380,10 @@ class TestDataSourceNotionApi:
|
||||
|
||||
|
||||
class TestDataSourceNotionDatasetSyncApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app, patch_tenant):
|
||||
api = DataSourceNotionDatasetSyncApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -403,6 +423,10 @@ class TestDataSourceNotionDatasetSyncApi:
|
||||
|
||||
|
||||
class TestDataSourceNotionDocumentSyncApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app, patch_tenant):
|
||||
api = DataSourceNotionDocumentSyncApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -1,7 +1,10 @@
|
||||
"""Testcontainers integration tests for controllers.console.explore.conversation endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import controllers.console.explore.conversation as conversation_module
|
||||
@@ -48,24 +51,12 @@ def user():
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_db_and_session():
|
||||
with (
|
||||
patch.object(
|
||||
conversation_module,
|
||||
"db",
|
||||
MagicMock(session=MagicMock(), engine=MagicMock()),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.explore.conversation.Session",
|
||||
MagicMock(),
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
class TestConversationListApi:
|
||||
def test_get_success(self, app: Flask, chat_app, user):
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app, chat_app, user):
|
||||
api = conversation_module.ConversationListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@@ -90,7 +81,7 @@ class TestConversationListApi:
|
||||
assert result["has_more"] is False
|
||||
assert len(result["data"]) == 2
|
||||
|
||||
def test_last_conversation_not_exists(self, app: Flask, chat_app, user):
|
||||
def test_last_conversation_not_exists(self, app, chat_app, user):
|
||||
api = conversation_module.ConversationListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@@ -106,7 +97,7 @@ class TestConversationListApi:
|
||||
with pytest.raises(NotFound):
|
||||
method(chat_app)
|
||||
|
||||
def test_wrong_app_mode(self, app: Flask, non_chat_app):
|
||||
def test_wrong_app_mode(self, app, non_chat_app):
|
||||
api = conversation_module.ConversationListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@@ -116,7 +107,11 @@ class TestConversationListApi:
|
||||
|
||||
|
||||
class TestConversationApi:
|
||||
def test_delete_success(self, app: Flask, chat_app, user):
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_delete_success(self, app, chat_app, user):
|
||||
api = conversation_module.ConversationApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
@@ -134,7 +129,7 @@ class TestConversationApi:
|
||||
assert status == 204
|
||||
assert body["result"] == "success"
|
||||
|
||||
def test_delete_not_found(self, app: Flask, chat_app, user):
|
||||
def test_delete_not_found(self, app, chat_app, user):
|
||||
api = conversation_module.ConversationApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
@@ -150,7 +145,7 @@ class TestConversationApi:
|
||||
with pytest.raises(NotFound):
|
||||
method(chat_app, "cid")
|
||||
|
||||
def test_delete_wrong_app_mode(self, app: Flask, non_chat_app):
|
||||
def test_delete_wrong_app_mode(self, app, non_chat_app):
|
||||
api = conversation_module.ConversationApi()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
@@ -160,7 +155,11 @@ class TestConversationApi:
|
||||
|
||||
|
||||
class TestConversationRenameApi:
|
||||
def test_rename_success(self, app: Flask, chat_app, user):
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_rename_success(self, app, chat_app, user):
|
||||
api = conversation_module.ConversationRenameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@@ -179,7 +178,7 @@ class TestConversationRenameApi:
|
||||
|
||||
assert result["id"] == "cid"
|
||||
|
||||
def test_rename_not_found(self, app: Flask, chat_app, user):
|
||||
def test_rename_not_found(self, app, chat_app, user):
|
||||
api = conversation_module.ConversationRenameApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
@@ -197,7 +196,11 @@ class TestConversationRenameApi:
|
||||
|
||||
|
||||
class TestConversationPinApi:
|
||||
def test_pin_success(self, app: Flask, chat_app, user):
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_pin_success(self, app, chat_app, user):
|
||||
api = conversation_module.ConversationPinApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
@@ -215,7 +218,11 @@ class TestConversationPinApi:
|
||||
|
||||
|
||||
class TestConversationUnPinApi:
|
||||
def test_unpin_success(self, app: Flask, chat_app, user):
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_unpin_success(self, app, chat_app, user):
|
||||
api = conversation_module.ConversationUnPinApi()
|
||||
method = unwrap(api.patch)
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
"""Testcontainers integration tests for controllers.console.workspace.tool_providers endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask_restx import Api
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.workspace.tool_providers import (
|
||||
@@ -31,7 +33,6 @@ from controllers.console.workspace.tool_providers import (
|
||||
ToolOAuthCustomClient,
|
||||
ToolPluginOAuthApi,
|
||||
ToolProviderListApi,
|
||||
ToolProviderMCPApi,
|
||||
ToolWorkflowListApi,
|
||||
ToolWorkflowProviderCreateApi,
|
||||
ToolWorkflowProviderDeleteApi,
|
||||
@@ -39,8 +40,6 @@ from controllers.console.workspace.tool_providers import (
|
||||
ToolWorkflowProviderUpdateApi,
|
||||
is_valid_url,
|
||||
)
|
||||
from core.db.session_factory import configure_session_factory
|
||||
from extensions.ext_database import db
|
||||
from services.tools.mcp_tools_manage_service import ReconnectResult
|
||||
|
||||
|
||||
@@ -61,17 +60,8 @@ def _mock_user_tenant():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
|
||||
api = Api(app)
|
||||
api.add_resource(ToolProviderMCPApi, "/console/api/workspaces/current/tool-provider/mcp")
|
||||
db.init_app(app)
|
||||
# Configure session factory used by controller code
|
||||
with app.app_context():
|
||||
configure_session_factory(db.engine)
|
||||
return app.test_client()
|
||||
def client(flask_app_with_containers):
|
||||
return flask_app_with_containers.test_client()
|
||||
|
||||
|
||||
@patch(
|
||||
@@ -152,10 +142,14 @@ class TestUtils:
|
||||
assert not is_valid_url("")
|
||||
assert not is_valid_url("ftp://example.com")
|
||||
assert not is_valid_url("not-a-url")
|
||||
assert not is_valid_url(None)
|
||||
assert not is_valid_url(None) # type: ignore[arg-type]
|
||||
|
||||
|
||||
class TestToolProviderListApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app):
|
||||
api = ToolProviderListApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -175,6 +169,10 @@ class TestToolProviderListApi:
|
||||
|
||||
|
||||
class TestBuiltinProviderApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_list_tools(self, app):
|
||||
api = ToolBuiltinProviderListToolsApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -379,6 +377,10 @@ class TestBuiltinProviderApis:
|
||||
|
||||
|
||||
class TestApiProviderApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_add(self, app):
|
||||
api = ToolApiProviderAddApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -502,6 +504,10 @@ class TestApiProviderApis:
|
||||
|
||||
|
||||
class TestWorkflowApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_create(self, app):
|
||||
api = ToolWorkflowProviderCreateApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -587,6 +593,10 @@ class TestWorkflowApis:
|
||||
|
||||
|
||||
class TestLists:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_builtin_list(self, app):
|
||||
api = ToolBuiltinListApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -649,6 +659,10 @@ class TestLists:
|
||||
|
||||
|
||||
class TestLabels:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_labels(self, app):
|
||||
api = ToolLabelsApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -664,6 +678,10 @@ class TestLabels:
|
||||
|
||||
|
||||
class TestOAuth:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_oauth_no_client(self, app):
|
||||
api = ToolPluginOAuthApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -692,6 +710,10 @@ class TestOAuth:
|
||||
|
||||
|
||||
class TestOAuthCustomClient:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_save_custom_client(self, app):
|
||||
api = ToolOAuthCustomClient()
|
||||
method = unwrap(api.post)
|
||||
@@ -1,3 +1,7 @@
|
||||
"""Testcontainers integration tests for controllers.console.workspace.trigger_providers endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -40,6 +44,10 @@ def mock_user():
|
||||
|
||||
|
||||
class TestTriggerProviderApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_icon_success(self, app):
|
||||
api = TriggerProviderIconApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -84,6 +92,10 @@ class TestTriggerProviderApis:
|
||||
|
||||
|
||||
class TestTriggerSubscriptionListApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_list_success(self, app):
|
||||
api = TriggerSubscriptionListApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -115,6 +127,10 @@ class TestTriggerSubscriptionListApi:
|
||||
|
||||
|
||||
class TestTriggerSubscriptionBuilderApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_create_builder(self, app):
|
||||
api = TriggerSubscriptionBuilderCreateApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -219,6 +235,10 @@ class TestTriggerSubscriptionBuilderApis:
|
||||
|
||||
|
||||
class TestTriggerSubscriptionCrud:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_update_rename_only(self, app):
|
||||
api = TriggerSubscriptionUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -321,6 +341,10 @@ class TestTriggerSubscriptionCrud:
|
||||
|
||||
|
||||
class TestTriggerOAuthApis:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_oauth_authorize_success(self, app):
|
||||
api = TriggerOAuthAuthorizeApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -455,6 +479,10 @@ class TestTriggerOAuthApis:
|
||||
|
||||
|
||||
class TestTriggerOAuthClientManageApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_client(self, app):
|
||||
api = TriggerOAuthClientManageApi()
|
||||
method = unwrap(api.get)
|
||||
@@ -527,6 +555,10 @@ class TestTriggerOAuthClientManageApi:
|
||||
|
||||
|
||||
class TestTriggerSubscriptionVerifyApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_verify_success(self, app):
|
||||
api = TriggerSubscriptionVerifyApi()
|
||||
method = unwrap(api.post)
|
||||
@@ -0,0 +1,185 @@
|
||||
"""Testcontainers integration tests for plugin_permission_required decorator."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.workspace import plugin_permission_required
|
||||
from models.account import Tenant, TenantPluginPermission, TenantStatus
|
||||
|
||||
|
||||
def _create_tenant(db_session: Session) -> Tenant:
|
||||
tenant = Tenant(name="test-tenant", status=TenantStatus.NORMAL, plan="basic")
|
||||
db_session.add(tenant)
|
||||
db_session.commit()
|
||||
db_session.expire_all()
|
||||
return tenant
|
||||
|
||||
|
||||
def _create_permission(
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
install: TenantPluginPermission.InstallPermission = TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
debug: TenantPluginPermission.DebugPermission = TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
) -> TenantPluginPermission:
|
||||
perm = TenantPluginPermission(
|
||||
tenant_id=tenant_id,
|
||||
install_permission=install,
|
||||
debug_permission=debug,
|
||||
)
|
||||
db_session.add(perm)
|
||||
db_session.commit()
|
||||
db_session.expire_all()
|
||||
return perm
|
||||
|
||||
|
||||
class TestPluginPermissionRequired:
|
||||
def test_allows_without_permission(self, db_session_with_containers: Session):
|
||||
tenant = _create_tenant(db_session_with_containers)
|
||||
user = SimpleNamespace(is_admin_or_owner=False)
|
||||
|
||||
with patch(
|
||||
"controllers.console.workspace.current_account_with_tenant",
|
||||
return_value=(user, tenant.id),
|
||||
):
|
||||
|
||||
@plugin_permission_required()
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
assert handler() == "ok"
|
||||
|
||||
def test_install_nobody_forbidden(self, db_session_with_containers: Session):
|
||||
tenant = _create_tenant(db_session_with_containers)
|
||||
_create_permission(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
install=TenantPluginPermission.InstallPermission.NOBODY,
|
||||
debug=TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
user = SimpleNamespace(is_admin_or_owner=True)
|
||||
|
||||
with patch(
|
||||
"controllers.console.workspace.current_account_with_tenant",
|
||||
return_value=(user, tenant.id),
|
||||
):
|
||||
|
||||
@plugin_permission_required(install_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
|
||||
def test_install_admin_requires_admin(self, db_session_with_containers: Session):
|
||||
tenant = _create_tenant(db_session_with_containers)
|
||||
_create_permission(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
install=TenantPluginPermission.InstallPermission.ADMINS,
|
||||
debug=TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
user = SimpleNamespace(is_admin_or_owner=False)
|
||||
|
||||
with patch(
|
||||
"controllers.console.workspace.current_account_with_tenant",
|
||||
return_value=(user, tenant.id),
|
||||
):
|
||||
|
||||
@plugin_permission_required(install_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
|
||||
def test_install_admin_allows_admin(self, db_session_with_containers: Session):
|
||||
tenant = _create_tenant(db_session_with_containers)
|
||||
_create_permission(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
install=TenantPluginPermission.InstallPermission.ADMINS,
|
||||
debug=TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
user = SimpleNamespace(is_admin_or_owner=True)
|
||||
|
||||
with patch(
|
||||
"controllers.console.workspace.current_account_with_tenant",
|
||||
return_value=(user, tenant.id),
|
||||
):
|
||||
|
||||
@plugin_permission_required(install_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
assert handler() == "ok"
|
||||
|
||||
def test_debug_nobody_forbidden(self, db_session_with_containers: Session):
|
||||
tenant = _create_tenant(db_session_with_containers)
|
||||
_create_permission(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
install=TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
debug=TenantPluginPermission.DebugPermission.NOBODY,
|
||||
)
|
||||
user = SimpleNamespace(is_admin_or_owner=True)
|
||||
|
||||
with patch(
|
||||
"controllers.console.workspace.current_account_with_tenant",
|
||||
return_value=(user, tenant.id),
|
||||
):
|
||||
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
|
||||
def test_debug_admin_requires_admin(self, db_session_with_containers: Session):
|
||||
tenant = _create_tenant(db_session_with_containers)
|
||||
_create_permission(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
install=TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
debug=TenantPluginPermission.DebugPermission.ADMINS,
|
||||
)
|
||||
user = SimpleNamespace(is_admin_or_owner=False)
|
||||
|
||||
with patch(
|
||||
"controllers.console.workspace.current_account_with_tenant",
|
||||
return_value=(user, tenant.id),
|
||||
):
|
||||
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
|
||||
def test_debug_admin_allows_admin(self, db_session_with_containers: Session):
|
||||
tenant = _create_tenant(db_session_with_containers)
|
||||
_create_permission(
|
||||
db_session_with_containers,
|
||||
tenant.id,
|
||||
install=TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
debug=TenantPluginPermission.DebugPermission.ADMINS,
|
||||
)
|
||||
user = SimpleNamespace(is_admin_or_owner=True)
|
||||
|
||||
with patch(
|
||||
"controllers.console.workspace.current_account_with_tenant",
|
||||
return_value=(user, tenant.id),
|
||||
):
|
||||
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
assert handler() == "ok"
|
||||
@@ -1,5 +1,10 @@
|
||||
"""Testcontainers integration tests for controllers.mcp.mcp endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import types
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Response
|
||||
@@ -14,24 +19,6 @@ def unwrap(func):
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_db():
|
||||
module.db = types.SimpleNamespace(engine=object())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_session():
|
||||
session = MagicMock()
|
||||
session.__enter__.return_value = session
|
||||
session.__exit__.return_value = False
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_session(fake_session):
|
||||
module.Session = MagicMock(return_value=fake_session)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_mcp_ns():
|
||||
fake_ns = types.SimpleNamespace()
|
||||
@@ -44,8 +31,13 @@ def fake_payload(data):
|
||||
module.mcp_ns.payload = data
|
||||
|
||||
|
||||
_TENANT_ID = str(uuid4())
|
||||
_APP_ID = str(uuid4())
|
||||
_SERVER_ID = str(uuid4())
|
||||
|
||||
|
||||
class DummyServer:
|
||||
def __init__(self, status, app_id="app-1", tenant_id="tenant-1", server_id="srv-1"):
|
||||
def __init__(self, status, app_id=_APP_ID, tenant_id=_TENANT_ID, server_id=_SERVER_ID):
|
||||
self.status = status
|
||||
self.app_id = app_id
|
||||
self.tenant_id = tenant_id
|
||||
@@ -54,8 +46,8 @@ class DummyServer:
|
||||
|
||||
class DummyApp:
|
||||
def __init__(self, mode, workflow=None, app_model_config=None):
|
||||
self.id = "app-1"
|
||||
self.tenant_id = "tenant-1"
|
||||
self.id = _APP_ID
|
||||
self.tenant_id = _TENANT_ID
|
||||
self.mode = mode
|
||||
self.workflow = workflow
|
||||
self.app_model_config = app_model_config
|
||||
@@ -76,6 +68,7 @@ class DummyResult:
|
||||
return {"jsonrpc": "2.0", "result": "ok", "id": 1}
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("flask_req_ctx_with_containers")
|
||||
class TestMCPAppApi:
|
||||
@patch.object(module, "handle_mcp_request", return_value=DummyResult(), autospec=True)
|
||||
def test_success_request(self, mock_handle):
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Unit tests for controllers.web.conversation endpoints."""
|
||||
"""Testcontainers integration tests for controllers.web.conversation endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -7,7 +7,6 @@ from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.web.conversation import (
|
||||
@@ -33,18 +32,18 @@ def _end_user() -> SimpleNamespace:
|
||||
return SimpleNamespace(id="eu-1")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationListApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestConversationListApi:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_non_chat_mode_raises(self, app) -> None:
|
||||
with app.test_request_context("/conversations"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationListApi().get(_completion_app(), _end_user())
|
||||
|
||||
@patch("controllers.web.conversation.WebConversationService.pagination_by_last_id")
|
||||
@patch("controllers.web.conversation.db")
|
||||
def test_happy_path(self, mock_db: MagicMock, mock_paginate: MagicMock, app: Flask) -> None:
|
||||
def test_happy_path(self, mock_paginate: MagicMock, app) -> None:
|
||||
conv_id = str(uuid4())
|
||||
conv = SimpleNamespace(
|
||||
id=conv_id,
|
||||
@@ -56,34 +55,26 @@ class TestConversationListApi:
|
||||
updated_at=1700000000,
|
||||
)
|
||||
mock_paginate.return_value = SimpleNamespace(limit=20, has_more=False, data=[conv])
|
||||
mock_db.engine = "engine"
|
||||
|
||||
session_mock = MagicMock()
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/conversations?limit=20"),
|
||||
patch("controllers.web.conversation.Session", return_value=session_ctx),
|
||||
):
|
||||
with app.test_request_context("/conversations?limit=20"):
|
||||
result = ConversationListApi().get(_chat_app(), _end_user())
|
||||
|
||||
assert result["limit"] == 20
|
||||
assert result["has_more"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationApi (delete)
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestConversationApi:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_non_chat_mode_raises(self, app) -> None:
|
||||
with app.test_request_context(f"/conversations/{uuid4()}"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationApi().delete(_completion_app(), _end_user(), uuid4())
|
||||
|
||||
@patch("controllers.web.conversation.ConversationService.delete")
|
||||
def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None:
|
||||
def test_delete_success(self, mock_delete: MagicMock, app) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}"):
|
||||
result, status = ConversationApi().delete(_chat_app(), _end_user(), c_id)
|
||||
@@ -92,25 +83,26 @@ class TestConversationApi:
|
||||
assert result["result"] == "success"
|
||||
|
||||
@patch("controllers.web.conversation.ConversationService.delete", side_effect=ConversationNotExistsError())
|
||||
def test_delete_not_found(self, mock_delete: MagicMock, app: Flask) -> None:
|
||||
def test_delete_not_found(self, mock_delete: MagicMock, app) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}"):
|
||||
with pytest.raises(NotFound, match="Conversation Not Exists"):
|
||||
ConversationApi().delete(_chat_app(), _end_user(), c_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationRenameApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestConversationRenameApi:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_non_chat_mode_raises(self, app) -> None:
|
||||
with app.test_request_context(f"/conversations/{uuid4()}/name", method="POST", json={"name": "x"}):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationRenameApi().post(_completion_app(), _end_user(), uuid4())
|
||||
|
||||
@patch("controllers.web.conversation.ConversationService.rename")
|
||||
@patch("controllers.web.conversation.web_ns")
|
||||
def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None:
|
||||
def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app) -> None:
|
||||
c_id = uuid4()
|
||||
mock_ns.payload = {"name": "New Name", "auto_generate": False}
|
||||
conv = SimpleNamespace(
|
||||
@@ -134,7 +126,7 @@ class TestConversationRenameApi:
|
||||
side_effect=ConversationNotExistsError(),
|
||||
)
|
||||
@patch("controllers.web.conversation.web_ns")
|
||||
def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None:
|
||||
def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app) -> None:
|
||||
c_id = uuid4()
|
||||
mock_ns.payload = {"name": "X", "auto_generate": False}
|
||||
|
||||
@@ -143,17 +135,18 @@ class TestConversationRenameApi:
|
||||
ConversationRenameApi().post(_chat_app(), _end_user(), c_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ConversationPinApi / ConversationUnPinApi
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestConversationPinApi:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_non_chat_mode_raises(self, app) -> None:
|
||||
with app.test_request_context(f"/conversations/{uuid4()}/pin", method="PATCH"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationPinApi().patch(_completion_app(), _end_user(), uuid4())
|
||||
|
||||
@patch("controllers.web.conversation.WebConversationService.pin")
|
||||
def test_pin_success(self, mock_pin: MagicMock, app: Flask) -> None:
|
||||
def test_pin_success(self, mock_pin: MagicMock, app) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"):
|
||||
result = ConversationPinApi().patch(_chat_app(), _end_user(), c_id)
|
||||
@@ -161,7 +154,7 @@ class TestConversationPinApi:
|
||||
assert result["result"] == "success"
|
||||
|
||||
@patch("controllers.web.conversation.WebConversationService.pin", side_effect=ConversationNotExistsError())
|
||||
def test_pin_not_found(self, mock_pin: MagicMock, app: Flask) -> None:
|
||||
def test_pin_not_found(self, mock_pin: MagicMock, app) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"):
|
||||
with pytest.raises(NotFound):
|
||||
@@ -169,13 +162,17 @@ class TestConversationPinApi:
|
||||
|
||||
|
||||
class TestConversationUnPinApi:
|
||||
def test_non_chat_mode_raises(self, app: Flask) -> None:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_non_chat_mode_raises(self, app) -> None:
|
||||
with app.test_request_context(f"/conversations/{uuid4()}/unpin", method="PATCH"):
|
||||
with pytest.raises(NotChatAppError):
|
||||
ConversationUnPinApi().patch(_completion_app(), _end_user(), uuid4())
|
||||
|
||||
@patch("controllers.web.conversation.WebConversationService.unpin")
|
||||
def test_unpin_success(self, mock_unpin: MagicMock, app: Flask) -> None:
|
||||
def test_unpin_success(self, mock_unpin: MagicMock, app) -> None:
|
||||
c_id = uuid4()
|
||||
with app.test_request_context(f"/conversations/{c_id}/unpin", method="PATCH"):
|
||||
result = ConversationUnPinApi().patch(_chat_app(), _end_user(), c_id)
|
||||
@@ -1,9 +1,12 @@
|
||||
"""Testcontainers integration tests for controllers.web.forgot_password endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.web.forgot_password import (
|
||||
ForgotPasswordCheckApi,
|
||||
@@ -12,13 +15,6 @@ from controllers.web.forgot_password import (
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
flask_app = Flask(__name__)
|
||||
flask_app.config["TESTING"] = True
|
||||
return flask_app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _patch_wraps():
|
||||
wraps_features = SimpleNamespace(enable_email_password_login=True)
|
||||
@@ -33,6 +29,10 @@ def _patch_wraps():
|
||||
|
||||
|
||||
class TestForgotPasswordSendEmailApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
@patch("controllers.web.forgot_password.AccountService.send_reset_password_email")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.web.forgot_password.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@@ -69,6 +69,10 @@ class TestForgotPasswordSendEmailApi:
|
||||
|
||||
|
||||
class TestForgotPasswordCheckApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
@patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit")
|
||||
@patch("controllers.web.forgot_password.AccountService.generate_reset_password_token")
|
||||
@patch("controllers.web.forgot_password.AccountService.revoke_reset_password_token")
|
||||
@@ -143,6 +147,10 @@ class TestForgotPasswordCheckApi:
|
||||
|
||||
|
||||
class TestForgotPasswordResetApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
@patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account")
|
||||
@patch("controllers.web.forgot_password.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.web.forgot_password.Session")
|
||||
@@ -1,13 +1,14 @@
|
||||
"""Unit tests for controllers.web.wraps — JWT auth decorator and validation helpers."""
|
||||
"""Testcontainers integration tests for controllers.web.wraps — JWT auth decorator and validation helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
|
||||
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
|
||||
@@ -18,12 +19,8 @@ from controllers.web.wraps import (
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_webapp_token
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestValidateWebappToken:
|
||||
def test_enterprise_enabled_and_app_auth_requires_webapp_source(self) -> None:
|
||||
"""When both flags are true, a non-webapp source must raise."""
|
||||
decoded = {"token_source": "other"}
|
||||
with pytest.raises(WebAppAuthRequiredError):
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True)
|
||||
@@ -38,7 +35,6 @@ class TestValidateWebappToken:
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True)
|
||||
|
||||
def test_public_app_rejects_webapp_source(self) -> None:
|
||||
"""When auth is not required, a webapp-sourced token must be rejected."""
|
||||
decoded = {"token_source": "webapp"}
|
||||
with pytest.raises(Unauthorized):
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False)
|
||||
@@ -52,18 +48,13 @@ class TestValidateWebappToken:
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False)
|
||||
|
||||
def test_system_enabled_but_app_public(self) -> None:
|
||||
"""system_webapp_auth_enabled=True but app is public — webapp source rejected."""
|
||||
decoded = {"token_source": "webapp"}
|
||||
with pytest.raises(Unauthorized):
|
||||
_validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=True)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_user_accessibility
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestValidateUserAccessibility:
|
||||
def test_skips_when_auth_disabled(self) -> None:
|
||||
"""No checks when system or app auth is disabled."""
|
||||
_validate_user_accessibility(
|
||||
decoded={},
|
||||
app_code="code",
|
||||
@@ -123,7 +114,6 @@ class TestValidateUserAccessibility:
|
||||
def test_external_auth_type_checks_sso_update_time(
|
||||
self, mock_perm_check: MagicMock, mock_sso_time: MagicMock
|
||||
) -> None:
|
||||
# granted_at is before SSO update time → denied
|
||||
mock_sso_time.return_value = datetime.now(UTC)
|
||||
old_granted = int((datetime.now(UTC) - timedelta(hours=1)).timestamp())
|
||||
decoded = {"user_id": "u1", "auth_type": "external", "granted_at": old_granted}
|
||||
@@ -164,7 +154,6 @@ class TestValidateUserAccessibility:
|
||||
recent_granted = int(datetime.now(UTC).timestamp())
|
||||
decoded = {"user_id": "u1", "auth_type": "external", "granted_at": recent_granted}
|
||||
settings = SimpleNamespace(access_mode="public")
|
||||
# Should not raise
|
||||
_validate_user_accessibility(
|
||||
decoded=decoded,
|
||||
app_code="code",
|
||||
@@ -191,10 +180,49 @@ class TestValidateUserAccessibility:
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# decode_jwt_token
|
||||
# ---------------------------------------------------------------------------
|
||||
class TestDecodeJwtToken:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers):
|
||||
return flask_app_with_containers
|
||||
|
||||
def _create_app_site_enduser(self, db_session: Session, *, enable_site: bool = True):
|
||||
from models.model import App, AppMode, CustomizeTokenStrategy, EndUser, Site
|
||||
|
||||
tenant_id = str(uuid4())
|
||||
app_model = App(
|
||||
tenant_id=tenant_id,
|
||||
mode=AppMode.CHAT.value,
|
||||
name="test-app",
|
||||
enable_site=enable_site,
|
||||
enable_api=True,
|
||||
)
|
||||
db_session.add(app_model)
|
||||
db_session.commit()
|
||||
db_session.expire_all()
|
||||
|
||||
site = Site(
|
||||
app_id=app_model.id,
|
||||
title="test-site",
|
||||
default_language="en-US",
|
||||
customize_token_strategy=CustomizeTokenStrategy.NOT_ALLOW,
|
||||
code="code1",
|
||||
)
|
||||
db_session.add(site)
|
||||
db_session.commit()
|
||||
db_session.expire_all()
|
||||
|
||||
end_user = EndUser(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_model.id,
|
||||
type="browser",
|
||||
session_id="sess-1",
|
||||
)
|
||||
db_session.add(end_user)
|
||||
db_session.commit()
|
||||
db_session.expire_all()
|
||||
|
||||
return app_model, site, end_user
|
||||
|
||||
@patch("controllers.web.wraps._validate_user_accessibility")
|
||||
@patch("controllers.web.wraps._validate_webapp_token")
|
||||
@patch("controllers.web.wraps.EnterpriseService.WebAppAuth.get_app_access_mode_by_id")
|
||||
@@ -202,10 +230,8 @@ class TestDecodeJwtToken:
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_happy_path(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
@@ -213,40 +239,28 @@ class TestDecodeJwtToken:
|
||||
mock_access_mode: MagicMock,
|
||||
mock_validate_token: MagicMock,
|
||||
mock_validate_user: MagicMock,
|
||||
app: Flask,
|
||||
app,
|
||||
db_session_with_containers: Session,
|
||||
) -> None:
|
||||
app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers)
|
||||
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
"app_code": site.code,
|
||||
"app_id": app_model.id,
|
||||
"end_user_id": end_user.id,
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", enable_site=True)
|
||||
site = SimpleNamespace(code="code1")
|
||||
end_user = SimpleNamespace(id="eu-1", session_id="sess-1")
|
||||
with app.test_request_context("/", headers={"X-App-Code": site.code}):
|
||||
result_app, result_user = decode_jwt_token()
|
||||
|
||||
# Configure session mock to return correct objects via scalar()
|
||||
session_mock = MagicMock()
|
||||
session_mock.scalar.side_effect = [app_model, site, end_user]
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
result_app, result_user = decode_jwt_token()
|
||||
|
||||
assert result_app.id == "app-1"
|
||||
assert result_user.id == "eu-1"
|
||||
assert result_app.id == app_model.id
|
||||
assert result_user.id == end_user.id
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
def test_missing_token_raises_unauthorized(
|
||||
self, mock_extract: MagicMock, mock_features: MagicMock, app: Flask
|
||||
) -> None:
|
||||
def test_missing_token_raises_unauthorized(self, mock_extract: MagicMock, mock_features: MagicMock, app) -> None:
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
mock_extract.return_value = None
|
||||
|
||||
@@ -257,137 +271,98 @@ class TestDecodeJwtToken:
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_missing_app_raises_not_found(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app: Flask,
|
||||
app,
|
||||
) -> None:
|
||||
non_existent_id = str(uuid4())
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
"app_id": non_existent_id,
|
||||
"end_user_id": str(uuid4()),
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
session_mock = MagicMock()
|
||||
session_mock.scalar.return_value = None # No app found
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(NotFound):
|
||||
decode_jwt_token()
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(NotFound):
|
||||
decode_jwt_token()
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_disabled_site_raises_bad_request(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app: Flask,
|
||||
app,
|
||||
db_session_with_containers: Session,
|
||||
) -> None:
|
||||
app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers, enable_site=False)
|
||||
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
"app_code": site.code,
|
||||
"app_id": app_model.id,
|
||||
"end_user_id": end_user.id,
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", enable_site=False)
|
||||
|
||||
session_mock = MagicMock()
|
||||
# scalar calls: app_model, site (code found), then end_user
|
||||
session_mock.scalar.side_effect = [app_model, SimpleNamespace(code="code1"), None]
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(BadRequest, match="Site is disabled"):
|
||||
decode_jwt_token()
|
||||
with app.test_request_context("/", headers={"X-App-Code": site.code}):
|
||||
with pytest.raises(BadRequest, match="Site is disabled"):
|
||||
decode_jwt_token()
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_missing_end_user_raises_not_found(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app: Flask,
|
||||
app,
|
||||
db_session_with_containers: Session,
|
||||
) -> None:
|
||||
app_model, site, _ = self._create_app_site_enduser(db_session_with_containers)
|
||||
non_existent_eu = str(uuid4())
|
||||
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
"app_code": site.code,
|
||||
"app_id": app_model.id,
|
||||
"end_user_id": non_existent_eu,
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", enable_site=True)
|
||||
site = SimpleNamespace(code="code1")
|
||||
|
||||
session_mock = MagicMock()
|
||||
session_mock.scalar.side_effect = [app_model, site, None] # end_user is None
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(NotFound):
|
||||
decode_jwt_token()
|
||||
with app.test_request_context("/", headers={"X-App-Code": site.code}):
|
||||
with pytest.raises(NotFound):
|
||||
decode_jwt_token()
|
||||
|
||||
@patch("controllers.web.wraps.FeatureService.get_system_features")
|
||||
@patch("controllers.web.wraps.PassportService")
|
||||
@patch("controllers.web.wraps.extract_webapp_passport")
|
||||
@patch("controllers.web.wraps.db")
|
||||
def test_user_id_mismatch_raises_unauthorized(
|
||||
self,
|
||||
mock_db: MagicMock,
|
||||
mock_extract: MagicMock,
|
||||
mock_passport_cls: MagicMock,
|
||||
mock_features: MagicMock,
|
||||
app: Flask,
|
||||
app,
|
||||
db_session_with_containers: Session,
|
||||
) -> None:
|
||||
app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers)
|
||||
|
||||
mock_extract.return_value = "jwt-token"
|
||||
mock_passport_cls.return_value.verify.return_value = {
|
||||
"app_code": "code1",
|
||||
"app_id": "app-1",
|
||||
"end_user_id": "eu-1",
|
||||
"app_code": site.code,
|
||||
"app_id": app_model.id,
|
||||
"end_user_id": end_user.id,
|
||||
}
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", enable_site=True)
|
||||
site = SimpleNamespace(code="code1")
|
||||
end_user = SimpleNamespace(id="eu-1", session_id="sess-1")
|
||||
|
||||
session_mock = MagicMock()
|
||||
session_mock.scalar.side_effect = [app_model, site, end_user]
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__ = MagicMock(return_value=session_mock)
|
||||
session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_db.engine = "engine"
|
||||
|
||||
with patch("controllers.web.wraps.Session", return_value=session_ctx):
|
||||
with app.test_request_context("/", headers={"X-App-Code": "code1"}):
|
||||
with pytest.raises(Unauthorized, match="expired"):
|
||||
decode_jwt_token(user_id="different-user")
|
||||
with app.test_request_context("/", headers={"X-App-Code": site.code}):
|
||||
with pytest.raises(Unauthorized, match="expired"):
|
||||
decode_jwt_token(user_id="different-user")
|
||||
@@ -141,7 +141,7 @@ class TestModelLoadBalancingService:
|
||||
tenant_id=tenant_id,
|
||||
provider_name="openai",
|
||||
model_name="gpt-3.5-turbo",
|
||||
model_type="text-generation", # Use the origin model type that matches the query
|
||||
model_type="llm",
|
||||
enabled=True,
|
||||
load_balancing_enabled=False,
|
||||
)
|
||||
@@ -298,7 +298,7 @@ class TestModelLoadBalancingService:
|
||||
tenant_id=tenant.id,
|
||||
provider_name="openai",
|
||||
model_name="gpt-3.5-turbo",
|
||||
model_type="text-generation", # Use the origin model type that matches the query
|
||||
model_type="llm",
|
||||
name="config1",
|
||||
encrypted_config='{"api_key": "test_key"}',
|
||||
enabled=True,
|
||||
@@ -417,7 +417,7 @@ class TestModelLoadBalancingService:
|
||||
tenant_id=tenant.id,
|
||||
provider_name="openai",
|
||||
model_name="gpt-3.5-turbo",
|
||||
model_type="text-generation", # Use the origin model type that matches the query
|
||||
model_type="llm",
|
||||
name="config1",
|
||||
encrypted_config='{"api_key": "test_key"}',
|
||||
enabled=True,
|
||||
|
||||
@@ -1,157 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import app_import as app_import_module
|
||||
from services.app_dsl_service import ImportStatus
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
class _Result:
|
||||
def __init__(self, status: ImportStatus, app_id: str | None = "app-1"):
|
||||
self.status = status
|
||||
self.app_id = app_id
|
||||
|
||||
def model_dump(self, mode: str = "json"):
|
||||
return {"status": self.status, "app_id": self.app_id}
|
||||
|
||||
|
||||
class _SessionContext:
|
||||
def __init__(self, session):
|
||||
self._session = session
|
||||
|
||||
def __enter__(self):
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
|
||||
def _install_session(monkeypatch: pytest.MonkeyPatch, session: MagicMock) -> None:
|
||||
monkeypatch.setattr(app_import_module, "Session", lambda *_: _SessionContext(session))
|
||||
monkeypatch.setattr(app_import_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
|
||||
def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None:
|
||||
features = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=enabled))
|
||||
monkeypatch.setattr(app_import_module.FeatureService, "get_system_features", lambda: features)
|
||||
|
||||
|
||||
def test_import_post_returns_failed_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
session.commit.assert_called_once()
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
|
||||
|
||||
def test_import_post_returns_pending_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.PENDING),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
session.commit.assert_called_once()
|
||||
assert status == 202
|
||||
assert response["status"] == ImportStatus.PENDING
|
||||
|
||||
|
||||
def test_import_post_updates_webapp_auth_when_enabled(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
_install_features(monkeypatch, enabled=True)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"),
|
||||
)
|
||||
update_access = MagicMock()
|
||||
monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
session.commit.assert_called_once()
|
||||
update_access.assert_called_once_with("app-123", "private")
|
||||
assert status == 200
|
||||
assert response["status"] == ImportStatus.COMPLETED
|
||||
|
||||
|
||||
def test_import_confirm_returns_failed_status(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportConfirmApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"confirm_import",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"):
|
||||
response, status = method(import_id="import-1")
|
||||
|
||||
session.commit.assert_called_once()
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
|
||||
|
||||
def test_import_check_dependencies_returns_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportCheckDependenciesApi()
|
||||
method = _unwrap(api.get)
|
||||
|
||||
session = MagicMock()
|
||||
_install_session(monkeypatch, session)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"check_dependencies",
|
||||
lambda *_args, **_kwargs: SimpleNamespace(model_dump=lambda mode="json": {"leaked_dependencies": []}),
|
||||
)
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports/app-1/check-dependencies", method="GET"):
|
||||
response, status = method(app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert status == 200
|
||||
assert response["leaked_dependencies"] == []
|
||||
@@ -1,142 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console.workspace import plugin_permission_required
|
||||
from models.account import TenantPluginPermission
|
||||
|
||||
|
||||
class _SessionStub:
|
||||
def __init__(self, permission):
|
||||
self._permission = permission
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def query(self, *_args, **_kwargs):
|
||||
return self
|
||||
|
||||
def where(self, *_args, **_kwargs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return self._permission
|
||||
|
||||
|
||||
def _workspace_module():
|
||||
return importlib.import_module(plugin_permission_required.__module__)
|
||||
|
||||
|
||||
def _patch_session(monkeypatch: pytest.MonkeyPatch, permission):
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "Session", lambda *_args, **_kwargs: _SessionStub(permission))
|
||||
monkeypatch.setattr(module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
|
||||
def test_plugin_permission_allows_without_permission(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=False)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, None)
|
||||
|
||||
@plugin_permission_required()
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
assert handler() == "ok"
|
||||
|
||||
|
||||
def test_plugin_permission_install_nobody_forbidden(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=True)
|
||||
permission = SimpleNamespace(
|
||||
install_permission=TenantPluginPermission.InstallPermission.NOBODY,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, permission)
|
||||
|
||||
@plugin_permission_required(install_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
|
||||
|
||||
def test_plugin_permission_install_admin_requires_admin(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=False)
|
||||
permission = SimpleNamespace(
|
||||
install_permission=TenantPluginPermission.InstallPermission.ADMINS,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, permission)
|
||||
|
||||
@plugin_permission_required(install_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
|
||||
|
||||
def test_plugin_permission_install_admin_allows_admin(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=True)
|
||||
permission = SimpleNamespace(
|
||||
install_permission=TenantPluginPermission.InstallPermission.ADMINS,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
|
||||
)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, permission)
|
||||
|
||||
@plugin_permission_required(install_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
assert handler() == "ok"
|
||||
|
||||
|
||||
def test_plugin_permission_debug_nobody_forbidden(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=True)
|
||||
permission = SimpleNamespace(
|
||||
install_permission=TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.NOBODY,
|
||||
)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, permission)
|
||||
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
|
||||
|
||||
def test_plugin_permission_debug_admin_requires_admin(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(is_admin_or_owner=False)
|
||||
permission = SimpleNamespace(
|
||||
install_permission=TenantPluginPermission.InstallPermission.EVERYONE,
|
||||
debug_permission=TenantPluginPermission.DebugPermission.ADMINS,
|
||||
)
|
||||
module = _workspace_module()
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
_patch_session(monkeypatch, permission)
|
||||
|
||||
@plugin_permission_required(debug_required=True)
|
||||
def handler():
|
||||
return "ok"
|
||||
|
||||
with pytest.raises(Forbidden):
|
||||
handler()
|
||||
@@ -768,6 +768,7 @@ class TestSegmentApiGet:
|
||||
``current_account_with_tenant()`` and ``marshal``.
|
||||
"""
|
||||
|
||||
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
|
||||
@patch("controllers.service_api.dataset.segment.marshal")
|
||||
@patch("controllers.service_api.dataset.segment.SegmentService")
|
||||
@patch("controllers.service_api.dataset.segment.DocumentService")
|
||||
@@ -780,6 +781,7 @@ class TestSegmentApiGet:
|
||||
mock_doc_svc,
|
||||
mock_seg_svc,
|
||||
mock_marshal,
|
||||
mock_summary_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
@@ -791,7 +793,8 @@ class TestSegmentApiGet:
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
|
||||
mock_seg_svc.get_segments.return_value = ([mock_segment], 1)
|
||||
mock_marshal.return_value = [{"id": mock_segment.id}]
|
||||
mock_marshal.return_value = {"id": mock_segment.id}
|
||||
mock_summary_svc.get_segments_summaries.return_value = {}
|
||||
|
||||
# Act
|
||||
with app.test_request_context(
|
||||
@@ -872,6 +875,7 @@ class TestSegmentApiPost:
|
||||
mock_rate_limit.enabled = False
|
||||
mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit
|
||||
|
||||
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
|
||||
@patch("controllers.service_api.dataset.segment.marshal")
|
||||
@patch("controllers.service_api.dataset.segment.SegmentService")
|
||||
@patch("controllers.service_api.dataset.segment.DocumentService")
|
||||
@@ -888,6 +892,7 @@ class TestSegmentApiPost:
|
||||
mock_doc_svc,
|
||||
mock_seg_svc,
|
||||
mock_marshal,
|
||||
mock_summary_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
@@ -909,7 +914,8 @@ class TestSegmentApiPost:
|
||||
|
||||
mock_seg_svc.segment_create_args_validate.return_value = None
|
||||
mock_seg_svc.multi_create_segment.return_value = [mock_segment]
|
||||
mock_marshal.return_value = [{"id": mock_segment.id}]
|
||||
mock_marshal.return_value = {"id": mock_segment.id}
|
||||
mock_summary_svc.get_segments_summaries.return_value = {}
|
||||
|
||||
segments_data = [{"content": "Test segment content", "answer": "Test answer"}]
|
||||
|
||||
@@ -1206,6 +1212,7 @@ class TestDatasetSegmentApiUpdate:
|
||||
mock_rate_limit.enabled = False
|
||||
mock_feature_svc.get_knowledge_rate_limit.return_value = mock_rate_limit
|
||||
|
||||
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
|
||||
@patch("controllers.service_api.dataset.segment.marshal")
|
||||
@patch("controllers.service_api.dataset.segment.SegmentService")
|
||||
@patch("controllers.service_api.dataset.segment.DocumentService")
|
||||
@@ -1224,6 +1231,7 @@ class TestDatasetSegmentApiUpdate:
|
||||
mock_doc_svc,
|
||||
mock_seg_svc,
|
||||
mock_marshal,
|
||||
mock_summary_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
@@ -1240,6 +1248,7 @@ class TestDatasetSegmentApiUpdate:
|
||||
updated = Mock()
|
||||
mock_seg_svc.update_segment.return_value = updated
|
||||
mock_marshal.return_value = {"id": mock_segment.id}
|
||||
mock_summary_svc.get_segment_summary.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}",
|
||||
@@ -1349,6 +1358,7 @@ class TestDatasetSegmentApiGetSingle:
|
||||
``current_account_with_tenant()`` and ``marshal``.
|
||||
"""
|
||||
|
||||
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
|
||||
@patch("controllers.service_api.dataset.segment.marshal")
|
||||
@patch("controllers.service_api.dataset.segment.SegmentService")
|
||||
@patch("controllers.service_api.dataset.segment.DocumentService")
|
||||
@@ -1363,6 +1373,7 @@ class TestDatasetSegmentApiGetSingle:
|
||||
mock_doc_svc,
|
||||
mock_seg_svc,
|
||||
mock_marshal,
|
||||
mock_summary_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
@@ -1376,6 +1387,7 @@ class TestDatasetSegmentApiGetSingle:
|
||||
mock_doc_svc.get_document.return_value = mock_doc
|
||||
mock_seg_svc.get_segment_by_id.return_value = mock_segment
|
||||
mock_marshal.return_value = {"id": mock_segment.id}
|
||||
mock_summary_svc.get_segment_summary.return_value = None
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}",
|
||||
@@ -1393,6 +1405,55 @@ class TestDatasetSegmentApiGetSingle:
|
||||
assert "data" in response
|
||||
assert response["doc_form"] == IndexStructureType.PARAGRAPH_INDEX
|
||||
|
||||
@patch("controllers.service_api.dataset.segment.SummaryIndexService")
|
||||
@patch("controllers.service_api.dataset.segment.marshal")
|
||||
@patch("controllers.service_api.dataset.segment.SegmentService")
|
||||
@patch("controllers.service_api.dataset.segment.DocumentService")
|
||||
@patch("controllers.service_api.dataset.segment.DatasetService")
|
||||
@patch("controllers.service_api.dataset.segment.current_account_with_tenant")
|
||||
@patch("controllers.service_api.dataset.segment.db")
|
||||
def test_get_single_segment_includes_summary(
|
||||
self,
|
||||
mock_db,
|
||||
mock_account_fn,
|
||||
mock_dataset_svc,
|
||||
mock_doc_svc,
|
||||
mock_seg_svc,
|
||||
mock_marshal,
|
||||
mock_summary_svc,
|
||||
app,
|
||||
mock_tenant,
|
||||
mock_dataset,
|
||||
mock_segment,
|
||||
):
|
||||
"""Test that single segment response includes summary content from SummaryIndexService."""
|
||||
mock_account_fn.return_value = (Mock(), mock_tenant.id)
|
||||
mock_db.session.scalar.return_value = mock_dataset
|
||||
mock_dataset_svc.check_dataset_model_setting.return_value = None
|
||||
mock_doc = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX)
|
||||
mock_doc_svc.get_document.return_value = mock_doc
|
||||
mock_seg_svc.get_segment_by_id.return_value = mock_segment
|
||||
mock_marshal.return_value = {"id": mock_segment.id, "summary": None}
|
||||
|
||||
mock_summary_record = Mock()
|
||||
mock_summary_record.summary_content = "This is the segment summary"
|
||||
mock_summary_svc.get_segment_summary.return_value = mock_summary_record
|
||||
|
||||
with app.test_request_context(
|
||||
f"/datasets/{mock_dataset.id}/documents/doc-id/segments/{mock_segment.id}",
|
||||
method="GET",
|
||||
):
|
||||
api = DatasetSegmentApi()
|
||||
response, status = api.get(
|
||||
tenant_id=mock_tenant.id,
|
||||
dataset_id=mock_dataset.id,
|
||||
document_id="doc-id",
|
||||
segment_id=mock_segment.id,
|
||||
)
|
||||
|
||||
assert status == 200
|
||||
assert response["data"]["summary"] == "This is the segment summary"
|
||||
|
||||
@patch("controllers.service_api.dataset.segment.current_account_with_tenant")
|
||||
@patch("controllers.service_api.dataset.segment.db")
|
||||
def test_get_single_segment_dataset_not_found(
|
||||
|
||||
@@ -415,12 +415,44 @@ class TestUtilityFunctions:
|
||||
label="Upload",
|
||||
required=False,
|
||||
),
|
||||
VariableEntity(
|
||||
type=VariableEntityType.CHECKBOX,
|
||||
variable="enabled",
|
||||
description="Enable flag",
|
||||
label="Enabled",
|
||||
required=False,
|
||||
),
|
||||
VariableEntity(
|
||||
type=VariableEntityType.JSON_OBJECT,
|
||||
variable="config",
|
||||
description="Config object",
|
||||
label="Config",
|
||||
required=True,
|
||||
),
|
||||
VariableEntity(
|
||||
type=VariableEntityType.JSON_OBJECT,
|
||||
variable="schema_config",
|
||||
description="Config with schema",
|
||||
label="Schema Config",
|
||||
required=False,
|
||||
json_schema={
|
||||
"properties": {
|
||||
"host": {"type": "string"},
|
||||
"port": {"type": "number"},
|
||||
},
|
||||
"required": ["host"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
parameters_dict: dict[str, str] = {
|
||||
"name": "Enter your name",
|
||||
"category": "Select category",
|
||||
"count": "Enter count",
|
||||
"enabled": "Enable flag",
|
||||
"config": "Config object",
|
||||
"schema_config": "Config with schema",
|
||||
}
|
||||
|
||||
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
|
||||
@@ -437,20 +469,35 @@ class TestUtilityFunctions:
|
||||
assert "count" in parameters
|
||||
assert parameters["count"]["type"] == "number"
|
||||
|
||||
# FILE type should be skipped - it creates empty dict but gets filtered later
|
||||
# Check that it doesn't have any meaningful content
|
||||
if "upload" in parameters:
|
||||
assert parameters["upload"] == {}
|
||||
# FILE type is skipped entirely via `continue` — key should not exist
|
||||
assert "upload" not in parameters
|
||||
|
||||
# CHECKBOX maps to boolean
|
||||
assert parameters["enabled"]["type"] == "boolean"
|
||||
|
||||
# JSON_OBJECT without json_schema maps to object
|
||||
assert parameters["config"]["type"] == "object"
|
||||
assert "properties" not in parameters["config"]
|
||||
|
||||
# JSON_OBJECT with json_schema forwards schema keys
|
||||
assert parameters["schema_config"]["type"] == "object"
|
||||
assert parameters["schema_config"]["properties"] == {
|
||||
"host": {"type": "string"},
|
||||
"port": {"type": "number"},
|
||||
}
|
||||
assert parameters["schema_config"]["required"] == ["host"]
|
||||
assert parameters["schema_config"]["additionalProperties"] is False
|
||||
|
||||
# Check required fields
|
||||
assert "name" in required
|
||||
assert "count" in required
|
||||
assert "config" in required
|
||||
assert "category" not in required
|
||||
|
||||
# Note: _get_request_id function has been removed as request_id is now passed as parameter
|
||||
|
||||
def test_convert_input_form_to_parameters_jsonschema_validation_ok(self):
|
||||
"""Current schema uses 'number' for numeric fields; it should be a valid JSON Schema."""
|
||||
"""Generated schema with all supported types should be valid JSON Schema."""
|
||||
user_input_form = [
|
||||
VariableEntity(
|
||||
type=VariableEntityType.NUMBER,
|
||||
@@ -466,11 +513,27 @@ class TestUtilityFunctions:
|
||||
label="Name",
|
||||
required=False,
|
||||
),
|
||||
VariableEntity(
|
||||
type=VariableEntityType.CHECKBOX,
|
||||
variable="enabled",
|
||||
description="Toggle",
|
||||
label="Enabled",
|
||||
required=False,
|
||||
),
|
||||
VariableEntity(
|
||||
type=VariableEntityType.JSON_OBJECT,
|
||||
variable="metadata",
|
||||
description="Metadata",
|
||||
label="Metadata",
|
||||
required=False,
|
||||
),
|
||||
]
|
||||
|
||||
parameters_dict = {
|
||||
"count": "Enter count",
|
||||
"name": "Enter your name",
|
||||
"enabled": "Toggle flag",
|
||||
"metadata": "Metadata object",
|
||||
}
|
||||
|
||||
parameters, required = convert_input_form_to_parameters(user_input_form, parameters_dict)
|
||||
@@ -485,9 +548,12 @@ class TestUtilityFunctions:
|
||||
# 1) The schema itself must be valid
|
||||
jsonschema.Draft202012Validator.check_schema(schema)
|
||||
|
||||
# 2) Both float and integer instances should pass validation
|
||||
# 2) Validate instances with all types
|
||||
jsonschema.validate(instance={"count": 3.14, "name": "alice"}, schema=schema)
|
||||
jsonschema.validate(instance={"count": 2, "name": "bob"}, schema=schema)
|
||||
jsonschema.validate(
|
||||
instance={"count": 2, "enabled": True, "metadata": {"key": "val"}},
|
||||
schema=schema,
|
||||
)
|
||||
|
||||
def test_legacy_float_type_schema_is_invalid(self):
|
||||
"""Legacy/buggy behavior: using 'float' should produce an invalid JSON Schema."""
|
||||
|
||||
@@ -521,11 +521,11 @@ def test_generate_name_trace(trace_instance):
|
||||
def test_add_trace_success(trace_instance):
|
||||
data = LangfuseTrace(id="t1", name="trace")
|
||||
trace_instance.add_trace(data)
|
||||
trace_instance.langfuse_client.trace.assert_called_once()
|
||||
trace_instance.langfuse_client.api.ingestion.batch.assert_called_once()
|
||||
|
||||
|
||||
def test_add_trace_error(trace_instance):
|
||||
trace_instance.langfuse_client.trace.side_effect = Exception("error")
|
||||
trace_instance.langfuse_client.api.ingestion.batch.side_effect = Exception("error")
|
||||
data = LangfuseTrace(id="t1", name="trace")
|
||||
with pytest.raises(ValueError, match="LangFuse Failed to create trace: error"):
|
||||
trace_instance.add_trace(data)
|
||||
@@ -534,11 +534,11 @@ def test_add_trace_error(trace_instance):
|
||||
def test_add_span_success(trace_instance):
|
||||
data = LangfuseSpan(id="s1", name="span", trace_id="t1")
|
||||
trace_instance.add_span(data)
|
||||
trace_instance.langfuse_client.span.assert_called_once()
|
||||
trace_instance.langfuse_client.api.ingestion.batch.assert_called_once()
|
||||
|
||||
|
||||
def test_add_span_error(trace_instance):
|
||||
trace_instance.langfuse_client.span.side_effect = Exception("error")
|
||||
trace_instance.langfuse_client.api.ingestion.batch.side_effect = Exception("error")
|
||||
data = LangfuseSpan(id="s1", name="span", trace_id="t1")
|
||||
with pytest.raises(ValueError, match="LangFuse Failed to create span: error"):
|
||||
trace_instance.add_span(data)
|
||||
@@ -554,11 +554,11 @@ def test_update_span(trace_instance):
|
||||
def test_add_generation_success(trace_instance):
|
||||
data = LangfuseGeneration(id="g1", name="gen", trace_id="t1")
|
||||
trace_instance.add_generation(data)
|
||||
trace_instance.langfuse_client.generation.assert_called_once()
|
||||
trace_instance.langfuse_client.api.ingestion.batch.assert_called_once()
|
||||
|
||||
|
||||
def test_add_generation_error(trace_instance):
|
||||
trace_instance.langfuse_client.generation.side_effect = Exception("error")
|
||||
trace_instance.langfuse_client.api.ingestion.batch.side_effect = Exception("error")
|
||||
data = LangfuseGeneration(id="g1", name="gen", trace_id="t1")
|
||||
with pytest.raises(ValueError, match="LangFuse Failed to create generation: error"):
|
||||
trace_instance.add_generation(data)
|
||||
@@ -585,12 +585,12 @@ def test_api_check_error(trace_instance):
|
||||
def test_get_project_key_success(trace_instance):
|
||||
mock_data = MagicMock()
|
||||
mock_data.id = "proj-1"
|
||||
trace_instance.langfuse_client.client.projects.get.return_value = MagicMock(data=[mock_data])
|
||||
trace_instance.langfuse_client.api.projects.get.return_value = MagicMock(data=[mock_data])
|
||||
assert trace_instance.get_project_key() == "proj-1"
|
||||
|
||||
|
||||
def test_get_project_key_error(trace_instance):
|
||||
trace_instance.langfuse_client.client.projects.get.side_effect = Exception("fail")
|
||||
trace_instance.langfuse_client.api.projects.get.side_effect = Exception("fail")
|
||||
with pytest.raises(ValueError, match="LangFuse get project key failed: fail"):
|
||||
trace_instance.get_project_key()
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
model_type="llm",
|
||||
enabled=True,
|
||||
load_balancing_enabled=True,
|
||||
)
|
||||
@@ -61,7 +61,7 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
model_type="llm",
|
||||
name="__inherit__",
|
||||
encrypted_config=None,
|
||||
enabled=True,
|
||||
@@ -70,7 +70,7 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
model_type="llm",
|
||||
name="first",
|
||||
encrypted_config='{"openai_api_key": "fake_key"}',
|
||||
enabled=True,
|
||||
@@ -110,7 +110,7 @@ def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_ent
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
model_type="llm",
|
||||
enabled=True,
|
||||
load_balancing_enabled=True,
|
||||
)
|
||||
@@ -121,7 +121,7 @@ def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_ent
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
model_type="llm",
|
||||
name="__inherit__",
|
||||
encrypted_config=None,
|
||||
enabled=True,
|
||||
@@ -157,7 +157,7 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
model_type="llm",
|
||||
enabled=True,
|
||||
load_balancing_enabled=False,
|
||||
)
|
||||
@@ -168,7 +168,7 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
model_type="llm",
|
||||
name="__inherit__",
|
||||
encrypted_config=None,
|
||||
enabled=True,
|
||||
@@ -177,7 +177,7 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
model_type="llm",
|
||||
name="first",
|
||||
encrypted_config='{"openai_api_key": "fake_key"}',
|
||||
enabled=True,
|
||||
@@ -270,7 +270,7 @@ def test_get_default_model_uses_injected_runtime_for_existing_default_record(moc
|
||||
tenant_id="tenant-id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type=ModelType.LLM.to_origin_model_type(),
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
mock_session = Mock()
|
||||
mock_session.scalar.return_value = existing_default_model
|
||||
@@ -449,7 +449,7 @@ def test_update_default_model_record_updates_existing_record(mocker: MockerFixtu
|
||||
tenant_id="tenant-id",
|
||||
provider_name="anthropic",
|
||||
model_name="claude-3-sonnet",
|
||||
model_type=ModelType.LLM.to_origin_model_type(),
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
mock_session = Mock()
|
||||
mock_session.scalar.return_value = existing_default_model
|
||||
@@ -487,7 +487,7 @@ def test_update_default_model_record_creates_record_with_origin_model_type(mocke
|
||||
assert created_default_model.tenant_id == "tenant-id"
|
||||
assert created_default_model.provider_name == "openai"
|
||||
assert created_default_model.model_name == "gpt-4"
|
||||
assert created_default_model.model_type == ModelType.LLM.to_origin_model_type()
|
||||
assert created_default_model.model_type == ModelType.LLM
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
|
||||
@@ -202,7 +202,7 @@ class TestProviderModel:
|
||||
# Assert
|
||||
assert provider.provider_type == ProviderType.CUSTOM
|
||||
assert provider.is_valid is False
|
||||
assert provider.quota_type == ""
|
||||
assert provider.quota_type is None
|
||||
assert provider.quota_limit is None
|
||||
assert provider.quota_used == 0
|
||||
assert provider.credential_id is None
|
||||
|
||||
@@ -5,6 +5,7 @@ Covers:
|
||||
- License status caching (get_cached_license_status)
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -15,9 +16,178 @@ from services.enterprise.enterprise_service import (
|
||||
VALID_LICENSE_CACHE_TTL,
|
||||
DefaultWorkspaceJoinResult,
|
||||
EnterpriseService,
|
||||
WebAppSettings,
|
||||
WorkspacePermission,
|
||||
try_join_default_workspace,
|
||||
)
|
||||
|
||||
MODULE = "services.enterprise.enterprise_service"
|
||||
|
||||
|
||||
class TestEnterpriseServiceInfo:
|
||||
def test_get_info_delegates(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"version": "1.0"}
|
||||
result = EnterpriseService.get_info()
|
||||
|
||||
req.send_request.assert_called_once_with("GET", "/info")
|
||||
assert result == {"version": "1.0"}
|
||||
|
||||
def test_get_workspace_info_delegates(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"name": "ws"}
|
||||
result = EnterpriseService.get_workspace_info("tenant-1")
|
||||
|
||||
req.send_request.assert_called_once_with("GET", "/workspace/tenant-1/info")
|
||||
assert result == {"name": "ws"}
|
||||
|
||||
|
||||
class TestSsoSettingsLastUpdateTime:
|
||||
def test_app_sso_parses_valid_timestamp(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = "2025-01-15T10:30:00+00:00"
|
||||
result = EnterpriseService.get_app_sso_settings_last_update_time()
|
||||
|
||||
assert isinstance(result, datetime)
|
||||
assert result.year == 2025
|
||||
|
||||
def test_app_sso_raises_on_empty(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = ""
|
||||
with pytest.raises(ValueError, match="No data found"):
|
||||
EnterpriseService.get_app_sso_settings_last_update_time()
|
||||
|
||||
def test_app_sso_raises_on_invalid_format(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = "not-a-date"
|
||||
with pytest.raises(ValueError, match="Invalid date format"):
|
||||
EnterpriseService.get_app_sso_settings_last_update_time()
|
||||
|
||||
def test_workspace_sso_parses_valid_timestamp(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = "2025-06-01T00:00:00+00:00"
|
||||
result = EnterpriseService.get_workspace_sso_settings_last_update_time()
|
||||
|
||||
assert isinstance(result, datetime)
|
||||
|
||||
def test_workspace_sso_raises_on_empty(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = None
|
||||
with pytest.raises(ValueError, match="No data found"):
|
||||
EnterpriseService.get_workspace_sso_settings_last_update_time()
|
||||
|
||||
|
||||
class TestWorkspacePermissionService:
|
||||
def test_raises_on_empty_workspace_id(self):
|
||||
with pytest.raises(ValueError, match="workspace_id must be provided"):
|
||||
EnterpriseService.WorkspacePermissionService.get_permission("")
|
||||
|
||||
def test_raises_on_missing_data(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = None
|
||||
with pytest.raises(ValueError, match="No data found"):
|
||||
EnterpriseService.WorkspacePermissionService.get_permission("ws-1")
|
||||
|
||||
def test_raises_on_missing_permission_key(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"other": "data"}
|
||||
with pytest.raises(ValueError, match="No data found"):
|
||||
EnterpriseService.WorkspacePermissionService.get_permission("ws-1")
|
||||
|
||||
def test_returns_parsed_permission(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {
|
||||
"permission": {
|
||||
"workspaceId": "ws-1",
|
||||
"allowMemberInvite": True,
|
||||
"allowOwnerTransfer": False,
|
||||
}
|
||||
}
|
||||
result = EnterpriseService.WorkspacePermissionService.get_permission("ws-1")
|
||||
|
||||
assert isinstance(result, WorkspacePermission)
|
||||
assert result.workspace_id == "ws-1"
|
||||
assert result.allow_member_invite is True
|
||||
assert result.allow_owner_transfer is False
|
||||
|
||||
|
||||
class TestWebAppAuth:
|
||||
def test_is_user_allowed_returns_result_field(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"result": True}
|
||||
assert EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp("u1", "a1") is True
|
||||
|
||||
def test_is_user_allowed_defaults_false(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {}
|
||||
assert EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp("u1", "a1") is False
|
||||
|
||||
def test_batch_is_user_allowed_returns_empty_for_no_apps(self):
|
||||
assert EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps("u1", []) == {}
|
||||
|
||||
def test_batch_is_user_allowed_raises_on_empty_response(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = None
|
||||
with pytest.raises(ValueError, match="No data found"):
|
||||
EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps("u1", ["a1"])
|
||||
|
||||
def test_get_app_access_mode_raises_on_empty_app_id(self):
|
||||
with pytest.raises(ValueError, match="app_id must be provided"):
|
||||
EnterpriseService.WebAppAuth.get_app_access_mode_by_id("")
|
||||
|
||||
def test_get_app_access_mode_returns_settings(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"accessMode": "public"}
|
||||
result = EnterpriseService.WebAppAuth.get_app_access_mode_by_id("a1")
|
||||
|
||||
assert isinstance(result, WebAppSettings)
|
||||
assert result.access_mode == "public"
|
||||
|
||||
def test_batch_get_returns_empty_for_no_apps(self):
|
||||
assert EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id([]) == {}
|
||||
|
||||
def test_batch_get_maps_access_modes(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"accessModes": {"a1": "public", "a2": "private"}}
|
||||
result = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(["a1", "a2"])
|
||||
|
||||
assert result["a1"].access_mode == "public"
|
||||
assert result["a2"].access_mode == "private"
|
||||
|
||||
def test_batch_get_raises_on_invalid_format(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"accessModes": "not-a-dict"}
|
||||
with pytest.raises(ValueError, match="Invalid data format"):
|
||||
EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(["a1"])
|
||||
|
||||
def test_update_access_mode_raises_on_empty_app_id(self):
|
||||
with pytest.raises(ValueError, match="app_id must be provided"):
|
||||
EnterpriseService.WebAppAuth.update_app_access_mode("", "public")
|
||||
|
||||
def test_update_access_mode_raises_on_invalid_mode(self):
|
||||
with pytest.raises(ValueError, match="access_mode must be"):
|
||||
EnterpriseService.WebAppAuth.update_app_access_mode("a1", "invalid")
|
||||
|
||||
def test_update_access_mode_delegates_and_returns(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"result": True}
|
||||
result = EnterpriseService.WebAppAuth.update_app_access_mode("a1", "public")
|
||||
|
||||
assert result is True
|
||||
req.send_request.assert_called_once_with(
|
||||
"POST", "/webapp/access-mode", json={"appId": "a1", "accessMode": "public"}
|
||||
)
|
||||
|
||||
def test_cleanup_webapp_raises_on_empty_app_id(self):
|
||||
with pytest.raises(ValueError, match="app_id must be provided"):
|
||||
EnterpriseService.WebAppAuth.cleanup_webapp("")
|
||||
|
||||
def test_cleanup_webapp_delegates(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
EnterpriseService.WebAppAuth.cleanup_webapp("a1")
|
||||
|
||||
req.send_request.assert_called_once_with("DELETE", "/webapp/clean", params={"appId": "a1"})
|
||||
|
||||
|
||||
class TestJoinDefaultWorkspace:
|
||||
def test_join_default_workspace_success(self):
|
||||
|
||||
@@ -7,14 +7,20 @@ This module covers the pre-uninstall plugin hook behavior:
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from httpx import HTTPStatusError
|
||||
|
||||
from configs import dify_config
|
||||
from services.enterprise.plugin_manager_service import (
|
||||
CheckCredentialPolicyComplianceRequest,
|
||||
CredentialPolicyViolationError,
|
||||
PluginCredentialType,
|
||||
PluginManagerService,
|
||||
PreUninstallPluginRequest,
|
||||
)
|
||||
|
||||
MODULE = "services.enterprise.plugin_manager_service"
|
||||
|
||||
|
||||
class TestTryPreUninstallPlugin:
|
||||
def test_try_pre_uninstall_plugin_success(self):
|
||||
@@ -88,3 +94,46 @@ class TestTryPreUninstallPlugin:
|
||||
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
|
||||
)
|
||||
mock_logger.exception.assert_called_once()
|
||||
|
||||
|
||||
class TestCheckCredentialPolicyCompliance:
|
||||
def _request(self, cred_type=PluginCredentialType.MODEL):
|
||||
return CheckCredentialPolicyComplianceRequest(
|
||||
dify_credential_id="cred-1", provider="openai", credential_type=cred_type
|
||||
)
|
||||
|
||||
def test_passes_when_result_true(self):
|
||||
with patch(f"{MODULE}.EnterprisePluginManagerRequest") as req:
|
||||
req.send_request.return_value = {"result": True}
|
||||
PluginManagerService.check_credential_policy_compliance(self._request())
|
||||
|
||||
req.send_request.assert_called_once()
|
||||
|
||||
def test_raises_violation_when_result_false(self):
|
||||
with patch(f"{MODULE}.EnterprisePluginManagerRequest") as req:
|
||||
req.send_request.return_value = {"result": False}
|
||||
with pytest.raises(CredentialPolicyViolationError, match="Credentials not available"):
|
||||
PluginManagerService.check_credential_policy_compliance(self._request())
|
||||
|
||||
def test_raises_violation_on_invalid_response_format(self):
|
||||
with patch(f"{MODULE}.EnterprisePluginManagerRequest") as req:
|
||||
req.send_request.return_value = "not-a-dict"
|
||||
with pytest.raises(CredentialPolicyViolationError, match="error occurred"):
|
||||
PluginManagerService.check_credential_policy_compliance(self._request())
|
||||
|
||||
def test_raises_violation_on_api_exception(self):
|
||||
with patch(f"{MODULE}.EnterprisePluginManagerRequest") as req:
|
||||
req.send_request.side_effect = ConnectionError("network fail")
|
||||
with pytest.raises(CredentialPolicyViolationError, match="error occurred"):
|
||||
PluginManagerService.check_credential_policy_compliance(self._request())
|
||||
|
||||
def test_model_dump_serializes_credential_type_as_number(self):
|
||||
body = self._request(PluginCredentialType.TOOL)
|
||||
data = body.model_dump()
|
||||
|
||||
assert data["credential_type"] == 1
|
||||
assert data["dify_credential_id"] == "cred-1"
|
||||
|
||||
def test_model_credential_type_values(self):
|
||||
assert PluginCredentialType.MODEL.to_number() == 0
|
||||
assert PluginCredentialType.TOOL.to_number() == 1
|
||||
|
||||
@@ -0,0 +1,183 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from models.account import TenantPluginAutoUpgradeStrategy
|
||||
|
||||
MODULE = "services.plugin.plugin_auto_upgrade_service"
|
||||
|
||||
|
||||
def _patched_session():
|
||||
"""Patch Session(db.engine) to return a mock session as context manager."""
|
||||
session = MagicMock()
|
||||
session_cls = MagicMock()
|
||||
session_cls.return_value.__enter__ = MagicMock(return_value=session)
|
||||
session_cls.return_value.__exit__ = MagicMock(return_value=False)
|
||||
patcher = patch(f"{MODULE}.Session", session_cls)
|
||||
db_patcher = patch(f"{MODULE}.db")
|
||||
return patcher, db_patcher, session
|
||||
|
||||
|
||||
class TestGetStrategy:
|
||||
def test_returns_strategy_when_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
strategy = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = strategy
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
result = PluginAutoUpgradeService.get_strategy("t1")
|
||||
|
||||
assert result is strategy
|
||||
|
||||
def test_returns_none_when_not_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
result = PluginAutoUpgradeService.get_strategy("t1")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestChangeStrategy:
|
||||
def test_creates_new_strategy(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.return_value = MagicMock()
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
result = PluginAutoUpgradeService.change_strategy(
|
||||
"t1",
|
||||
TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
|
||||
3,
|
||||
TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL,
|
||||
[],
|
||||
[],
|
||||
)
|
||||
|
||||
assert result is True
|
||||
session.add.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_updates_existing_strategy(self):
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
result = PluginAutoUpgradeService.change_strategy(
|
||||
"t1",
|
||||
TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST,
|
||||
5,
|
||||
TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL,
|
||||
["p1"],
|
||||
["p2"],
|
||||
)
|
||||
|
||||
assert result is True
|
||||
assert existing.strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST
|
||||
assert existing.upgrade_time_of_day == 5
|
||||
assert existing.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL
|
||||
assert existing.exclude_plugins == ["p1"]
|
||||
assert existing.include_plugins == ["p2"]
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
class TestExcludePlugin:
|
||||
def test_creates_default_strategy_when_none_exists(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with (
|
||||
p1,
|
||||
p2,
|
||||
patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls,
|
||||
patch(f"{MODULE}.PluginAutoUpgradeService.change_strategy") as cs,
|
||||
):
|
||||
strat_cls.StrategySetting.FIX_ONLY = "fix_only"
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
cs.return_value = True
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
result = PluginAutoUpgradeService.exclude_plugin("t1", "plugin-1")
|
||||
|
||||
assert result is True
|
||||
cs.assert_called_once()
|
||||
|
||||
def test_appends_to_exclude_list_in_exclude_mode(self):
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
existing.upgrade_mode = "exclude"
|
||||
existing.exclude_plugins = ["p-existing"]
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||
strat_cls.UpgradeMode.ALL = "all"
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
result = PluginAutoUpgradeService.exclude_plugin("t1", "p-new")
|
||||
|
||||
assert result is True
|
||||
assert existing.exclude_plugins == ["p-existing", "p-new"]
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_removes_from_include_list_in_partial_mode(self):
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
existing.upgrade_mode = "partial"
|
||||
existing.include_plugins = ["p1", "p2"]
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||
strat_cls.UpgradeMode.ALL = "all"
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
result = PluginAutoUpgradeService.exclude_plugin("t1", "p1")
|
||||
|
||||
assert result is True
|
||||
assert existing.include_plugins == ["p2"]
|
||||
|
||||
def test_switches_to_exclude_mode_from_all(self):
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
existing.upgrade_mode = "all"
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||
strat_cls.UpgradeMode.ALL = "all"
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
result = PluginAutoUpgradeService.exclude_plugin("t1", "p1")
|
||||
|
||||
assert result is True
|
||||
assert existing.upgrade_mode == "exclude"
|
||||
assert existing.exclude_plugins == ["p1"]
|
||||
|
||||
def test_no_duplicate_in_exclude_list(self):
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
existing.upgrade_mode = "exclude"
|
||||
existing.exclude_plugins = ["p1"]
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
|
||||
strat_cls.UpgradeMode.EXCLUDE = "exclude"
|
||||
strat_cls.UpgradeMode.PARTIAL = "partial"
|
||||
strat_cls.UpgradeMode.ALL = "all"
|
||||
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
|
||||
|
||||
PluginAutoUpgradeService.exclude_plugin("t1", "p1")
|
||||
|
||||
assert existing.exclude_plugins == ["p1"]
|
||||
@@ -0,0 +1,75 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from models.account import TenantPluginPermission
|
||||
|
||||
MODULE = "services.plugin.plugin_permission_service"
|
||||
|
||||
|
||||
def _patched_session():
|
||||
"""Patch Session(db.engine) to return a mock session as context manager."""
|
||||
session = MagicMock()
|
||||
session_cls = MagicMock()
|
||||
session_cls.return_value.__enter__ = MagicMock(return_value=session)
|
||||
session_cls.return_value.__exit__ = MagicMock(return_value=False)
|
||||
patcher = patch(f"{MODULE}.Session", session_cls)
|
||||
db_patcher = patch(f"{MODULE}.db")
|
||||
return patcher, db_patcher, session
|
||||
|
||||
|
||||
class TestGetPermission:
|
||||
def test_returns_permission_when_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
permission = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = permission
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
result = PluginPermissionService.get_permission("t1")
|
||||
|
||||
assert result is permission
|
||||
|
||||
def test_returns_none_when_not_found(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
result = PluginPermissionService.get_permission("t1")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestChangePermission:
|
||||
def test_creates_new_permission_when_not_exists(self):
|
||||
p1, p2, session = _patched_session()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with p1, p2, patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
|
||||
perm_cls.return_value = MagicMock()
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
result = PluginPermissionService.change_permission(
|
||||
"t1", TenantPluginPermission.InstallPermission.EVERYONE, TenantPluginPermission.DebugPermission.EVERYONE
|
||||
)
|
||||
|
||||
session.add.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_updates_existing_permission(self):
|
||||
p1, p2, session = _patched_session()
|
||||
existing = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = existing
|
||||
|
||||
with p1, p2:
|
||||
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||
|
||||
result = PluginPermissionService.change_permission(
|
||||
"t1", TenantPluginPermission.InstallPermission.ADMINS, TenantPluginPermission.DebugPermission.ADMINS
|
||||
)
|
||||
|
||||
assert existing.install_permission == TenantPluginPermission.InstallPermission.ADMINS
|
||||
assert existing.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
|
||||
session.commit.assert_called_once()
|
||||
session.add.assert_not_called()
|
||||
0
api/tests/unit_tests/services/retention/__init__.py
Normal file
0
api/tests/unit_tests/services/retention/__init__.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from services.retention.conversation.messages_clean_policy import (
|
||||
BillingDisabledPolicy,
|
||||
BillingSandboxPolicy,
|
||||
SimpleMessage,
|
||||
create_message_clean_policy,
|
||||
)
|
||||
|
||||
MODULE = "services.retention.conversation.messages_clean_policy"
|
||||
|
||||
|
||||
def _msg(msg_id: str, app_id: str, days_ago: int = 0) -> SimpleMessage:
|
||||
return SimpleMessage(
|
||||
id=msg_id,
|
||||
app_id=app_id,
|
||||
created_at=datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=days_ago),
|
||||
)
|
||||
|
||||
|
||||
class TestBillingDisabledPolicy:
|
||||
def test_returns_all_message_ids(self):
|
||||
policy = BillingDisabledPolicy()
|
||||
msgs = [_msg("m1", "app1"), _msg("m2", "app2"), _msg("m3", "app1")]
|
||||
|
||||
result = policy.filter_message_ids(msgs, {"app1": "t1", "app2": "t2"})
|
||||
|
||||
assert set(result) == {"m1", "m2", "m3"}
|
||||
|
||||
def test_empty_messages_returns_empty(self):
|
||||
assert BillingDisabledPolicy().filter_message_ids([], {}) == []
|
||||
|
||||
|
||||
class TestBillingSandboxPolicy:
|
||||
def _policy(self, plans, *, graceful_days=21, whitelist=None, now=1_000_000_000):
|
||||
return BillingSandboxPolicy(
|
||||
plan_provider=lambda _ids: plans,
|
||||
graceful_period_days=graceful_days,
|
||||
tenant_whitelist=whitelist,
|
||||
current_timestamp=now,
|
||||
)
|
||||
|
||||
def test_empty_messages_returns_empty(self):
|
||||
policy = self._policy({})
|
||||
assert policy.filter_message_ids([], {"app1": "t1"}) == []
|
||||
|
||||
def test_empty_app_to_tenant_returns_empty(self):
|
||||
policy = self._policy({})
|
||||
assert policy.filter_message_ids([_msg("m1", "app1")], {}) == []
|
||||
|
||||
def test_empty_plans_returns_empty(self):
|
||||
policy = self._policy({})
|
||||
msgs = [_msg("m1", "app1")]
|
||||
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == []
|
||||
|
||||
def test_non_sandbox_tenant_skipped(self):
|
||||
plans = {"t1": {"plan": "professional", "expiration_date": 0}}
|
||||
policy = self._policy(plans)
|
||||
msgs = [_msg("m1", "app1")]
|
||||
|
||||
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == []
|
||||
|
||||
def test_sandbox_no_previous_subscription_deletes(self):
|
||||
plans = {"t1": {"plan": "sandbox", "expiration_date": -1}}
|
||||
policy = self._policy(plans)
|
||||
msgs = [_msg("m1", "app1")]
|
||||
|
||||
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == ["m1"]
|
||||
|
||||
def test_sandbox_expired_beyond_grace_period_deletes(self):
|
||||
now = 1_000_000_000
|
||||
expired_long_ago = now - (22 * 24 * 60 * 60) # 22 days ago > 21 day grace
|
||||
plans = {"t1": {"plan": "sandbox", "expiration_date": expired_long_ago}}
|
||||
policy = self._policy(plans, now=now)
|
||||
msgs = [_msg("m1", "app1")]
|
||||
|
||||
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == ["m1"]
|
||||
|
||||
def test_sandbox_within_grace_period_kept(self):
|
||||
now = 1_000_000_000
|
||||
expired_recently = now - (10 * 24 * 60 * 60) # 10 days ago < 21 day grace
|
||||
plans = {"t1": {"plan": "sandbox", "expiration_date": expired_recently}}
|
||||
policy = self._policy(plans, now=now)
|
||||
msgs = [_msg("m1", "app1")]
|
||||
|
||||
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == []
|
||||
|
||||
def test_whitelisted_tenant_skipped(self):
|
||||
plans = {"t1": {"plan": "sandbox", "expiration_date": -1}}
|
||||
policy = self._policy(plans, whitelist=["t1"])
|
||||
msgs = [_msg("m1", "app1")]
|
||||
|
||||
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == []
|
||||
|
||||
def test_message_without_tenant_mapping_skipped(self):
|
||||
plans = {"t1": {"plan": "sandbox", "expiration_date": -1}}
|
||||
policy = self._policy(plans)
|
||||
msgs = [_msg("m1", "unmapped_app")]
|
||||
|
||||
assert policy.filter_message_ids(msgs, {"app1": "t1"}) == []
|
||||
|
||||
def test_mixed_tenants_only_sandbox_deleted(self):
|
||||
plans = {
|
||||
"t_sandbox": {"plan": "sandbox", "expiration_date": -1},
|
||||
"t_pro": {"plan": "professional", "expiration_date": 0},
|
||||
}
|
||||
policy = self._policy(plans)
|
||||
msgs = [_msg("m1", "app_sandbox"), _msg("m2", "app_pro")]
|
||||
app_map = {"app_sandbox": "t_sandbox", "app_pro": "t_pro"}
|
||||
|
||||
result = policy.filter_message_ids(msgs, app_map)
|
||||
|
||||
assert result == ["m1"]
|
||||
|
||||
|
||||
class TestCreateMessageCleanPolicy:
|
||||
def test_billing_disabled_returns_disabled_policy(self):
|
||||
with patch(f"{MODULE}.dify_config") as cfg:
|
||||
cfg.BILLING_ENABLED = False
|
||||
policy = create_message_clean_policy()
|
||||
|
||||
assert isinstance(policy, BillingDisabledPolicy)
|
||||
|
||||
def test_billing_enabled_returns_sandbox_policy(self):
|
||||
with (
|
||||
patch(f"{MODULE}.dify_config") as cfg,
|
||||
patch(f"{MODULE}.BillingService") as bs,
|
||||
):
|
||||
cfg.BILLING_ENABLED = True
|
||||
bs.get_expired_subscription_cleanup_whitelist.return_value = ["wl1"]
|
||||
bs.get_plan_bulk_with_cache = MagicMock()
|
||||
policy = create_message_clean_policy(graceful_period_days=30)
|
||||
|
||||
assert isinstance(policy, BillingSandboxPolicy)
|
||||
@@ -317,7 +317,7 @@ def test_init_inherit_config_should_create_and_persist_inherit_configuration(
|
||||
assert inherit_config.tenant_id == "tenant-1"
|
||||
assert inherit_config.provider_name == "openai"
|
||||
assert inherit_config.model_name == "gpt-4o-mini"
|
||||
assert inherit_config.model_type == "text-generation"
|
||||
assert inherit_config.model_type == "llm"
|
||||
assert inherit_config.name == "__inherit__"
|
||||
mock_db.session.add.assert_called_once_with(inherit_config)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@@ -0,0 +1,598 @@
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolParameter, ToolProviderType
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
MODULE = "services.tools.tools_transform_service"
|
||||
|
||||
|
||||
class TestToolTransformService:
|
||||
"""Test cases for ToolTransformService.convert_tool_entity_to_api_entity method"""
|
||||
|
||||
def test_convert_tool_with_parameter_override(self):
|
||||
"""Test that runtime parameters correctly override base parameters"""
|
||||
# Create mock base parameters
|
||||
base_param1 = Mock(spec=ToolParameter)
|
||||
base_param1.name = "param1"
|
||||
base_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
base_param1.type = "string"
|
||||
base_param1.label = "Base Param 1"
|
||||
|
||||
base_param2 = Mock(spec=ToolParameter)
|
||||
base_param2.name = "param2"
|
||||
base_param2.form = ToolParameter.ToolParameterForm.FORM
|
||||
base_param2.type = "string"
|
||||
base_param2.label = "Base Param 2"
|
||||
|
||||
# Create mock runtime parameters that override base parameters
|
||||
runtime_param1 = Mock(spec=ToolParameter)
|
||||
runtime_param1.name = "param1"
|
||||
runtime_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
runtime_param1.type = "string"
|
||||
runtime_param1.label = "Runtime Param 1" # Different label to verify override
|
||||
|
||||
# Create mock tool
|
||||
mock_tool = Mock(spec=Tool)
|
||||
mock_tool.entity = Mock()
|
||||
mock_tool.entity.parameters = [base_param1, base_param2]
|
||||
mock_tool.entity.identity = Mock()
|
||||
mock_tool.entity.identity.author = "test_author"
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
|
||||
mock_tool.entity.description = Mock()
|
||||
mock_tool.entity.description.human = I18nObject(en_US="Test description")
|
||||
mock_tool.entity.output_schema = {}
|
||||
mock_tool.get_runtime_parameters.return_value = [runtime_param1]
|
||||
|
||||
# Mock fork_tool_runtime to return the same tool
|
||||
mock_tool.fork_tool_runtime.return_value = mock_tool
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolApiEntity)
|
||||
assert result.author == "test_author"
|
||||
assert result.name == "test_tool"
|
||||
assert result.parameters is not None
|
||||
assert len(result.parameters) == 2
|
||||
|
||||
# Find the overridden parameter
|
||||
overridden_param = next((p for p in result.parameters if p.name == "param1"), None)
|
||||
assert overridden_param is not None
|
||||
assert overridden_param.label == "Runtime Param 1" # Should be runtime version
|
||||
|
||||
# Find the non-overridden parameter
|
||||
original_param = next((p for p in result.parameters if p.name == "param2"), None)
|
||||
assert original_param is not None
|
||||
assert original_param.label == "Base Param 2" # Should be base version
|
||||
|
||||
def test_convert_tool_with_additional_runtime_parameters(self):
|
||||
"""Test that additional runtime parameters are added to the final list"""
|
||||
# Create mock base parameters
|
||||
base_param1 = Mock(spec=ToolParameter)
|
||||
base_param1.name = "param1"
|
||||
base_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
base_param1.type = "string"
|
||||
base_param1.label = "Base Param 1"
|
||||
|
||||
# Create mock runtime parameters - one that overrides and one that's new
|
||||
runtime_param1 = Mock(spec=ToolParameter)
|
||||
runtime_param1.name = "param1"
|
||||
runtime_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
runtime_param1.type = "string"
|
||||
runtime_param1.label = "Runtime Param 1"
|
||||
|
||||
runtime_param2 = Mock(spec=ToolParameter)
|
||||
runtime_param2.name = "runtime_only"
|
||||
runtime_param2.form = ToolParameter.ToolParameterForm.FORM
|
||||
runtime_param2.type = "string"
|
||||
runtime_param2.label = "Runtime Only Param"
|
||||
|
||||
# Create mock tool
|
||||
mock_tool = Mock(spec=Tool)
|
||||
mock_tool.entity = Mock()
|
||||
mock_tool.entity.parameters = [base_param1]
|
||||
mock_tool.entity.identity = Mock()
|
||||
mock_tool.entity.identity.author = "test_author"
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
|
||||
mock_tool.entity.description = Mock()
|
||||
mock_tool.entity.description.human = I18nObject(en_US="Test description")
|
||||
mock_tool.entity.output_schema = {}
|
||||
mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2]
|
||||
|
||||
# Mock fork_tool_runtime to return the same tool
|
||||
mock_tool.fork_tool_runtime.return_value = mock_tool
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolApiEntity)
|
||||
assert result.parameters is not None
|
||||
assert len(result.parameters) == 2
|
||||
|
||||
# Check that both parameters are present
|
||||
param_names = [p.name for p in result.parameters]
|
||||
assert "param1" in param_names
|
||||
assert "runtime_only" in param_names
|
||||
|
||||
# Verify the overridden parameter has runtime version
|
||||
overridden_param = next((p for p in result.parameters if p.name == "param1"), None)
|
||||
assert overridden_param is not None
|
||||
assert overridden_param.label == "Runtime Param 1"
|
||||
|
||||
# Verify the new runtime parameter is included
|
||||
new_param = next((p for p in result.parameters if p.name == "runtime_only"), None)
|
||||
assert new_param is not None
|
||||
assert new_param.label == "Runtime Only Param"
|
||||
|
||||
def test_convert_tool_with_non_form_runtime_parameters(self):
|
||||
"""Test that non-FORM runtime parameters are not added as new parameters"""
|
||||
# Create mock base parameters
|
||||
base_param1 = Mock(spec=ToolParameter)
|
||||
base_param1.name = "param1"
|
||||
base_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
base_param1.type = "string"
|
||||
base_param1.label = "Base Param 1"
|
||||
|
||||
# Create mock runtime parameters with different forms
|
||||
runtime_param1 = Mock(spec=ToolParameter)
|
||||
runtime_param1.name = "param1"
|
||||
runtime_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
runtime_param1.type = "string"
|
||||
runtime_param1.label = "Runtime Param 1"
|
||||
|
||||
runtime_param2 = Mock(spec=ToolParameter)
|
||||
runtime_param2.name = "llm_param"
|
||||
runtime_param2.form = ToolParameter.ToolParameterForm.LLM
|
||||
runtime_param2.type = "string"
|
||||
runtime_param2.label = "LLM Param"
|
||||
|
||||
# Create mock tool
|
||||
mock_tool = Mock(spec=Tool)
|
||||
mock_tool.entity = Mock()
|
||||
mock_tool.entity.parameters = [base_param1]
|
||||
mock_tool.entity.identity = Mock()
|
||||
mock_tool.entity.identity.author = "test_author"
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
|
||||
mock_tool.entity.description = Mock()
|
||||
mock_tool.entity.description.human = I18nObject(en_US="Test description")
|
||||
mock_tool.entity.output_schema = {}
|
||||
mock_tool.get_runtime_parameters.return_value = [runtime_param1, runtime_param2]
|
||||
|
||||
# Mock fork_tool_runtime to return the same tool
|
||||
mock_tool.fork_tool_runtime.return_value = mock_tool
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolApiEntity)
|
||||
assert result.parameters is not None
|
||||
assert len(result.parameters) == 1 # Only the FORM parameter should be present
|
||||
|
||||
# Check that only the FORM parameter is present
|
||||
param_names = [p.name for p in result.parameters]
|
||||
assert "param1" in param_names
|
||||
assert "llm_param" not in param_names
|
||||
|
||||
def test_convert_tool_with_empty_parameters(self):
|
||||
"""Test conversion with empty base and runtime parameters"""
|
||||
# Create mock tool with no parameters
|
||||
mock_tool = Mock(spec=Tool)
|
||||
mock_tool.entity = Mock()
|
||||
mock_tool.entity.parameters = []
|
||||
mock_tool.entity.identity = Mock()
|
||||
mock_tool.entity.identity.author = "test_author"
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
|
||||
mock_tool.entity.description = Mock()
|
||||
mock_tool.entity.description.human = I18nObject(en_US="Test description")
|
||||
mock_tool.entity.output_schema = {}
|
||||
mock_tool.get_runtime_parameters.return_value = []
|
||||
|
||||
# Mock fork_tool_runtime to return the same tool
|
||||
mock_tool.fork_tool_runtime.return_value = mock_tool
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolApiEntity)
|
||||
assert result.parameters is not None
|
||||
assert len(result.parameters) == 0
|
||||
|
||||
def test_convert_tool_with_none_parameters(self):
|
||||
"""Test conversion when base parameters is None"""
|
||||
# Create mock tool with None parameters
|
||||
mock_tool = Mock(spec=Tool)
|
||||
mock_tool.entity = Mock()
|
||||
mock_tool.entity.parameters = None
|
||||
mock_tool.entity.identity = Mock()
|
||||
mock_tool.entity.identity.author = "test_author"
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
|
||||
mock_tool.entity.description = Mock()
|
||||
mock_tool.entity.description.human = I18nObject(en_US="Test description")
|
||||
mock_tool.entity.output_schema = {}
|
||||
mock_tool.get_runtime_parameters.return_value = []
|
||||
|
||||
# Mock fork_tool_runtime to return the same tool
|
||||
mock_tool.fork_tool_runtime.return_value = mock_tool
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolApiEntity)
|
||||
assert result.parameters is not None
|
||||
assert len(result.parameters) == 0
|
||||
|
||||
def test_convert_tool_parameter_order_preserved(self):
|
||||
"""Test that parameter order is preserved correctly"""
|
||||
# Create mock base parameters in specific order
|
||||
base_param1 = Mock(spec=ToolParameter)
|
||||
base_param1.name = "param1"
|
||||
base_param1.form = ToolParameter.ToolParameterForm.FORM
|
||||
base_param1.type = "string"
|
||||
base_param1.label = "Base Param 1"
|
||||
|
||||
base_param2 = Mock(spec=ToolParameter)
|
||||
base_param2.name = "param2"
|
||||
base_param2.form = ToolParameter.ToolParameterForm.FORM
|
||||
base_param2.type = "string"
|
||||
base_param2.label = "Base Param 2"
|
||||
|
||||
base_param3 = Mock(spec=ToolParameter)
|
||||
base_param3.name = "param3"
|
||||
base_param3.form = ToolParameter.ToolParameterForm.FORM
|
||||
base_param3.type = "string"
|
||||
base_param3.label = "Base Param 3"
|
||||
|
||||
# Create runtime parameter that overrides middle parameter
|
||||
runtime_param2 = Mock(spec=ToolParameter)
|
||||
runtime_param2.name = "param2"
|
||||
runtime_param2.form = ToolParameter.ToolParameterForm.FORM
|
||||
runtime_param2.type = "string"
|
||||
runtime_param2.label = "Runtime Param 2"
|
||||
|
||||
# Create new runtime parameter
|
||||
runtime_param4 = Mock(spec=ToolParameter)
|
||||
runtime_param4.name = "param4"
|
||||
runtime_param4.form = ToolParameter.ToolParameterForm.FORM
|
||||
runtime_param4.type = "string"
|
||||
runtime_param4.label = "Runtime Param 4"
|
||||
|
||||
# Create mock tool
|
||||
mock_tool = Mock(spec=Tool)
|
||||
mock_tool.entity = Mock()
|
||||
mock_tool.entity.parameters = [base_param1, base_param2, base_param3]
|
||||
mock_tool.entity.identity = Mock()
|
||||
mock_tool.entity.identity.author = "test_author"
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
mock_tool.entity.identity.label = I18nObject(en_US="Test Tool")
|
||||
mock_tool.entity.description = Mock()
|
||||
mock_tool.entity.description.human = I18nObject(en_US="Test description")
|
||||
mock_tool.entity.output_schema = {}
|
||||
mock_tool.get_runtime_parameters.return_value = [runtime_param2, runtime_param4]
|
||||
|
||||
# Mock fork_tool_runtime to return the same tool
|
||||
mock_tool.fork_tool_runtime.return_value = mock_tool
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.convert_tool_entity_to_api_entity(mock_tool, "test_tenant", None)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolApiEntity)
|
||||
assert result.parameters is not None
|
||||
assert len(result.parameters) == 4
|
||||
|
||||
# Check that order is maintained: base parameters first, then new runtime parameters
|
||||
param_names = [p.name for p in result.parameters]
|
||||
assert param_names == ["param1", "param2", "param3", "param4"]
|
||||
|
||||
# Verify that param2 was overridden with runtime version
|
||||
param2 = result.parameters[1]
|
||||
assert param2.name == "param2"
|
||||
assert param2.label == "Runtime Param 2"
|
||||
|
||||
|
||||
class TestWorkflowProviderToUserProvider:
|
||||
"""Test cases for ToolTransformService.workflow_provider_to_user_provider method"""
|
||||
|
||||
def test_workflow_provider_to_user_provider_with_workflow_app_id(self):
|
||||
"""Test that workflow_provider_to_user_provider correctly sets workflow_app_id."""
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
|
||||
# Create mock workflow tool provider controller
|
||||
workflow_app_id = "app_123"
|
||||
provider_id = "provider_123"
|
||||
mock_controller = Mock(spec=WorkflowToolProviderController)
|
||||
mock_controller.provider_id = provider_id
|
||||
mock_controller.entity = Mock()
|
||||
mock_controller.entity.identity = Mock()
|
||||
mock_controller.entity.identity.author = "test_author"
|
||||
mock_controller.entity.identity.name = "test_workflow_tool"
|
||||
mock_controller.entity.identity.description = I18nObject(en_US="Test description")
|
||||
mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
|
||||
mock_controller.entity.identity.icon_dark = None
|
||||
mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=mock_controller,
|
||||
labels=["label1", "label2"],
|
||||
workflow_app_id=workflow_app_id,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolProviderApiEntity)
|
||||
assert result.id == provider_id
|
||||
assert result.author == "test_author"
|
||||
assert result.name == "test_workflow_tool"
|
||||
assert result.type == ToolProviderType.WORKFLOW
|
||||
assert result.workflow_app_id == workflow_app_id
|
||||
assert result.labels == ["label1", "label2"]
|
||||
assert result.is_team_authorization is True
|
||||
assert result.plugin_id is None
|
||||
assert result.plugin_unique_identifier is None
|
||||
assert result.tools == []
|
||||
|
||||
def test_workflow_provider_to_user_provider_without_workflow_app_id(self):
|
||||
"""Test that workflow_provider_to_user_provider works when workflow_app_id is not provided."""
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
|
||||
# Create mock workflow tool provider controller
|
||||
provider_id = "provider_123"
|
||||
mock_controller = Mock(spec=WorkflowToolProviderController)
|
||||
mock_controller.provider_id = provider_id
|
||||
mock_controller.entity = Mock()
|
||||
mock_controller.entity.identity = Mock()
|
||||
mock_controller.entity.identity.author = "test_author"
|
||||
mock_controller.entity.identity.name = "test_workflow_tool"
|
||||
mock_controller.entity.identity.description = I18nObject(en_US="Test description")
|
||||
mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
|
||||
mock_controller.entity.identity.icon_dark = None
|
||||
mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")
|
||||
|
||||
# Call the method without workflow_app_id
|
||||
result = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=mock_controller,
|
||||
labels=["label1"],
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolProviderApiEntity)
|
||||
assert result.id == provider_id
|
||||
assert result.workflow_app_id is None
|
||||
assert result.labels == ["label1"]
|
||||
|
||||
def test_workflow_provider_to_user_provider_workflow_app_id_none(self):
|
||||
"""Test that workflow_provider_to_user_provider handles None workflow_app_id explicitly."""
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
|
||||
# Create mock workflow tool provider controller
|
||||
provider_id = "provider_123"
|
||||
mock_controller = Mock(spec=WorkflowToolProviderController)
|
||||
mock_controller.provider_id = provider_id
|
||||
mock_controller.entity = Mock()
|
||||
mock_controller.entity.identity = Mock()
|
||||
mock_controller.entity.identity.author = "test_author"
|
||||
mock_controller.entity.identity.name = "test_workflow_tool"
|
||||
mock_controller.entity.identity.description = I18nObject(en_US="Test description")
|
||||
mock_controller.entity.identity.icon = {"type": "emoji", "content": "🔧"}
|
||||
mock_controller.entity.identity.icon_dark = None
|
||||
mock_controller.entity.identity.label = I18nObject(en_US="Test Workflow Tool")
|
||||
|
||||
# Call the method with explicit None values
|
||||
result = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=mock_controller,
|
||||
labels=None,
|
||||
workflow_app_id=None,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, ToolProviderApiEntity)
|
||||
assert result.id == provider_id
|
||||
assert result.workflow_app_id is None
|
||||
assert result.labels == []
|
||||
|
||||
def test_workflow_provider_to_user_provider_preserves_other_fields(self):
|
||||
"""Test that workflow_provider_to_user_provider preserves all other entity fields."""
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
|
||||
# Create mock workflow tool provider controller with various fields
|
||||
workflow_app_id = "app_456"
|
||||
provider_id = "provider_456"
|
||||
mock_controller = Mock(spec=WorkflowToolProviderController)
|
||||
mock_controller.provider_id = provider_id
|
||||
mock_controller.entity = Mock()
|
||||
mock_controller.entity.identity = Mock()
|
||||
mock_controller.entity.identity.author = "another_author"
|
||||
mock_controller.entity.identity.name = "another_workflow_tool"
|
||||
mock_controller.entity.identity.description = I18nObject(
|
||||
en_US="Another description", zh_Hans="Another description"
|
||||
)
|
||||
mock_controller.entity.identity.icon = {"type": "emoji", "content": "⚙️"}
|
||||
mock_controller.entity.identity.icon_dark = {"type": "emoji", "content": "🔧"}
|
||||
mock_controller.entity.identity.label = I18nObject(
|
||||
en_US="Another Workflow Tool", zh_Hans="Another Workflow Tool"
|
||||
)
|
||||
|
||||
# Call the method
|
||||
result = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=mock_controller,
|
||||
labels=["automation", "workflow"],
|
||||
workflow_app_id=workflow_app_id,
|
||||
)
|
||||
|
||||
# Verify all fields are preserved correctly
|
||||
assert isinstance(result, ToolProviderApiEntity)
|
||||
assert result.id == provider_id
|
||||
assert result.author == "another_author"
|
||||
assert result.name == "another_workflow_tool"
|
||||
assert result.description.en_US == "Another description"
|
||||
assert result.description.zh_Hans == "Another description"
|
||||
assert result.icon == {"type": "emoji", "content": "⚙️"}
|
||||
assert result.icon_dark == {"type": "emoji", "content": "🔧"}
|
||||
assert result.label.en_US == "Another Workflow Tool"
|
||||
assert result.label.zh_Hans == "Another Workflow Tool"
|
||||
assert result.type == ToolProviderType.WORKFLOW
|
||||
assert result.workflow_app_id == workflow_app_id
|
||||
assert result.labels == ["automation", "workflow"]
|
||||
assert result.masked_credentials == {}
|
||||
assert result.is_team_authorization is True
|
||||
assert result.allow_delete is True
|
||||
assert result.plugin_id is None
|
||||
assert result.plugin_unique_identifier is None
|
||||
assert result.tools == []
|
||||
|
||||
|
||||
class TestGetToolProviderIconUrl:
|
||||
def test_builtin_provider_returns_console_url(self):
|
||||
with patch(f"{MODULE}.dify_config") as cfg:
|
||||
cfg.CONSOLE_API_URL = "https://app.dify.ai"
|
||||
url = ToolTransformService.get_tool_provider_icon_url("builtin", "google", "icon.png")
|
||||
|
||||
assert "/builtin/google/icon" in url
|
||||
assert url.startswith("https://app.dify.ai/console/api/workspaces/current/tool-provider")
|
||||
|
||||
def test_builtin_provider_with_no_console_url(self):
|
||||
with patch(f"{MODULE}.dify_config") as cfg:
|
||||
cfg.CONSOLE_API_URL = None
|
||||
url = ToolTransformService.get_tool_provider_icon_url("builtin", "slack", "icon.png")
|
||||
|
||||
assert "/builtin/slack/icon" in url
|
||||
|
||||
def test_api_provider_parses_json_icon(self):
|
||||
icon_json = '{"background": "#fff", "content": "A"}'
|
||||
result = ToolTransformService.get_tool_provider_icon_url("api", "my-api", icon_json)
|
||||
assert result == {"background": "#fff", "content": "A"}
|
||||
|
||||
def test_api_provider_returns_dict_icon_directly(self):
|
||||
icon = {"background": "#000", "content": "B"}
|
||||
result = ToolTransformService.get_tool_provider_icon_url("api", "my-api", icon)
|
||||
assert result == icon
|
||||
|
||||
def test_api_provider_returns_fallback_on_invalid_json(self):
|
||||
result = ToolTransformService.get_tool_provider_icon_url("api", "my-api", "not-json")
|
||||
assert result == {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
def test_workflow_provider_behaves_like_api(self):
|
||||
icon = {"background": "#123", "content": "W"}
|
||||
assert ToolTransformService.get_tool_provider_icon_url("workflow", "wf", icon) == icon
|
||||
|
||||
def test_mcp_returns_icon_as_is(self):
|
||||
assert ToolTransformService.get_tool_provider_icon_url("mcp", "srv", "icon-value") == "icon-value"
|
||||
|
||||
def test_unknown_type_returns_empty(self):
|
||||
assert ToolTransformService.get_tool_provider_icon_url("unknown", "x", "i") == ""
|
||||
|
||||
|
||||
class TestRepackProvider:
|
||||
def test_repacks_dict_provider_icon(self):
|
||||
provider = {"type": "builtin", "name": "google", "icon": "old"}
|
||||
with patch.object(ToolTransformService, "get_tool_provider_icon_url", return_value="/new-url") as mock_fn:
|
||||
ToolTransformService.repack_provider("t1", provider)
|
||||
|
||||
assert provider["icon"] == "/new-url"
|
||||
mock_fn.assert_called_once_with(provider_type="builtin", provider_name="google", icon="old")
|
||||
|
||||
def test_repacks_tool_provider_api_entity_without_plugin(self):
|
||||
entity = MagicMock(spec=ToolProviderApiEntity)
|
||||
entity.plugin_id = None
|
||||
entity.type = ToolProviderType.BUILT_IN
|
||||
entity.name = "slack"
|
||||
entity.icon = "icon.svg"
|
||||
entity.icon_dark = "dark.svg"
|
||||
|
||||
with patch.object(ToolTransformService, "get_tool_provider_icon_url", return_value="/url"):
|
||||
ToolTransformService.repack_provider("t1", entity)
|
||||
|
||||
assert entity.icon == "/url"
|
||||
assert entity.icon_dark == "/url"
|
||||
|
||||
|
||||
class TestConvertMcpSchemaToParameter:
|
||||
def test_simple_object_schema(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search query"},
|
||||
"count": {"type": "integer", "description": "Result count"},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
params = ToolTransformService.convert_mcp_schema_to_parameter(schema)
|
||||
|
||||
assert len(params) == 2
|
||||
query_param = next(p for p in params if p.name == "query")
|
||||
count_param = next(p for p in params if p.name == "count")
|
||||
assert query_param.required is True
|
||||
assert count_param.required is False
|
||||
assert count_param.type.value == "number"
|
||||
|
||||
def test_float_maps_to_number(self):
|
||||
schema = {"type": "object", "properties": {"rate": {"type": "float"}}, "required": []}
|
||||
assert ToolTransformService.convert_mcp_schema_to_parameter(schema)[0].type.value == "number"
|
||||
|
||||
def test_array_type_attaches_input_schema(self):
|
||||
prop = {"type": "array", "description": "Items", "items": {"type": "string"}}
|
||||
schema = {"type": "object", "properties": {"items": prop}, "required": []}
|
||||
param = ToolTransformService.convert_mcp_schema_to_parameter(schema)[0]
|
||||
assert param.input_schema is not None
|
||||
|
||||
def test_non_object_schema_returns_empty(self):
|
||||
assert ToolTransformService.convert_mcp_schema_to_parameter({"type": "string"}) == []
|
||||
|
||||
def test_missing_properties_returns_empty(self):
|
||||
assert ToolTransformService.convert_mcp_schema_to_parameter({"type": "object"}) == []
|
||||
|
||||
def test_list_type_uses_first_element(self):
|
||||
schema = {"type": "object", "properties": {"f": {"type": ["string", "null"]}}, "required": []}
|
||||
assert ToolTransformService.convert_mcp_schema_to_parameter(schema)[0].type.value == "string"
|
||||
|
||||
def test_missing_description_defaults_empty(self):
|
||||
schema = {"type": "object", "properties": {"f": {"type": "string"}}, "required": []}
|
||||
assert ToolTransformService.convert_mcp_schema_to_parameter(schema)[0].llm_description == ""
|
||||
|
||||
|
||||
class TestApiProviderToController:
|
||||
def test_api_key_header_auth(self):
|
||||
db_provider = MagicMock()
|
||||
db_provider.credentials = {"auth_type": "api_key_header"}
|
||||
with patch(f"{MODULE}.ApiToolProviderController") as ctrl_cls:
|
||||
ctrl_cls.from_db.return_value = MagicMock()
|
||||
ToolTransformService.api_provider_to_controller(db_provider)
|
||||
ctrl_cls.from_db.assert_called_once_with(db_provider=db_provider, auth_type=ApiProviderAuthType.API_KEY_HEADER)
|
||||
|
||||
def test_api_key_query_auth(self):
|
||||
db_provider = MagicMock()
|
||||
db_provider.credentials = {"auth_type": "api_key_query"}
|
||||
with patch(f"{MODULE}.ApiToolProviderController") as ctrl_cls:
|
||||
ctrl_cls.from_db.return_value = MagicMock()
|
||||
ToolTransformService.api_provider_to_controller(db_provider)
|
||||
ctrl_cls.from_db.assert_called_once_with(db_provider=db_provider, auth_type=ApiProviderAuthType.API_KEY_QUERY)
|
||||
|
||||
def test_legacy_api_key_maps_to_header(self):
|
||||
db_provider = MagicMock()
|
||||
db_provider.credentials = {"auth_type": "api_key"}
|
||||
with patch(f"{MODULE}.ApiToolProviderController") as ctrl_cls:
|
||||
ctrl_cls.from_db.return_value = MagicMock()
|
||||
ToolTransformService.api_provider_to_controller(db_provider)
|
||||
ctrl_cls.from_db.assert_called_once_with(db_provider=db_provider, auth_type=ApiProviderAuthType.API_KEY_HEADER)
|
||||
|
||||
def test_unknown_auth_defaults_to_none(self):
|
||||
db_provider = MagicMock()
|
||||
db_provider.credentials = {"auth_type": "something_else"}
|
||||
with patch(f"{MODULE}.ApiToolProviderController") as ctrl_cls:
|
||||
ctrl_cls.from_db.return_value = MagicMock()
|
||||
ToolTransformService.api_provider_to_controller(db_provider)
|
||||
ctrl_cls.from_db.assert_called_once_with(db_provider=db_provider, auth_type=ApiProviderAuthType.NONE)
|
||||
174
api/uv.lock
generated
174
api/uv.lock
generated
@@ -1700,30 +1700,30 @@ requires-dist = [
|
||||
{ name = "httpx-sse", specifier = "~=0.4.0" },
|
||||
{ name = "jieba", specifier = "==0.42.1" },
|
||||
{ name = "json-repair", specifier = ">=0.55.1" },
|
||||
{ name = "langfuse", specifier = "~=2.51.3" },
|
||||
{ name = "langfuse", specifier = ">=3.0.0,<5.0.0" },
|
||||
{ name = "langsmith", specifier = "~=0.7.16" },
|
||||
{ name = "litellm", specifier = "==1.82.6" },
|
||||
{ name = "markdown", specifier = "~=3.10.2" },
|
||||
{ name = "mlflow-skinny", specifier = ">=3.0.0" },
|
||||
{ name = "numpy", specifier = "~=1.26.4" },
|
||||
{ name = "openpyxl", specifier = "~=3.1.5" },
|
||||
{ name = "opentelemetry-api", specifier = "==1.28.0" },
|
||||
{ name = "opentelemetry-distro", specifier = "==0.49b0" },
|
||||
{ name = "opentelemetry-exporter-otlp", specifier = "==1.28.0" },
|
||||
{ name = "opentelemetry-exporter-otlp-proto-common", specifier = "==1.28.0" },
|
||||
{ name = "opentelemetry-exporter-otlp-proto-grpc", specifier = "==1.28.0" },
|
||||
{ name = "opentelemetry-exporter-otlp-proto-http", specifier = "==1.28.0" },
|
||||
{ name = "opentelemetry-instrumentation", specifier = "==0.49b0" },
|
||||
{ name = "opentelemetry-instrumentation-celery", specifier = "==0.49b0" },
|
||||
{ name = "opentelemetry-instrumentation-flask", specifier = "==0.49b0" },
|
||||
{ name = "opentelemetry-instrumentation-httpx", specifier = "==0.49b0" },
|
||||
{ name = "opentelemetry-instrumentation-redis", specifier = "==0.49b0" },
|
||||
{ name = "opentelemetry-instrumentation-sqlalchemy", specifier = "==0.49b0" },
|
||||
{ name = "opentelemetry-api", specifier = "==1.40.0" },
|
||||
{ name = "opentelemetry-distro", specifier = "==0.61b0" },
|
||||
{ name = "opentelemetry-exporter-otlp", specifier = "==1.40.0" },
|
||||
{ name = "opentelemetry-exporter-otlp-proto-common", specifier = "==1.40.0" },
|
||||
{ name = "opentelemetry-exporter-otlp-proto-grpc", specifier = "==1.40.0" },
|
||||
{ name = "opentelemetry-exporter-otlp-proto-http", specifier = "==1.40.0" },
|
||||
{ name = "opentelemetry-instrumentation", specifier = "==0.61b0" },
|
||||
{ name = "opentelemetry-instrumentation-celery", specifier = "==0.61b0" },
|
||||
{ name = "opentelemetry-instrumentation-flask", specifier = "==0.61b0" },
|
||||
{ name = "opentelemetry-instrumentation-httpx", specifier = "==0.61b0" },
|
||||
{ name = "opentelemetry-instrumentation-redis", specifier = "==0.61b0" },
|
||||
{ name = "opentelemetry-instrumentation-sqlalchemy", specifier = "==0.61b0" },
|
||||
{ name = "opentelemetry-propagator-b3", specifier = "==1.40.0" },
|
||||
{ name = "opentelemetry-proto", specifier = "==1.28.0" },
|
||||
{ name = "opentelemetry-sdk", specifier = "==1.28.0" },
|
||||
{ name = "opentelemetry-semantic-conventions", specifier = "==0.49b0" },
|
||||
{ name = "opentelemetry-util-http", specifier = "==0.49b0" },
|
||||
{ name = "opentelemetry-proto", specifier = "==1.40.0" },
|
||||
{ name = "opentelemetry-sdk", specifier = "==1.40.0" },
|
||||
{ name = "opentelemetry-semantic-conventions", specifier = "==0.61b0" },
|
||||
{ name = "opentelemetry-util-http", specifier = "==0.61b0" },
|
||||
{ name = "opik", specifier = "~=1.10.37" },
|
||||
{ name = "packaging", specifier = "~=23.2" },
|
||||
{ name = "pandas", extras = ["excel", "output-formatting", "performance"], specifier = "~=3.0.1" },
|
||||
@@ -3393,20 +3393,22 @@ sdist = { url = "https://files.pythonhosted.org/packages/0e/72/a3add0e4eec4eb9e2
|
||||
|
||||
[[package]]
|
||||
name = "langfuse"
|
||||
version = "2.51.5"
|
||||
version = "4.0.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "anyio" },
|
||||
{ name = "backoff" },
|
||||
{ name = "httpx" },
|
||||
{ name = "idna" },
|
||||
{ name = "openai" },
|
||||
{ name = "opentelemetry-api" },
|
||||
{ name = "opentelemetry-exporter-otlp-proto-http" },
|
||||
{ name = "opentelemetry-sdk" },
|
||||
{ name = "packaging" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "wrapt" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/3c/e9/22c9c05d877ab85da6d9008aaa7360f2a9ad58787a8e36e00b1b5be9a990/langfuse-2.51.5.tar.gz", hash = "sha256:55bc37b5c5d3ae133c1a95db09117cfb3117add110ba02ebbf2ce45ac4395c5b", size = 117574, upload-time = "2024-10-09T00:59:15.016Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c9/94/ab00e21fa5977d6b9c68fb3a95de2aa1a1e586964ff2af3e37405bf65d9f/langfuse-4.0.1.tar.gz", hash = "sha256:40a6daf3ab505945c314246d5b577d48fcfde0a47e8c05267ea6bd494ae9608e", size = 272749, upload-time = "2026-03-19T14:03:34.508Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/03/f7/242a13ca094c78464b7d4df77dfe7d4c44ed77b15fed3d2e3486afa5d2e1/langfuse-2.51.5-py3-none-any.whl", hash = "sha256:b95401ca710ef94b521afa6541933b6f93d7cfd4a97523c8fc75bca4d6d219fb", size = 214281, upload-time = "2024-10-09T00:59:12.596Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/27/8f/3145ef00940f9c29d7e0200fd040f35616eac21c6ab4610a1ba14f3a04c1/langfuse-4.0.1-py3-none-any.whl", hash = "sha256:e22f49ea31304f97fc31a97c014ba63baa8802d9568295d54f06b00b43c30524", size = 465049, upload-time = "2026-03-19T14:03:32.527Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4200,95 +4202,95 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-api"
|
||||
version = "1.28.0"
|
||||
version = "1.40.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "deprecated" },
|
||||
{ name = "importlib-metadata" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/79/36/260eaea0f74fdd0c0d8f22ed3a3031109ea1c85531f94f4fde266c29e29a/opentelemetry_api-1.28.0.tar.gz", hash = "sha256:578610bcb8aa5cdcb11169d136cc752958548fb6ccffb0969c1036b0ee9e5353", size = 62803, upload-time = "2024-11-05T19:14:45.497Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/2c/1d/4049a9e8698361cc1a1aa03a6c59e4fa4c71e0c0f94a30f988a6876a2ae6/opentelemetry_api-1.40.0.tar.gz", hash = "sha256:159be641c0b04d11e9ecd576906462773eb97ae1b657730f0ecf64d32071569f", size = 70851, upload-time = "2026-03-04T14:17:21.555Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/22/e4/3b25d8b856791c04d8a62b1257b5fc09dc41a057800db06885af8ddcdce1/opentelemetry_api-1.28.0-py3-none-any.whl", hash = "sha256:8457cd2c59ea1bd0988560f021656cecd254ad7ef6be4ba09dbefeca2409ce52", size = 64314, upload-time = "2024-11-05T19:14:21.659Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5f/bf/93795954016c522008da367da292adceed71cca6ee1717e1d64c83089099/opentelemetry_api-1.40.0-py3-none-any.whl", hash = "sha256:82dd69331ae74b06f6a874704be0cfaa49a1650e1537d4a813b86ecef7d0ecf9", size = 68676, upload-time = "2026-03-04T14:17:01.24Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-distro"
|
||||
version = "0.49b0"
|
||||
version = "0.61b0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "opentelemetry-api" },
|
||||
{ name = "opentelemetry-instrumentation" },
|
||||
{ name = "opentelemetry-sdk" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/4d/75/7cb7c33899e66bb366d40a889111a78c22df0951038b6699f1663e715a9f/opentelemetry_distro-0.49b0.tar.gz", hash = "sha256:1bafa274f9e83baa0d2a5d47ed02caffcf9bcca60107b389b145400d82b07513", size = 2560, upload-time = "2024-11-05T19:21:39.379Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f5/00/1f8acc51326956a596fefaf67751380001af36029132a7a07d4debce3c06/opentelemetry_distro-0.61b0.tar.gz", hash = "sha256:975b845f50181ad53753becf4fd4b123b54fa04df5a9d78812264436d6518981", size = 2590, upload-time = "2026-03-04T14:20:12.453Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/4c/db/806172b6a4933966eee518db814b375e620602f7fe776b74ef795690f135/opentelemetry_distro-0.49b0-py3-none-any.whl", hash = "sha256:1af4074702f605ea210753dd41947dc2fd61b39724f23cdcf15d5654867cd3c2", size = 3318, upload-time = "2024-11-05T19:20:34.065Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/56/2c/efcc995cd7484e6e55b1d26bd7fa6c55ca96bd415ff94310b52c19f330b0/opentelemetry_distro-0.61b0-py3-none-any.whl", hash = "sha256:f21d1ac0627549795d75e332006dd068877f00e461b1b2e8fe4568d6eb7b9590", size = 3349, upload-time = "2026-03-04T14:18:57.788Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-exporter-otlp"
|
||||
version = "1.28.0"
|
||||
version = "1.40.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "opentelemetry-exporter-otlp-proto-grpc" },
|
||||
{ name = "opentelemetry-exporter-otlp-proto-http" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/eb/16/14e3fc163930ea68f0980a4cdd4ae5796e60aeb898965990e13263d64baf/opentelemetry_exporter_otlp-1.28.0.tar.gz", hash = "sha256:31ae7495831681dd3da34ac457f6970f147465ae4b9aae3a888d7a581c7cd868", size = 6170, upload-time = "2024-11-05T19:14:47.349Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d0/37/b6708e0eff5c5fb9aba2e0ea09f7f3bcbfd12a592d2a780241b5f6014df7/opentelemetry_exporter_otlp-1.40.0.tar.gz", hash = "sha256:7caa0870b95e2fcb59d64e16e2b639ecffb07771b6cd0000b5d12e5e4fef765a", size = 6152, upload-time = "2026-03-04T14:17:23.235Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c2/82/3f521b3c1f2a411ed60a24a8c9f486c1beeaf8c6c55337c87d3ae1642151/opentelemetry_exporter_otlp-1.28.0-py3-none-any.whl", hash = "sha256:1fd02d70f2c1b7ac5579c81e78de4594b188d3317c8ceb69e8b53900fb7b40fd", size = 7024, upload-time = "2024-11-05T19:14:24.534Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2d/fc/aea77c28d9f3ffef2fdafdc3f4a235aee4091d262ddabd25882f47ce5c5f/opentelemetry_exporter_otlp-1.40.0-py3-none-any.whl", hash = "sha256:48c87e539ec9afb30dc443775a1334cc5487de2f72a770a4c00b1610bf6c697d", size = 7023, upload-time = "2026-03-04T14:17:03.612Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-exporter-otlp-proto-common"
|
||||
version = "1.28.0"
|
||||
version = "1.40.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "opentelemetry-proto" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c2/8d/5d411084ac441052f4c9bae03a1aec65ae5d16b439fea7b9c5ac3842c013/opentelemetry_exporter_otlp_proto_common-1.28.0.tar.gz", hash = "sha256:5fa0419b0c8e291180b0fc8430a20dd44a3f3236f8e0827992145914f273ec4f", size = 18505, upload-time = "2024-11-05T19:14:48.204Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/51/bc/1559d46557fe6eca0b46c88d4c2676285f1f3be2e8d06bb5d15fbffc814a/opentelemetry_exporter_otlp_proto_common-1.40.0.tar.gz", hash = "sha256:1cbee86a4064790b362a86601ee7934f368b81cd4cc2f2e163902a6e7818a0fa", size = 20416, upload-time = "2026-03-04T14:17:23.801Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e1/72/3c44aabc74db325aaba09361b6a0d80f6d601f0ff86ecea8ee655c9538fc/opentelemetry_exporter_otlp_proto_common-1.28.0-py3-none-any.whl", hash = "sha256:467e6437d24e020156dffecece8c0a4471a8a60f6a34afeda7386df31a092410", size = 18403, upload-time = "2024-11-05T19:14:25.798Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8b/ca/8f122055c97a932311a3f640273f084e738008933503d0c2563cd5d591fc/opentelemetry_exporter_otlp_proto_common-1.40.0-py3-none-any.whl", hash = "sha256:7081ff453835a82417bf38dccf122c827c3cbc94f2079b03bba02a3165f25149", size = 18369, upload-time = "2026-03-04T14:17:04.796Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-exporter-otlp-proto-grpc"
|
||||
version = "1.28.0"
|
||||
version = "1.40.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "deprecated" },
|
||||
{ name = "googleapis-common-protos" },
|
||||
{ name = "grpcio" },
|
||||
{ name = "opentelemetry-api" },
|
||||
{ name = "opentelemetry-exporter-otlp-proto-common" },
|
||||
{ name = "opentelemetry-proto" },
|
||||
{ name = "opentelemetry-sdk" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/43/4d/f215162e58041afb4bdf5dbd0d8faf0b7fc9bf7b3d3fc0e44e06f9e7e869/opentelemetry_exporter_otlp_proto_grpc-1.28.0.tar.gz", hash = "sha256:47a11c19dc7f4289e220108e113b7de90d59791cb4c37fc29f69a6a56f2c3735", size = 26237, upload-time = "2024-11-05T19:14:49.026Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/8f/7f/b9e60435cfcc7590fa87436edad6822240dddbc184643a2a005301cc31f4/opentelemetry_exporter_otlp_proto_grpc-1.40.0.tar.gz", hash = "sha256:bd4015183e40b635b3dab8da528b27161ba83bf4ef545776b196f0fb4ec47740", size = 25759, upload-time = "2026-03-04T14:17:24.4Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/1d/b5/afabc8106abc0f9cfeecf5b3e682622b3e04bba1d9b967dbfcd91b9c4ebe/opentelemetry_exporter_otlp_proto_grpc-1.28.0-py3-none-any.whl", hash = "sha256:edbdc53e7783f88d4535db5807cb91bd7b1ec9e9b9cdbfee14cd378f29a3b328", size = 18532, upload-time = "2024-11-05T19:14:26.853Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/96/6f/7ee0980afcbdcd2d40362da16f7f9796bd083bf7f0b8e038abfbc0300f5d/opentelemetry_exporter_otlp_proto_grpc-1.40.0-py3-none-any.whl", hash = "sha256:2aa0ca53483fe0cf6405087a7491472b70335bc5c7944378a0a8e72e86995c52", size = 20304, upload-time = "2026-03-04T14:17:05.942Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-exporter-otlp-proto-http"
|
||||
version = "1.28.0"
|
||||
version = "1.40.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "deprecated" },
|
||||
{ name = "googleapis-common-protos" },
|
||||
{ name = "opentelemetry-api" },
|
||||
{ name = "opentelemetry-exporter-otlp-proto-common" },
|
||||
{ name = "opentelemetry-proto" },
|
||||
{ name = "opentelemetry-sdk" },
|
||||
{ name = "requests" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f1/2a/555f2845928086cd51aa6941c7a546470805b68ed631ec139ce7d841763d/opentelemetry_exporter_otlp_proto_http-1.28.0.tar.gz", hash = "sha256:d83a9a03a8367ead577f02a64127d827c79567de91560029688dd5cfd0152a8e", size = 15051, upload-time = "2024-11-05T19:14:49.813Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/2e/fa/73d50e2c15c56be4d000c98e24221d494674b0cc95524e2a8cb3856d95a4/opentelemetry_exporter_otlp_proto_http-1.40.0.tar.gz", hash = "sha256:db48f5e0f33217588bbc00274a31517ba830da576e59503507c839b38fa0869c", size = 17772, upload-time = "2026-03-04T14:17:25.324Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b2/ce/80d5adabbf7ab4a0ca7b5e0f4039b24d273be370c3ba85fc05b13794411c/opentelemetry_exporter_otlp_proto_http-1.28.0-py3-none-any.whl", hash = "sha256:e8f3f7961b747edb6b44d51de4901a61e9c01d50debd747b120a08c4996c7e7b", size = 17228, upload-time = "2024-11-05T19:14:28.613Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a0/3a/8865d6754e61c9fb170cdd530a124a53769ee5f740236064816eb0ca7301/opentelemetry_exporter_otlp_proto_http-1.40.0-py3-none-any.whl", hash = "sha256:a8d1dab28f504c5d96577d6509f80a8150e44e8f45f82cdbe0e34c99ab040069", size = 19960, upload-time = "2026-03-04T14:17:07.153Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-instrumentation"
|
||||
version = "0.49b0"
|
||||
version = "0.61b0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "opentelemetry-api" },
|
||||
@@ -4296,14 +4298,14 @@ dependencies = [
|
||||
{ name = "packaging" },
|
||||
{ name = "wrapt" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/de/6b/6c25b15063c92a011cf3f68375971e2c58a9c764690847edc97df2d94eeb/opentelemetry_instrumentation-0.49b0.tar.gz", hash = "sha256:398a93e0b9dc2d11cc8627e1761665c506fe08c6b2df252a2ab3ade53d751c46", size = 26478, upload-time = "2024-11-05T19:21:41.402Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/da/37/6bf8e66bfcee5d3c6515b79cb2ee9ad05fe573c20f7ceb288d0e7eeec28c/opentelemetry_instrumentation-0.61b0.tar.gz", hash = "sha256:cb21b48db738c9de196eba6b805b4ff9de3b7f187e4bbf9a466fa170514f1fc7", size = 32606, upload-time = "2026-03-04T14:20:16.825Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/93/61/e0d21e958d6072ce25c4f5e26a1d22835fc86f80836660adf6badb6038ce/opentelemetry_instrumentation-0.49b0-py3-none-any.whl", hash = "sha256:68364d73a1ff40894574cbc6138c5f98674790cae1f3b0865e21cf702f24dcb3", size = 30694, upload-time = "2024-11-05T19:20:38.584Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d8/3e/f6f10f178b6316de67f0dfdbbb699a24fbe8917cf1743c1595fb9dcdd461/opentelemetry_instrumentation-0.61b0-py3-none-any.whl", hash = "sha256:92a93a280e69788e8f88391247cc530fd81f16f2b011979d4d6398f805cfbc63", size = 33448, upload-time = "2026-03-04T14:19:02.447Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-instrumentation-asgi"
|
||||
version = "0.49b0"
|
||||
version = "0.61b0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "asgiref" },
|
||||
@@ -4312,28 +4314,28 @@ dependencies = [
|
||||
{ name = "opentelemetry-semantic-conventions" },
|
||||
{ name = "opentelemetry-util-http" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e8/55/693c3d0938ba5fead5c3aa4ac7022a992b4ff99a8e9979800d0feb843ff4/opentelemetry_instrumentation_asgi-0.49b0.tar.gz", hash = "sha256:959fd9b1345c92f20c6ef1d42f92ef6a76b3c3083fbc4104d59da6859b15b083", size = 24117, upload-time = "2024-11-05T19:21:46.769Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/00/3e/143cf5c034e58037307e6a24f06e0dd64b2c49ae60a965fc580027581931/opentelemetry_instrumentation_asgi-0.61b0.tar.gz", hash = "sha256:9d08e127244361dc33976d39dd4ca8f128b5aa5a7ae425208400a80a095019b5", size = 26691, upload-time = "2026-03-04T14:20:21.038Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2c/0b/7900c782a1dfaa584588d724bc3bbdf8405a32497537dd96b3fcbf8461b9/opentelemetry_instrumentation_asgi-0.49b0-py3-none-any.whl", hash = "sha256:722a90856457c81956c88f35a6db606cc7db3231046b708aae2ddde065723dbe", size = 16326, upload-time = "2024-11-05T19:20:46.176Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/19/78/154470cf9d741a7487fbb5067357b87386475bbb77948a6707cae982e158/opentelemetry_instrumentation_asgi-0.61b0-py3-none-any.whl", hash = "sha256:e4b3ce6b66074e525e717efff20745434e5efd5d9df6557710856fba356da7a4", size = 16980, upload-time = "2026-03-04T14:19:10.894Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-instrumentation-celery"
|
||||
version = "0.49b0"
|
||||
version = "0.61b0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "opentelemetry-api" },
|
||||
{ name = "opentelemetry-instrumentation" },
|
||||
{ name = "opentelemetry-semantic-conventions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/4c/8b/9b8a9dda3ed53354c6f707a45cdb7a4730e1c109b50fc1b413525493f811/opentelemetry_instrumentation_celery-0.49b0.tar.gz", hash = "sha256:afbaee97cc9c75f29bcc9784f16f8e37c415d4fe9b334748c5b90a3d30d12473", size = 14702, upload-time = "2024-11-05T19:21:53.672Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/8d/43/e79108a804d16b1dc8ff28edd0e94ac393cf6359a5adcd7cdd2ec4be85f4/opentelemetry_instrumentation_celery-0.61b0.tar.gz", hash = "sha256:0e352a567dc89ed8bc083fc635035ce3c5b96bbbd92831ffd676e93b87f8e94f", size = 14780, upload-time = "2026-03-04T14:20:27.776Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/21/8c/d7d4adb36abbc0e517a69f7a069f32742122ae22d6017202f64570d9f4c5/opentelemetry_instrumentation_celery-0.49b0-py3-none-any.whl", hash = "sha256:38d4a78c78f33020032ef77ef0ead756bdf7838bcfb603de10f5925d39f14929", size = 13749, upload-time = "2024-11-05T19:20:54.98Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a2/ed/c05f3c84b455654eb6c047474ffde61ed92efc24030f64213c98bca9d44b/opentelemetry_instrumentation_celery-0.61b0-py3-none-any.whl", hash = "sha256:01235733ff0cdf571cb03b270645abb14b9c8d830313dc5842097ec90146320b", size = 13856, upload-time = "2026-03-04T14:19:20.98Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-instrumentation-fastapi"
|
||||
version = "0.49b0"
|
||||
version = "0.61b0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "opentelemetry-api" },
|
||||
@@ -4342,14 +4344,14 @@ dependencies = [
|
||||
{ name = "opentelemetry-semantic-conventions" },
|
||||
{ name = "opentelemetry-util-http" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/fe/bf/8e6d2a4807360f2203192017eb4845f5628dbeaf0597adf3d141cc5c24e1/opentelemetry_instrumentation_fastapi-0.49b0.tar.gz", hash = "sha256:6d14935c41fd3e49328188b6a59dd4c37bd17a66b01c15b0c64afa9714a1f905", size = 19230, upload-time = "2024-11-05T19:21:59.361Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/37/35/aa727bb6e6ef930dcdc96a617b83748fece57b43c47d83ba8d83fbeca657/opentelemetry_instrumentation_fastapi-0.61b0.tar.gz", hash = "sha256:3a24f35b07c557ae1bbc483bf8412221f25d79a405f8b047de8b670722e2fa9f", size = 24800, upload-time = "2026-03-04T14:20:32.759Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b1/f4/0895b9410c10abf987c90dee1b7688a8f2214a284fe15e575648f6a1473a/opentelemetry_instrumentation_fastapi-0.49b0-py3-none-any.whl", hash = "sha256:646e1b18523cbe6860ae9711eb2c7b9c85466c3c7697cd6b8fb5180d85d3fe6e", size = 12101, upload-time = "2024-11-05T19:21:01.805Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/91/05/acfeb2cccd434242a0a7d0ea29afaf077e04b42b35b485d89aee4e0d9340/opentelemetry_instrumentation_fastapi-0.61b0-py3-none-any.whl", hash = "sha256:a1a844d846540d687d377516b2ff698b51d87c781b59f47c214359c4a241047c", size = 13485, upload-time = "2026-03-04T14:19:30.351Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-instrumentation-flask"
|
||||
version = "0.49b0"
|
||||
version = "0.61b0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "opentelemetry-api" },
|
||||
@@ -4359,14 +4361,14 @@ dependencies = [
|
||||
{ name = "opentelemetry-util-http" },
|
||||
{ name = "packaging" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/17/12/dc72873fb1e35699941d8eb6a53ef25e8c5843dea37665dad33bd720f047/opentelemetry_instrumentation_flask-0.49b0.tar.gz", hash = "sha256:f7c5ab67753c4781a2e21c8f43dc5fc02ece74fdd819466c75d025db80aa7576", size = 19176, upload-time = "2024-11-05T19:22:00.816Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d9/33/d6852d8f2c3eef86f2f8c858d6f5315983c7063e07e595519e96d4c31c06/opentelemetry_instrumentation_flask-0.61b0.tar.gz", hash = "sha256:e9faf58dfd9860a1868442d180142645abdafc1a652dd73d469a5efd106a7d49", size = 24071, upload-time = "2026-03-04T14:20:33.437Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a2/fc/354da8f33ef0daebfc8e4eac995d342ae13a35097bbad512cfe0d2f3c61a/opentelemetry_instrumentation_flask-0.49b0-py3-none-any.whl", hash = "sha256:f3ef330c3cee3e2c161f27f1e7017c8800b9bfb6f9204f2f7bfb0b274874be0e", size = 14582, upload-time = "2024-11-05T19:21:02.793Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3e/41/619f3530324a58491f2d20f216a10dd7393629b29db4610dda642a27f4ed/opentelemetry_instrumentation_flask-0.61b0-py3-none-any.whl", hash = "sha256:e8ce474d7ce543bfbbb3e93f8a6f8263348af9d7b45502f387420cf3afa71253", size = 15996, upload-time = "2026-03-04T14:19:31.304Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-instrumentation-httpx"
|
||||
version = "0.49b0"
|
||||
version = "0.61b0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "opentelemetry-api" },
|
||||
@@ -4375,14 +4377,14 @@ dependencies = [
|
||||
{ name = "opentelemetry-util-http" },
|
||||
{ name = "wrapt" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a0/53/8b5e05e55a513d846ead5afb0509bec37a34a1c3e82f30b13d14156334b1/opentelemetry_instrumentation_httpx-0.49b0.tar.gz", hash = "sha256:07165b624f3e58638cee47ecf1c81939a8c2beb7e42ce9f69e25a9f21dc3f4cf", size = 17750, upload-time = "2024-11-05T19:22:02.911Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/cd/2a/e2becd55e33c29d1d9ef76e2579040ed1951cb33bacba259f6aff2fdd2a6/opentelemetry_instrumentation_httpx-0.61b0.tar.gz", hash = "sha256:6569ec097946c5551c2a4252f74c98666addd1bf047c1dde6b4ef426719ff8dd", size = 24104, upload-time = "2026-03-04T14:20:34.752Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/3b/9f/843391c6d645cd4f6914b27bc807fc1ff52b97f84cbe3ca675641976b23f/opentelemetry_instrumentation_httpx-0.49b0-py3-none-any.whl", hash = "sha256:e59e0d2fda5ef841630c68da1d78ff9192f63590a9099f12f0eab614abdf239a", size = 14110, upload-time = "2024-11-05T19:21:04.698Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/af/88/dde310dce56e2d85cf1a09507f5888544955309edc4b8d22971d6d3d1417/opentelemetry_instrumentation_httpx-0.61b0-py3-none-any.whl", hash = "sha256:dee05c93a6593a5dc3ae5d9d5c01df8b4e2c5d02e49275e5558534ee46343d5e", size = 17198, upload-time = "2026-03-04T14:19:33.585Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-instrumentation-redis"
|
||||
version = "0.49b0"
|
||||
version = "0.61b0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "opentelemetry-api" },
|
||||
@@ -4390,14 +4392,14 @@ dependencies = [
|
||||
{ name = "opentelemetry-semantic-conventions" },
|
||||
{ name = "wrapt" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/19/5b/1398eb2f92fd76787ccec28d24dc4c7dfaaf97a7557e7729e2f7c2c05d84/opentelemetry_instrumentation_redis-0.49b0.tar.gz", hash = "sha256:922542c3bd192ad4ba74e2c7e0a253c7c58a5cefbd6f89da2aba4d193a974703", size = 11353, upload-time = "2024-11-05T19:22:12.822Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/cf/21/26205f89358a5f2be3ee5512d3d3bce16b622977f64aeaa9d3fa8887dd39/opentelemetry_instrumentation_redis-0.61b0.tar.gz", hash = "sha256:ae0fbb56be9a641e621d55b02a7d62977a2c77c5ee760addd79b9b266e46e523", size = 14781, upload-time = "2026-03-04T14:20:45.694Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/24/e4/4f258fef0759629f2e8a0210d5533cfef3ecad69ff35be044637a3e2783e/opentelemetry_instrumentation_redis-0.49b0-py3-none-any.whl", hash = "sha256:b7d8f758bac53e77b7e7ca98ce80f91230577502dacb619ebe8e8b6058042067", size = 12453, upload-time = "2024-11-05T19:21:18.534Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a5/e1/8f4c8e4194291dbe828aeabe779050a8497b379ad90040a5a0a7074b1d08/opentelemetry_instrumentation_redis-0.61b0-py3-none-any.whl", hash = "sha256:8d4e850bbb5f8eeafa44c0eac3a007990c7125de187bc9c3659e29ff7e091172", size = 15506, upload-time = "2026-03-04T14:19:48.588Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-instrumentation-sqlalchemy"
|
||||
version = "0.49b0"
|
||||
version = "0.61b0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "opentelemetry-api" },
|
||||
@@ -4406,14 +4408,14 @@ dependencies = [
|
||||
{ name = "packaging" },
|
||||
{ name = "wrapt" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a0/a7/24f6cce3808ae1802dd1b60d752fbab877db5655198929cf4ee8ea416923/opentelemetry_instrumentation_sqlalchemy-0.49b0.tar.gz", hash = "sha256:32658e520fc8b35823c722f5d8831d3a410b76dd2724adb2887befc041ddef04", size = 13194, upload-time = "2024-11-05T19:22:14.92Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/9e/4f/3a325b180944610697a0a926d49d782b41a86120050d44fefb2715b630ac/opentelemetry_instrumentation_sqlalchemy-0.61b0.tar.gz", hash = "sha256:13a3a159a2043a52f0180b3757fbaa26741b0e08abb50deddce4394c118956e6", size = 15343, upload-time = "2026-03-04T14:20:47.648Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ec/6b/a1a3685fed593282999cdc374ece15efbd56f8d774bd368bf7ff2cf5923c/opentelemetry_instrumentation_sqlalchemy-0.49b0-py3-none-any.whl", hash = "sha256:d854052d2b02cd0562e5628a514c8153fceada7f585137e173165dfd0a46ef6a", size = 13358, upload-time = "2024-11-05T19:21:23.654Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1f/97/b906a930c6a1a20c53ecc8b58cabc2cdd0ce560a2b5d44259084ffe4333e/opentelemetry_instrumentation_sqlalchemy-0.61b0-py3-none-any.whl", hash = "sha256:f115e0be54116ba4c327b8d7b68db4045ee18d44439d888ab8130a549c50d1c1", size = 14547, upload-time = "2026-03-04T14:19:53.088Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-instrumentation-wsgi"
|
||||
version = "0.49b0"
|
||||
version = "0.61b0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "opentelemetry-api" },
|
||||
@@ -4421,9 +4423,9 @@ dependencies = [
|
||||
{ name = "opentelemetry-semantic-conventions" },
|
||||
{ name = "opentelemetry-util-http" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/17/2b/91b022b004ac9e9ab0eefd10bc4257975291f88adc81b4ef2c601ddb1adf/opentelemetry_instrumentation_wsgi-0.49b0.tar.gz", hash = "sha256:0812a02e132f8fc3d5c897bba84e530c37b85c315b199bb97ca6508279e7eb23", size = 17733, upload-time = "2024-11-05T19:22:24.3Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/89/e5/189f2845362cfe78e356ba127eab21456309def411c6874aa4800c3de816/opentelemetry_instrumentation_wsgi-0.61b0.tar.gz", hash = "sha256:380f2ae61714e5303275a80b2e14c58571573cd1fddf496d8c39fb9551c5e532", size = 19898, upload-time = "2026-03-04T14:20:54.068Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/02/1d/59979665778ed8c85bc31c92b75571cd7afb8e3322fb513c87fe1bad6d78/opentelemetry_instrumentation_wsgi-0.49b0-py3-none-any.whl", hash = "sha256:8869ccf96611827e4448417718920e9eec6d25bffb5bf72c7952c7346ec33fbc", size = 13699, upload-time = "2024-11-05T19:21:35.039Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/96/75/d6b42ba26f3c921be6d01b16561b7bb863f843bad7ac3a5011f62617bcab/opentelemetry_instrumentation_wsgi-0.61b0-py3-none-any.whl", hash = "sha256:bd33b0824166f24134a3400648805e8d2e6a7951f070241294e8b8866611d7fa", size = 14628, upload-time = "2026-03-04T14:20:03.934Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4441,50 +4443,50 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-proto"
|
||||
version = "1.28.0"
|
||||
version = "1.40.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "protobuf" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c9/63/ac4cef4d30ea0ca1d2153ad2fc62d91d1cf3b89b0e4e5cbd61a8c567885f/opentelemetry_proto-1.28.0.tar.gz", hash = "sha256:4a45728dfefa33f7908b828b9b7c9f2c6de42a05d5ec7b285662ddae71c4c870", size = 34331, upload-time = "2024-11-05T19:14:59.503Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/4c/77/dd38991db037fdfce45849491cb61de5ab000f49824a00230afb112a4392/opentelemetry_proto-1.40.0.tar.gz", hash = "sha256:03f639ca129ba513f5819810f5b1f42bcb371391405d99c168fe6937c62febcd", size = 45667, upload-time = "2026-03-04T14:17:31.194Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/86/94/c0b43d16e1d96ee1e699373aa59f14a3aa2e7126af3f11d6adc5dcc531cd/opentelemetry_proto-1.28.0-py3-none-any.whl", hash = "sha256:d5ad31b997846543b8e15504657d9a8cf1ad3c71dcbbb6c4799b1ab29e38f7f9", size = 55832, upload-time = "2024-11-05T19:14:40.446Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b9/b2/189b2577dde745b15625b3214302605b1353436219d42b7912e77fa8dc24/opentelemetry_proto-1.40.0-py3-none-any.whl", hash = "sha256:266c4385d88923a23d63e353e9761af0f47a6ed0d486979777fe4de59dc9b25f", size = 72073, upload-time = "2026-03-04T14:17:16.673Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-sdk"
|
||||
version = "1.28.0"
|
||||
version = "1.40.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "opentelemetry-api" },
|
||||
{ name = "opentelemetry-semantic-conventions" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0c/5b/a509ccab93eacc6044591d5ec437d8266e76f893d0389bbf7e5592c7da32/opentelemetry_sdk-1.28.0.tar.gz", hash = "sha256:41d5420b2e3fb7716ff4981b510d551eff1fc60eb5a95cf7335b31166812a893", size = 156155, upload-time = "2024-11-05T19:15:00.451Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/58/fd/3c3125b20ba18ce2155ba9ea74acb0ae5d25f8cd39cfd37455601b7955cc/opentelemetry_sdk-1.40.0.tar.gz", hash = "sha256:18e9f5ec20d859d268c7cb3c5198c8d105d073714db3de50b593b8c1345a48f2", size = 184252, upload-time = "2026-03-04T14:17:31.87Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c3/fe/c8decbebb5660529f1d6ba65e50a45b1294022dfcba2968fc9c8697c42b2/opentelemetry_sdk-1.28.0-py3-none-any.whl", hash = "sha256:4b37da81d7fad67f6683c4420288c97f4ed0d988845d5886435f428ec4b8429a", size = 118692, upload-time = "2024-11-05T19:14:41.669Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2c/c5/6a852903d8bfac758c6dc6e9a68b015d3c33f2f1be5e9591e0f4b69c7e0a/opentelemetry_sdk-1.40.0-py3-none-any.whl", hash = "sha256:787d2154a71f4b3d81f20524a8ce061b7db667d24e46753f32a7bc48f1c1f3f1", size = 141951, upload-time = "2026-03-04T14:17:17.961Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-semantic-conventions"
|
||||
version = "0.49b0"
|
||||
version = "0.61b0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "deprecated" },
|
||||
{ name = "opentelemetry-api" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ee/c8/433b0e54143f8c9369f5c4a7a83e73eec7eb2ee7d0b7e81a9243e78c8e80/opentelemetry_semantic_conventions-0.49b0.tar.gz", hash = "sha256:dbc7b28339e5390b6b28e022835f9bac4e134a80ebf640848306d3c5192557e8", size = 95227, upload-time = "2024-11-05T19:15:01.443Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/6d/c0/4ae7973f3c2cfd2b6e321f1675626f0dab0a97027cc7a297474c9c8f3d04/opentelemetry_semantic_conventions-0.61b0.tar.gz", hash = "sha256:072f65473c5d7c6dc0355b27d6c9d1a679d63b6d4b4b16a9773062cb7e31192a", size = 145755, upload-time = "2026-03-04T14:17:32.664Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/25/05/20104df4ef07d3bf5c3fd6bcc796ef70ab4ea4309378a9ba57bc4b4d01fa/opentelemetry_semantic_conventions-0.49b0-py3-none-any.whl", hash = "sha256:0458117f6ead0b12e3221813e3e511d85698c31901cac84682052adb9c17c7cd", size = 159214, upload-time = "2024-11-05T19:14:43.047Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b2/37/cc6a55e448deaa9b27377d087da8615a3416d8ad523d5960b78dbeadd02a/opentelemetry_semantic_conventions-0.61b0-py3-none-any.whl", hash = "sha256:fa530a96be229795f8cef353739b618148b0fe2b4b3f005e60e262926c4d38e2", size = 231621, upload-time = "2026-03-04T14:17:19.33Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-util-http"
|
||||
version = "0.49b0"
|
||||
version = "0.61b0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a3/99/377ef446928808211b127b9ab31c348bc465c8da4514ebeec6e4a3de3d21/opentelemetry_util_http-0.49b0.tar.gz", hash = "sha256:02928496afcffd58a7c15baf99d2cedae9b8325a8ac52b0d0877b2e8f936dd1b", size = 7863, upload-time = "2024-11-05T19:22:26.973Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/57/3c/f0196223efc5c4ca19f8fad3d5462b171ac6333013335ce540c01af419e9/opentelemetry_util_http-0.61b0.tar.gz", hash = "sha256:1039cb891334ad2731affdf034d8fb8b48c239af9b6dd295e5fabd07f1c95572", size = 11361, upload-time = "2026-03-04T14:20:57.01Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/66/0e/ab0a89b315d0bacdd355a345bb69b20c50fc1f0804b52b56fe1c35a60e68/opentelemetry_util_http-0.49b0-py3-none-any.whl", hash = "sha256:8661bbd6aea1839badc44de067ec9c15c05eab05f729f496c856c50a1203caf1", size = 6945, upload-time = "2024-11-05T19:21:37.81Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0d/e5/c08aaaf2f64288d2b6ef65741d2de5454e64af3e050f34285fb1907492fe/opentelemetry_util_http-0.61b0-py3-none-any.whl", hash = "sha256:8e715e848233e9527ea47e275659ea60a57a75edf5206a3b937e236a6da5fc33", size = 9281, upload-time = "2026-03-04T14:20:08.364Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5253,11 +5255,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "pygments"
|
||||
version = "2.19.2"
|
||||
version = "2.20.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c3/b2/bc9c9196916376152d655522fdcebac55e66de6603a76a02bca1b6414f6c/pygments-2.20.0.tar.gz", hash = "sha256:6757cd03768053ff99f3039c1a36d6c0aa0b263438fcab17520b30a303a82b5f", size = 4955991, upload-time = "2026-03-29T13:29:33.898Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f4/7e/a72dd26f3b0f4f2bf1dd8923c85f7ceb43172af56d63c7383eb62b332364/pygments-2.20.0-py3-none-any.whl", hash = "sha256:81a9e26dd42fd28a23a2d169d86d7ac03b46e2f8b59ed4698fb4785f946d0176", size = 1231151, upload-time = "2026-03-29T13:29:30.038Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -24,5 +24,4 @@ cp "$MIDDLEWARE_ENV_EXAMPLE" "$MIDDLEWARE_ENV"
|
||||
cd "$ROOT/api"
|
||||
uv sync --group dev
|
||||
|
||||
cd "$ROOT/web"
|
||||
pnpm install
|
||||
pnpm --dir "$ROOT" install
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
|
||||
ROOT="$(dirname "$SCRIPT_DIR")"
|
||||
|
||||
cd "$ROOT/docker"
|
||||
docker compose -f docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
|
||||
ROOT="$(dirname "$SCRIPT_DIR")"
|
||||
|
||||
cd "$ROOT/docker"
|
||||
docker compose --env-file middleware.env -f docker-compose.middleware.yaml -p dify up -d
|
||||
|
||||
@@ -3,6 +3,6 @@
|
||||
set -x
|
||||
|
||||
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
|
||||
cd "$SCRIPT_DIR/../web"
|
||||
ROOT_DIR="$(dirname "$SCRIPT_DIR")"
|
||||
|
||||
pnpm install && pnpm dev:inspect
|
||||
pnpm --dir "$ROOT_DIR" install && pnpm --dir "$ROOT_DIR/web" dev:inspect
|
||||
|
||||
@@ -186,8 +186,10 @@ CELERY_WORKER_CLASS=
|
||||
# it is recommended to set it to 360 to support a longer sse connection time.
|
||||
GUNICORN_TIMEOUT=360
|
||||
|
||||
# The number of Celery workers. The default is 1, and can be set as needed.
|
||||
CELERY_WORKER_AMOUNT=
|
||||
# The number of Celery workers. The default is 4 for development environments
|
||||
# to allow parallel processing of workflows, document indexing, and other async tasks.
|
||||
# Adjust based on your system resources and workload requirements.
|
||||
CELERY_WORKER_AMOUNT=4
|
||||
|
||||
# Flag indicating whether to enable autoscaling of Celery workers.
|
||||
#
|
||||
|
||||
@@ -46,7 +46,7 @@ x-shared-env: &shared-api-worker-env
|
||||
SERVER_WORKER_CONNECTIONS: ${SERVER_WORKER_CONNECTIONS:-10}
|
||||
CELERY_WORKER_CLASS: ${CELERY_WORKER_CLASS:-}
|
||||
GUNICORN_TIMEOUT: ${GUNICORN_TIMEOUT:-360}
|
||||
CELERY_WORKER_AMOUNT: ${CELERY_WORKER_AMOUNT:-}
|
||||
CELERY_WORKER_AMOUNT: ${CELERY_WORKER_AMOUNT:-4}
|
||||
CELERY_AUTO_SCALE: ${CELERY_AUTO_SCALE:-false}
|
||||
CELERY_MAX_WORKERS: ${CELERY_MAX_WORKERS:-}
|
||||
CELERY_MIN_WORKERS: ${CELERY_MIN_WORKERS:-}
|
||||
|
||||
@@ -19,15 +19,18 @@ It tests:
|
||||
- `uv`
|
||||
- Docker
|
||||
|
||||
Run the following commands from the repository root.
|
||||
|
||||
Install Playwright browsers once:
|
||||
|
||||
```bash
|
||||
cd e2e
|
||||
pnpm install
|
||||
pnpm e2e:install
|
||||
pnpm check
|
||||
pnpm -C e2e e2e:install
|
||||
pnpm -C e2e check
|
||||
```
|
||||
|
||||
`pnpm install` is resolved through the repository workspace and uses the shared root lockfile plus `pnpm-workspace.yaml`.
|
||||
|
||||
Use `pnpm check` as the default local verification step after editing E2E TypeScript, Cucumber support code, or feature glue. It runs formatting, linting, and type checks for this package.
|
||||
|
||||
Common commands:
|
||||
@@ -35,20 +38,20 @@ Common commands:
|
||||
```bash
|
||||
# authenticated-only regression (default excludes @fresh)
|
||||
# expects backend API, frontend artifact, and middleware stack to already be running
|
||||
pnpm e2e
|
||||
pnpm -C e2e e2e
|
||||
|
||||
# full reset + fresh install + authenticated scenarios
|
||||
# starts required middleware/dependencies for you
|
||||
pnpm e2e:full
|
||||
pnpm -C e2e e2e:full
|
||||
|
||||
# run a tagged subset
|
||||
pnpm e2e -- --tags @smoke
|
||||
pnpm -C e2e e2e -- --tags @smoke
|
||||
|
||||
# headed browser
|
||||
pnpm e2e:headed -- --tags @smoke
|
||||
pnpm -C e2e e2e:headed -- --tags @smoke
|
||||
|
||||
# slow down browser actions for local debugging
|
||||
E2E_SLOW_MO=500 pnpm e2e:headed -- --tags @smoke
|
||||
E2E_SLOW_MO=500 pnpm -C e2e e2e:headed -- --tags @smoke
|
||||
```
|
||||
|
||||
Frontend artifact behavior:
|
||||
@@ -101,7 +104,7 @@ Because of that, the `@fresh` install scenario only runs in the `pnpm e2e:full*`
|
||||
Reset all persisted E2E state:
|
||||
|
||||
```bash
|
||||
pnpm e2e:reset
|
||||
pnpm -C e2e e2e:reset
|
||||
```
|
||||
|
||||
This removes:
|
||||
@@ -117,7 +120,7 @@ This removes:
|
||||
Start the full middleware stack:
|
||||
|
||||
```bash
|
||||
pnpm e2e:middleware:up
|
||||
pnpm -C e2e e2e:middleware:up
|
||||
```
|
||||
|
||||
Stop the full middleware stack:
|
||||
|
||||
@@ -14,21 +14,11 @@
|
||||
"e2e:reset": "tsx ./scripts/setup.ts reset"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@cucumber/cucumber": "12.7.0",
|
||||
"@playwright/test": "1.51.1",
|
||||
"@types/node": "25.5.0",
|
||||
"tsx": "4.21.0",
|
||||
"typescript": "5.9.3",
|
||||
"vite-plus": "latest"
|
||||
},
|
||||
"engines": {
|
||||
"node": "^22.22.1"
|
||||
},
|
||||
"packageManager": "pnpm@10.32.1",
|
||||
"pnpm": {
|
||||
"overrides": {
|
||||
"vite": "npm:@voidzero-dev/vite-plus-core@latest",
|
||||
"vitest": "npm:@voidzero-dev/vite-plus-test@latest"
|
||||
}
|
||||
"@cucumber/cucumber": "catalog:",
|
||||
"@playwright/test": "catalog:",
|
||||
"@types/node": "catalog:",
|
||||
"tsx": "catalog:",
|
||||
"typescript": "catalog:",
|
||||
"vite-plus": "catalog:"
|
||||
}
|
||||
}
|
||||
|
||||
2632
e2e/pnpm-lock.yaml
generated
2632
e2e/pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
11
package.json
Normal file
11
package.json
Normal file
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"name": "dify",
|
||||
"private": true,
|
||||
"engines": {
|
||||
"node": "^22.22.1"
|
||||
},
|
||||
"packageManager": "pnpm@10.33.0",
|
||||
"devDependencies": {
|
||||
"taze": "catalog:"
|
||||
}
|
||||
}
|
||||
4859
web/pnpm-lock.yaml → pnpm-lock.yaml
generated
4859
web/pnpm-lock.yaml → pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
260
pnpm-workspace.yaml
Normal file
260
pnpm-workspace.yaml
Normal file
@@ -0,0 +1,260 @@
|
||||
trustPolicy: no-downgrade
|
||||
minimumReleaseAge: 1440
|
||||
blockExoticSubdeps: true
|
||||
strictDepBuilds: true
|
||||
allowBuilds:
|
||||
'@parcel/watcher': false
|
||||
canvas: false
|
||||
esbuild: false
|
||||
sharp: false
|
||||
packages:
|
||||
- web
|
||||
- e2e
|
||||
- sdks/nodejs-client
|
||||
overrides:
|
||||
"@lexical/code": npm:lexical-code-no-prism@0.41.0
|
||||
"@monaco-editor/loader": 1.7.0
|
||||
"@nolyfill/safe-buffer": npm:safe-buffer@^5.2.1
|
||||
array-includes: npm:@nolyfill/array-includes@^1.0.44
|
||||
array.prototype.findlast: npm:@nolyfill/array.prototype.findlast@^1.0.44
|
||||
array.prototype.findlastindex: npm:@nolyfill/array.prototype.findlastindex@^1.0.44
|
||||
array.prototype.flat: npm:@nolyfill/array.prototype.flat@^1.0.44
|
||||
array.prototype.flatmap: npm:@nolyfill/array.prototype.flatmap@^1.0.44
|
||||
array.prototype.tosorted: npm:@nolyfill/array.prototype.tosorted@^1.0.44
|
||||
assert: npm:@nolyfill/assert@^1.0.26
|
||||
axios: 1.14.0
|
||||
brace-expansion@<2.0.2: 2.0.2
|
||||
canvas: ^3.2.2
|
||||
devalue@<5.3.2: 5.3.2
|
||||
dompurify@>=3.1.3 <=3.3.1: 3.3.2
|
||||
es-iterator-helpers: npm:@nolyfill/es-iterator-helpers@^1.0.21
|
||||
esbuild@<0.27.2: 0.27.2
|
||||
flatted@<=3.4.1: 3.4.2
|
||||
glob@>=10.2.0 <10.5.0: 11.1.0
|
||||
hasown: npm:@nolyfill/hasown@^1.0.44
|
||||
is-arguments: npm:@nolyfill/is-arguments@^1.0.44
|
||||
is-core-module: npm:@nolyfill/is-core-module@^1.0.39
|
||||
is-generator-function: npm:@nolyfill/is-generator-function@^1.0.44
|
||||
is-typed-array: npm:@nolyfill/is-typed-array@^1.0.44
|
||||
isarray: npm:@nolyfill/isarray@^1.0.44
|
||||
object.assign: npm:@nolyfill/object.assign@^1.0.44
|
||||
object.entries: npm:@nolyfill/object.entries@^1.0.44
|
||||
object.fromentries: npm:@nolyfill/object.fromentries@^1.0.44
|
||||
object.groupby: npm:@nolyfill/object.groupby@^1.0.44
|
||||
object.values: npm:@nolyfill/object.values@^1.0.44
|
||||
pbkdf2: ~3.1.5
|
||||
pbkdf2@<3.1.3: 3.1.3
|
||||
picomatch@<2.3.2: 2.3.2
|
||||
picomatch@>=4.0.0 <4.0.4: 4.0.4
|
||||
prismjs: ~1.30
|
||||
prismjs@<1.30.0: 1.30.0
|
||||
rollup@>=4.0.0 <4.59.0: 4.59.0
|
||||
safe-buffer: ^5.2.1
|
||||
safe-regex-test: npm:@nolyfill/safe-regex-test@^1.0.44
|
||||
safer-buffer: npm:@nolyfill/safer-buffer@^1.0.44
|
||||
side-channel: npm:@nolyfill/side-channel@^1.0.44
|
||||
smol-toml@<1.6.1: 1.6.1
|
||||
solid-js: 1.9.11
|
||||
string-width: ~8.2.0
|
||||
string.prototype.includes: npm:@nolyfill/string.prototype.includes@^1.0.44
|
||||
string.prototype.matchall: npm:@nolyfill/string.prototype.matchall@^1.0.44
|
||||
string.prototype.repeat: npm:@nolyfill/string.prototype.repeat@^1.0.44
|
||||
string.prototype.trimend: npm:@nolyfill/string.prototype.trimend@^1.0.44
|
||||
svgo@>=3.0.0 <3.3.3: 3.3.3
|
||||
tar@<=7.5.10: 7.5.11
|
||||
typed-array-buffer: npm:@nolyfill/typed-array-buffer@^1.0.44
|
||||
undici@>=7.0.0 <7.24.0: 7.24.0
|
||||
vite: npm:@voidzero-dev/vite-plus-core@0.1.14
|
||||
vitest: npm:@voidzero-dev/vite-plus-test@0.1.14
|
||||
which-typed-array: npm:@nolyfill/which-typed-array@^1.0.44
|
||||
yaml@>=2.0.0 <2.8.3: 2.8.3
|
||||
yauzl@<3.2.1: 3.2.1
|
||||
catalog:
|
||||
"@amplitude/analytics-browser": 2.38.0
|
||||
"@amplitude/plugin-session-replay-browser": 1.27.5
|
||||
"@antfu/eslint-config": 7.7.3
|
||||
"@base-ui/react": 1.3.0
|
||||
"@chromatic-com/storybook": 5.1.1
|
||||
"@cucumber/cucumber": 12.7.0
|
||||
"@egoist/tailwindcss-icons": 1.9.2
|
||||
"@emoji-mart/data": 1.2.1
|
||||
"@eslint-react/eslint-plugin": 3.0.0
|
||||
"@eslint/js": ^10.0.1
|
||||
"@floating-ui/react": 0.27.19
|
||||
"@formatjs/intl-localematcher": 0.8.2
|
||||
"@headlessui/react": 2.2.9
|
||||
"@heroicons/react": 2.2.0
|
||||
"@hono/node-server": 1.19.11
|
||||
"@iconify-json/heroicons": 1.2.3
|
||||
"@iconify-json/ri": 1.2.10
|
||||
"@lexical/code": 0.42.0
|
||||
"@lexical/link": 0.42.0
|
||||
"@lexical/list": 0.42.0
|
||||
"@lexical/react": 0.42.0
|
||||
"@lexical/selection": 0.42.0
|
||||
"@lexical/text": 0.42.0
|
||||
"@lexical/utils": 0.42.0
|
||||
"@mdx-js/loader": 3.1.1
|
||||
"@mdx-js/react": 3.1.1
|
||||
"@mdx-js/rollup": 3.1.1
|
||||
"@monaco-editor/react": 4.7.0
|
||||
"@next/eslint-plugin-next": 16.2.1
|
||||
"@next/mdx": 16.2.1
|
||||
"@orpc/client": 1.13.13
|
||||
"@orpc/contract": 1.13.13
|
||||
"@orpc/openapi-client": 1.13.13
|
||||
"@orpc/tanstack-query": 1.13.13
|
||||
"@playwright/test": 1.58.2
|
||||
"@remixicon/react": 4.9.0
|
||||
"@rgrove/parse-xml": 4.2.0
|
||||
"@sentry/react": 10.46.0
|
||||
"@storybook/addon-docs": 10.3.3
|
||||
"@storybook/addon-links": 10.3.3
|
||||
"@storybook/addon-onboarding": 10.3.3
|
||||
"@storybook/addon-themes": 10.3.3
|
||||
"@storybook/nextjs-vite": 10.3.3
|
||||
"@storybook/react": 10.3.3
|
||||
"@streamdown/math": 1.0.2
|
||||
"@svgdotjs/svg.js": 3.2.5
|
||||
"@t3-oss/env-nextjs": 0.13.11
|
||||
"@tailwindcss/typography": 0.5.19
|
||||
"@tanstack/eslint-plugin-query": 5.95.2
|
||||
"@tanstack/react-devtools": 0.10.0
|
||||
"@tanstack/react-form": 1.28.5
|
||||
"@tanstack/react-form-devtools": 0.2.19
|
||||
"@tanstack/react-query": 5.95.2
|
||||
"@tanstack/react-query-devtools": 5.95.2
|
||||
"@testing-library/dom": 10.4.1
|
||||
"@testing-library/jest-dom": 6.9.1
|
||||
"@testing-library/react": 16.3.2
|
||||
"@testing-library/user-event": 14.6.1
|
||||
"@tsslint/cli": 3.0.2
|
||||
"@tsslint/compat-eslint": 3.0.2
|
||||
"@tsslint/config": 3.0.2
|
||||
"@types/js-cookie": 3.0.6
|
||||
"@types/js-yaml": 4.0.9
|
||||
"@types/negotiator": 0.6.4
|
||||
"@types/node": 25.5.0
|
||||
"@types/postcss-js": 4.1.0
|
||||
"@types/qs": 6.15.0
|
||||
"@types/react": 19.2.14
|
||||
"@types/react-dom": 19.2.3
|
||||
"@types/react-syntax-highlighter": 15.5.13
|
||||
"@types/react-window": 1.8.8
|
||||
"@types/sortablejs": 1.15.9
|
||||
"@typescript-eslint/eslint-plugin": ^8.57.2
|
||||
"@typescript-eslint/parser": 8.57.2
|
||||
"@typescript/native-preview": 7.0.0-dev.20260329.1
|
||||
"@vitejs/plugin-react": 6.0.1
|
||||
"@vitejs/plugin-rsc": 0.5.21
|
||||
"@vitest/coverage-v8": 4.1.2
|
||||
abcjs: 6.6.2
|
||||
agentation: 3.0.2
|
||||
ahooks: 3.9.7
|
||||
autoprefixer: 10.4.27
|
||||
axios: 1.14.0
|
||||
class-variance-authority: 0.7.1
|
||||
clsx: 2.1.1
|
||||
cmdk: 1.1.1
|
||||
code-inspector-plugin: 1.4.5
|
||||
copy-to-clipboard: 3.3.3
|
||||
cron-parser: 5.5.0
|
||||
dayjs: 1.11.20
|
||||
decimal.js: 10.6.0
|
||||
dompurify: 3.3.3
|
||||
echarts: 6.0.0
|
||||
echarts-for-react: 3.0.6
|
||||
elkjs: 0.11.1
|
||||
embla-carousel-autoplay: 8.6.0
|
||||
embla-carousel-react: 8.6.0
|
||||
emoji-mart: 5.6.0
|
||||
es-toolkit: 1.45.1
|
||||
eslint: 10.1.0
|
||||
eslint-markdown: 0.6.0
|
||||
eslint-plugin-better-tailwindcss: 4.3.2
|
||||
eslint-plugin-hyoban: 0.14.1
|
||||
eslint-plugin-markdown-preferences: 0.40.3
|
||||
eslint-plugin-no-barrel-files: 1.2.2
|
||||
eslint-plugin-react-hooks: 7.0.1
|
||||
eslint-plugin-react-refresh: 0.5.2
|
||||
eslint-plugin-sonarjs: 4.0.2
|
||||
eslint-plugin-storybook: 10.3.3
|
||||
fast-deep-equal: 3.1.3
|
||||
foxact: 0.3.0
|
||||
happy-dom: 20.8.9
|
||||
hono: 4.12.9
|
||||
html-entities: 2.6.0
|
||||
html-to-image: 1.11.13
|
||||
husky: 9.1.7
|
||||
i18next: 25.10.10
|
||||
i18next-resources-to-backend: 1.2.1
|
||||
iconify-import-svg: 0.1.2
|
||||
immer: 11.1.4
|
||||
jotai: 2.19.0
|
||||
js-audio-recorder: 1.0.7
|
||||
js-cookie: 3.0.5
|
||||
js-yaml: 4.1.1
|
||||
jsonschema: 1.5.0
|
||||
katex: 0.16.44
|
||||
knip: 6.1.0
|
||||
ky: 1.14.3
|
||||
lamejs: 1.2.1
|
||||
lexical: 0.42.0
|
||||
lint-staged: 16.4.0
|
||||
mermaid: 11.13.0
|
||||
mime: 4.1.0
|
||||
mitt: 3.0.1
|
||||
negotiator: 1.0.0
|
||||
next: 16.2.1
|
||||
next-themes: 0.4.6
|
||||
nuqs: 2.8.9
|
||||
pinyin-pro: 3.28.0
|
||||
postcss: 8.5.8
|
||||
postcss-js: 5.1.0
|
||||
qrcode.react: 4.2.0
|
||||
qs: 6.15.0
|
||||
react: 19.2.4
|
||||
react-18-input-autosize: 3.0.0
|
||||
react-dom: 19.2.4
|
||||
react-easy-crop: 5.5.7
|
||||
react-hotkeys-hook: 5.2.4
|
||||
react-i18next: 16.6.6
|
||||
react-multi-email: 1.0.25
|
||||
react-papaparse: 4.4.0
|
||||
react-pdf-highlighter: 8.0.0-rc.0
|
||||
react-server-dom-webpack: 19.2.4
|
||||
react-sortablejs: 6.1.4
|
||||
react-syntax-highlighter: 15.6.6
|
||||
react-textarea-autosize: 8.5.9
|
||||
react-window: 1.8.11
|
||||
reactflow: 11.11.4
|
||||
remark-breaks: 4.0.0
|
||||
remark-directive: 4.0.0
|
||||
sass: 1.98.0
|
||||
scheduler: 0.27.0
|
||||
sharp: 0.34.5
|
||||
sortablejs: 1.15.7
|
||||
std-semver: 1.0.8
|
||||
storybook: 10.3.3
|
||||
streamdown: 2.5.0
|
||||
string-ts: 2.3.1
|
||||
tailwind-merge: 2.6.1
|
||||
tailwindcss: 3.4.19
|
||||
taze: 19.10.0
|
||||
tldts: 7.0.27
|
||||
tsup: ^8.5.1
|
||||
tsx: 4.21.0
|
||||
typescript: 5.9.3
|
||||
uglify-js: 3.19.3
|
||||
unist-util-visit: 5.1.0
|
||||
use-context-selector: 2.0.0
|
||||
uuid: 13.0.0
|
||||
vinext: 0.0.38
|
||||
vite: npm:@voidzero-dev/vite-plus-core@0.1.14
|
||||
vite-plugin-inspect: 12.0.0-beta.1
|
||||
vite-plus: 0.1.14
|
||||
vitest: npm:@voidzero-dev/vite-plus-test@0.1.14
|
||||
vitest-canvas-mock: 1.1.4
|
||||
zod: 4.3.6
|
||||
zundo: 2.3.0
|
||||
zustand: 5.0.12
|
||||
@@ -100,6 +100,10 @@ Notes:
|
||||
- Chat/completion require a stable `user` identifier in the request payload.
|
||||
- For streaming responses, iterate the returned AsyncIterable. Use `stream.toText()` to collect text.
|
||||
|
||||
## Maintainers
|
||||
|
||||
This package is published from the repository workspace. Install dependencies from the repository root with `pnpm install`, then use `./scripts/publish.sh` for dry runs and publishing so `catalog:` dependencies are resolved before release.
|
||||
|
||||
## License
|
||||
|
||||
This SDK is released under the MIT License.
|
||||
|
||||
@@ -54,24 +54,17 @@
|
||||
"publish:npm": "./scripts/publish.sh"
|
||||
},
|
||||
"dependencies": {
|
||||
"axios": "^1.13.6"
|
||||
"axios": "catalog:"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@eslint/js": "^10.0.1",
|
||||
"@types/node": "^25.4.0",
|
||||
"@typescript-eslint/eslint-plugin": "^8.57.0",
|
||||
"@typescript-eslint/parser": "^8.57.0",
|
||||
"@vitest/coverage-v8": "4.0.18",
|
||||
"eslint": "^10.0.3",
|
||||
"tsup": "^8.5.1",
|
||||
"typescript": "^5.9.3",
|
||||
"vitest": "^4.0.18"
|
||||
},
|
||||
"pnpm": {
|
||||
"overrides": {
|
||||
"flatted@<=3.4.1": "3.4.2",
|
||||
"picomatch@>=4.0.0 <4.0.4": "4.0.4",
|
||||
"rollup@>=4.0.0 <4.59.0": "4.59.0"
|
||||
}
|
||||
"@eslint/js": "catalog:",
|
||||
"@types/node": "catalog:",
|
||||
"@typescript-eslint/eslint-plugin": "catalog:",
|
||||
"@typescript-eslint/parser": "catalog:",
|
||||
"@vitest/coverage-v8": "catalog:",
|
||||
"eslint": "catalog:",
|
||||
"tsup": "catalog:",
|
||||
"typescript": "catalog:",
|
||||
"vitest": "catalog:"
|
||||
}
|
||||
}
|
||||
|
||||
2255
sdks/nodejs-client/pnpm-lock.yaml
generated
2255
sdks/nodejs-client/pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
@@ -1,2 +0,0 @@
|
||||
onlyBuiltDependencies:
|
||||
- esbuild
|
||||
@@ -5,10 +5,12 @@
|
||||
# A beautiful and reliable script to publish the SDK to npm
|
||||
#
|
||||
# Usage:
|
||||
# ./scripts/publish.sh # Normal publish
|
||||
# ./scripts/publish.sh # Normal publish
|
||||
# ./scripts/publish.sh --dry-run # Test without publishing
|
||||
# ./scripts/publish.sh --skip-tests # Skip tests (not recommended)
|
||||
#
|
||||
# This script requires pnpm because the workspace uses catalog: dependencies.
|
||||
#
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
@@ -62,11 +64,27 @@ divider() {
|
||||
echo -e "${DIM}─────────────────────────────────────────────────────────────────${NC}"
|
||||
}
|
||||
|
||||
run_npm() {
|
||||
env \
|
||||
-u npm_config_npm_globalconfig \
|
||||
-u NPM_CONFIG_NPM_GLOBALCONFIG \
|
||||
-u npm_config_verify_deps_before_run \
|
||||
-u NPM_CONFIG_VERIFY_DEPS_BEFORE_RUN \
|
||||
-u npm_config__jsr_registry \
|
||||
-u NPM_CONFIG__JSR_REGISTRY \
|
||||
-u npm_config_catalog \
|
||||
-u NPM_CONFIG_CATALOG \
|
||||
-u npm_config_overrides \
|
||||
-u NPM_CONFIG_OVERRIDES \
|
||||
npm "$@"
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# Configuration
|
||||
# ============================================================================
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
PROJECT_DIR="$(cd "$SCRIPT_DIR/.." && pwd)"
|
||||
REPO_ROOT="$(git -C "$SCRIPT_DIR" rev-parse --show-toplevel 2>/dev/null || (cd "$SCRIPT_DIR/../../.." && pwd))"
|
||||
|
||||
DRY_RUN=false
|
||||
SKIP_TESTS=false
|
||||
@@ -123,23 +141,23 @@ main() {
|
||||
error "npm is not installed"
|
||||
exit 1
|
||||
fi
|
||||
NPM_VERSION=$(npm -v)
|
||||
NPM_VERSION=$(run_npm -v)
|
||||
success "npm: v$NPM_VERSION"
|
||||
|
||||
# Check pnpm (optional, for local dev)
|
||||
if command -v pnpm &> /dev/null; then
|
||||
PNPM_VERSION=$(pnpm -v)
|
||||
success "pnpm: v$PNPM_VERSION"
|
||||
else
|
||||
info "pnpm not found (optional)"
|
||||
if ! command -v pnpm &> /dev/null; then
|
||||
error "pnpm is required because this workspace publishes catalog: dependencies"
|
||||
info "Install pnpm with Corepack: corepack enable"
|
||||
exit 1
|
||||
fi
|
||||
PNPM_VERSION=$(pnpm -v)
|
||||
success "pnpm: v$PNPM_VERSION"
|
||||
|
||||
# Check npm login status
|
||||
if ! npm whoami &> /dev/null; then
|
||||
if ! run_npm whoami &> /dev/null; then
|
||||
error "Not logged in to npm. Run 'npm login' first."
|
||||
exit 1
|
||||
fi
|
||||
NPM_USER=$(npm whoami)
|
||||
NPM_USER=$(run_npm whoami)
|
||||
success "Logged in as: ${BOLD}$NPM_USER${NC}"
|
||||
|
||||
# ========================================================================
|
||||
@@ -154,11 +172,11 @@ main() {
|
||||
success "Version: ${BOLD}$PACKAGE_VERSION${NC}"
|
||||
|
||||
# Check if version already exists on npm
|
||||
if npm view "$PACKAGE_NAME@$PACKAGE_VERSION" version &> /dev/null; then
|
||||
if run_npm view "$PACKAGE_NAME@$PACKAGE_VERSION" version &> /dev/null; then
|
||||
error "Version $PACKAGE_VERSION already exists on npm!"
|
||||
echo ""
|
||||
info "Current published versions:"
|
||||
npm view "$PACKAGE_NAME" versions --json 2>/dev/null | tail -5
|
||||
run_npm view "$PACKAGE_NAME" versions --json 2>/dev/null | tail -5
|
||||
echo ""
|
||||
warning "Please update the version in package.json before publishing."
|
||||
exit 1
|
||||
@@ -170,11 +188,7 @@ main() {
|
||||
# ========================================================================
|
||||
step "Step 3/6: Installing dependencies..."
|
||||
|
||||
if command -v pnpm &> /dev/null; then
|
||||
pnpm install --frozen-lockfile 2>/dev/null || pnpm install
|
||||
else
|
||||
npm ci 2>/dev/null || npm install
|
||||
fi
|
||||
pnpm --dir "$REPO_ROOT" install --frozen-lockfile 2>/dev/null || pnpm --dir "$REPO_ROOT" install
|
||||
success "Dependencies installed"
|
||||
|
||||
# ========================================================================
|
||||
@@ -185,11 +199,7 @@ main() {
|
||||
if [[ "$SKIP_TESTS" == true ]]; then
|
||||
warning "Skipping tests (--skip-tests flag)"
|
||||
else
|
||||
if command -v pnpm &> /dev/null; then
|
||||
pnpm test
|
||||
else
|
||||
npm test
|
||||
fi
|
||||
pnpm test
|
||||
success "All tests passed"
|
||||
fi
|
||||
|
||||
@@ -201,11 +211,7 @@ main() {
|
||||
# Clean previous build
|
||||
rm -rf dist
|
||||
|
||||
if command -v pnpm &> /dev/null; then
|
||||
pnpm run build
|
||||
else
|
||||
npm run build
|
||||
fi
|
||||
pnpm run build
|
||||
success "Build completed"
|
||||
|
||||
# Verify build output
|
||||
@@ -223,15 +229,32 @@ main() {
|
||||
# Step 6: Publish
|
||||
# ========================================================================
|
||||
step "Step 6/6: Publishing to npm..."
|
||||
|
||||
|
||||
PACK_DIR="$(mktemp -d)"
|
||||
trap 'rm -rf "$PACK_DIR"' EXIT
|
||||
|
||||
pnpm pack --pack-destination "$PACK_DIR" >/dev/null
|
||||
PACKAGE_TARBALL="$(find "$PACK_DIR" -maxdepth 1 -name '*.tgz' | head -n 1)"
|
||||
|
||||
if [[ -z "$PACKAGE_TARBALL" ]]; then
|
||||
error "Pack failed - no tarball generated"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if tar -xOf "$PACKAGE_TARBALL" package/package.json | grep -q '"catalog:'; then
|
||||
error "Packed manifest still contains catalog: references"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
divider
|
||||
echo -e "${CYAN}Package contents:${NC}"
|
||||
npm pack --dry-run 2>&1 | head -30
|
||||
tar -tzf "$PACKAGE_TARBALL" | head -30
|
||||
divider
|
||||
|
||||
if [[ "$DRY_RUN" == true ]]; then
|
||||
warning "DRY-RUN: Skipping actual publish"
|
||||
echo ""
|
||||
info "Packed artifact: $PACKAGE_TARBALL"
|
||||
info "To publish for real, run without --dry-run flag"
|
||||
else
|
||||
echo ""
|
||||
@@ -239,7 +262,7 @@ main() {
|
||||
echo -e "${DIM}Press Enter to continue, or Ctrl+C to cancel...${NC}"
|
||||
read -r
|
||||
|
||||
npm publish --access public
|
||||
pnpm publish --access public --no-git-checks
|
||||
|
||||
echo ""
|
||||
success "🎉 Successfully published ${BOLD}$PACKAGE_NAME@$PACKAGE_VERSION${NC} to npm!"
|
||||
|
||||
@@ -10,6 +10,7 @@ export default defineConfig({
|
||||
// We can not upgrade these yet
|
||||
'tailwind-merge',
|
||||
'tailwindcss',
|
||||
'typescript',
|
||||
],
|
||||
|
||||
write: true,
|
||||
@@ -1,32 +0,0 @@
|
||||
.env
|
||||
.env.*
|
||||
|
||||
# Logs
|
||||
logs
|
||||
*.log*
|
||||
|
||||
# node
|
||||
node_modules
|
||||
dist
|
||||
build
|
||||
coverage
|
||||
.husky
|
||||
.next
|
||||
.pnpm-store
|
||||
|
||||
# vscode
|
||||
.vscode
|
||||
|
||||
# webstorm
|
||||
.idea
|
||||
*.iml
|
||||
*.iws
|
||||
*.ipr
|
||||
|
||||
|
||||
# Jetbrains
|
||||
.idea
|
||||
|
||||
# git
|
||||
.git
|
||||
.gitignore
|
||||
@@ -19,21 +19,27 @@ ENV NEXT_PUBLIC_BASE_PATH="$NEXT_PUBLIC_BASE_PATH"
|
||||
# install packages
|
||||
FROM base AS packages
|
||||
|
||||
WORKDIR /app/web
|
||||
WORKDIR /app
|
||||
|
||||
COPY package.json pnpm-lock.yaml /app/web/
|
||||
COPY package.json pnpm-lock.yaml pnpm-workspace.yaml /app/
|
||||
COPY web/package.json /app/web/
|
||||
COPY e2e/package.json /app/e2e/
|
||||
COPY sdks/nodejs-client/package.json /app/sdks/nodejs-client/
|
||||
|
||||
# Use packageManager from package.json
|
||||
RUN corepack install
|
||||
|
||||
RUN pnpm install --frozen-lockfile
|
||||
# Install only the web workspace to keep image builds from pulling in
|
||||
# unrelated workspace dependencies such as e2e tooling.
|
||||
RUN pnpm install --filter ./web... --frozen-lockfile
|
||||
|
||||
# build resources
|
||||
FROM base AS builder
|
||||
WORKDIR /app/web
|
||||
COPY --from=packages /app/web/ .
|
||||
WORKDIR /app
|
||||
COPY --from=packages /app/ .
|
||||
COPY . .
|
||||
|
||||
WORKDIR /app/web
|
||||
ENV NODE_OPTIONS="--max-old-space-size=4096"
|
||||
RUN pnpm build
|
||||
|
||||
@@ -64,13 +70,13 @@ RUN addgroup -S -g ${dify_uid} dify && \
|
||||
chown -R dify:dify /app
|
||||
|
||||
|
||||
WORKDIR /app/web
|
||||
WORKDIR /app
|
||||
|
||||
COPY --from=builder --chown=dify:dify /app/web/public ./public
|
||||
COPY --from=builder --chown=dify:dify /app/web/public ./web/public
|
||||
COPY --from=builder --chown=dify:dify /app/web/.next/standalone ./
|
||||
COPY --from=builder --chown=dify:dify /app/web/.next/static ./.next/static
|
||||
COPY --from=builder --chown=dify:dify /app/web/.next/static ./web/.next/static
|
||||
|
||||
COPY --chown=dify:dify --chmod=755 docker/entrypoint.sh ./entrypoint.sh
|
||||
COPY --chown=dify:dify --chmod=755 web/docker/entrypoint.sh ./entrypoint.sh
|
||||
|
||||
ARG COMMIT_SHA
|
||||
ENV COMMIT_SHA=${COMMIT_SHA}
|
||||
|
||||
34
web/Dockerfile.dockerignore
Normal file
34
web/Dockerfile.dockerignore
Normal file
@@ -0,0 +1,34 @@
|
||||
**
|
||||
!package.json
|
||||
!pnpm-lock.yaml
|
||||
!pnpm-workspace.yaml
|
||||
!.nvmrc
|
||||
!web/
|
||||
!web/**
|
||||
!e2e/
|
||||
!e2e/package.json
|
||||
!sdks/
|
||||
!sdks/nodejs-client/
|
||||
!sdks/nodejs-client/package.json
|
||||
|
||||
.git
|
||||
node_modules
|
||||
.pnpm-store
|
||||
web/.env
|
||||
web/.env.*
|
||||
web/logs
|
||||
web/*.log*
|
||||
web/node_modules
|
||||
web/dist
|
||||
web/build
|
||||
web/coverage
|
||||
web/.husky
|
||||
web/.next
|
||||
web/.pnpm-store
|
||||
web/.vscode
|
||||
web/.idea
|
||||
web/*.iml
|
||||
web/*.iws
|
||||
web/*.ipr
|
||||
e2e/node_modules
|
||||
sdks/nodejs-client/node_modules
|
||||
@@ -24,18 +24,24 @@ For example, use `vp install` instead of `pnpm install` and `vp test` instead of
|
||||
>
|
||||
> Learn more: [Corepack]
|
||||
|
||||
Run the following commands from the repository root.
|
||||
|
||||
First, install the dependencies:
|
||||
|
||||
```bash
|
||||
pnpm install
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> JavaScript dependencies are managed by the workspace files at the repository root: `package.json`, `pnpm-lock.yaml`, `pnpm-workspace.yaml`, and `.nvmrc`.
|
||||
> Install dependencies from the repository root, then run frontend scripts from `web/`.
|
||||
|
||||
Then, configure the environment variables.
|
||||
Create a file named `.env.local` in the current directory and copy the contents from `.env.example`.
|
||||
Create `web/.env.local` and copy the contents from `web/.env.example`.
|
||||
Modify the values of these environment variables according to your requirements:
|
||||
|
||||
```bash
|
||||
cp .env.example .env.local
|
||||
cp web/.env.example web/.env.local
|
||||
```
|
||||
|
||||
> [!IMPORTANT]
|
||||
@@ -46,16 +52,16 @@ cp .env.example .env.local
|
||||
Finally, run the development server:
|
||||
|
||||
```bash
|
||||
pnpm run dev
|
||||
pnpm -C web run dev
|
||||
# or if you are using vinext which provides a better development experience
|
||||
pnpm run dev:vinext
|
||||
pnpm -C web run dev:vinext
|
||||
# (optional) start the dev proxy server so that you can use online API in development
|
||||
pnpm run dev:proxy
|
||||
pnpm -C web run dev:proxy
|
||||
```
|
||||
|
||||
Open <http://localhost:3000> with your browser to see the result.
|
||||
|
||||
You can start editing the file under folder `app`.
|
||||
You can start editing the files under `web/app`.
|
||||
The page auto-updates as you edit the file.
|
||||
|
||||
## Deploy
|
||||
@@ -65,19 +71,25 @@ The page auto-updates as you edit the file.
|
||||
First, build the app for production:
|
||||
|
||||
```bash
|
||||
pnpm run build
|
||||
pnpm -C web run build
|
||||
```
|
||||
|
||||
Then, start the server:
|
||||
|
||||
```bash
|
||||
pnpm run start
|
||||
pnpm -C web run start
|
||||
```
|
||||
|
||||
If you build the Docker image manually, use the repository root as the build context:
|
||||
|
||||
```bash
|
||||
docker build -f web/Dockerfile -t dify-web .
|
||||
```
|
||||
|
||||
If you want to customize the host and port:
|
||||
|
||||
```bash
|
||||
pnpm run start --port=3001 --host=0.0.0.0
|
||||
pnpm -C web run start --port=3001 --host=0.0.0.0
|
||||
```
|
||||
|
||||
## Storybook
|
||||
@@ -87,7 +99,7 @@ This project uses [Storybook] for UI component development.
|
||||
To start the storybook server, run:
|
||||
|
||||
```bash
|
||||
pnpm storybook
|
||||
pnpm -C web storybook
|
||||
```
|
||||
|
||||
Open <http://localhost:6006> with your browser to see the result.
|
||||
@@ -112,7 +124,7 @@ We use [Vitest] and [React Testing Library] for Unit Testing.
|
||||
Run test:
|
||||
|
||||
```bash
|
||||
pnpm test
|
||||
pnpm -C web test
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import type { Mock } from 'vitest'
|
||||
import type { FeatureStoreState } from '@/app/components/base/features/store'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
@@ -28,7 +27,7 @@ type SetupOptions = {
|
||||
}
|
||||
|
||||
let mockFeatureStoreState: FeatureStoreState
|
||||
let mockSetFeatures: Mock
|
||||
let mockSetFeatures = vi.fn()
|
||||
const mockStore = {
|
||||
getState: vi.fn<() => FeatureStoreState>(() => mockFeatureStoreState),
|
||||
}
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import type { Mock } from 'vitest'
|
||||
import type { FeatureStoreState } from '@/app/components/base/features/store'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
@@ -28,7 +27,7 @@ type SetupOptions = {
|
||||
}
|
||||
|
||||
let mockFeatureStoreState: FeatureStoreState
|
||||
let mockSetFeatures: Mock
|
||||
let mockSetFeatures = vi.fn()
|
||||
const mockStore = {
|
||||
getState: vi.fn<() => FeatureStoreState>(() => mockFeatureStoreState),
|
||||
}
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import type { Mock } from 'vitest'
|
||||
import type { ModelConfig, PromptVariable } from '@/models/debug'
|
||||
import type { ToolItem } from '@/types/app'
|
||||
import { render, screen } from '@testing-library/react'
|
||||
@@ -74,10 +73,10 @@ type MockContext = {
|
||||
history: boolean
|
||||
query: boolean
|
||||
}
|
||||
showHistoryModal: Mock
|
||||
showHistoryModal: () => void
|
||||
modelConfig: ModelConfig
|
||||
setModelConfig: Mock
|
||||
setPrevPromptConfig: Mock
|
||||
setModelConfig: (modelConfig: ModelConfig) => void
|
||||
setPrevPromptConfig: (configs: ModelConfig['configs']) => void
|
||||
}
|
||||
|
||||
const createPromptVariable = (overrides: Partial<PromptVariable> = {}): PromptVariable => ({
|
||||
@@ -142,7 +141,7 @@ const createContextValue = (overrides: Partial<MockContext> = {}): MockContext =
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const mockUseContext = useContextSelector.useContext as Mock
|
||||
const mockUseContext = vi.mocked(useContextSelector.useContext)
|
||||
|
||||
const renderConfig = (contextOverrides: Partial<MockContext> = {}) => {
|
||||
const contextValue = createContextValue(contextOverrides)
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
import type { Mock } from 'vitest'
|
||||
import { act, renderHook } from '@testing-library/react'
|
||||
import { useDSLDragDrop } from '../use-dsl-drag-drop'
|
||||
|
||||
describe('useDSLDragDrop', () => {
|
||||
let container: HTMLDivElement
|
||||
let mockOnDSLFileDropped: Mock
|
||||
let mockOnDSLFileDropped = vi.fn<(file: File) => void>()
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
container = document.createElement('div')
|
||||
document.body.appendChild(container)
|
||||
mockOnDSLFileDropped = vi.fn()
|
||||
mockOnDSLFileDropped = vi.fn<(file: File) => void>()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
|
||||
@@ -11,15 +11,15 @@ type EmblaEventName = 'reInit' | 'select'
|
||||
type EmblaListener = (api: MockEmblaApi | undefined) => void
|
||||
|
||||
type MockEmblaApi = {
|
||||
scrollPrev: Mock
|
||||
scrollNext: Mock
|
||||
scrollTo: Mock
|
||||
selectedScrollSnap: Mock
|
||||
canScrollPrev: Mock
|
||||
canScrollNext: Mock
|
||||
slideNodes: Mock
|
||||
on: Mock
|
||||
off: Mock
|
||||
scrollPrev: Mock<() => void>
|
||||
scrollNext: Mock<() => void>
|
||||
scrollTo: Mock<(index: number) => void>
|
||||
selectedScrollSnap: Mock<() => number>
|
||||
canScrollPrev: Mock<() => boolean>
|
||||
canScrollNext: Mock<() => boolean>
|
||||
slideNodes: Mock<() => HTMLDivElement[]>
|
||||
on: Mock<(event: EmblaEventName, callback: EmblaListener) => void>
|
||||
off: Mock<(event: EmblaEventName, callback: EmblaListener) => void>
|
||||
}
|
||||
|
||||
let mockCanScrollPrev = false
|
||||
@@ -33,19 +33,19 @@ const mockCarouselRef = vi.fn()
|
||||
const mockedUseEmblaCarousel = vi.mocked(useEmblaCarousel)
|
||||
|
||||
const createMockEmblaApi = (): MockEmblaApi => ({
|
||||
scrollPrev: vi.fn(),
|
||||
scrollNext: vi.fn(),
|
||||
scrollTo: vi.fn(),
|
||||
selectedScrollSnap: vi.fn(() => mockSelectedIndex),
|
||||
canScrollPrev: vi.fn(() => mockCanScrollPrev),
|
||||
canScrollNext: vi.fn(() => mockCanScrollNext),
|
||||
slideNodes: vi.fn(() =>
|
||||
Array.from({ length: mockSlideCount }).fill(document.createElement('div')),
|
||||
scrollPrev: vi.fn<() => void>(),
|
||||
scrollNext: vi.fn<() => void>(),
|
||||
scrollTo: vi.fn<(index: number) => void>(),
|
||||
selectedScrollSnap: vi.fn<() => number>(() => mockSelectedIndex),
|
||||
canScrollPrev: vi.fn<() => boolean>(() => mockCanScrollPrev),
|
||||
canScrollNext: vi.fn<() => boolean>(() => mockCanScrollNext),
|
||||
slideNodes: vi.fn<() => HTMLDivElement[]>(() =>
|
||||
Array.from({ length: mockSlideCount }, () => document.createElement('div')),
|
||||
),
|
||||
on: vi.fn((event: EmblaEventName, callback: EmblaListener) => {
|
||||
on: vi.fn<(event: EmblaEventName, callback: EmblaListener) => void>((event, callback) => {
|
||||
listeners[event].push(callback)
|
||||
}),
|
||||
off: vi.fn((event: EmblaEventName, callback: EmblaListener) => {
|
||||
off: vi.fn<(event: EmblaEventName, callback: EmblaListener) => void>((event, callback) => {
|
||||
listeners[event] = listeners[event].filter(listener => listener !== callback)
|
||||
}),
|
||||
})
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import type { Mock } from 'vitest'
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import UpgradeBtn from '../index'
|
||||
@@ -14,14 +13,16 @@ vi.mock('@/context/modal-context', () => ({
|
||||
}),
|
||||
}))
|
||||
|
||||
type GtagHandler = (command: string, action: string, payload: { loc: string }) => void
|
||||
|
||||
// Typed window accessor for gtag tracking tests
|
||||
const gtagWindow = window as unknown as Record<string, Mock | undefined>
|
||||
let mockGtag: Mock | undefined
|
||||
const gtagWindow = window as unknown as { gtag?: GtagHandler }
|
||||
let mockGtag = vi.fn<GtagHandler>()
|
||||
|
||||
describe('UpgradeBtn', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockGtag = vi.fn()
|
||||
mockGtag = vi.fn<GtagHandler>()
|
||||
gtagWindow.gtag = mockGtag
|
||||
})
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user