mirror of
https://github.com/langgenius/dify.git
synced 2026-03-29 20:06:48 +00:00
Compare commits
10 Commits
34028
...
codex/i18n
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6c4a0beebe | ||
|
|
45f2d03911 | ||
|
|
a1171877a4 | ||
|
|
f06cc339cc | ||
|
|
6bf8982559 | ||
|
|
364d7ebc40 | ||
|
|
7cc81e9a43 | ||
|
|
3409c519e2 | ||
|
|
5851b42af3 | ||
|
|
c5eae67ac9 |
2
.github/workflows/style.yml
vendored
2
.github/workflows/style.yml
vendored
@@ -49,7 +49,7 @@ jobs:
|
||||
|
||||
- name: Run Type Checks
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: make type-check
|
||||
run: make type-check-core
|
||||
|
||||
- name: Dotenv check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
|
||||
486
.github/workflows/translate-i18n-claude.yml
vendored
486
.github/workflows/translate-i18n-claude.yml
vendored
@@ -1,12 +1,10 @@
|
||||
name: Translate i18n Files with Claude Code
|
||||
|
||||
# Note: claude-code-action doesn't support push events directly.
|
||||
# Push events are handled by trigger-i18n-sync.yml which sends repository_dispatch.
|
||||
# See: https://github.com/langgenius/dify/issues/30743
|
||||
|
||||
on:
|
||||
repository_dispatch:
|
||||
types: [i18n-sync]
|
||||
push:
|
||||
branches: [main]
|
||||
paths:
|
||||
- 'web/i18n/en-US/*.json'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
files:
|
||||
@@ -18,9 +16,9 @@ on:
|
||||
required: false
|
||||
type: string
|
||||
mode:
|
||||
description: 'Sync mode: incremental (only changes) or full (re-check all keys)'
|
||||
description: 'Sync mode: incremental (compare with previous en-US revision) or full (sync all keys in scope)'
|
||||
required: false
|
||||
default: 'incremental'
|
||||
default: incremental
|
||||
type: choice
|
||||
options:
|
||||
- incremental
|
||||
@@ -30,6 +28,10 @@ permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
concurrency:
|
||||
group: translate-i18n-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
translate:
|
||||
if: github.repository == 'langgenius/dify'
|
||||
@@ -51,380 +53,132 @@ jobs:
|
||||
- name: Setup web environment
|
||||
uses: ./.github/actions/setup-web
|
||||
|
||||
- name: Detect changed files and generate diff
|
||||
id: detect_changes
|
||||
- name: Prepare sync context
|
||||
id: context
|
||||
shell: bash
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
||||
# Manual trigger
|
||||
if [ -n "${{ github.event.inputs.files }}" ]; then
|
||||
echo "CHANGED_FILES=${{ github.event.inputs.files }}" >> $GITHUB_OUTPUT
|
||||
else
|
||||
# Get all JSON files in en-US directory
|
||||
files=$(ls web/i18n/en-US/*.json 2>/dev/null | xargs -n1 basename | sed 's/.json$//' | tr '\n' ' ')
|
||||
echo "CHANGED_FILES=$files" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
echo "TARGET_LANGS=${{ github.event.inputs.languages }}" >> $GITHUB_OUTPUT
|
||||
echo "SYNC_MODE=${{ github.event.inputs.mode || 'incremental' }}" >> $GITHUB_OUTPUT
|
||||
|
||||
# For manual trigger with incremental mode, get diff from last commit
|
||||
# For full mode, we'll do a complete check anyway
|
||||
if [ "${{ github.event.inputs.mode }}" == "full" ]; then
|
||||
echo "Full mode: will check all keys" > /tmp/i18n-diff.txt
|
||||
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
|
||||
else
|
||||
git diff HEAD~1..HEAD -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt
|
||||
if [ -s /tmp/i18n-diff.txt ]; then
|
||||
echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
fi
|
||||
elif [ "${{ github.event_name }}" == "repository_dispatch" ]; then
|
||||
# Triggered by push via trigger-i18n-sync.yml workflow
|
||||
# Validate required payload fields
|
||||
if [ -z "${{ github.event.client_payload.changed_files }}" ]; then
|
||||
echo "Error: repository_dispatch payload missing required 'changed_files' field" >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "CHANGED_FILES=${{ github.event.client_payload.changed_files }}" >> $GITHUB_OUTPUT
|
||||
echo "TARGET_LANGS=" >> $GITHUB_OUTPUT
|
||||
echo "SYNC_MODE=${{ github.event.client_payload.sync_mode || 'incremental' }}" >> $GITHUB_OUTPUT
|
||||
|
||||
# Decode the base64-encoded diff from the trigger workflow
|
||||
if [ -n "${{ github.event.client_payload.diff_base64 }}" ]; then
|
||||
if ! echo "${{ github.event.client_payload.diff_base64 }}" | base64 -d > /tmp/i18n-diff.txt 2>&1; then
|
||||
echo "Warning: Failed to decode base64 diff payload" >&2
|
||||
echo "" > /tmp/i18n-diff.txt
|
||||
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
|
||||
elif [ -s /tmp/i18n-diff.txt ]; then
|
||||
echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
else
|
||||
echo "" > /tmp/i18n-diff.txt
|
||||
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
|
||||
if [ "${{ github.event_name }}" = "push" ]; then
|
||||
BASE_SHA="${{ github.event.before }}"
|
||||
if [ -z "$BASE_SHA" ] || [ "$BASE_SHA" = "0000000000000000000000000000000000000000" ]; then
|
||||
BASE_SHA=$(git rev-parse HEAD~1 2>/dev/null || true)
|
||||
fi
|
||||
HEAD_SHA="${{ github.sha }}"
|
||||
CHANGED_FILES=$(git diff --name-only "$BASE_SHA" "$HEAD_SHA" -- 'web/i18n/en-US/*.json' 2>/dev/null | sed -n 's@^.*/@@p' | sed 's/\.json$//' | tr '\n' ' ' | sed 's/[[:space:]]*$//')
|
||||
TARGET_LANGS=""
|
||||
SYNC_MODE="incremental"
|
||||
else
|
||||
echo "Unsupported event type: ${{ github.event_name }}"
|
||||
exit 1
|
||||
HEAD_SHA=$(git rev-parse HEAD)
|
||||
TARGET_LANGS="${{ github.event.inputs.languages }}"
|
||||
SYNC_MODE="${{ github.event.inputs.mode || 'incremental' }}"
|
||||
if [ -n "${{ github.event.inputs.files }}" ]; then
|
||||
CHANGED_FILES="${{ github.event.inputs.files }}"
|
||||
else
|
||||
CHANGED_FILES=$(find web/i18n/en-US -maxdepth 1 -type f -name '*.json' -print | sed -n 's@^.*/@@p' | sed 's/\.json$//' | sort | tr '\n' ' ' | sed 's/[[:space:]]*$//')
|
||||
fi
|
||||
if [ "$SYNC_MODE" = "incremental" ]; then
|
||||
BASE_SHA=$(git rev-parse HEAD~1 2>/dev/null || true)
|
||||
else
|
||||
BASE_SHA=""
|
||||
fi
|
||||
fi
|
||||
|
||||
# Truncate diff if too large (keep first 50KB)
|
||||
if [ -f /tmp/i18n-diff.txt ]; then
|
||||
head -c 50000 /tmp/i18n-diff.txt > /tmp/i18n-diff-truncated.txt
|
||||
mv /tmp/i18n-diff-truncated.txt /tmp/i18n-diff.txt
|
||||
FILE_ARGS=""
|
||||
if [ -n "$CHANGED_FILES" ]; then
|
||||
FILE_ARGS="--file $CHANGED_FILES"
|
||||
fi
|
||||
|
||||
echo "Detected files: $(cat $GITHUB_OUTPUT | grep CHANGED_FILES || echo 'none')"
|
||||
LANG_ARGS=""
|
||||
if [ -n "$TARGET_LANGS" ]; then
|
||||
LANG_ARGS="--lang $TARGET_LANGS"
|
||||
fi
|
||||
|
||||
{
|
||||
echo "BASE_SHA=$BASE_SHA"
|
||||
echo "HEAD_SHA=$HEAD_SHA"
|
||||
echo "CHANGED_FILES=$CHANGED_FILES"
|
||||
echo "TARGET_LANGS=$TARGET_LANGS"
|
||||
echo "SYNC_MODE=$SYNC_MODE"
|
||||
echo "FILE_ARGS=$FILE_ARGS"
|
||||
echo "LANG_ARGS=$LANG_ARGS"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
echo "Files: ${CHANGED_FILES:-<none>}"
|
||||
echo "Languages: ${TARGET_LANGS:-<all supported>}"
|
||||
echo "Mode: $SYNC_MODE"
|
||||
|
||||
- name: Run Claude Code for Translation Sync
|
||||
if: steps.detect_changes.outputs.CHANGED_FILES != ''
|
||||
uses: anthropics/claude-code-action@ff9acae5886d41a99ed4ec14b7dc147d55834722 # v1.0.77
|
||||
if: steps.context.outputs.CHANGED_FILES != ''
|
||||
uses: anthropics/claude-code-action@88c168b39e7e64da0286d812b6e9fbebb6708185 # v1.0.82
|
||||
with:
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
github_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
# Allow github-actions bot to trigger this workflow via repository_dispatch
|
||||
# See: https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
|
||||
allowed_bots: 'github-actions[bot]'
|
||||
prompt: |
|
||||
You are a professional i18n synchronization engineer for the Dify project.
|
||||
Your task is to keep all language translations in sync with the English source (en-US).
|
||||
You are the i18n sync agent for the Dify repository.
|
||||
Your job is to keep translations synchronized with the English source files under `${{ github.workspace }}/web/i18n/en-US/`, then open a PR with the result.
|
||||
|
||||
## CRITICAL TOOL RESTRICTIONS
|
||||
- Use **Read** tool to read files (NOT cat or bash)
|
||||
- Use **Edit** tool to modify JSON files (NOT node, jq, or bash scripts)
|
||||
- Use **Bash** ONLY for: git commands, gh commands, pnpm commands
|
||||
- Run bash commands ONE BY ONE, never combine with && or ||
|
||||
- NEVER use `$()` command substitution - it's not supported. Split into separate commands instead.
|
||||
Use absolute paths at all times:
|
||||
- Repo root: `${{ github.workspace }}`
|
||||
- Web directory: `${{ github.workspace }}/web`
|
||||
- Language config: `${{ github.workspace }}/web/i18n-config/languages.ts`
|
||||
|
||||
## WORKING DIRECTORY & ABSOLUTE PATHS
|
||||
Claude Code sandbox working directory may vary. Always use absolute paths:
|
||||
- For pnpm: `pnpm --dir ${{ github.workspace }}/web <command>`
|
||||
- For git: `git -C ${{ github.workspace }} <command>`
|
||||
- For gh: `gh --repo ${{ github.repository }} <command>`
|
||||
- For file paths: `${{ github.workspace }}/web/i18n/`
|
||||
Inputs:
|
||||
- Files in scope: `${{ steps.context.outputs.CHANGED_FILES }}`
|
||||
- Target languages: `${{ steps.context.outputs.TARGET_LANGS }}`
|
||||
- Sync mode: `${{ steps.context.outputs.SYNC_MODE }}`
|
||||
- Base SHA: `${{ steps.context.outputs.BASE_SHA }}`
|
||||
- Head SHA: `${{ steps.context.outputs.HEAD_SHA }}`
|
||||
- Scoped file args: `${{ steps.context.outputs.FILE_ARGS }}`
|
||||
- Scoped language args: `${{ steps.context.outputs.LANG_ARGS }}`
|
||||
|
||||
## EFFICIENCY RULES
|
||||
- **ONE Edit per language file** - batch all key additions into a single Edit
|
||||
- Insert new keys at the beginning of JSON (after `{`), lint:fix will sort them
|
||||
- Translate ALL keys for a language mentally first, then do ONE Edit
|
||||
|
||||
## Context
|
||||
- Changed/target files: ${{ steps.detect_changes.outputs.CHANGED_FILES }}
|
||||
- Target languages (empty means all supported): ${{ steps.detect_changes.outputs.TARGET_LANGS }}
|
||||
- Sync mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}
|
||||
- Translation files are located in: ${{ github.workspace }}/web/i18n/{locale}/{filename}.json
|
||||
- Language configuration is in: ${{ github.workspace }}/web/i18n-config/languages.ts
|
||||
- Git diff is available: ${{ steps.detect_changes.outputs.DIFF_AVAILABLE }}
|
||||
|
||||
## CRITICAL DESIGN: Verify First, Then Sync
|
||||
|
||||
You MUST follow this three-phase approach:
|
||||
|
||||
═══════════════════════════════════════════════════════════════
|
||||
║ PHASE 1: VERIFY - Analyze and Generate Change Report ║
|
||||
═══════════════════════════════════════════════════════════════
|
||||
|
||||
### Step 1.1: Analyze Git Diff (for incremental mode)
|
||||
Use the Read tool to read `/tmp/i18n-diff.txt` to see the git diff.
|
||||
|
||||
Parse the diff to categorize changes:
|
||||
- Lines with `+` (not `+++`): Added or modified values
|
||||
- Lines with `-` (not `---`): Removed or old values
|
||||
- Identify specific keys for each category:
|
||||
* ADD: Keys that appear only in `+` lines (new keys)
|
||||
* UPDATE: Keys that appear in both `-` and `+` lines (value changed)
|
||||
* DELETE: Keys that appear only in `-` lines (removed keys)
|
||||
|
||||
### Step 1.2: Read Language Configuration
|
||||
Use the Read tool to read `${{ github.workspace }}/web/i18n-config/languages.ts`.
|
||||
Extract all languages with `supported: true`.
|
||||
|
||||
### Step 1.3: Run i18n:check for Each Language
|
||||
```bash
|
||||
pnpm --dir ${{ github.workspace }}/web install --frozen-lockfile
|
||||
```
|
||||
```bash
|
||||
pnpm --dir ${{ github.workspace }}/web run i18n:check
|
||||
```
|
||||
|
||||
This will report:
|
||||
- Missing keys (need to ADD)
|
||||
- Extra keys (need to DELETE)
|
||||
|
||||
### Step 1.4: Generate Change Report
|
||||
|
||||
Create a structured report identifying:
|
||||
```
|
||||
╔══════════════════════════════════════════════════════════════╗
|
||||
║ I18N SYNC CHANGE REPORT ║
|
||||
╠══════════════════════════════════════════════════════════════╣
|
||||
║ Files to process: [list] ║
|
||||
║ Languages to sync: [list] ║
|
||||
╠══════════════════════════════════════════════════════════════╣
|
||||
║ ADD (New Keys): ║
|
||||
║ - [filename].[key]: "English value" ║
|
||||
║ ... ║
|
||||
╠══════════════════════════════════════════════════════════════╣
|
||||
║ UPDATE (Modified Keys - MUST re-translate): ║
|
||||
║ - [filename].[key]: "Old value" → "New value" ║
|
||||
║ ... ║
|
||||
╠══════════════════════════════════════════════════════════════╣
|
||||
║ DELETE (Extra Keys): ║
|
||||
║ - [language]/[filename].[key] ║
|
||||
║ ... ║
|
||||
╚══════════════════════════════════════════════════════════════╝
|
||||
```
|
||||
|
||||
**IMPORTANT**: For UPDATE detection, compare git diff to find keys where
|
||||
the English value changed. These MUST be re-translated even if target
|
||||
language already has a translation (it's now stale!).
|
||||
|
||||
═══════════════════════════════════════════════════════════════
|
||||
║ PHASE 2: SYNC - Execute Changes Based on Report ║
|
||||
═══════════════════════════════════════════════════════════════
|
||||
|
||||
### Step 2.1: Process ADD Operations (BATCH per language file)
|
||||
|
||||
**CRITICAL WORKFLOW for efficiency:**
|
||||
1. First, translate ALL new keys for ALL languages mentally
|
||||
2. Then, for EACH language file, do ONE Edit operation:
|
||||
- Read the file once
|
||||
- Insert ALL new keys at the beginning (right after the opening `{`)
|
||||
- Don't worry about alphabetical order - lint:fix will sort them later
|
||||
|
||||
Example Edit (adding 3 keys to zh-Hans/app.json):
|
||||
```
|
||||
old_string: '{\n "accessControl"'
|
||||
new_string: '{\n "newKey1": "translation1",\n "newKey2": "translation2",\n "newKey3": "translation3",\n "accessControl"'
|
||||
```
|
||||
|
||||
**IMPORTANT**:
|
||||
- ONE Edit per language file (not one Edit per key!)
|
||||
- Always use the Edit tool. NEVER use bash scripts, node, or jq.
|
||||
|
||||
### Step 2.2: Process UPDATE Operations
|
||||
|
||||
**IMPORTANT: Special handling for zh-Hans and ja-JP**
|
||||
If zh-Hans or ja-JP files were ALSO modified in the same push:
|
||||
- Run: `git -C ${{ github.workspace }} diff HEAD~1 --name-only` and check for zh-Hans or ja-JP files
|
||||
- If found, it means someone manually translated them. Apply these rules:
|
||||
|
||||
1. **Missing keys**: Still ADD them (completeness required)
|
||||
2. **Existing translations**: Compare with the NEW English value:
|
||||
- If translation is **completely wrong** or **unrelated** → Update it
|
||||
- If translation is **roughly correct** (captures the meaning) → Keep it, respect manual work
|
||||
- When in doubt, **keep the manual translation**
|
||||
|
||||
Example:
|
||||
- English changed: "Save" → "Save Changes"
|
||||
- Manual translation: "保存更改" → Keep it (correct meaning)
|
||||
- Manual translation: "删除" → Update it (completely wrong)
|
||||
|
||||
For other languages:
|
||||
Use Edit tool to replace the old value with the new translation.
|
||||
You can batch multiple updates in one Edit if they are adjacent.
|
||||
|
||||
### Step 2.3: Process DELETE Operations
|
||||
For extra keys reported by i18n:check:
|
||||
- Run: `pnpm --dir ${{ github.workspace }}/web run i18n:check --auto-remove`
|
||||
- Or manually remove from target language JSON files
|
||||
|
||||
## Translation Guidelines
|
||||
|
||||
- PRESERVE all placeholders exactly as-is:
|
||||
- `{{variable}}` - Mustache interpolation
|
||||
- `${variable}` - Template literal
|
||||
- `<tag>content</tag>` - HTML tags
|
||||
- `_one`, `_other` - Pluralization suffixes (these are KEY suffixes, not values)
|
||||
|
||||
**CRITICAL: Variable names and tag names MUST stay in English - NEVER translate them**
|
||||
|
||||
✅ CORRECT examples:
|
||||
- English: "{{count}} items" → Japanese: "{{count}} 個のアイテム"
|
||||
- English: "{{name}} updated" → Korean: "{{name}} 업데이트됨"
|
||||
- English: "<email>{{email}}</email>" → Chinese: "<email>{{email}}</email>"
|
||||
- English: "<CustomLink>Marketplace</CustomLink>" → Japanese: "<CustomLink>マーケットプレイス</CustomLink>"
|
||||
|
||||
❌ WRONG examples (NEVER do this - will break the application):
|
||||
- "{{count}}" → "{{カウント}}" ❌ (variable name translated to Japanese)
|
||||
- "{{name}}" → "{{이름}}" ❌ (variable name translated to Korean)
|
||||
- "{{email}}" → "{{邮箱}}" ❌ (variable name translated to Chinese)
|
||||
- "<email>" → "<メール>" ❌ (tag name translated)
|
||||
- "<CustomLink>" → "<自定义链接>" ❌ (component name translated)
|
||||
|
||||
- Use appropriate language register (formal/informal) based on existing translations
|
||||
- Match existing translation style in each language
|
||||
- Technical terms: check existing conventions per language
|
||||
- For CJK languages: no spaces between characters unless necessary
|
||||
- For RTL languages (ar-TN, fa-IR): ensure proper text handling
|
||||
|
||||
## Output Format Requirements
|
||||
- Alphabetical key ordering (if original file uses it)
|
||||
- 2-space indentation
|
||||
- Trailing newline at end of file
|
||||
- Valid JSON (use proper escaping for special characters)
|
||||
|
||||
═══════════════════════════════════════════════════════════════
|
||||
║ PHASE 3: RE-VERIFY - Confirm All Issues Resolved ║
|
||||
═══════════════════════════════════════════════════════════════
|
||||
|
||||
### Step 3.1: Run Lint Fix (IMPORTANT!)
|
||||
```bash
|
||||
pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- 'i18n/**/*.json'
|
||||
```
|
||||
This ensures:
|
||||
- JSON keys are sorted alphabetically (jsonc/sort-keys rule)
|
||||
- Valid i18n keys (dify-i18n/valid-i18n-keys rule)
|
||||
- No extra keys (dify-i18n/no-extra-keys rule)
|
||||
|
||||
### Step 3.2: Run Final i18n Check
|
||||
```bash
|
||||
pnpm --dir ${{ github.workspace }}/web run i18n:check
|
||||
```
|
||||
|
||||
### Step 3.3: Fix Any Remaining Issues
|
||||
If check reports issues:
|
||||
- Go back to PHASE 2 for unresolved items
|
||||
- Repeat until check passes
|
||||
|
||||
### Step 3.4: Generate Final Summary
|
||||
```
|
||||
╔══════════════════════════════════════════════════════════════╗
|
||||
║ SYNC COMPLETED SUMMARY ║
|
||||
╠══════════════════════════════════════════════════════════════╣
|
||||
║ Language │ Added │ Updated │ Deleted │ Status ║
|
||||
╠══════════════════════════════════════════════════════════════╣
|
||||
║ zh-Hans │ 5 │ 2 │ 1 │ ✓ Complete ║
|
||||
║ ja-JP │ 5 │ 2 │ 1 │ ✓ Complete ║
|
||||
║ ... │ ... │ ... │ ... │ ... ║
|
||||
╠══════════════════════════════════════════════════════════════╣
|
||||
║ i18n:check │ PASSED - All keys in sync ║
|
||||
╚══════════════════════════════════════════════════════════════╝
|
||||
```
|
||||
|
||||
## Mode-Specific Behavior
|
||||
|
||||
**SYNC_MODE = "incremental"** (default):
|
||||
- Focus on keys identified from git diff
|
||||
- Also check i18n:check output for any missing/extra keys
|
||||
- Efficient for small changes
|
||||
|
||||
**SYNC_MODE = "full"**:
|
||||
- Compare ALL keys between en-US and each language
|
||||
- Run i18n:check to identify all discrepancies
|
||||
- Use for first-time sync or fixing historical issues
|
||||
|
||||
## Important Notes
|
||||
|
||||
1. Always run i18n:check BEFORE and AFTER making changes
|
||||
2. The check script is the source of truth for missing/extra keys
|
||||
3. For UPDATE scenario: git diff is the source of truth for changed values
|
||||
4. Create a single commit with all translation changes
|
||||
5. If any translation fails, continue with others and report failures
|
||||
|
||||
═══════════════════════════════════════════════════════════════
|
||||
║ PHASE 4: COMMIT AND CREATE PR ║
|
||||
═══════════════════════════════════════════════════════════════
|
||||
|
||||
After all translations are complete and verified:
|
||||
|
||||
### Step 4.1: Check for changes
|
||||
```bash
|
||||
git -C ${{ github.workspace }} status --porcelain
|
||||
```
|
||||
|
||||
If there are changes:
|
||||
|
||||
### Step 4.2: Create a new branch and commit
|
||||
Run these git commands ONE BY ONE (not combined with &&).
|
||||
**IMPORTANT**: Do NOT use `$()` command substitution. Use two separate commands:
|
||||
|
||||
1. First, get the timestamp:
|
||||
```bash
|
||||
date +%Y%m%d-%H%M%S
|
||||
```
|
||||
(Note the output, e.g., "20260115-143052")
|
||||
|
||||
2. Then create branch using the timestamp value:
|
||||
```bash
|
||||
git -C ${{ github.workspace }} checkout -b chore/i18n-sync-20260115-143052
|
||||
```
|
||||
(Replace "20260115-143052" with the actual timestamp from step 1)
|
||||
|
||||
3. Stage changes:
|
||||
```bash
|
||||
git -C ${{ github.workspace }} add web/i18n/
|
||||
```
|
||||
|
||||
4. Commit:
|
||||
```bash
|
||||
git -C ${{ github.workspace }} commit -m "chore(i18n): sync translations with en-US - Mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}"
|
||||
```
|
||||
|
||||
5. Push:
|
||||
```bash
|
||||
git -C ${{ github.workspace }} push origin HEAD
|
||||
```
|
||||
|
||||
### Step 4.3: Create Pull Request
|
||||
```bash
|
||||
gh pr create --repo ${{ github.repository }} --title "chore(i18n): sync translations with en-US" --body "## Summary
|
||||
|
||||
This PR was automatically generated to sync i18n translation files.
|
||||
|
||||
### Changes
|
||||
- Mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}
|
||||
- Files processed: ${{ steps.detect_changes.outputs.CHANGED_FILES }}
|
||||
|
||||
### Verification
|
||||
- [x] \`i18n:check\` passed
|
||||
- [x] \`lint:fix\` applied
|
||||
|
||||
🤖 Generated with Claude Code GitHub Action" --base main
|
||||
```
|
||||
Tool rules:
|
||||
- Use Read for repository files.
|
||||
- Use Edit for JSON updates.
|
||||
- Use Bash only for `git`, `gh`, `pnpm`, and `date`.
|
||||
- Run Bash commands one by one. Do not combine commands with `&&`, `||`, pipes, or command substitution.
|
||||
|
||||
Required execution plan:
|
||||
1. Resolve target languages.
|
||||
- If no target languages were provided, read `${{ github.workspace }}/web/i18n-config/languages.ts` and use every language with `supported: true`.
|
||||
2. Stay strictly in scope.
|
||||
- Only process the files listed in `Files in scope`.
|
||||
- Only process the resolved target languages.
|
||||
- Do not touch unrelated i18n files.
|
||||
3. Detect English changes per file.
|
||||
- Read the current English JSON file for each file in scope.
|
||||
- If sync mode is `incremental` and `Base SHA` is not empty, run:
|
||||
`git -C ${{ github.workspace }} show <Base SHA>:web/i18n/en-US/<file>.json`
|
||||
- If sync mode is `full` or `Base SHA` is empty, skip historical comparison and treat the current English file as the only source of truth for structural sync.
|
||||
- If the file did not exist at Base SHA, treat all current keys as ADD.
|
||||
- Compare previous and current English JSON to identify:
|
||||
- ADD: key only in current
|
||||
- UPDATE: key exists in both and the English value changed
|
||||
- DELETE: key only in previous
|
||||
- Do not rely on a truncated diff file.
|
||||
4. Run a scoped pre-check before editing:
|
||||
- `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
|
||||
- Use this command as the source of truth for missing and extra keys inside the current scope.
|
||||
5. Apply translations.
|
||||
- For every target language and scoped file:
|
||||
- ADD missing keys.
|
||||
- UPDATE stale translations when the English value changed.
|
||||
- DELETE removed keys. Prefer `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope.
|
||||
- For `zh-Hans` and `ja-JP`, if the locale file also changed between Base SHA and Head SHA, preserve manual translations unless they are clearly wrong for the new English value. If in doubt, keep the manual translation.
|
||||
- Preserve placeholders exactly: `{{variable}}`, `${variable}`, HTML tags, component tags, and variable names.
|
||||
- Match the existing terminology and register used by each locale.
|
||||
- Prefer one Edit per file when stable, but prioritize correctness over batching.
|
||||
6. Verify only the edited files.
|
||||
- Run `pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- <relative edited i18n file paths>`
|
||||
- Run `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}`
|
||||
- If verification fails, fix the remaining problems before continuing.
|
||||
7. Create a PR only when there are changes in `web/i18n/`.
|
||||
- Check `git -C ${{ github.workspace }} status --porcelain -- web/i18n/`
|
||||
- Create branch `chore/i18n-sync-<timestamp>`
|
||||
- Commit message: `chore(i18n): sync translations with en-US`
|
||||
- Push the branch and open a PR against `main`
|
||||
- PR title: `chore(i18n): sync translations with en-US`
|
||||
- PR body: summarize files, languages, sync mode, and verification commands
|
||||
8. If there are no translation changes after verification, do not create a branch, commit, or PR.
|
||||
claude_args: |
|
||||
--max-turns 150
|
||||
--allowedTools "Read,Write,Edit,Bash(git *),Bash(git:*),Bash(gh *),Bash(gh:*),Bash(pnpm *),Bash(pnpm:*),Bash(date *),Bash(date:*),Glob,Grep"
|
||||
--max-turns 80
|
||||
--allowedTools "Read,Edit,Bash(git *),Bash(git:*),Bash(gh *),Bash(gh:*),Bash(pnpm *),Bash(pnpm:*),Bash(date *),Bash(date:*),Glob,Grep"
|
||||
|
||||
66
.github/workflows/trigger-i18n-sync.yml
vendored
66
.github/workflows/trigger-i18n-sync.yml
vendored
@@ -1,66 +0,0 @@
|
||||
name: Trigger i18n Sync on Push
|
||||
|
||||
# This workflow bridges the push event to repository_dispatch
|
||||
# because claude-code-action doesn't support push events directly.
|
||||
# See: https://github.com/langgenius/dify/issues/30743
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
paths:
|
||||
- 'web/i18n/en-US/*.json'
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
trigger:
|
||||
if: github.repository == 'langgenius/dify'
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Detect changed files and generate diff
|
||||
id: detect
|
||||
run: |
|
||||
BEFORE_SHA="${{ github.event.before }}"
|
||||
# Handle edge case: force push may have null/zero SHA
|
||||
if [ -z "$BEFORE_SHA" ] || [ "$BEFORE_SHA" = "0000000000000000000000000000000000000000" ]; then
|
||||
BEFORE_SHA="HEAD~1"
|
||||
fi
|
||||
|
||||
# Detect changed i18n files
|
||||
changed=$(git diff --name-only "$BEFORE_SHA" "${{ github.sha }}" -- 'web/i18n/en-US/*.json' 2>/dev/null | xargs -n1 basename 2>/dev/null | sed 's/.json$//' | tr '\n' ' ' || echo "")
|
||||
echo "changed_files=$changed" >> $GITHUB_OUTPUT
|
||||
|
||||
# Generate diff for context
|
||||
git diff "$BEFORE_SHA" "${{ github.sha }}" -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt
|
||||
|
||||
# Truncate if too large (keep first 50KB to match receiving workflow)
|
||||
head -c 50000 /tmp/i18n-diff.txt > /tmp/i18n-diff-truncated.txt
|
||||
mv /tmp/i18n-diff-truncated.txt /tmp/i18n-diff.txt
|
||||
|
||||
# Base64 encode the diff for safe JSON transport (portable, single-line)
|
||||
diff_base64=$(base64 < /tmp/i18n-diff.txt | tr -d '\n')
|
||||
echo "diff_base64=$diff_base64" >> $GITHUB_OUTPUT
|
||||
|
||||
if [ -n "$changed" ]; then
|
||||
echo "has_changes=true" >> $GITHUB_OUTPUT
|
||||
echo "Detected changed files: $changed"
|
||||
else
|
||||
echo "has_changes=false" >> $GITHUB_OUTPUT
|
||||
echo "No i18n changes detected"
|
||||
fi
|
||||
|
||||
- name: Trigger i18n sync workflow
|
||||
if: steps.detect.outputs.has_changes == 'true'
|
||||
uses: peter-evans/repository-dispatch@28959ce8df70de7be546dd1250a005dd32156697 # v4.0.1
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
event-type: i18n-sync
|
||||
client-payload: '{"changed_files": "${{ steps.detect.outputs.changed_files }}", "diff_base64": "${{ steps.detect.outputs.diff_base64 }}", "sync_mode": "incremental", "trigger_sha": "${{ github.sha }}"}'
|
||||
5
.github/workflows/web-tests.yml
vendored
5
.github/workflows/web-tests.yml
vendored
@@ -22,8 +22,8 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
shardIndex: [1, 2, 3, 4, 5, 6]
|
||||
shardTotal: [6]
|
||||
shardIndex: [1, 2, 3, 4]
|
||||
shardTotal: [4]
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -66,7 +66,6 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup web environment
|
||||
|
||||
7
Makefile
7
Makefile
@@ -74,6 +74,12 @@ type-check:
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@echo "✅ Type checks complete"
|
||||
|
||||
type-check-core:
|
||||
@echo "📝 Running core type checks (basedpyright + mypy)..."
|
||||
@./dev/basedpyright-check $(PATH_TO_CHECK)
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@echo "✅ Core type checks complete"
|
||||
|
||||
test:
|
||||
@echo "🧪 Running backend unit tests..."
|
||||
@if [ -n "$(TARGET_TESTS)" ]; then \
|
||||
@@ -133,6 +139,7 @@ help:
|
||||
@echo " make check - Check code with ruff"
|
||||
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
|
||||
@echo " make type-check - Run type checks (basedpyright, pyrefly, mypy)"
|
||||
@echo " make type-check-core - Run core type checks (basedpyright, mypy)"
|
||||
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
|
||||
@echo ""
|
||||
@echo "Docker Build Targets:"
|
||||
|
||||
@@ -127,7 +127,8 @@ ALIYUN_OSS_AUTH_VERSION=v1
|
||||
ALIYUN_OSS_REGION=your-region
|
||||
# Don't start with '/'. OSS doesn't support leading slash in object names.
|
||||
ALIYUN_OSS_PATH=your-path
|
||||
ALIYUN_CLOUDBOX_ID=your-cloudbox-id
|
||||
# Optional CloudBox ID for Aliyun OSS, DO NOT enable it if you are not using CloudBox.
|
||||
#ALIYUN_CLOUDBOX_ID=your-cloudbox-id
|
||||
|
||||
# Google Storage configuration
|
||||
GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
|
||||
|
||||
@@ -8,6 +8,7 @@ Go admin-api caller.
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
@@ -87,7 +88,7 @@ class EnterpriseAppDSLExport(Resource):
|
||||
"""Export an app's DSL as YAML."""
|
||||
include_secret = request.args.get("include_secret", "false").lower() == "true"
|
||||
|
||||
app_model = db.session.query(App).filter_by(id=app_id).first()
|
||||
app_model = db.session.get(App, app_id)
|
||||
if not app_model:
|
||||
return {"message": "app not found"}, 404
|
||||
|
||||
@@ -104,7 +105,7 @@ def _get_active_account(email: str) -> Account | None:
|
||||
|
||||
Workspace membership is already validated by the Go admin-api caller.
|
||||
"""
|
||||
account = db.session.query(Account).filter_by(email=email).first()
|
||||
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
|
||||
if account is None or account.status != AccountStatus.ACTIVE:
|
||||
return None
|
||||
return account
|
||||
|
||||
@@ -18,7 +18,7 @@ from graphon.model_runtime.entities import (
|
||||
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentToolEntity
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
@@ -104,11 +104,14 @@ class BaseAgentRunner(AppRunner):
|
||||
)
|
||||
# get how many agent thoughts have been created
|
||||
self.agent_thought_count = (
|
||||
db.session.query(MessageAgentThought)
|
||||
.where(
|
||||
MessageAgentThought.message_id == self.message.id,
|
||||
db.session.scalar(
|
||||
select(func.count())
|
||||
.select_from(MessageAgentThought)
|
||||
.where(
|
||||
MessageAgentThought.message_id == self.message.id,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
db.session.close()
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import select, update
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@@ -70,23 +70,21 @@ class DatasetIndexToolCallbackHandler:
|
||||
)
|
||||
child_chunk = db.session.scalar(child_chunk_stmt)
|
||||
if child_chunk:
|
||||
_ = (
|
||||
db.session.query(DocumentSegment)
|
||||
db.session.execute(
|
||||
update(DocumentSegment)
|
||||
.where(DocumentSegment.id == child_chunk.segment_id)
|
||||
.update(
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
|
||||
)
|
||||
.values(hit_count=DocumentSegment.hit_count + 1)
|
||||
)
|
||||
else:
|
||||
query = db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.index_node_id == document.metadata["doc_id"]
|
||||
)
|
||||
conditions = [DocumentSegment.index_node_id == document.metadata["doc_id"]]
|
||||
|
||||
if "dataset_id" in document.metadata:
|
||||
query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])
|
||||
conditions.append(DocumentSegment.dataset_id == document.metadata["dataset_id"])
|
||||
|
||||
# add hit count to document segment
|
||||
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
|
||||
db.session.execute(
|
||||
update(DocumentSegment).where(*conditions).values(hit_count=DocumentSegment.hit_count + 1)
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ def encrypt_token(tenant_id: str, token: str):
|
||||
from extensions.ext_database import db
|
||||
from models.account import Tenant
|
||||
|
||||
if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()):
|
||||
if not (tenant := db.session.get(Tenant, tenant_id)):
|
||||
raise ValueError(f"Tenant with id {tenant_id} not found")
|
||||
assert tenant.encrypt_public_key is not None
|
||||
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
|
||||
|
||||
@@ -10,6 +10,7 @@ from graphon.model_runtime.entities.llm_entities import LLMResult
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
|
||||
@@ -410,8 +411,8 @@ class LLMGenerator:
|
||||
model_config: ModelConfig,
|
||||
ideal_output: str | None,
|
||||
):
|
||||
last_run: Message | None = (
|
||||
db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
|
||||
last_run: Message | None = db.session.scalar(
|
||||
select(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).limit(1)
|
||||
)
|
||||
if not last_run:
|
||||
return LLMGenerator.__instruction_modify_common(
|
||||
|
||||
@@ -227,7 +227,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
get app
|
||||
"""
|
||||
try:
|
||||
app = db.session.query(App).where(App.id == app_id).where(App.tenant_id == tenant_id).first()
|
||||
app = db.session.scalar(select(App).where(App.id == app_id, App.tenant_id == tenant_id).limit(1))
|
||||
except Exception:
|
||||
raise ValueError("app not found")
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
@@ -31,7 +31,7 @@ class ToolLabelManager:
|
||||
raise ValueError("Unsupported tool type")
|
||||
|
||||
# delete old labels
|
||||
db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id).delete()
|
||||
db.session.execute(delete(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id))
|
||||
|
||||
# insert new labels
|
||||
for label in labels:
|
||||
|
||||
@@ -255,11 +255,11 @@ class ToolManager:
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
|
||||
else:
|
||||
builtin_provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
builtin_provider = db.session.scalar(
|
||||
select(BuiltinToolProvider)
|
||||
.where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if builtin_provider is None:
|
||||
@@ -818,13 +818,13 @@ class ToolManager:
|
||||
|
||||
:return: the provider controller, the credentials
|
||||
"""
|
||||
provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
provider: ApiToolProvider | None = db.session.scalar(
|
||||
select(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.id == provider_id,
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
@@ -872,13 +872,13 @@ class ToolManager:
|
||||
get api provider
|
||||
"""
|
||||
provider_name = provider
|
||||
provider_obj: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
provider_obj: ApiToolProvider | None = db.session.scalar(
|
||||
select(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if provider_obj is None:
|
||||
@@ -964,10 +964,10 @@ class ToolManager:
|
||||
@classmethod
|
||||
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict:
|
||||
try:
|
||||
workflow_provider: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
workflow_provider: WorkflowToolProvider | None = db.session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if workflow_provider is None:
|
||||
@@ -981,10 +981,10 @@ class ToolManager:
|
||||
@classmethod
|
||||
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict:
|
||||
try:
|
||||
api_provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
api_provider: ApiToolProvider | None = db.session.scalar(
|
||||
select(ApiToolProvider)
|
||||
.where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if api_provider is None:
|
||||
|
||||
@@ -110,7 +110,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
context_list: list[RetrievalSourceMetadata] = []
|
||||
resource_number = 1
|
||||
for segment in sorted_segments:
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||
dataset = db.session.get(Dataset, segment.dataset_id)
|
||||
document_stmt = select(Document).where(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
|
||||
@@ -205,7 +205,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
if self.return_resource:
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||
dataset = db.session.get(Dataset, segment.dataset_id)
|
||||
dataset_document_stmt = select(DatasetDocument).where(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
|
||||
@@ -35,15 +35,13 @@ class ApiKeyAuthService:
|
||||
|
||||
@staticmethod
|
||||
def get_auth_credentials(tenant_id: str, category: str, provider: str):
|
||||
data_source_api_key_bindings = (
|
||||
db.session.query(DataSourceApiKeyAuthBinding)
|
||||
.where(
|
||||
data_source_api_key_bindings = db.session.scalar(
|
||||
select(DataSourceApiKeyAuthBinding).where(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
|
||||
DataSourceApiKeyAuthBinding.category == category,
|
||||
DataSourceApiKeyAuthBinding.provider == provider,
|
||||
DataSourceApiKeyAuthBinding.disabled.is_(False),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not data_source_api_key_bindings:
|
||||
return None
|
||||
@@ -54,10 +52,11 @@ class ApiKeyAuthService:
|
||||
|
||||
@staticmethod
|
||||
def delete_provider_auth(tenant_id: str, binding_id: str):
|
||||
data_source_api_key_binding = (
|
||||
db.session.query(DataSourceApiKeyAuthBinding)
|
||||
.where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id)
|
||||
.first()
|
||||
data_source_api_key_binding = db.session.scalar(
|
||||
select(DataSourceApiKeyAuthBinding).where(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
|
||||
DataSourceApiKeyAuthBinding.id == binding_id,
|
||||
)
|
||||
)
|
||||
if data_source_api_key_binding:
|
||||
db.session.delete(data_source_api_key_binding)
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from unittest.mock import patch
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
@@ -534,3 +536,283 @@ class TestWorkspaceService:
|
||||
# Verify database state
|
||||
db_session_with_containers.refresh(tenant)
|
||||
assert tenant.id is not None
|
||||
|
||||
def test_get_tenant_info_should_raise_assertion_when_join_missing(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""TenantAccountJoin must exist; missing join should raise AssertionError."""
|
||||
fake = Faker()
|
||||
account = Account(email=fake.email(), name=fake.name(), interface_language="en-US", status="active")
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
tenant = Tenant(name=fake.company(), status="normal", plan="basic")
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# No TenantAccountJoin created
|
||||
with patch("services.workspace_service.current_user", account):
|
||||
with pytest.raises(AssertionError, match="TenantAccountJoin not found"):
|
||||
WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
def test_get_tenant_info_should_set_replace_webapp_logo_to_none_when_flag_absent(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""replace_webapp_logo should be None when custom_config_dict does not have the key."""
|
||||
import json
|
||||
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
tenant.custom_config = json.dumps({})
|
||||
db_session_with_containers.commit()
|
||||
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = True
|
||||
|
||||
with patch("services.workspace_service.current_user", account):
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
assert result is not None
|
||||
assert result["custom_config"]["replace_webapp_logo"] is None
|
||||
|
||||
def test_get_tenant_info_should_use_files_url_for_logo_url(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""The logo URL should use dify_config.FILES_URL as the base."""
|
||||
import json
|
||||
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
tenant.custom_config = json.dumps({"replace_webapp_logo": True})
|
||||
db_session_with_containers.commit()
|
||||
|
||||
custom_base = "https://cdn.mycompany.io"
|
||||
mock_external_service_dependencies["dify_config"].FILES_URL = custom_base
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = True
|
||||
|
||||
with patch("services.workspace_service.current_user", account):
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
assert result is not None
|
||||
assert result["custom_config"]["replace_webapp_logo"].startswith(custom_base)
|
||||
|
||||
def test_get_tenant_info_should_not_include_cloud_fields_in_self_hosted(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""next_credit_reset_date and trial_credits should NOT appear in SELF_HOSTED mode."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
mock_external_service_dependencies["dify_config"].EDITION = "SELF_HOSTED"
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = False
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
|
||||
|
||||
with patch("services.workspace_service.current_user", account):
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
assert result is not None
|
||||
assert "next_credit_reset_date" not in result
|
||||
assert "trial_credits" not in result
|
||||
assert "trial_credits_used" not in result
|
||||
|
||||
def test_get_tenant_info_cloud_credit_reset_date(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""next_credit_reset_date should be present in CLOUD edition."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
mock_external_service_dependencies["dify_config"].EDITION = "CLOUD"
|
||||
feature = mock_external_service_dependencies["feature_service"].get_features.return_value
|
||||
feature.can_replace_logo = False
|
||||
feature.next_credit_reset_date = "2025-02-01"
|
||||
feature.billing.subscription.plan = "professional"
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
|
||||
|
||||
with (
|
||||
patch("services.workspace_service.current_user", account),
|
||||
patch("services.credit_pool_service.CreditPoolService.get_pool", return_value=None),
|
||||
):
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
assert result is not None
|
||||
assert result["next_credit_reset_date"] == "2025-02-01"
|
||||
|
||||
def test_get_tenant_info_cloud_paid_pool_not_full(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""trial_credits come from paid pool when plan is not sandbox and pool is not full."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
mock_external_service_dependencies["dify_config"].EDITION = "CLOUD"
|
||||
feature = mock_external_service_dependencies["feature_service"].get_features.return_value
|
||||
feature.can_replace_logo = False
|
||||
feature.next_credit_reset_date = "2025-02-01"
|
||||
feature.billing.subscription.plan = "professional"
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
|
||||
|
||||
paid_pool = MagicMock(quota_limit=1000, quota_used=200)
|
||||
|
||||
with (
|
||||
patch("services.workspace_service.current_user", account),
|
||||
patch("services.credit_pool_service.CreditPoolService.get_pool", return_value=paid_pool),
|
||||
):
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == 1000
|
||||
assert result["trial_credits_used"] == 200
|
||||
|
||||
def test_get_tenant_info_cloud_paid_pool_unlimited(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""quota_limit == -1 means unlimited; service should use paid pool."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
mock_external_service_dependencies["dify_config"].EDITION = "CLOUD"
|
||||
feature = mock_external_service_dependencies["feature_service"].get_features.return_value
|
||||
feature.can_replace_logo = False
|
||||
feature.next_credit_reset_date = "2025-02-01"
|
||||
feature.billing.subscription.plan = "professional"
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
|
||||
|
||||
paid_pool = MagicMock(quota_limit=-1, quota_used=999)
|
||||
|
||||
with (
|
||||
patch("services.workspace_service.current_user", account),
|
||||
patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[paid_pool, None]),
|
||||
):
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == -1
|
||||
assert result["trial_credits_used"] == 999
|
||||
|
||||
def test_get_tenant_info_cloud_fall_back_to_trial_when_paid_full(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""When paid pool is exhausted, switch to trial pool."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
mock_external_service_dependencies["dify_config"].EDITION = "CLOUD"
|
||||
feature = mock_external_service_dependencies["feature_service"].get_features.return_value
|
||||
feature.can_replace_logo = False
|
||||
feature.next_credit_reset_date = "2025-02-01"
|
||||
feature.billing.subscription.plan = "professional"
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
|
||||
|
||||
paid_pool = MagicMock(quota_limit=500, quota_used=500)
|
||||
trial_pool = MagicMock(quota_limit=100, quota_used=10)
|
||||
|
||||
with (
|
||||
patch("services.workspace_service.current_user", account),
|
||||
patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[paid_pool, trial_pool]),
|
||||
):
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == 100
|
||||
assert result["trial_credits_used"] == 10
|
||||
|
||||
def test_get_tenant_info_cloud_fall_back_to_trial_when_paid_none(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""When paid_pool is None, fall back to trial pool."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
mock_external_service_dependencies["dify_config"].EDITION = "CLOUD"
|
||||
feature = mock_external_service_dependencies["feature_service"].get_features.return_value
|
||||
feature.can_replace_logo = False
|
||||
feature.next_credit_reset_date = "2025-02-01"
|
||||
feature.billing.subscription.plan = "professional"
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
|
||||
|
||||
trial_pool = MagicMock(quota_limit=50, quota_used=5)
|
||||
|
||||
with (
|
||||
patch("services.workspace_service.current_user", account),
|
||||
patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[None, trial_pool]),
|
||||
):
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == 50
|
||||
assert result["trial_credits_used"] == 5
|
||||
|
||||
def test_get_tenant_info_cloud_sandbox_uses_trial_pool(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""When plan is SANDBOX, skip paid pool and use trial pool."""
|
||||
from enums.cloud_plan import CloudPlan
|
||||
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
mock_external_service_dependencies["dify_config"].EDITION = "CLOUD"
|
||||
feature = mock_external_service_dependencies["feature_service"].get_features.return_value
|
||||
feature.can_replace_logo = False
|
||||
feature.next_credit_reset_date = "2025-02-01"
|
||||
feature.billing.subscription.plan = CloudPlan.SANDBOX
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
|
||||
|
||||
paid_pool = MagicMock(quota_limit=1000, quota_used=0)
|
||||
trial_pool = MagicMock(quota_limit=200, quota_used=20)
|
||||
|
||||
with (
|
||||
patch("services.workspace_service.current_user", account),
|
||||
patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[paid_pool, trial_pool]),
|
||||
):
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == 200
|
||||
assert result["trial_credits_used"] == 20
|
||||
|
||||
def test_get_tenant_info_cloud_both_pools_none(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""When both paid and trial pools are absent, trial_credits should not be set."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
mock_external_service_dependencies["dify_config"].EDITION = "CLOUD"
|
||||
feature = mock_external_service_dependencies["feature_service"].get_features.return_value
|
||||
feature.can_replace_logo = False
|
||||
feature.next_credit_reset_date = "2025-02-01"
|
||||
feature.billing.subscription.plan = "professional"
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
|
||||
|
||||
with (
|
||||
patch("services.workspace_service.current_user", account),
|
||||
patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[None, None]),
|
||||
):
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
assert result is not None
|
||||
assert "trial_credits" not in result
|
||||
assert "trial_credits_used" not in result
|
||||
|
||||
@@ -64,18 +64,18 @@ class TestGetActiveAccount:
|
||||
def test_returns_active_account(self, mock_db):
|
||||
mock_account = MagicMock()
|
||||
mock_account.status = "active"
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
|
||||
mock_db.session.scalar.return_value = mock_account
|
||||
|
||||
result = _get_active_account("user@example.com")
|
||||
|
||||
assert result is mock_account
|
||||
mock_db.session.query.return_value.filter_by.assert_called_once_with(email="user@example.com")
|
||||
mock_db.session.scalar.assert_called_once()
|
||||
|
||||
@patch("controllers.inner_api.app.dsl.db")
|
||||
def test_returns_none_for_inactive_account(self, mock_db):
|
||||
mock_account = MagicMock()
|
||||
mock_account.status = "banned"
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
|
||||
mock_db.session.scalar.return_value = mock_account
|
||||
|
||||
result = _get_active_account("banned@example.com")
|
||||
|
||||
@@ -83,7 +83,7 @@ class TestGetActiveAccount:
|
||||
|
||||
@patch("controllers.inner_api.app.dsl.db")
|
||||
def test_returns_none_for_nonexistent_email(self, mock_db):
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
|
||||
result = _get_active_account("missing@example.com")
|
||||
|
||||
@@ -205,7 +205,7 @@ class TestEnterpriseAppDSLExport:
|
||||
@patch("controllers.inner_api.app.dsl.db")
|
||||
def test_export_success_returns_200(self, mock_db, mock_dsl_cls, api_instance, app: Flask):
|
||||
mock_app = MagicMock()
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_app
|
||||
mock_db.session.get.return_value = mock_app
|
||||
mock_dsl_cls.export_dsl.return_value = "version: 0.6.0\nkind: app\n"
|
||||
|
||||
unwrapped = inspect.unwrap(api_instance.get)
|
||||
@@ -221,7 +221,7 @@ class TestEnterpriseAppDSLExport:
|
||||
@patch("controllers.inner_api.app.dsl.db")
|
||||
def test_export_with_secret(self, mock_db, mock_dsl_cls, api_instance, app: Flask):
|
||||
mock_app = MagicMock()
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_app
|
||||
mock_db.session.get.return_value = mock_app
|
||||
mock_dsl_cls.export_dsl.return_value = "yaml-data"
|
||||
|
||||
unwrapped = inspect.unwrap(api_instance.get)
|
||||
@@ -234,7 +234,7 @@ class TestEnterpriseAppDSLExport:
|
||||
|
||||
@patch("controllers.inner_api.app.dsl.db")
|
||||
def test_export_app_not_found_returns_404(self, mock_db, api_instance, app: Flask):
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db.session.get.return_value = None
|
||||
|
||||
unwrapped = inspect.unwrap(api_instance.get)
|
||||
with app.test_request_context("?include_secret=false"):
|
||||
|
||||
@@ -621,7 +621,7 @@ class TestConvertDatasetRetrieverTool:
|
||||
class TestBaseAgentRunnerInit:
|
||||
def test_init_sets_stream_tool_call_and_files(self, mocker):
|
||||
session = mocker.MagicMock()
|
||||
session.query.return_value.where.return_value.count.return_value = 2
|
||||
session.scalar.return_value = 2
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
mocker.patch.object(BaseAgentRunner, "organize_agent_history", return_value=[])
|
||||
|
||||
@@ -114,13 +114,9 @@ class TestOnToolEnd:
|
||||
document = mocker.Mock()
|
||||
document.metadata = {"document_id": "doc-1", "doc_id": "node-1"}
|
||||
|
||||
mock_query = mocker.Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
|
||||
handler.on_tool_end([document])
|
||||
|
||||
mock_query.update.assert_called_once()
|
||||
mock_db.session.execute.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
def test_on_tool_end_non_parent_child_index(self, handler, mocker):
|
||||
@@ -138,13 +134,9 @@ class TestOnToolEnd:
|
||||
"dataset_id": "dataset-1",
|
||||
}
|
||||
|
||||
mock_query = mocker.Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
|
||||
handler.on_tool_end([document])
|
||||
|
||||
mock_query.update.assert_called_once()
|
||||
mock_db.session.execute.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
def test_on_tool_end_empty_documents(self, handler):
|
||||
|
||||
@@ -38,13 +38,13 @@ class TestObfuscatedToken:
|
||||
|
||||
|
||||
class TestEncryptToken:
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_successful_encryption(self, mock_encrypt, mock_query):
|
||||
"""Test successful token encryption"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_data"
|
||||
|
||||
result = encrypt_token("tenant-123", "test_token")
|
||||
@@ -52,10 +52,10 @@ class TestEncryptToken:
|
||||
assert result == base64.b64encode(b"encrypted_data").decode()
|
||||
mock_encrypt.assert_called_with("test_token", "mock_public_key")
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
def test_tenant_not_found(self, mock_query):
|
||||
"""Test error when tenant doesn't exist"""
|
||||
mock_query.return_value.where.return_value.first.return_value = None
|
||||
mock_query.return_value = None
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypt_token("invalid-tenant", "test_token")
|
||||
@@ -119,7 +119,7 @@ class TestGetDecryptDecoding:
|
||||
|
||||
|
||||
class TestEncryptDecryptIntegration:
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
@patch("libs.rsa.decrypt")
|
||||
def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_query):
|
||||
@@ -127,7 +127,7 @@ class TestEncryptDecryptIntegration:
|
||||
# Setup mock tenant
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value = mock_tenant
|
||||
|
||||
# Setup mock encryption/decryption
|
||||
original_token = "test_token_123"
|
||||
@@ -146,14 +146,14 @@ class TestEncryptDecryptIntegration:
|
||||
class TestSecurity:
|
||||
"""Critical security tests for encryption system"""
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_cross_tenant_isolation(self, mock_encrypt, mock_query):
|
||||
"""Ensure tokens encrypted for one tenant cannot be used by another"""
|
||||
# Setup mock tenant
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "tenant1_public_key"
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_for_tenant1"
|
||||
|
||||
# Encrypt token for tenant1
|
||||
@@ -181,12 +181,12 @@ class TestSecurity:
|
||||
with pytest.raises(Exception, match="Decryption error"):
|
||||
decrypt_token("tenant-123", tampered)
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_encryption_randomness(self, mock_encrypt, mock_query):
|
||||
"""Ensure same plaintext produces different ciphertext"""
|
||||
mock_tenant = MagicMock(encrypt_public_key="key")
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value = mock_tenant
|
||||
|
||||
# Different outputs for same input
|
||||
mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"]
|
||||
@@ -205,13 +205,13 @@ class TestEdgeCases:
|
||||
# Test empty string (which is a valid str type)
|
||||
assert obfuscated_token("") == ""
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_query):
|
||||
"""Test encryption of empty token"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_empty"
|
||||
|
||||
result = encrypt_token("tenant-123", "")
|
||||
@@ -219,13 +219,13 @@ class TestEdgeCases:
|
||||
assert result == base64.b64encode(b"encrypted_empty").decode()
|
||||
mock_encrypt.assert_called_with("", "mock_public_key")
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_query):
|
||||
"""Test tokens containing special/unicode characters"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_special"
|
||||
|
||||
# Test various special characters
|
||||
@@ -242,13 +242,13 @@ class TestEdgeCases:
|
||||
assert result == base64.b64encode(b"encrypted_special").decode()
|
||||
mock_encrypt.assert_called_with(token, "mock_public_key")
|
||||
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_query):
|
||||
"""Test behavior when token exceeds RSA encryption limits"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_query.return_value = mock_tenant
|
||||
|
||||
# RSA 2048-bit can only encrypt ~245 bytes
|
||||
# The actual limit depends on padding scheme
|
||||
|
||||
@@ -314,8 +314,8 @@ class TestLLMGenerator:
|
||||
assert "An unexpected error occurred" in result["error"]
|
||||
|
||||
def test_instruction_modify_legacy_no_last_run(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
|
||||
# Mock __instruction_modify_common call via invoke_llm
|
||||
mock_response = MagicMock()
|
||||
@@ -328,12 +328,12 @@ class TestLLMGenerator:
|
||||
assert result == {"modified": "prompt"}
|
||||
|
||||
def test_instruction_modify_legacy_with_last_run(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
last_run = MagicMock()
|
||||
last_run.query = "q"
|
||||
last_run.answer = "a"
|
||||
last_run.error = "e"
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = last_run
|
||||
mock_scalar.return_value = last_run
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = '{"modified": "prompt"}'
|
||||
@@ -483,8 +483,8 @@ class TestLLMGenerator:
|
||||
|
||||
def test_instruction_modify_common_placeholders(self, mock_model_instance, model_config_entity):
|
||||
# Testing placeholders replacement via instruction_modify_legacy for convenience
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = '{"ok": true}'
|
||||
@@ -504,8 +504,8 @@ class TestLLMGenerator:
|
||||
assert "current_val" in user_msg_dict["instruction"]
|
||||
|
||||
def test_instruction_modify_common_no_braces(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = "No braces here"
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
@@ -516,8 +516,8 @@ class TestLLMGenerator:
|
||||
assert "Could not find a valid JSON object" in result["error"]
|
||||
|
||||
def test_instruction_modify_common_not_dict(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = "[1, 2, 3]"
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
@@ -556,8 +556,8 @@ class TestLLMGenerator:
|
||||
)
|
||||
|
||||
def test_instruction_modify_common_invoke_error(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke Failed")
|
||||
|
||||
result = LLMGenerator.instruction_modify_legacy(
|
||||
@@ -566,8 +566,8 @@ class TestLLMGenerator:
|
||||
assert "Failed to generate code" in result["error"]
|
||||
|
||||
def test_instruction_modify_common_exception(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
mock_model_instance.invoke_llm.side_effect = Exception("Random error")
|
||||
|
||||
result = LLMGenerator.instruction_modify_legacy(
|
||||
@@ -576,8 +576,8 @@ class TestLLMGenerator:
|
||||
assert "An unexpected error occurred" in result["error"]
|
||||
|
||||
def test_instruction_modify_common_json_error(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = "No JSON here"
|
||||
|
||||
@@ -332,27 +332,21 @@ class TestPluginAppBackwardsInvocation:
|
||||
PluginAppBackwardsInvocation._get_user("uid")
|
||||
|
||||
def test_get_app_returns_app(self, mocker):
|
||||
query_chain = MagicMock()
|
||||
query_chain.where.return_value = query_chain
|
||||
app_obj = MagicMock(id="app")
|
||||
query_chain.first.return_value = app_obj
|
||||
db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain)))
|
||||
db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=app_obj)))
|
||||
mocker.patch("core.plugin.backwards_invocation.app.db", db)
|
||||
|
||||
assert PluginAppBackwardsInvocation._get_app("app", "tenant") is app_obj
|
||||
|
||||
def test_get_app_raises_when_missing(self, mocker):
|
||||
query_chain = MagicMock()
|
||||
query_chain.where.return_value = query_chain
|
||||
query_chain.first.return_value = None
|
||||
db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain)))
|
||||
db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=None)))
|
||||
mocker.patch("core.plugin.backwards_invocation.app.db", db)
|
||||
|
||||
with pytest.raises(ValueError, match="app not found"):
|
||||
PluginAppBackwardsInvocation._get_app("app", "tenant")
|
||||
|
||||
def test_get_app_raises_when_query_fails(self, mocker):
|
||||
db = SimpleNamespace(session=MagicMock(query=MagicMock(side_effect=RuntimeError("db down"))))
|
||||
db = SimpleNamespace(session=MagicMock(scalar=MagicMock(side_effect=RuntimeError("db down"))))
|
||||
mocker.patch("core.plugin.backwards_invocation.app.db", db)
|
||||
|
||||
with pytest.raises(ValueError, match="app not found"):
|
||||
|
||||
@@ -38,11 +38,9 @@ def test_tool_label_manager_filter_tool_labels():
|
||||
def test_tool_label_manager_update_tool_labels_db():
|
||||
controller = _api_controller("api-1")
|
||||
with patch("core.tools.tool_label_manager.db") as mock_db:
|
||||
delete_query = mock_db.session.query.return_value.where.return_value
|
||||
delete_query.delete.return_value = None
|
||||
ToolLabelManager.update_tool_labels(controller, ["search", "search", "invalid"])
|
||||
|
||||
delete_query.delete.assert_called_once()
|
||||
mock_db.session.execute.assert_called_once()
|
||||
# only one valid unique label should be inserted.
|
||||
assert mock_db.session.add.call_count == 1
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@@ -220,9 +220,7 @@ def test_get_tool_runtime_builtin_with_credentials_decrypts_and_forks():
|
||||
with patch.object(ToolManager, "get_builtin_provider", return_value=controller):
|
||||
with patch("core.helper.credential_utils.check_credential_policy_compliance"):
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = (
|
||||
builtin_provider
|
||||
)
|
||||
mock_db.session.scalar.return_value = builtin_provider
|
||||
encrypter = Mock()
|
||||
encrypter.decrypt.return_value = {"api_key": "secret"}
|
||||
cache = Mock()
|
||||
@@ -274,7 +272,7 @@ def test_get_tool_runtime_builtin_refreshes_expired_oauth_credentials(
|
||||
)
|
||||
refreshed = SimpleNamespace(credentials={"token": "new"}, expires_at=123456)
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = builtin_provider
|
||||
mock_db.session.scalar.return_value = builtin_provider
|
||||
encrypter = Mock()
|
||||
encrypter.decrypt.return_value = {"token": "old"}
|
||||
encrypter.encrypt.return_value = {"token": "encrypted"}
|
||||
@@ -698,12 +696,10 @@ def test_get_api_provider_controller_returns_controller_and_credentials():
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="disclaimer",
|
||||
)
|
||||
db_query = Mock()
|
||||
db_query.where.return_value.first.return_value = provider
|
||||
controller = Mock()
|
||||
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.query.return_value = db_query
|
||||
mock_db.session.scalar.return_value = provider
|
||||
with patch(
|
||||
"core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller
|
||||
) as mock_from_db:
|
||||
@@ -730,12 +726,10 @@ def test_user_get_api_provider_masks_credentials_and_adds_labels():
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="disclaimer",
|
||||
)
|
||||
db_query = Mock()
|
||||
db_query.where.return_value.first.return_value = provider
|
||||
controller = Mock()
|
||||
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.query.return_value = db_query
|
||||
mock_db.session.scalar.return_value = provider
|
||||
with patch("core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller):
|
||||
encrypter = Mock()
|
||||
encrypter.decrypt.return_value = {"api_key_value": "secret"}
|
||||
@@ -750,7 +744,7 @@ def test_user_get_api_provider_masks_credentials_and_adds_labels():
|
||||
|
||||
def test_get_api_provider_controller_not_found_raises():
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
with pytest.raises(ToolProviderNotFoundError, match="api provider missing not found"):
|
||||
ToolManager.get_api_provider_controller("tenant-1", "missing")
|
||||
|
||||
@@ -809,14 +803,14 @@ def test_generate_tool_icon_urls_for_workflow_and_api():
|
||||
workflow_provider = SimpleNamespace(icon='{"background": "#222", "content": "W"}')
|
||||
api_provider = SimpleNamespace(icon='{"background": "#333", "content": "A"}')
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [workflow_provider, api_provider]
|
||||
mock_db.session.scalar.side_effect = [workflow_provider, api_provider]
|
||||
assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "wf-1") == {"background": "#222", "content": "W"}
|
||||
assert ToolManager.generate_api_tool_icon_url("tenant-1", "api-1") == {"background": "#333", "content": "A"}
|
||||
|
||||
|
||||
def test_generate_tool_icon_urls_missing_workflow_and_api_use_default():
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db.session.scalar.return_value = None
|
||||
assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "missing")["background"] == "#252525"
|
||||
assert ToolManager.generate_api_tool_icon_url("tenant-1", "missing")["background"] == "#252525"
|
||||
|
||||
|
||||
@@ -263,7 +263,7 @@ def test_single_dataset_retriever_non_economy_run_sorts_context_and_resources():
|
||||
)
|
||||
db_session = Mock()
|
||||
db_session.scalar.side_effect = [dataset, lookup_doc_low, lookup_doc_high]
|
||||
db_session.query.return_value.filter_by.return_value.first.return_value = dataset
|
||||
db_session.get.return_value = dataset
|
||||
|
||||
tool = SingleDatasetRetrieverTool(
|
||||
tenant_id="tenant-1",
|
||||
@@ -444,7 +444,7 @@ def test_multi_dataset_retriever_run_orders_segments_and_returns_resources():
|
||||
)
|
||||
db_session = Mock()
|
||||
db_session.scalars.return_value.all.return_value = [segment_for_node_2, segment_for_node_1]
|
||||
db_session.query.return_value.filter_by.return_value.first.side_effect = [
|
||||
db_session.get.side_effect = [
|
||||
SimpleNamespace(id="dataset-2", name="Dataset Two"),
|
||||
SimpleNamespace(id="dataset-1", name="Dataset One"),
|
||||
]
|
||||
|
||||
@@ -1,558 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
|
||||
from models.dataset import Dataset
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
DocumentMetadataOperation,
|
||||
MetadataArgs,
|
||||
MetadataDetail,
|
||||
MetadataOperationData,
|
||||
)
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DocumentStub:
|
||||
id: str
|
||||
name: str
|
||||
uploader: str
|
||||
upload_date: datetime
|
||||
last_update_date: datetime
|
||||
data_source_type: str
|
||||
doc_metadata: dict[str, object] | None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db(mocker: MockerFixture) -> MagicMock:
|
||||
mocked_db = mocker.patch("services.metadata_service.db")
|
||||
mocked_db.session = MagicMock()
|
||||
return mocked_db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client(mocker: MockerFixture) -> MagicMock:
|
||||
return mocker.patch("services.metadata_service.redis_client")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_current_account(mocker: MockerFixture) -> MagicMock:
|
||||
mock_user = SimpleNamespace(id="user-1")
|
||||
return mocker.patch("services.metadata_service.current_account_with_tenant", return_value=(mock_user, "tenant-1"))
|
||||
|
||||
|
||||
def _build_document(document_id: str, doc_metadata: dict[str, object] | None = None) -> _DocumentStub:
|
||||
now = datetime(2025, 1, 1, 10, 30, tzinfo=UTC)
|
||||
return _DocumentStub(
|
||||
id=document_id,
|
||||
name=f"doc-{document_id}",
|
||||
uploader="qa@example.com",
|
||||
upload_date=now,
|
||||
last_update_date=now,
|
||||
data_source_type="upload_file",
|
||||
doc_metadata=doc_metadata,
|
||||
)
|
||||
|
||||
|
||||
def _dataset(**kwargs: Any) -> Dataset:
|
||||
return cast(Dataset, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def test_create_metadata_should_raise_value_error_when_name_exceeds_limit() -> None:
|
||||
# Arrange
|
||||
metadata_args = MetadataArgs(type="string", name="x" * 256)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="cannot exceed 255"):
|
||||
MetadataService.create_metadata("dataset-1", metadata_args)
|
||||
|
||||
|
||||
def test_create_metadata_should_raise_value_error_when_metadata_name_already_exists(
|
||||
mock_db: MagicMock,
|
||||
mock_current_account: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
metadata_args = MetadataArgs(type="string", name="priority")
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
MetadataService.create_metadata("dataset-1", metadata_args)
|
||||
|
||||
# Assert
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_create_metadata_should_raise_value_error_when_name_collides_with_builtin(
|
||||
mock_db: MagicMock, mock_current_account: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
metadata_args = MetadataArgs(type="string", name=BuiltInField.document_name)
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Built-in fields"):
|
||||
MetadataService.create_metadata("dataset-1", metadata_args)
|
||||
|
||||
|
||||
def test_create_metadata_should_persist_metadata_when_input_is_valid(
|
||||
mock_db: MagicMock, mock_current_account: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
metadata_args = MetadataArgs(type="number", name="score")
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = MetadataService.create_metadata("dataset-1", metadata_args)
|
||||
|
||||
# Assert
|
||||
assert result.tenant_id == "tenant-1"
|
||||
assert result.dataset_id == "dataset-1"
|
||||
assert result.type == "number"
|
||||
assert result.name == "score"
|
||||
assert result.created_by == "user-1"
|
||||
mock_db.session.add.assert_called_once_with(result)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_update_metadata_name_should_raise_value_error_when_name_exceeds_limit() -> None:
|
||||
# Arrange
|
||||
too_long_name = "x" * 256
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="cannot exceed 255"):
|
||||
MetadataService.update_metadata_name("dataset-1", "metadata-1", too_long_name)
|
||||
|
||||
|
||||
def test_update_metadata_name_should_raise_value_error_when_duplicate_name_exists(
|
||||
mock_db: MagicMock, mock_current_account: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
MetadataService.update_metadata_name("dataset-1", "metadata-1", "duplicate")
|
||||
|
||||
# Assert
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_update_metadata_name_should_raise_value_error_when_name_collides_with_builtin(
|
||||
mock_db: MagicMock,
|
||||
mock_current_account: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Built-in fields"):
|
||||
MetadataService.update_metadata_name("dataset-1", "metadata-1", BuiltInField.source)
|
||||
|
||||
# Assert
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_update_metadata_name_should_update_bound_documents_and_return_metadata(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mock_current_account: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
fixed_now = datetime(2025, 2, 1, 0, 0, tzinfo=UTC)
|
||||
mocker.patch("services.metadata_service.naive_utc_now", return_value=fixed_now)
|
||||
|
||||
metadata = SimpleNamespace(id="metadata-1", name="old_name", updated_by=None, updated_at=None)
|
||||
bindings = [SimpleNamespace(document_id="doc-1"), SimpleNamespace(document_id="doc-2")]
|
||||
query_duplicate = MagicMock()
|
||||
query_duplicate.filter_by.return_value.first.return_value = None
|
||||
query_metadata = MagicMock()
|
||||
query_metadata.filter_by.return_value.first.return_value = metadata
|
||||
query_bindings = MagicMock()
|
||||
query_bindings.filter_by.return_value.all.return_value = bindings
|
||||
mock_db.session.query.side_effect = [query_duplicate, query_metadata, query_bindings]
|
||||
|
||||
doc_1 = _build_document("1", {"old_name": "value", "other": "keep"})
|
||||
doc_2 = _build_document("2", None)
|
||||
mock_get_documents = mocker.patch("services.metadata_service.DocumentService.get_document_by_ids")
|
||||
mock_get_documents.return_value = [doc_1, doc_2]
|
||||
|
||||
# Act
|
||||
result = MetadataService.update_metadata_name("dataset-1", "metadata-1", "new_name")
|
||||
|
||||
# Assert
|
||||
assert result is metadata
|
||||
assert metadata.name == "new_name"
|
||||
assert metadata.updated_by == "user-1"
|
||||
assert metadata.updated_at == fixed_now
|
||||
assert doc_1.doc_metadata == {"other": "keep", "new_name": "value"}
|
||||
assert doc_2.doc_metadata == {"new_name": None}
|
||||
mock_get_documents.assert_called_once_with(["doc-1", "doc-2"])
|
||||
mock_db.session.commit.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_update_metadata_name_should_return_none_when_metadata_does_not_exist(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mock_current_account: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
mock_logger = mocker.patch("services.metadata_service.logger")
|
||||
|
||||
query_duplicate = MagicMock()
|
||||
query_duplicate.filter_by.return_value.first.return_value = None
|
||||
query_metadata = MagicMock()
|
||||
query_metadata.filter_by.return_value.first.return_value = None
|
||||
mock_db.session.query.side_effect = [query_duplicate, query_metadata]
|
||||
|
||||
# Act
|
||||
result = MetadataService.update_metadata_name("dataset-1", "missing-id", "new_name")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_logger.exception.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_delete_metadata_should_remove_metadata_and_related_document_fields(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
metadata = SimpleNamespace(id="metadata-1", name="obsolete")
|
||||
bindings = [SimpleNamespace(document_id="doc-1")]
|
||||
query_metadata = MagicMock()
|
||||
query_metadata.filter_by.return_value.first.return_value = metadata
|
||||
query_bindings = MagicMock()
|
||||
query_bindings.filter_by.return_value.all.return_value = bindings
|
||||
mock_db.session.query.side_effect = [query_metadata, query_bindings]
|
||||
|
||||
document = _build_document("1", {"obsolete": "legacy", "remaining": "value"})
|
||||
mocker.patch("services.metadata_service.DocumentService.get_document_by_ids", return_value=[document])
|
||||
|
||||
# Act
|
||||
result = MetadataService.delete_metadata("dataset-1", "metadata-1")
|
||||
|
||||
# Assert
|
||||
assert result is metadata
|
||||
assert document.doc_metadata == {"remaining": "value"}
|
||||
mock_db.session.delete.assert_called_once_with(metadata)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
|
||||
|
||||
|
||||
def test_delete_metadata_should_return_none_when_metadata_is_missing(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_logger = mocker.patch("services.metadata_service.logger")
|
||||
|
||||
# Act
|
||||
result = MetadataService.delete_metadata("dataset-1", "missing-id")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_logger.exception.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
|
||||
|
||||
|
||||
def test_get_built_in_fields_should_return_all_expected_fields() -> None:
|
||||
# Arrange
|
||||
expected_names = {
|
||||
BuiltInField.document_name,
|
||||
BuiltInField.uploader,
|
||||
BuiltInField.upload_date,
|
||||
BuiltInField.last_update_date,
|
||||
BuiltInField.source,
|
||||
}
|
||||
|
||||
# Act
|
||||
result = MetadataService.get_built_in_fields()
|
||||
|
||||
# Assert
|
||||
assert {item["name"] for item in result} == expected_names
|
||||
assert [item["type"] for item in result] == ["string", "string", "time", "time", "string"]
|
||||
|
||||
|
||||
def test_enable_built_in_field_should_return_immediately_when_already_enabled(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=True)
|
||||
get_docs = mocker.patch("services.metadata_service.DocumentService.get_working_documents_by_dataset_id")
|
||||
|
||||
# Act
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
|
||||
# Assert
|
||||
get_docs.assert_not_called()
|
||||
mock_db.session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_enable_built_in_field_should_populate_documents_and_enable_flag(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
|
||||
doc_1 = _build_document("1", {"custom": "value"})
|
||||
doc_2 = _build_document("2", None)
|
||||
mocker.patch(
|
||||
"services.metadata_service.DocumentService.get_working_documents_by_dataset_id",
|
||||
return_value=[doc_1, doc_2],
|
||||
)
|
||||
|
||||
# Act
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
|
||||
# Assert
|
||||
assert dataset.built_in_field_enabled is True
|
||||
assert doc_1.doc_metadata is not None
|
||||
assert doc_1.doc_metadata[BuiltInField.document_name] == "doc-1"
|
||||
assert doc_1.doc_metadata[BuiltInField.source] == MetadataDataSource.upload_file
|
||||
assert doc_2.doc_metadata is not None
|
||||
assert doc_2.doc_metadata[BuiltInField.uploader] == "qa@example.com"
|
||||
mock_db.session.commit.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
|
||||
|
||||
|
||||
def test_disable_built_in_field_should_return_immediately_when_already_disabled(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
|
||||
get_docs = mocker.patch("services.metadata_service.DocumentService.get_working_documents_by_dataset_id")
|
||||
|
||||
# Act
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
|
||||
# Assert
|
||||
get_docs.assert_not_called()
|
||||
mock_db.session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_disable_built_in_field_should_remove_builtin_keys_and_disable_flag(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=True)
|
||||
document = _build_document(
|
||||
"1",
|
||||
{
|
||||
BuiltInField.document_name: "doc",
|
||||
BuiltInField.uploader: "user",
|
||||
BuiltInField.upload_date: 1.0,
|
||||
BuiltInField.last_update_date: 2.0,
|
||||
BuiltInField.source: MetadataDataSource.upload_file,
|
||||
"custom": "keep",
|
||||
},
|
||||
)
|
||||
mocker.patch(
|
||||
"services.metadata_service.DocumentService.get_working_documents_by_dataset_id",
|
||||
return_value=[document],
|
||||
)
|
||||
|
||||
# Act
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
|
||||
# Assert
|
||||
assert dataset.built_in_field_enabled is False
|
||||
assert document.doc_metadata == {"custom": "keep"}
|
||||
mock_db.session.commit.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
|
||||
|
||||
|
||||
def test_update_documents_metadata_should_replace_metadata_and_create_bindings_on_full_update(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mock_current_account: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
|
||||
document = _build_document("1", {"legacy": "value"})
|
||||
mocker.patch("services.metadata_service.DocumentService.get_document", return_value=document)
|
||||
delete_chain = mock_db.session.query.return_value.filter_by.return_value
|
||||
delete_chain.delete.return_value = 1
|
||||
operation = DocumentMetadataOperation(
|
||||
document_id="1",
|
||||
metadata_list=[MetadataDetail(id="meta-1", name="priority", value="high")],
|
||||
partial_update=False,
|
||||
)
|
||||
metadata_args = MetadataOperationData(operation_data=[operation])
|
||||
|
||||
# Act
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
# Assert
|
||||
assert document.doc_metadata == {"priority": "high"}
|
||||
delete_chain.delete.assert_called_once()
|
||||
assert mock_db.session.commit.call_count == 1
|
||||
mock_redis_client.delete.assert_called_once_with("document_metadata_lock_1")
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_update_documents_metadata_should_skip_existing_binding_and_preserve_existing_fields_on_partial_update(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mock_current_account: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=True)
|
||||
document = _build_document("1", {"existing": "value"})
|
||||
mocker.patch("services.metadata_service.DocumentService.get_document", return_value=document)
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
|
||||
operation = DocumentMetadataOperation(
|
||||
document_id="1",
|
||||
metadata_list=[MetadataDetail(id="meta-1", name="new_key", value="new_value")],
|
||||
partial_update=True,
|
||||
)
|
||||
metadata_args = MetadataOperationData(operation_data=[operation])
|
||||
|
||||
# Act
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
# Assert
|
||||
assert document.doc_metadata is not None
|
||||
assert document.doc_metadata["existing"] == "value"
|
||||
assert document.doc_metadata["new_key"] == "new_value"
|
||||
assert document.doc_metadata[BuiltInField.source] == MetadataDataSource.upload_file
|
||||
assert mock_db.session.commit.call_count == 1
|
||||
assert mock_db.session.add.call_count == 1
|
||||
mock_redis_client.delete.assert_called_once_with("document_metadata_lock_1")
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_update_documents_metadata_should_raise_and_rollback_when_document_not_found(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
|
||||
mocker.patch("services.metadata_service.DocumentService.get_document", return_value=None)
|
||||
operation = DocumentMetadataOperation(document_id="404", metadata_list=[], partial_update=True)
|
||||
metadata_args = MetadataOperationData(operation_data=[operation])
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Document not found"):
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
# Assert
|
||||
mock_db.session.rollback.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("document_metadata_lock_404")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("dataset_id", "document_id", "expected_key"),
|
||||
[
|
||||
("dataset-1", None, "dataset_metadata_lock_dataset-1"),
|
||||
(None, "doc-1", "document_metadata_lock_doc-1"),
|
||||
],
|
||||
)
|
||||
def test_knowledge_base_metadata_lock_check_should_set_lock_when_not_already_locked(
|
||||
dataset_id: str | None,
|
||||
document_id: str | None,
|
||||
expected_key: str,
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
|
||||
# Act
|
||||
MetadataService.knowledge_base_metadata_lock_check(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
mock_redis_client.set.assert_called_once_with(expected_key, 1, ex=3600)
|
||||
|
||||
|
||||
def test_knowledge_base_metadata_lock_check_should_raise_when_dataset_lock_exists(
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = 1
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="knowledge base metadata operation is running"):
|
||||
MetadataService.knowledge_base_metadata_lock_check("dataset-1", None)
|
||||
|
||||
|
||||
def test_knowledge_base_metadata_lock_check_should_raise_when_document_lock_exists(
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = 1
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="document metadata operation is running"):
|
||||
MetadataService.knowledge_base_metadata_lock_check(None, "doc-1")
|
||||
|
||||
|
||||
def test_get_dataset_metadatas_should_exclude_builtin_and_include_binding_counts(mock_db: MagicMock) -> None:
|
||||
# Arrange
|
||||
dataset = _dataset(
|
||||
id="dataset-1",
|
||||
built_in_field_enabled=True,
|
||||
doc_metadata=[
|
||||
{"id": "meta-1", "name": "priority", "type": "string"},
|
||||
{"id": "built-in", "name": "ignored", "type": "string"},
|
||||
{"id": "meta-2", "name": "score", "type": "number"},
|
||||
],
|
||||
)
|
||||
count_chain = mock_db.session.query.return_value.filter_by.return_value
|
||||
count_chain.count.side_effect = [3, 1]
|
||||
|
||||
# Act
|
||||
result = MetadataService.get_dataset_metadatas(dataset)
|
||||
|
||||
# Assert
|
||||
assert result["built_in_field_enabled"] is True
|
||||
assert result["doc_metadata"] == [
|
||||
{"id": "meta-1", "name": "priority", "type": "string", "count": 3},
|
||||
{"id": "meta-2", "name": "score", "type": "number", "count": 1},
|
||||
]
|
||||
|
||||
|
||||
def test_get_dataset_metadatas_should_return_empty_list_when_no_metadata(mock_db: MagicMock) -> None:
|
||||
# Arrange
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=False, doc_metadata=None)
|
||||
|
||||
# Act
|
||||
result = MetadataService.get_dataset_metadatas(dataset)
|
||||
|
||||
# Assert
|
||||
assert result == {"doc_metadata": [], "built_in_field_enabled": False}
|
||||
mock_db.session.query.assert_not_called()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,576 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from models.account import Tenant
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants used throughout the tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
TENANT_ID = "tenant-abc"
|
||||
ACCOUNT_ID = "account-xyz"
|
||||
FILES_BASE_URL = "https://files.example.com"
|
||||
|
||||
DB_PATH = "services.workspace_service.db"
|
||||
FEATURE_SERVICE_PATH = "services.workspace_service.FeatureService.get_features"
|
||||
TENANT_SERVICE_PATH = "services.workspace_service.TenantService.has_roles"
|
||||
DIFY_CONFIG_PATH = "services.workspace_service.dify_config"
|
||||
CURRENT_USER_PATH = "services.workspace_service.current_user"
|
||||
CREDIT_POOL_SERVICE_PATH = "services.credit_pool_service.CreditPoolService.get_pool"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers / factories
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_tenant(
|
||||
tenant_id: str = TENANT_ID,
|
||||
name: str = "My Workspace",
|
||||
plan: str = "sandbox",
|
||||
status: str = "active",
|
||||
custom_config: dict | None = None,
|
||||
) -> Tenant:
|
||||
"""Create a minimal Tenant-like namespace."""
|
||||
return cast(
|
||||
Tenant,
|
||||
SimpleNamespace(
|
||||
id=tenant_id,
|
||||
name=name,
|
||||
plan=plan,
|
||||
status=status,
|
||||
created_at="2024-01-01T00:00:00Z",
|
||||
custom_config_dict=custom_config or {},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _make_feature(
|
||||
can_replace_logo: bool = False,
|
||||
next_credit_reset_date: str | None = None,
|
||||
billing_plan: str = "sandbox",
|
||||
) -> MagicMock:
|
||||
"""Create a feature namespace matching what FeatureService.get_features returns."""
|
||||
feature = MagicMock()
|
||||
feature.can_replace_logo = can_replace_logo
|
||||
feature.next_credit_reset_date = next_credit_reset_date
|
||||
feature.billing.subscription.plan = billing_plan
|
||||
return feature
|
||||
|
||||
|
||||
def _make_pool(quota_limit: int, quota_used: int) -> MagicMock:
|
||||
pool = MagicMock()
|
||||
pool.quota_limit = quota_limit
|
||||
pool.quota_used = quota_used
|
||||
return pool
|
||||
|
||||
|
||||
def _make_tenant_account_join(role: str = "normal") -> SimpleNamespace:
|
||||
return SimpleNamespace(role=role)
|
||||
|
||||
|
||||
def _tenant_info(result: object) -> dict[str, Any] | None:
|
||||
return cast(dict[str, Any] | None, result)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_current_user() -> SimpleNamespace:
|
||||
"""Return a lightweight current_user stand-in."""
|
||||
return SimpleNamespace(id=ACCOUNT_ID)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_mocks(mocker: MockerFixture, mock_current_user: SimpleNamespace) -> dict:
|
||||
"""
|
||||
Patch the common external boundaries used by WorkspaceService.get_tenant_info.
|
||||
|
||||
Returns a dict of named mocks so individual tests can customise them.
|
||||
"""
|
||||
mocker.patch(CURRENT_USER_PATH, mock_current_user)
|
||||
|
||||
mock_db_session = mocker.patch(f"{DB_PATH}.session")
|
||||
mock_query_chain = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query_chain
|
||||
mock_query_chain.where.return_value = mock_query_chain
|
||||
mock_query_chain.first.return_value = _make_tenant_account_join(role="owner")
|
||||
|
||||
mock_feature = mocker.patch(FEATURE_SERVICE_PATH, return_value=_make_feature())
|
||||
mock_has_roles = mocker.patch(TENANT_SERVICE_PATH, return_value=False)
|
||||
mock_config = mocker.patch(DIFY_CONFIG_PATH)
|
||||
mock_config.EDITION = "SELF_HOSTED"
|
||||
mock_config.FILES_URL = FILES_BASE_URL
|
||||
|
||||
return {
|
||||
"db_session": mock_db_session,
|
||||
"query_chain": mock_query_chain,
|
||||
"get_features": mock_feature,
|
||||
"has_roles": mock_has_roles,
|
||||
"config": mock_config,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. None Tenant Handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_tenant_info_should_return_none_when_tenant_is_none() -> None:
|
||||
"""get_tenant_info should short-circuit and return None for a falsy tenant."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
tenant = None
|
||||
|
||||
# Act
|
||||
result = WorkspaceService.get_tenant_info(cast(Tenant, tenant))
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_tenant_info_should_return_none_when_tenant_is_falsy() -> None:
|
||||
"""get_tenant_info treats any falsy value as absent (e.g. empty string, 0)."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange / Act / Assert
|
||||
assert WorkspaceService.get_tenant_info("") is None # type: ignore[arg-type]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Basic Tenant Info — happy path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_tenant_info_should_return_base_fields(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""get_tenant_info should always return the six base scalar fields."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["id"] == TENANT_ID
|
||||
assert result["name"] == "My Workspace"
|
||||
assert result["plan"] == "sandbox"
|
||||
assert result["status"] == "active"
|
||||
assert result["created_at"] == "2024-01-01T00:00:00Z"
|
||||
assert result["trial_end_reason"] is None
|
||||
|
||||
|
||||
def test_get_tenant_info_should_populate_role_from_tenant_account_join(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""The 'role' field should be taken from TenantAccountJoin, not the default."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
basic_mocks["query_chain"].first.return_value = _make_tenant_account_join(role="admin")
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["role"] == "admin"
|
||||
|
||||
|
||||
def test_get_tenant_info_should_raise_assertion_when_tenant_account_join_missing(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""
|
||||
The service asserts that TenantAccountJoin exists.
|
||||
Missing join should raise AssertionError.
|
||||
"""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
basic_mocks["query_chain"].first.return_value = None
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(AssertionError, match="TenantAccountJoin not found"):
|
||||
WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. Logo Customisation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_tenant_info_should_include_custom_config_when_logo_allowed_and_admin(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""custom_config block should appear for OWNER/ADMIN when can_replace_logo is True."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
|
||||
basic_mocks["has_roles"].return_value = True
|
||||
tenant = _make_tenant(
|
||||
custom_config={
|
||||
"replace_webapp_logo": True,
|
||||
"remove_webapp_brand": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert "custom_config" in result
|
||||
assert result["custom_config"]["remove_webapp_brand"] is True
|
||||
expected_logo_url = f"{FILES_BASE_URL}/files/workspaces/{TENANT_ID}/webapp-logo"
|
||||
assert result["custom_config"]["replace_webapp_logo"] == expected_logo_url
|
||||
|
||||
|
||||
def test_get_tenant_info_should_set_replace_webapp_logo_to_none_when_flag_absent(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""replace_webapp_logo should be None when custom_config_dict does not have the key."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
|
||||
basic_mocks["has_roles"].return_value = True
|
||||
tenant = _make_tenant(custom_config={}) # no replace_webapp_logo key
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["custom_config"]["replace_webapp_logo"] is None
|
||||
|
||||
|
||||
def test_get_tenant_info_should_not_include_custom_config_when_logo_not_allowed(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""custom_config should be absent when can_replace_logo is False."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=False)
|
||||
basic_mocks["has_roles"].return_value = True
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert "custom_config" not in result
|
||||
|
||||
|
||||
def test_get_tenant_info_should_not_include_custom_config_when_user_not_admin(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""custom_config block is gated on OWNER or ADMIN role."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
|
||||
basic_mocks["has_roles"].return_value = False # regular member
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert "custom_config" not in result
|
||||
|
||||
|
||||
def test_get_tenant_info_should_use_files_url_for_logo_url(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""The logo URL should use dify_config.FILES_URL as the base."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
custom_base = "https://cdn.mycompany.io"
|
||||
basic_mocks["config"].FILES_URL = custom_base
|
||||
basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
|
||||
basic_mocks["has_roles"].return_value = True
|
||||
tenant = _make_tenant(custom_config={"replace_webapp_logo": True})
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["custom_config"]["replace_webapp_logo"].startswith(custom_base)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Cloud-Edition Credit Features
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
CLOUD_BILLING_PLAN_NON_SANDBOX = "professional" # any plan that is not SANDBOX
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cloud_mocks(mocker: MockerFixture, mock_current_user: SimpleNamespace) -> dict:
|
||||
"""Patches for CLOUD edition tests, billing plan = professional by default."""
|
||||
mocker.patch(CURRENT_USER_PATH, mock_current_user)
|
||||
|
||||
mock_db_session = mocker.patch(f"{DB_PATH}.session")
|
||||
mock_query_chain = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query_chain
|
||||
mock_query_chain.where.return_value = mock_query_chain
|
||||
mock_query_chain.first.return_value = _make_tenant_account_join(role="owner")
|
||||
|
||||
mock_feature = mocker.patch(
|
||||
FEATURE_SERVICE_PATH,
|
||||
return_value=_make_feature(
|
||||
can_replace_logo=False,
|
||||
next_credit_reset_date="2025-02-01",
|
||||
billing_plan=CLOUD_BILLING_PLAN_NON_SANDBOX,
|
||||
),
|
||||
)
|
||||
mocker.patch(TENANT_SERVICE_PATH, return_value=False)
|
||||
mock_config = mocker.patch(DIFY_CONFIG_PATH)
|
||||
mock_config.EDITION = "CLOUD"
|
||||
mock_config.FILES_URL = FILES_BASE_URL
|
||||
|
||||
return {
|
||||
"db_session": mock_db_session,
|
||||
"query_chain": mock_query_chain,
|
||||
"get_features": mock_feature,
|
||||
"config": mock_config,
|
||||
}
|
||||
|
||||
|
||||
def test_get_tenant_info_should_add_next_credit_reset_date_in_cloud_edition(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""next_credit_reset_date should be present in CLOUD edition."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
CREDIT_POOL_SERVICE_PATH,
|
||||
side_effect=[None, None], # both paid and trial pools absent
|
||||
)
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["next_credit_reset_date"] == "2025-02-01"
|
||||
|
||||
|
||||
def test_get_tenant_info_should_use_paid_pool_when_plan_is_not_sandbox_and_pool_not_full(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""trial_credits/trial_credits_used come from the paid pool when conditions are met."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
paid_pool = _make_pool(quota_limit=1000, quota_used=200)
|
||||
mocker.patch(CREDIT_POOL_SERVICE_PATH, return_value=paid_pool)
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == 1000
|
||||
assert result["trial_credits_used"] == 200
|
||||
|
||||
|
||||
def test_get_tenant_info_should_use_paid_pool_when_quota_limit_is_infinite(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""quota_limit == -1 means unlimited; service should still use the paid pool."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
paid_pool = _make_pool(quota_limit=-1, quota_used=999)
|
||||
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, None])
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == -1
|
||||
assert result["trial_credits_used"] == 999
|
||||
|
||||
|
||||
def test_get_tenant_info_should_fall_back_to_trial_pool_when_paid_pool_is_full(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""When paid pool is exhausted (used >= limit), switch to trial pool."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
paid_pool = _make_pool(quota_limit=500, quota_used=500) # exactly full
|
||||
trial_pool = _make_pool(quota_limit=100, quota_used=10)
|
||||
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, trial_pool])
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == 100
|
||||
assert result["trial_credits_used"] == 10
|
||||
|
||||
|
||||
def test_get_tenant_info_should_fall_back_to_trial_pool_when_paid_pool_is_none(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""When paid_pool is None, fall back to trial pool."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
trial_pool = _make_pool(quota_limit=50, quota_used=5)
|
||||
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[None, trial_pool])
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == 50
|
||||
assert result["trial_credits_used"] == 5
|
||||
|
||||
|
||||
def test_get_tenant_info_should_fall_back_to_trial_pool_for_sandbox_plan(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""
|
||||
When the subscription plan IS SANDBOX, the paid pool branch is skipped
|
||||
entirely and we fall back to the trial pool.
|
||||
"""
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange — override billing plan to SANDBOX
|
||||
cloud_mocks["get_features"].return_value = _make_feature(
|
||||
next_credit_reset_date="2025-02-01",
|
||||
billing_plan=CloudPlan.SANDBOX,
|
||||
)
|
||||
paid_pool = _make_pool(quota_limit=1000, quota_used=0)
|
||||
trial_pool = _make_pool(quota_limit=200, quota_used=20)
|
||||
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, trial_pool])
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == 200
|
||||
assert result["trial_credits_used"] == 20
|
||||
|
||||
|
||||
def test_get_tenant_info_should_omit_trial_credits_when_both_pools_are_none(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""When both paid and trial pools are absent, trial_credits should not be set."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[None, None])
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert "trial_credits" not in result
|
||||
assert "trial_credits_used" not in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. Self-hosted / Non-Cloud Edition
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_tenant_info_should_not_include_cloud_fields_in_self_hosted(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""next_credit_reset_date and trial_credits should NOT appear in SELF_HOSTED mode."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange (basic_mocks already sets EDITION = "SELF_HOSTED")
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert "next_credit_reset_date" not in result
|
||||
assert "trial_credits" not in result
|
||||
assert "trial_credits_used" not in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. DB query integrity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_tenant_info_should_query_tenant_account_join_with_correct_ids(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""
|
||||
The DB query for TenantAccountJoin must be scoped to the correct
|
||||
tenant_id and current_user.id.
|
||||
"""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
tenant = _make_tenant(tenant_id="my-special-tenant")
|
||||
mock_current_user = mocker.patch(CURRENT_USER_PATH)
|
||||
mock_current_user.id = "special-user-id"
|
||||
|
||||
# Act
|
||||
WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
# Assert — db.session.query was invoked (at least once)
|
||||
basic_mocks["db_session"].query.assert_called()
|
||||
@@ -488,7 +488,8 @@ ALIYUN_OSS_REGION=ap-southeast-1
|
||||
ALIYUN_OSS_AUTH_VERSION=v4
|
||||
# Don't start with '/'. OSS doesn't support leading slash in object names.
|
||||
ALIYUN_OSS_PATH=your-path
|
||||
ALIYUN_CLOUDBOX_ID=your-cloudbox-id
|
||||
# Optional CloudBox ID for Aliyun OSS, DO NOT enable it if you are not using CloudBox.
|
||||
#ALIYUN_CLOUDBOX_ID=your-cloudbox-id
|
||||
|
||||
# Tencent COS Configuration
|
||||
#
|
||||
|
||||
@@ -275,6 +275,7 @@ services:
|
||||
# Use the shared environment variables.
|
||||
<<: *shared-api-worker-env
|
||||
DB_DATABASE: ${DB_PLUGIN_DATABASE:-dify_plugin}
|
||||
DB_SSL_MODE: ${DB_SSL_MODE:-disable}
|
||||
SERVER_PORT: ${PLUGIN_DAEMON_PORT:-5002}
|
||||
SERVER_KEY: ${PLUGIN_DAEMON_KEY:-lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi}
|
||||
MAX_PLUGIN_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800}
|
||||
|
||||
@@ -146,7 +146,6 @@ x-shared-env: &shared-api-worker-env
|
||||
ALIYUN_OSS_REGION: ${ALIYUN_OSS_REGION:-ap-southeast-1}
|
||||
ALIYUN_OSS_AUTH_VERSION: ${ALIYUN_OSS_AUTH_VERSION:-v4}
|
||||
ALIYUN_OSS_PATH: ${ALIYUN_OSS_PATH:-your-path}
|
||||
ALIYUN_CLOUDBOX_ID: ${ALIYUN_CLOUDBOX_ID:-your-cloudbox-id}
|
||||
TENCENT_COS_BUCKET_NAME: ${TENCENT_COS_BUCKET_NAME:-your-bucket-name}
|
||||
TENCENT_COS_SECRET_KEY: ${TENCENT_COS_SECRET_KEY:-your-secret-key}
|
||||
TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-your-secret-id}
|
||||
@@ -985,6 +984,7 @@ services:
|
||||
# Use the shared environment variables.
|
||||
<<: *shared-api-worker-env
|
||||
DB_DATABASE: ${DB_PLUGIN_DATABASE:-dify_plugin}
|
||||
DB_SSL_MODE: ${DB_SSL_MODE:-disable}
|
||||
SERVER_PORT: ${PLUGIN_DAEMON_PORT:-5002}
|
||||
SERVER_KEY: ${PLUGIN_DAEMON_KEY:-lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi}
|
||||
MAX_PLUGIN_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800}
|
||||
|
||||
Reference in New Issue
Block a user