Mirror of https://github.com/langgenius/dify.git (synced 2026-01-08 07:14:14 +00:00)
Compare commits
2 Commits
feat/remov
...
refactor-c
| Author | SHA1 | Date |
|---|---|---|
|  | c8cabc9bdb |  |
|  | 45e2d4627f |  |
2 .github/CODEOWNERS vendored
@@ -122,7 +122,7 @@ api/controllers/console/feature.py @GarfieldDai @GareArc
api/controllers/web/feature.py @GarfieldDai @GareArc

# Backend - Database Migrations
api/migrations/ @snakevash @laipz8200 @MRZHUH
api/migrations/ @snakevash @laipz8200

# Frontend
web/ @iamjoel
2 .github/workflows/autofix.yml vendored
@@ -79,7 +79,7 @@ jobs:
with:
  node-version: 22
  cache: pnpm
  cache-dependency-path: ./web/pnpm-lock.yaml
  cache-dependency-path: ./web/package.json

- name: Web dependencies
  working-directory: ./web
2 .github/workflows/style.yml vendored
@@ -90,7 +90,7 @@ jobs:
with:
  node-version: 22
  cache: pnpm
  cache-dependency-path: ./web/pnpm-lock.yaml
  cache-dependency-path: ./web/package.json

- name: Web dependencies
  if: steps.changed-files.outputs.any_changed == 'true'

@@ -55,7 +55,7 @@ jobs:
with:
  node-version: 'lts/*'
  cache: pnpm
  cache-dependency-path: ./web/pnpm-lock.yaml
  cache-dependency-path: ./web/package.json

- name: Install dependencies
  if: env.FILES_CHANGED == 'true'
358 .github/workflows/web-tests.yml vendored
@@ -13,7 +13,6 @@ jobs:
runs-on: ubuntu-latest
defaults:
  run:
    shell: bash
    working-directory: ./web

steps:

@@ -22,7 +21,14 @@ jobs:
with:
  persist-credentials: false

- name: Check changed files
  id: changed-files
  uses: tj-actions/changed-files@v46
  with:
    files: web/**

- name: Install pnpm
  if: steps.changed-files.outputs.any_changed == 'true'
  uses: pnpm/action-setup@v4
  with:
    package_json_file: web/package.json
@@ -30,355 +36,23 @@ jobs:
- name: Setup Node.js
  uses: actions/setup-node@v4
  if: steps.changed-files.outputs.any_changed == 'true'
  with:
    node-version: 22
    cache: pnpm
    cache-dependency-path: ./web/pnpm-lock.yaml

- name: Restore Jest cache
  uses: actions/cache@v4
  with:
    path: web/.cache/jest
    key: ${{ runner.os }}-jest-${{ hashFiles('web/pnpm-lock.yaml') }}
    restore-keys: |
      ${{ runner.os }}-jest-
    cache-dependency-path: ./web/package.json

- name: Install dependencies
  if: steps.changed-files.outputs.any_changed == 'true'
  working-directory: ./web
  run: pnpm install --frozen-lockfile

- name: Check i18n types synchronization
  if: steps.changed-files.outputs.any_changed == 'true'
  working-directory: ./web
  run: pnpm run check:i18n-types

- name: Run tests
  run: |
    pnpm exec jest \
      --ci \
      --maxWorkers=100% \
      --coverage \
      --passWithNoTests

- name: Coverage Summary
  if: always()
  id: coverage-summary
  run: |
    set -eo pipefail

    COVERAGE_FILE="coverage/coverage-final.json"
    COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json"

    if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then
      echo "has_coverage=false" >> "$GITHUB_OUTPUT"
      echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY"
      echo "Coverage data not found. Ensure Jest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
      exit 0
    fi

    echo "has_coverage=true" >> "$GITHUB_OUTPUT"

    node <<'NODE' >> "$GITHUB_STEP_SUMMARY"
      const fs = require('fs');
      const path = require('path');
      let libCoverage = null;

      try {
        libCoverage = require('istanbul-lib-coverage');
      } catch (error) {
        libCoverage = null;
      }

      const summaryPath = path.join('coverage', 'coverage-summary.json');
      const finalPath = path.join('coverage', 'coverage-final.json');

      const hasSummary = fs.existsSync(summaryPath);
      const hasFinal = fs.existsSync(finalPath);

      if (!hasSummary && !hasFinal) {
        console.log('### Test Coverage Summary :test_tube:');
        console.log('');
        console.log('No coverage data found.');
        process.exit(0);
      }

      const summary = hasSummary
        ? JSON.parse(fs.readFileSync(summaryPath, 'utf8'))
        : null;
      const coverage = hasFinal
        ? JSON.parse(fs.readFileSync(finalPath, 'utf8'))
        : null;

      const getLineCoverageFromStatements = (statementMap, statementHits) => {
        const lineHits = {};

        if (!statementMap || !statementHits) {
          return lineHits;
        }

        Object.entries(statementMap).forEach(([key, statement]) => {
          const line = statement?.start?.line;
          if (!line) {
            return;
          }
          const hits = statementHits[key] ?? 0;
          const previous = lineHits[line];
          lineHits[line] = previous === undefined ? hits : Math.max(previous, hits);
        });

        return lineHits;
      };

      const getFileCoverage = (entry) => (
        libCoverage ? libCoverage.createFileCoverage(entry) : null
      );

      const getLineHits = (entry, fileCoverage) => {
        const lineHits = entry.l ?? {};
        if (Object.keys(lineHits).length > 0) {
          return lineHits;
        }
        if (fileCoverage) {
          return fileCoverage.getLineCoverage();
        }
        return getLineCoverageFromStatements(entry.statementMap ?? {}, entry.s ?? {});
      };

      const getUncoveredLines = (entry, fileCoverage, lineHits) => {
        if (lineHits && Object.keys(lineHits).length > 0) {
          return Object.entries(lineHits)
            .filter(([, count]) => count === 0)
            .map(([line]) => Number(line))
            .sort((a, b) => a - b);
        }
        if (fileCoverage) {
          return fileCoverage.getUncoveredLines();
        }
        return [];
      };

      const totals = {
        lines: { covered: 0, total: 0 },
        statements: { covered: 0, total: 0 },
        branches: { covered: 0, total: 0 },
        functions: { covered: 0, total: 0 },
      };
      const fileSummaries = [];

      if (summary) {
        const totalEntry = summary.total ?? {};
        ['lines', 'statements', 'branches', 'functions'].forEach((key) => {
          if (totalEntry[key]) {
            totals[key].covered = totalEntry[key].covered ?? 0;
            totals[key].total = totalEntry[key].total ?? 0;
          }
        });

        Object.entries(summary)
          .filter(([file]) => file !== 'total')
          .forEach(([file, data]) => {
            fileSummaries.push({
              file,
              pct: data.lines?.pct ?? data.statements?.pct ?? 0,
              lines: {
                covered: data.lines?.covered ?? 0,
                total: data.lines?.total ?? 0,
              },
            });
          });
      } else if (coverage) {
        Object.entries(coverage).forEach(([file, entry]) => {
          const fileCoverage = getFileCoverage(entry);
          const lineHits = getLineHits(entry, fileCoverage);
          const statementHits = entry.s ?? {};
          const branchHits = entry.b ?? {};
          const functionHits = entry.f ?? {};

          const lineTotal = Object.keys(lineHits).length;
          const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;

          const statementTotal = Object.keys(statementHits).length;
          const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;

          const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
          const branchCovered = Object.values(branchHits).reduce(
            (acc, branches) => acc + branches.filter((n) => n > 0).length,
            0,
          );

          const functionTotal = Object.keys(functionHits).length;
          const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;

          totals.lines.total += lineTotal;
          totals.lines.covered += lineCovered;
          totals.statements.total += statementTotal;
          totals.statements.covered += statementCovered;
          totals.branches.total += branchTotal;
          totals.branches.covered += branchCovered;
          totals.functions.total += functionTotal;
          totals.functions.covered += functionCovered;

          const pct = (covered, tot) => (tot > 0 ? (covered / tot) * 100 : 0);

          fileSummaries.push({
            file,
            pct: pct(lineCovered || statementCovered, lineTotal || statementTotal),
            lines: {
              covered: lineCovered || statementCovered,
              total: lineTotal || statementTotal,
            },
          });
        });
      }

      const pct = (covered, tot) => (tot > 0 ? ((covered / tot) * 100).toFixed(2) : '0.00');

      console.log('### Test Coverage Summary :test_tube:');
      console.log('');
      console.log('| Metric | Coverage | Covered / Total |');
      console.log('|--------|----------|-----------------|');
      console.log(`| Lines | ${pct(totals.lines.covered, totals.lines.total)}% | ${totals.lines.covered} / ${totals.lines.total} |`);
      console.log(`| Statements | ${pct(totals.statements.covered, totals.statements.total)}% | ${totals.statements.covered} / ${totals.statements.total} |`);
      console.log(`| Branches | ${pct(totals.branches.covered, totals.branches.total)}% | ${totals.branches.covered} / ${totals.branches.total} |`);
      console.log(`| Functions | ${pct(totals.functions.covered, totals.functions.total)}% | ${totals.functions.covered} / ${totals.functions.total} |`);

      console.log('');
      console.log('<details><summary>File coverage (lowest lines first)</summary>');
      console.log('');
      console.log('```');
      fileSummaries
        .sort((a, b) => (a.pct - b.pct) || (b.lines.total - a.lines.total))
        .slice(0, 25)
        .forEach(({ file, pct, lines }) => {
          console.log(`${pct.toFixed(2)}%\t${lines.covered}/${lines.total}\t${file}`);
        });
      console.log('```');
      console.log('</details>');

      if (coverage) {
        const pctValue = (covered, tot) => {
          if (tot === 0) {
            return '0';
          }
          return ((covered / tot) * 100)
            .toFixed(2)
            .replace(/\.?0+$/, '');
        };

        const formatLineRanges = (lines) => {
          if (lines.length === 0) {
            return '';
          }
          const ranges = [];
          let start = lines[0];
          let end = lines[0];

          for (let i = 1; i < lines.length; i += 1) {
            const current = lines[i];
            if (current === end + 1) {
              end = current;
              continue;
            }
            ranges.push(start === end ? `${start}` : `${start}-${end}`);
            start = current;
            end = current;
          }
          ranges.push(start === end ? `${start}` : `${start}-${end}`);
          return ranges.join(',');
        };

        const tableTotals = {
          statements: { covered: 0, total: 0 },
          branches: { covered: 0, total: 0 },
          functions: { covered: 0, total: 0 },
          lines: { covered: 0, total: 0 },
        };
        const tableRows = Object.entries(coverage)
          .map(([file, entry]) => {
            const fileCoverage = getFileCoverage(entry);
            const lineHits = getLineHits(entry, fileCoverage);
            const statementHits = entry.s ?? {};
            const branchHits = entry.b ?? {};
            const functionHits = entry.f ?? {};

            const lineTotal = Object.keys(lineHits).length;
            const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
            const statementTotal = Object.keys(statementHits).length;
            const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
            const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
            const branchCovered = Object.values(branchHits).reduce(
              (acc, branches) => acc + branches.filter((n) => n > 0).length,
              0,
            );
            const functionTotal = Object.keys(functionHits).length;
            const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;

            tableTotals.lines.total += lineTotal;
            tableTotals.lines.covered += lineCovered;
            tableTotals.statements.total += statementTotal;
            tableTotals.statements.covered += statementCovered;
            tableTotals.branches.total += branchTotal;
            tableTotals.branches.covered += branchCovered;
            tableTotals.functions.total += functionTotal;
            tableTotals.functions.covered += functionCovered;

            const uncoveredLines = getUncoveredLines(entry, fileCoverage, lineHits);

            const filePath = entry.path ?? file;
            const relativePath = path.isAbsolute(filePath)
              ? path.relative(process.cwd(), filePath)
              : filePath;

            return {
              file: relativePath || file,
              statements: pctValue(statementCovered, statementTotal),
              branches: pctValue(branchCovered, branchTotal),
              functions: pctValue(functionCovered, functionTotal),
              lines: pctValue(lineCovered, lineTotal),
              uncovered: formatLineRanges(uncoveredLines),
            };
          })
          .sort((a, b) => a.file.localeCompare(b.file));

        const columns = [
          { key: 'file', header: 'File', align: 'left' },
          { key: 'statements', header: '% Stmts', align: 'right' },
          { key: 'branches', header: '% Branch', align: 'right' },
          { key: 'functions', header: '% Funcs', align: 'right' },
          { key: 'lines', header: '% Lines', align: 'right' },
          { key: 'uncovered', header: 'Uncovered Line #s', align: 'left' },
        ];

        const allFilesRow = {
          file: 'All files',
          statements: pctValue(tableTotals.statements.covered, tableTotals.statements.total),
          branches: pctValue(tableTotals.branches.covered, tableTotals.branches.total),
          functions: pctValue(tableTotals.functions.covered, tableTotals.functions.total),
          lines: pctValue(tableTotals.lines.covered, tableTotals.lines.total),
          uncovered: '',
        };

        const rowsForOutput = [allFilesRow, ...tableRows];
        const formatRow = (row) => `| ${columns
          .map(({ key }) => String(row[key] ?? ''))
          .join(' | ')} |`;
        const headerRow = `| ${columns.map(({ header }) => header).join(' | ')} |`;
        const dividerRow = `| ${columns
          .map(({ align }) => (align === 'right' ? '---:' : ':---'))
          .join(' | ')} |`;

        console.log('');
        console.log('<details><summary>Jest coverage table</summary>');
        console.log('');
        console.log(headerRow);
        console.log(dividerRow);
        rowsForOutput.forEach((row) => console.log(formatRow(row)));
        console.log('</details>');
      }
    NODE

- name: Upload Coverage Artifact
  if: steps.coverage-summary.outputs.has_coverage == 'true'
  uses: actions/upload-artifact@v4
  with:
    name: web-coverage-report
    path: web/coverage
    retention-days: 30
    if-no-files-found: error

  if: steps.changed-files.outputs.any_changed == 'true'
  working-directory: ./web
  run: pnpm test
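For reference, the deleted summary step is just walking Istanbul's coverage-final.json. A minimal Python sketch of the same statement-coverage arithmetic (not part of the diff; assumes Istanbul's documented layout, where each file entry carries statement hit counts under `s`):

```python
import json

def statement_coverage(path: str = "coverage/coverage-final.json") -> float:
    """Recompute the overall statement percentage the removed step reported."""
    with open(path, encoding="utf-8") as f:
        coverage = json.load(f)
    covered = total = 0
    for entry in coverage.values():
        hits = entry.get("s", {})  # statement id -> hit count
        total += len(hits)
        covered += sum(1 for n in hits.values() if n > 0)
    return (covered / total) * 100 if total else 0.0

if __name__ == "__main__":
    print(f"Statements: {statement_coverage():.2f}%")
```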
@@ -1,7 +1,9 @@
import base64
import datetime
import json
import logging
import secrets
import time
from typing import Any

import click

@@ -45,6 +47,7 @@ from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpi
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration
from services.plugin.plugin_service import PluginService
from services.sandbox_messages_clean_service import SandboxMessagesCleanService
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch

logger = logging.getLogger(__name__)

@@ -1900,3 +1903,76 @@ def migrate_oss(
    except Exception as e:
        db.session.rollback()
        click.echo(click.style(f"Failed to update DB storage_type: {str(e)}", fg="red"))


@click.command("clean-expired-sandbox-messages", help="Clean expired sandbox messages.")
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.")
@click.option(
    "--graceful-period",
    default=21,
    show_default=True,
    help="Graceful period in days after subscription expiration.",
)
@click.option(
    "--start-from",
    type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
    default=None,
    help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
)
@click.option(
    "--end-before",
    type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
    default=None,
    help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
)
@click.option("--dry-run", is_flag=True, default=False, help="Show message logs that would be cleaned without deleting.")
def clean_expired_sandbox_messages(
    batch_size: int,
    graceful_period: int,
    start_from: datetime.datetime,
    end_before: datetime.datetime,
    dry_run: bool,
):
    """
    Clean expired messages and related data for sandbox tenants.
    """
    if not dify_config.BILLING_ENABLED:
        click.echo(click.style("Billing is not enabled. Skipping sandbox messages cleanup.", fg="yellow"))
        return

    click.echo(click.style("clean_messages: start clean messages.", fg="green"))

    start_at = time.perf_counter()

    try:
        stats = SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
            start_from=start_from,
            end_before=end_before,
            graceful_period=graceful_period,
            batch_size=batch_size,
            dry_run=dry_run,
        )

        end_at = time.perf_counter()
        click.echo(
            click.style(
                f"clean_messages: completed successfully\n"
                f" - Latency: {end_at - start_at:.2f}s\n"
                f" - Batches processed: {stats['batches']}\n"
                f" - Messages found: {stats['total_messages']}\n"
                f" - Messages deleted: {stats['total_deleted']}",
                fg="green",
            )
        )
    except Exception as e:
        end_at = time.perf_counter()
        logger.exception("clean_messages failed")
        click.echo(
            click.style(
                f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}",
                fg="red",
            )
        )
        raise

    click.echo(click.style("Sandbox messages cleanup completed.", fg="green"))
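A quick way to exercise the new command is Click's test runner; a sketch, assuming the command is importable from `commands` and that real tests would stub out the service call:

```python
from click.testing import CliRunner

from commands import clean_expired_sandbox_messages  # import path assumed

runner = CliRunner()
# --dry-run reports what would be deleted without touching the database.
result = runner.invoke(
    clean_expired_sandbox_messages,
    ["--start-from", "2025-01-01", "--end-before", "2025-02-01", "--dry-run"],
)
print(result.exit_code, result.output)
```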
@@ -218,7 +218,7 @@ class PluginConfig(BaseSettings):
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
    description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
    default=600.0,
    default=300.0,
)

INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")
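The `PositiveFloat | None` shape matters because the consumer (see the plugin daemon hunk further down in this diff) turns the setting into an `httpx.Timeout`. A sketch of that normalization under the same None-disables semantics:

```python
import httpx

def to_httpx_timeout(value: float | httpx.Timeout | None) -> httpx.Timeout | None:
    if value is None:
        return None  # None disables request timeouts entirely
    if isinstance(value, httpx.Timeout):
        return value
    return httpx.Timeout(value)  # one budget for connect/read/write/pool

client = httpx.Client(timeout=to_httpx_timeout(300.0))
```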
@@ -338,7 +338,9 @@ class CompletionConversationApi(Resource):
current_user, _ = current_account_with_tenant()
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True))  # type: ignore

query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion")
query = sa.select(Conversation).where(
    Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
)

if args.keyword:
    query = query.join(Message, Message.conversation_id == Conversation.id).where(

@@ -450,7 +452,7 @@ class ChatConversationApi(Resource):
    .subquery()
)

query = sa.select(Conversation).where(Conversation.app_id == app_model.id)
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))

if args.keyword:
    keyword_filter = f"%{args.keyword}%"
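Both hunks fold the soft-delete flag into the base query with `.is_(False)` rather than `== False`; a self-contained sketch of the rendered SQL (toy model, not the real Conversation class):

```python
import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class Conversation(Base):
    __tablename__ = "conversations"
    id: Mapped[str] = mapped_column(sa.String, primary_key=True)
    is_deleted: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"))

# .is_(False) renders as "is_deleted IS false" and avoids E712-style lint
# complaints, while still excluding soft-deleted rows from listing queries.
print(sa.select(Conversation).where(Conversation.is_deleted.is_(False)))
```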
@@ -146,7 +146,7 @@ class DatasetUpdatePayload(BaseModel):
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: dict[str, Any] | None = None
partial_member_list: list[dict[str, str]] | None = None
partial_member_list: list[str] | None = None
external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None
external_knowledge_api_id: str | None = None

@@ -40,7 +40,7 @@ from .. import console_ns
logger = logging.getLogger(__name__)


class CompletionMessageExplorePayload(BaseModel):
class CompletionMessagePayload(BaseModel):
    inputs: dict[str, Any]
    query: str = ""
    files: list[dict[str, Any]] | None = None

@@ -71,7 +71,7 @@ class ChatMessagePayload(BaseModel):
        raise ValueError("must be a valid UUID") from exc


register_schema_models(console_ns, CompletionMessageExplorePayload, ChatMessagePayload)
register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)


# define completion api for user

@@ -80,13 +80,13 @@ register_schema_models(console_ns, CompletionMessageExplorePayload, ChatMessageP
    endpoint="installed_app_completion",
)
class CompletionApi(InstalledAppResource):
    @console_ns.expect(console_ns.models[CompletionMessageExplorePayload.__name__])
    @console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
    def post(self, installed_app):
        app_model = installed_app.app
        if app_model.mode != AppMode.COMPLETION:
            raise NotCompletionAppError()

        payload = CompletionMessageExplorePayload.model_validate(console_ns.payload or {})
        payload = CompletionMessagePayload.model_validate(console_ns.payload or {})
        args = payload.model_dump(exclude_none=True)

        streaming = payload.response_mode == "streaming"
@@ -1,40 +1,31 @@
from typing import Literal

from flask import request
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden

from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from fields.tag_fields import dataset_tag_fields
from libs.login import current_account_with_tenant, login_required
from models.model import Tag
from services.tag_service import TagService


class TagBasePayload(BaseModel):
    name: str = Field(description="Tag name", min_length=1, max_length=50)
    type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
def _validate_name(name):
    if not name or len(name) < 1 or len(name) > 50:
        raise ValueError("Name must be between 1 to 50 characters.")
    return name


class TagBindingPayload(BaseModel):
    tag_ids: list[str] = Field(description="Tag IDs to bind")
    target_id: str = Field(description="Target ID to bind tags to")
    type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")


class TagBindingRemovePayload(BaseModel):
    tag_id: str = Field(description="Tag ID to remove")
    target_id: str = Field(description="Target ID to unbind tag from")
    type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")


register_schema_models(
    console_ns,
    TagBasePayload,
    TagBindingPayload,
    TagBindingRemovePayload,
parser_tags = (
    reqparse.RequestParser()
    .add_argument(
        "name",
        nullable=False,
        required=True,
        help="Name must be between 1 to 50 characters.",
        type=_validate_name,
    )
    .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
)

@@ -52,7 +43,7 @@ class TagListApi(Resource):
    return tags, 200

@console_ns.expect(console_ns.models[TagBasePayload.__name__])
@console_ns.expect(parser_tags)
@setup_required
@login_required
@account_initialization_required

@@ -62,17 +53,22 @@ class TagListApi(Resource):
    if not (current_user.has_edit_permission or current_user.is_dataset_editor):
        raise Forbidden()

    payload = TagBasePayload.model_validate(console_ns.payload or {})
    tag = TagService.save_tags(payload.model_dump())
    args = parser_tags.parse_args()
    tag = TagService.save_tags(args)

    response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}

    return response, 200


parser_tag_id = reqparse.RequestParser().add_argument(
    "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
)


@console_ns.route("/tags/<uuid:tag_id>")
class TagUpdateDeleteApi(Resource):
    @console_ns.expect(console_ns.models[TagBasePayload.__name__])
    @console_ns.expect(parser_tag_id)
    @setup_required
    @login_required
    @account_initialization_required

@@ -83,8 +79,8 @@ class TagUpdateDeleteApi(Resource):
    if not (current_user.has_edit_permission or current_user.is_dataset_editor):
        raise Forbidden()

    payload = TagBasePayload.model_validate(console_ns.payload or {})
    tag = TagService.update_tags(payload.model_dump(), tag_id)
    args = parser_tag_id.parse_args()
    tag = TagService.update_tags(args, tag_id)

    binding_count = TagService.get_tag_binding_count(tag_id)

@@ -104,9 +100,17 @@ class TagUpdateDeleteApi(Resource):
    return 204


parser_create = (
    reqparse.RequestParser()
    .add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.")
    .add_argument("target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required.")
    .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
)


@console_ns.route("/tag-bindings/create")
class TagBindingCreateApi(Resource):
    @console_ns.expect(console_ns.models[TagBindingPayload.__name__])
    @console_ns.expect(parser_create)
    @setup_required
    @login_required
    @account_initialization_required

@@ -116,15 +120,23 @@ class TagBindingCreateApi(Resource):
    if not (current_user.has_edit_permission or current_user.is_dataset_editor):
        raise Forbidden()

    payload = TagBindingPayload.model_validate(console_ns.payload or {})
    TagService.save_tag_binding(payload.model_dump())
    args = parser_create.parse_args()
    TagService.save_tag_binding(args)

    return {"result": "success"}, 200


parser_remove = (
    reqparse.RequestParser()
    .add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
    .add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
    .add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
)


@console_ns.route("/tag-bindings/remove")
class TagBindingDeleteApi(Resource):
    @console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
    @console_ns.expect(parser_remove)
    @setup_required
    @login_required
    @account_initialization_required

@@ -134,7 +146,7 @@ class TagBindingDeleteApi(Resource):
    if not (current_user.has_edit_permission or current_user.is_dataset_editor):
        raise Forbidden()

    payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
    TagService.delete_tag_binding(payload.model_dump())
    args = parser_remove.parse_args()
    TagService.delete_tag_binding(args)

    return {"result": "success"}, 200
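The pattern this file returns to: a reqparse `type` callable doubles as the validator, raising ValueError on bad input. A runnable sketch with a hypothetical payload:

```python
from flask import Flask
from flask_restx import reqparse

def _validate_name(name):
    if not name or len(name) < 1 or len(name) > 50:
        raise ValueError("Name must be between 1 to 50 characters.")
    return name

parser = reqparse.RequestParser().add_argument(
    "name", type=_validate_name, required=True, nullable=False, location="json"
)

app = Flask(__name__)
with app.test_request_context(json={"name": "ops"}):
    print(parser.parse_args())  # {'name': 'ops'}
```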
@@ -49,7 +49,7 @@ class DatasetUpdatePayload(BaseModel):
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: RetrievalModel | None = None
partial_member_list: list[dict[str, str]] | None = None
partial_member_list: list[str] | None = None
external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None
external_knowledge_api_id: str | None = None
@@ -1,8 +1,7 @@
import logging

from flask import request
from flask_restx import fields, marshal_with
from pydantic import BaseModel, field_validator
from flask_restx import fields, marshal_with, reqparse
from werkzeug.exceptions import InternalServerError

import services

@@ -21,7 +20,6 @@ from controllers.web.error import (
from controllers.web.wraps import WebApiResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
from models.model import App
from services.audio_service import AudioService
from services.errors.audio import (

@@ -31,25 +29,6 @@ from services.errors.audio import (
    UnsupportedAudioTypeServiceError,
)

from ..common.schema import register_schema_models


class TextToAudioPayload(BaseModel):
    message_id: str | None = None
    voice: str | None = None
    text: str | None = None
    streaming: bool | None = None

    @field_validator("message_id")
    @classmethod
    def validate_message_id(cls, value: str | None) -> str | None:
        if value is None:
            return value
        return uuid_value(value)


register_schema_models(web_ns, TextToAudioPayload)

logger = logging.getLogger(__name__)

@@ -109,7 +88,6 @@ class AudioApi(WebApiResource):

@web_ns.route("/text-to-audio")
class TextApi(WebApiResource):
    @web_ns.expect(web_ns.models[TextToAudioPayload.__name__])
    @web_ns.doc("Text to Audio")
    @web_ns.doc(description="Convert text to audio using text-to-speech service.")
    @web_ns.doc(

@@ -124,11 +102,18 @@ class TextApi(WebApiResource):
    def post(self, app_model: App, end_user):
        """Convert text to audio"""
        try:
            payload = TextToAudioPayload.model_validate(web_ns.payload or {})
            parser = (
                reqparse.RequestParser()
                .add_argument("message_id", type=str, required=False, location="json")
                .add_argument("voice", type=str, location="json")
                .add_argument("text", type=str, location="json")
                .add_argument("streaming", type=bool, location="json")
            )
            args = parser.parse_args()

            message_id = payload.message_id
            text = payload.text
            voice = payload.voice
            message_id = args.get("message_id", None)
            text = args.get("text", None)
            voice = args.get("voice", None)
            response = AudioService.transcript_tts(
                app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
            )
@@ -1,11 +1,9 @@
import logging
from typing import Any, Literal

from pydantic import BaseModel, Field, field_validator
from flask_restx import reqparse
from werkzeug.exceptions import InternalServerError, NotFound

import services
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import (
    AppUnavailableError,

@@ -36,44 +34,25 @@ from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)


class CompletionMessagePayload(BaseModel):
    inputs: dict[str, Any] = Field(description="Input variables for the completion")
    query: str = Field(default="", description="Query text for completion")
    files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed")
    response_mode: Literal["blocking", "streaming"] | None = Field(
        default=None, description="Response mode: blocking or streaming"
    )
    retriever_from: str = Field(default="web_app", description="Source of retriever")


class ChatMessagePayload(BaseModel):
    inputs: dict[str, Any] = Field(description="Input variables for the chat")
    query: str = Field(description="User query/message")
    files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed")
    response_mode: Literal["blocking", "streaming"] | None = Field(
        default=None, description="Response mode: blocking or streaming"
    )
    conversation_id: str | None = Field(default=None, description="Conversation ID")
    parent_message_id: str | None = Field(default=None, description="Parent message ID")
    retriever_from: str = Field(default="web_app", description="Source of retriever")

    @field_validator("conversation_id", "parent_message_id")
    @classmethod
    def validate_uuid(cls, value: str | None) -> str | None:
        if value is None:
            return value
        return uuid_value(value)


register_schema_models(web_ns, CompletionMessagePayload, ChatMessagePayload)


# define completion api for user
@web_ns.route("/completion-messages")
class CompletionApi(WebApiResource):
    @web_ns.doc("Create Completion Message")
    @web_ns.doc(description="Create a completion message for text generation applications.")
    @web_ns.expect(web_ns.models[CompletionMessagePayload.__name__])
    @web_ns.doc(
        params={
            "inputs": {"description": "Input variables for the completion", "type": "object", "required": True},
            "query": {"description": "Query text for completion", "type": "string", "required": False},
            "files": {"description": "Files to be processed", "type": "array", "required": False},
            "response_mode": {
                "description": "Response mode: blocking or streaming",
                "type": "string",
                "enum": ["blocking", "streaming"],
                "required": False,
            },
            "retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
        }
    )
    @web_ns.doc(
        responses={
            200: "Success",

@@ -88,10 +67,18 @@ class CompletionApi(WebApiResource):
        if app_model.mode != AppMode.COMPLETION:
            raise NotCompletionAppError()

        payload = CompletionMessagePayload.model_validate(web_ns.payload or {})
        args = payload.model_dump(exclude_none=True)
        parser = (
            reqparse.RequestParser()
            .add_argument("inputs", type=dict, required=True, location="json")
            .add_argument("query", type=str, location="json", default="")
            .add_argument("files", type=list, required=False, location="json")
            .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
            .add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
        )

        streaming = payload.response_mode == "streaming"
        args = parser.parse_args()

        streaming = args["response_mode"] == "streaming"
        args["auto_generate_name"] = False

        try:

@@ -155,7 +142,22 @@ class CompletionStopApi(WebApiResource):
class ChatApi(WebApiResource):
    @web_ns.doc("Create Chat Message")
    @web_ns.doc(description="Create a chat message for conversational applications.")
    @web_ns.expect(web_ns.models[ChatMessagePayload.__name__])
    @web_ns.doc(
        params={
            "inputs": {"description": "Input variables for the chat", "type": "object", "required": True},
            "query": {"description": "User query/message", "type": "string", "required": True},
            "files": {"description": "Files to be processed", "type": "array", "required": False},
            "response_mode": {
                "description": "Response mode: blocking or streaming",
                "type": "string",
                "enum": ["blocking", "streaming"],
                "required": False,
            },
            "conversation_id": {"description": "Conversation UUID", "type": "string", "required": False},
            "parent_message_id": {"description": "Parent message UUID", "type": "string", "required": False},
            "retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
        }
    )
    @web_ns.doc(
        responses={
            200: "Success",

@@ -171,10 +173,20 @@ class ChatApi(WebApiResource):
        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()

        payload = ChatMessagePayload.model_validate(web_ns.payload or {})
        args = payload.model_dump(exclude_none=True)
        parser = (
            reqparse.RequestParser()
            .add_argument("inputs", type=dict, required=True, location="json")
            .add_argument("query", type=str, required=True, location="json")
            .add_argument("files", type=list, required=False, location="json")
            .add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
            .add_argument("conversation_id", type=uuid_value, location="json")
            .add_argument("parent_message_id", type=uuid_value, required=False, location="json")
            .add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
        )

        streaming = payload.response_mode == "streaming"
        args = parser.parse_args()

        streaming = args["response_mode"] == "streaming"
        args["auto_generate_name"] = False

        try:
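One behavioral detail of the restored parsers worth noting: parse_args materializes every declared key, so an omitted response_mode comes back as None and the `== "streaming"` check quietly falls through to blocking mode. A sketch:

```python
from flask import Flask
from flask_restx import reqparse

parser = reqparse.RequestParser().add_argument(
    "response_mode", type=str, choices=["blocking", "streaming"], location="json"
)

app = Flask(__name__)
with app.test_request_context(json={}):
    args = parser.parse_args()
    streaming = args["response_mode"] == "streaming"
    print(args["response_mode"], streaming)  # None False -> blocking by default
```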
@@ -1,4 +1,3 @@
import json
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal

@@ -121,7 +120,7 @@ class VariableEntity(BaseModel):
    allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
    allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
    allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
    json_schema: str | None = Field(default=None)
    json_schema: dict[str, Any] | None = Field(default=None)

    @field_validator("description", mode="before")
    @classmethod

@@ -135,17 +134,11 @@ class VariableEntity(BaseModel):

    @field_validator("json_schema")
    @classmethod
    def validate_json_schema(cls, schema: str | None) -> str | None:
    def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None:
        if schema is None:
            return None

        try:
            json_schema = json.loads(schema)
        except json.JSONDecodeError:
            raise ValueError(f"invalid json_schema value {schema}")

        try:
            Draft7Validator.check_schema(json_schema)
            Draft7Validator.check_schema(schema)
        except SchemaError as e:
            raise ValueError(f"Invalid JSON schema: {e.message}")
        return schema
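With json_schema now arriving pre-parsed as a dict, the validator only has to vet the schema object itself. A standalone sketch of the two checks involved:

```python
from jsonschema import Draft7Validator
from jsonschema.exceptions import SchemaError, ValidationError

schema = {"type": "object", "properties": {"name": {"type": "string"}}}

try:
    Draft7Validator.check_schema(schema)  # is the schema itself well-formed?
except SchemaError as e:
    raise ValueError(f"Invalid JSON schema: {e.message}")

try:
    Draft7Validator(schema).validate({"name": 42})  # does a value conform?
except ValidationError as e:
    print(f"does not match schema: {e.message}")
```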
@@ -1,4 +1,3 @@
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Union, final

@@ -176,13 +175,6 @@ class BaseAppGenerator:
                value = True
            elif value == 0:
                value = False
        case VariableEntityType.JSON_OBJECT:
            if not isinstance(value, str):
                raise ValueError(f"{variable_entity.variable} in input form must be a string")
            try:
                json.loads(value)
            except json.JSONDecodeError:
                raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object")
        case _:
            raise AssertionError("this statement should be unreachable.")
@@ -342,11 +342,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
self._task_state.llm_result.message.content = current_content

if isinstance(event, QueueLLMChunkEvent):
    event_type = self._message_cycle_manager.get_message_event_type(message_id=self._message_id)
    yield self._message_cycle_manager.message_to_stream_response(
        answer=cast(str, delta_text),
        message_id=self._message_id,
        event_type=event_type,
    )
else:
    yield self._agent_message_to_stream_response(

@@ -5,7 +5,7 @@ from threading import Thread
from typing import Union

from flask import Flask, current_app
from sqlalchemy import exists, select
from sqlalchemy import select
from sqlalchemy.orm import Session

from configs import dify_config

@@ -54,20 +54,6 @@ class MessageCycleManager:
    ):
        self._application_generate_entity = application_generate_entity
        self._task_state = task_state
        self._message_has_file: set[str] = set()

    def get_message_event_type(self, message_id: str) -> StreamEvent:
        if message_id in self._message_has_file:
            return StreamEvent.MESSAGE_FILE

        with Session(db.engine, expire_on_commit=False) as session:
            has_file = session.query(exists().where(MessageFile.message_id == message_id)).scalar()

        if has_file:
            self._message_has_file.add(message_id)
            return StreamEvent.MESSAGE_FILE

        return StreamEvent.MESSAGE

    def generate_conversation_name(self, *, conversation_id: str, query: str) -> Thread | None:
        """

@@ -228,11 +214,7 @@ class MessageCycleManager:
        return None

    def message_to_stream_response(
        self,
        answer: str,
        message_id: str,
        from_variable_selector: list[str] | None = None,
        event_type: StreamEvent | None = None,
        self, answer: str, message_id: str, from_variable_selector: list[str] | None = None
    ) -> MessageStreamResponse:
        """
        Message to stream response.

@@ -240,12 +222,16 @@ class MessageCycleManager:
        :param message_id: message id
        :return:
        """
        with Session(db.engine, expire_on_commit=False) as session:
            message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id))
            event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE

        return MessageStreamResponse(
            task_id=self._application_generate_entity.task_id,
            id=message_id,
            answer=answer,
            from_variable_selector=from_variable_selector,
            event=event_type or StreamEvent.MESSAGE,
            event=event_type,
        )

    def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
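The deleted get_message_event_type combined an EXISTS probe with a memo set so each message hits the database at most once; a toy, self-contained reconstruction of that idea:

```python
from sqlalchemy import String, create_engine, exists
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class MessageFile(Base):
    __tablename__ = "message_files"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    message_id: Mapped[str] = mapped_column(String, index=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

_has_file: set[str] = set()  # memoize positive probes only, as the removed code did

def message_has_file(message_id: str) -> bool:
    if message_id in _has_file:
        return True
    with Session(engine, expire_on_commit=False) as session:
        found = session.query(exists().where(MessageFile.message_id == message_id)).scalar()
    if found:
        _has_file.add(message_id)
    return bool(found)

print(message_has_file("m1"))  # False on an empty table
```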
@@ -39,7 +39,7 @@ from core.trigger.errors import (
plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))
_plugin_daemon_timeout_config = cast(
    float | httpx.Timeout | None,
    getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 600.0),
    getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 300.0),
)
plugin_daemon_request_timeout: httpx.Timeout | None
if _plugin_daemon_timeout_config is None:

@@ -2,7 +2,6 @@

from __future__ import annotations

import codecs
import re
from typing import Any

@@ -53,7 +52,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
    def __init__(self, fixed_separator: str = "\n\n", separators: list[str] | None = None, **kwargs: Any):
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._fixed_separator = codecs.decode(fixed_separator, "unicode_escape")
        self._fixed_separator = fixed_separator
        self._separators = separators or ["\n\n", "\n", "。", ". ", " ", ""]

    def split_text(self, text: str) -> list[str]:

@@ -95,8 +94,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
            splits = re.split(r" +", text)
        else:
            splits = text.split(separator)
            if self._keep_separator:
                splits = [s + separator for s in splits[:-1]] + splits[-1:]
                splits = [item + separator if i < len(splits) else item for i, item in enumerate(splits)]
    else:
        splits = list(text)
    if separator == "\n":

@@ -105,7 +103,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
    splits = [s for s in splits if (s not in {"", "\n"})]
    _good_splits = []
    _good_splits_lengths = []  # cache the lengths of the splits
    _separator = "" if self._keep_separator else separator
    _separator = separator if self._keep_separator else ""
    s_lens = self._length_function(splits)
    if separator != "":
        for s, s_len in zip(splits, s_lens):
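The keep-separator comprehensions in the middle hunk are not equivalent: the enumerate guard `i < len(splits)` is always true, so the replacement also suffixes the last fragment. A quick check:

```python
separator = "\n\n"
splits = "a\n\nb\n\nc".split(separator)

old = [s + separator for s in splits[:-1]] + splits[-1:]
new = [item + separator if i < len(splits) else item for i, item in enumerate(splits)]

print(old)  # ['a\n\n', 'b\n\n', 'c']       separator kept on all but the last
print(new)  # ['a\n\n', 'b\n\n', 'c\n\n']   separator kept on every fragment
```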
@@ -1,4 +1,3 @@
import json
from typing import Any

from jsonschema import Draft7Validator, ValidationError

@@ -43,25 +42,15 @@ class StartNode(Node[StartNodeData]):
    if value is None and variable.required:
        raise ValueError(f"{key} is required in input form")

    if not isinstance(value, dict):
        raise ValueError(f"{key} must be a JSON object")

    schema = variable.json_schema
    if not schema:
        continue

    if not value:
        continue

    try:
        json_schema = json.loads(schema)
    except json.JSONDecodeError as e:
        raise ValueError(f"{schema} must be a valid JSON object")

    try:
        json_value = json.loads(value)
    except json.JSONDecodeError as e:
        raise ValueError(f"{value} must be a valid JSON object")

    try:
        Draft7Validator(json_schema).validate(json_value)
        Draft7Validator(schema).validate(value)
    except ValidationError as e:
        raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
    node_inputs[key] = json_value
    node_inputs[key] = value
@@ -4,6 +4,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
    from commands import (
        add_qdrant_index,
        clean_expired_sandbox_messages,
        cleanup_orphaned_draft_variables,
        clear_free_plan_tenant_expired_logs,
        clear_orphaned_file_records,

@@ -54,6 +55,7 @@ def init_app(app: DifyApp):
        setup_datasource_oauth_client,
        transform_datasource_credentials,
        install_rag_pipeline_plugins,
        clean_expired_sandbox_messages,
    ]
    for cmd in cmds_to_register:
        app.cli.add_command(cmd)

@@ -11,7 +11,6 @@ from collections.abc import Generator, Mapping
from datetime import datetime
from hashlib import sha256
from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast
from uuid import UUID
from zoneinfo import available_timezones

from flask import Response, stream_with_context

@@ -120,19 +119,6 @@ def uuid_value(value: Any) -> str:
    raise ValueError(error)


def normalize_uuid(value: str | UUID) -> str:
    if not value:
        return ""

    try:
        return uuid_value(value)
    except ValueError as exc:
        raise ValueError("must be a valid UUID") from exc


UUIDStrOrEmpty = Annotated[str, AfterValidator(normalize_uuid)]


def alphanumeric(value: str):
    # check if the value is alphanumeric and underlined
    if re.match(r"^[a-zA-Z0-9_]+$", value):
@@ -1,29 +0,0 @@
"""remove unused is_deleted from conversations

Revision ID: e5d7a95e676f
Revises: d57accd375ae
Create Date: 2025-11-27 18:27:09.006691

"""
import sqlalchemy as sa
from alembic import op

revision = "e5d7a95e676f"
down_revision = "d57accd375ae"
branch_labels = None
depends_on = None


def upgrade():
    conversations = sa.table("conversations", sa.column("is_deleted", sa.Boolean))
    op.execute(sa.delete(conversations).where(conversations.c.is_deleted == sa.true()))

    with op.batch_alter_table("conversations", schema=None) as batch_op:
        batch_op.drop_column("is_deleted")


def downgrade():
    with op.batch_alter_table("conversations", schema=None) as batch_op:
        batch_op.add_column(
            sa.Column("is_deleted", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False)
        )

@@ -0,0 +1,33 @@
"""feat: add created_at id index to messages

Revision ID: 649d817a739e
Revises: 03ea244985ce
Create Date: 2025-12-18 16:39:33.090454

"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '649d817a739e'
down_revision = '03ea244985ce'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('messages', schema=None) as batch_op:
        batch_op.create_index('message_created_at_id_idx', ['created_at', 'id'], unique=False)

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('messages', schema=None) as batch_op:
        batch_op.drop_index('message_created_at_id_idx')

    # ### end Alembic commands ###
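A guess at what the composite (created_at, id) index buys: keyset pagination over messages with id as the tiebreaker, instead of OFFSET scans. A sketch reusing the Message model from this diff (so not standalone):

```python
import sqlalchemy as sa

def next_message_page(session, last_created_at, last_id, page_size=100):
    # Row-value comparison walks the (created_at, id) index in order.
    return session.scalars(
        sa.select(Message)
        .where(sa.tuple_(Message.created_at, Message.id) > (last_created_at, last_id))
        .order_by(Message.created_at, Message.id)
        .limit(page_size)
    ).all()
```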
@@ -676,6 +676,8 @@ class Conversation(Base):
    "MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all"
)

is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))

@property
def inputs(self) -> dict[str, Any]:
    inputs = self._inputs.copy()

@@ -963,6 +965,7 @@ class Message(Base):
    Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"),
    Index("message_created_at_idx", "created_at"),
    Index("message_app_mode_idx", "app_mode"),
    Index("message_created_at_id_idx", "created_at", "id"),
)

id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
@@ -1,90 +1,54 @@
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
import app
|
||||
from configs import dify_config
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.model import (
|
||||
App,
|
||||
Message,
|
||||
MessageAgentThought,
|
||||
MessageAnnotation,
|
||||
MessageChain,
|
||||
MessageFeedback,
|
||||
MessageFile,
|
||||
)
|
||||
from models.web import SavedMessage
|
||||
from services.feature_service import FeatureService
|
||||
from services.sandbox_messages_clean_service import SandboxMessagesCleanService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@app.celery.task(queue="dataset")
|
||||
@app.celery.task(queue="retention")
|
||||
def clean_messages():
|
||||
click.echo(click.style("Start clean messages.", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
plan_sandbox_clean_message_day = datetime.datetime.now() - datetime.timedelta(
|
||||
days=dify_config.PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING
|
||||
)
|
||||
while True:
|
||||
try:
|
# Old implementation (removed by this commit):

        # Main query with join and filter
        messages = (
            db.session.query(Message)
            .where(Message.created_at < plan_sandbox_clean_message_day)
            .order_by(Message.created_at.desc())
            .limit(100)
            .all()
        )
    except SQLAlchemyError:
        raise
    if not messages:
        break
    for message in messages:
        app = db.session.query(App).filter_by(id=message.app_id).first()
        if not app:
            logger.warning(
                "Expected App record to exist, but none was found, app_id=%s, message_id=%s",
                message.app_id,
                message.id,
            )
            continue
        features_cache_key = f"features:{app.tenant_id}"
        plan_cache = redis_client.get(features_cache_key)
        if plan_cache is None:
            features = FeatureService.get_features(app.tenant_id)
            redis_client.setex(features_cache_key, 600, features.billing.subscription.plan)
            plan = features.billing.subscription.plan
        else:
            plan = plan_cache.decode()
        if plan == CloudPlan.SANDBOX:
            # clean related message
            db.session.query(MessageFeedback).where(MessageFeedback.message_id == message.id).delete(
                synchronize_session=False
            )
            db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message.id).delete(
                synchronize_session=False
            )
            db.session.query(MessageChain).where(MessageChain.message_id == message.id).delete(
                synchronize_session=False
            )
            db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message.id).delete(
                synchronize_session=False
            )
            db.session.query(MessageFile).where(MessageFile.message_id == message.id).delete(
                synchronize_session=False
            )
            db.session.query(SavedMessage).where(SavedMessage.message_id == message.id).delete(
                synchronize_session=False
            )
            db.session.query(Message).where(Message.id == message.id).delete()
            db.session.commit()
    end_at = time.perf_counter()
    click.echo(click.style(f"Cleaned messages from db success latency: {end_at - start_at}", fg="green"))

# New implementation (added by this commit):

    """
    Clean expired messages from sandbox plan tenants.

    This task uses SandboxMessagesCleanService to efficiently clean messages in batches.
    """
    if not dify_config.BILLING_ENABLED:
        click.echo(click.style("Billing is not enabled. Skipping sandbox messages cleanup.", fg="yellow"))
        return

    click.echo(click.style("clean_messages: start clean messages.", fg="green"))
    start_at = time.perf_counter()

    try:
        stats = SandboxMessagesCleanService.clean_sandbox_messages_by_days(
            days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS,
            graceful_period=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD,
            batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE,
        )

        end_at = time.perf_counter()
        click.echo(
            click.style(
                f"clean_messages: completed successfully\n"
                f" - Latency: {end_at - start_at:.2f}s\n"
                f" - Batches processed: {stats['batches']}\n"
                f" - Messages found: {stats['total_messages']}\n"
                f" - Messages deleted: {stats['total_deleted']}",
                fg="green",
            )
        )
    except Exception as e:
        end_at = time.perf_counter()
        logger.exception("clean_messages failed")
        click.echo(
            click.style(
                f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}",
                fg="red",
            )
        )
        raise
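
# A hypothetical invocation of the reworked task, assuming it stays registered with
# the Flask CLI like the module's other maintenance commands (command name illustrative):
#   flask clean-messages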
@@ -47,6 +47,7 @@ class ConversationService:
            return InfiniteScrollPagination(data=[], limit=limit, has_more=False)

        stmt = select(Conversation).where(
            Conversation.is_deleted == False,
            Conversation.app_id == app_model.id,
            Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),
            Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),

@@ -165,6 +166,7 @@ class ConversationService:
                Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),
                Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
                Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
                Conversation.is_deleted == False,
            )
            .first()
        )

@@ -23,7 +23,7 @@ class RagPipelineDatasetCreateEntity(BaseModel):
    description: str
    icon_info: IconInfo
    permission: str
    partial_member_list: list[dict[str, str]] | None = None
    partial_member_list: list[str] | None = None
    yaml_content: str | None = None
488 api/services/sandbox_messages_clean_service.py Normal file
@@ -0,0 +1,488 @@
import datetime
import json
import logging
from collections.abc import Sequence
from dataclasses import dataclass
from typing import cast

from sqlalchemy import delete, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session

from configs import dify_config
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import (
    App,
    AppAnnotationHitHistory,
    DatasetRetrieverResource,
    Message,
    MessageAgentThought,
    MessageAnnotation,
    MessageChain,
    MessageFeedback,
    MessageFile,
)
from models.web import SavedMessage
from services.billing_service import BillingService, SubscriptionPlan

logger = logging.getLogger(__name__)


@dataclass
class SimpleMessage:
    """Lightweight message info containing only essential fields for cleaning."""

    id: str
    app_id: str
    created_at: datetime.datetime
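
    # A sketch of a value this dataclass carries (ids are hypothetical):
    #   SimpleMessage(id="msg-1", app_id="app-1", created_at=datetime.datetime(2024, 1, 1, 12, 0))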


class SandboxMessagesCleanService:
    """
    Service for cleaning expired messages from sandbox plan tenants.
    """

    # Redis key prefix for tenant plan cache
    PLAN_CACHE_KEY_PREFIX = "tenant_plan:"
    # Cache TTL: 10 minutes
    PLAN_CACHE_TTL = 600

    @classmethod
    def clean_sandbox_messages_by_time_range(
        cls,
        start_from: datetime.datetime,
        end_before: datetime.datetime,
        graceful_period: int = 21,
        batch_size: int = 1000,
        dry_run: bool = False,
    ) -> dict[str, int]:
        """
        Clean sandbox messages within a specific time range [start_from, end_before).

        Args:
            start_from: Start time (inclusive) of the range
            end_before: End time (exclusive) of the range
            graceful_period: Grace period in days after subscription expiration
            batch_size: Number of messages to process per batch
            dry_run: Whether to perform a dry run (no actual deletion)

        Returns:
            Statistics about the cleaning operation

        Raises:
            ValueError: If start_from >= end_before
        """
        if start_from >= end_before:
            raise ValueError(f"start_from ({start_from}) must be less than end_before ({end_before})")

        if batch_size <= 0:
            raise ValueError(f"batch_size ({batch_size}) must be greater than 0")

        if graceful_period < 0:
            raise ValueError(f"graceful_period ({graceful_period}) must be greater than or equal to 0")

        logger.info("clean_messages: start_from=%s, end_before=%s, batch_size=%s", start_from, end_before, batch_size)

        return cls._clean_sandbox_messages_by_time_range(
            start_from=start_from,
            end_before=end_before,
            graceful_period=graceful_period,
            batch_size=batch_size,
            dry_run=dry_run,
        )

    @classmethod
    def clean_sandbox_messages_by_days(
        cls,
        days: int = 30,
        graceful_period: int = 21,
        batch_size: int = 1000,
        dry_run: bool = False,
    ) -> dict[str, int]:
        """
        Clean sandbox messages older than the specified number of days.

        Args:
            days: Number of days to look back from now
            graceful_period: Grace period in days after subscription expiration
            batch_size: Number of messages to process per batch
            dry_run: Whether to perform a dry run (no actual deletion)

        Returns:
            Statistics about the cleaning operation
        """
        if days < 0:
            raise ValueError(f"days ({days}) must be greater than or equal to 0")

        if batch_size <= 0:
            raise ValueError(f"batch_size ({batch_size}) must be greater than 0")

        if graceful_period < 0:
            raise ValueError(f"graceful_period ({graceful_period}) must be greater than or equal to 0")

        end_before = datetime.datetime.now() - datetime.timedelta(days=days)
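        # e.g. days=30 on 2024-02-01 gives end_before=2024-01-02; only messages created
        # before that instant become candidates (still subject to plan and grace-period checks).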

        logger.info("clean_messages: days=%s, end_before=%s, batch_size=%s", days, end_before, batch_size)

        return cls._clean_sandbox_messages_by_time_range(
            end_before=end_before,
            start_from=None,
            graceful_period=graceful_period,
            batch_size=batch_size,
            dry_run=dry_run,
        )

    @classmethod
    def _clean_sandbox_messages_by_time_range(
        cls,
        end_before: datetime.datetime,
        start_from: datetime.datetime | None = None,
        graceful_period: int = 21,
        batch_size: int = 1000,
        dry_run: bool = False,
    ) -> dict[str, int]:
        """
        Internal method to clean sandbox messages within a time range using cursor-based pagination.
        The time range is [start_from, end_before) - a left-closed, right-open interval.

        Steps:
        1. Iterate messages using cursor pagination (by created_at, id)
        2. Extract app_ids from messages
        3. Query tenant_ids from apps
        4. Batch fetch subscription plans
        5. Delete messages from sandbox tenants

        Args:
            end_before: End time (exclusive) of the range
            start_from: Optional start time (inclusive) of the range
            graceful_period: Grace period in days after subscription expiration
            batch_size: Number of messages to process per batch
            dry_run: Whether to perform a dry run (no actual deletion)

        Returns:
            Dict with statistics: batches, total_messages, total_deleted
        """
        stats = {
            "batches": 0,
            "total_messages": 0,
            "total_deleted": 0,
        }

        if not dify_config.BILLING_ENABLED:
            logger.info("clean_messages: billing is not enabled, skip cleaning messages")
            return stats

        tenant_whitelist = cls._get_tenant_whitelist()
        logger.info("clean_messages: tenant_whitelist=%s", tenant_whitelist)

        # Cursor-based pagination using (created_at, id) to avoid infinite loops
        # and ensure proper ordering with time-based filtering
        _cursor: tuple[datetime.datetime, str] | None = None

        logger.info(
            "clean_messages: start cleaning messages (dry_run=%s), start_from=%s, end_before=%s",
            dry_run,
            start_from,
            end_before,
        )

        while True:
            stats["batches"] += 1

            # Step 1: Fetch a batch of messages using the cursor
            with Session(db.engine, expire_on_commit=False) as session:
                msg_stmt = (
                    select(Message.id, Message.app_id, Message.created_at)
                    .where(Message.created_at < end_before)
                    .order_by(Message.created_at, Message.id)
                    .limit(batch_size)
                )

                if start_from:
                    msg_stmt = msg_stmt.where(Message.created_at >= start_from)

                # Apply cursor condition: (created_at, id) > (last_created_at, last_message_id)
                # This translates to:
                # created_at > last_created_at OR (created_at = last_created_at AND id > last_message_id)
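                # Roughly the SQL this produces (placeholder names illustrative):
                #   WHERE created_at > :last_created_at
                #      OR (created_at = :last_created_at AND id > :last_message_id)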
                if _cursor:
                    # Continuing from previous batch
                    msg_stmt = msg_stmt.where(
                        (Message.created_at > _cursor[0])
                        | ((Message.created_at == _cursor[0]) & (Message.id > _cursor[1]))
                    )

                raw_messages = list(session.execute(msg_stmt).all())
                messages = [
                    SimpleMessage(id=msg_id, app_id=app_id, created_at=msg_created_at)
                    for msg_id, app_id, msg_created_at in raw_messages
                ]

                if not messages:
                    logger.info("clean_messages (batch %s): no more messages to process", stats["batches"])
                    break

                # Update cursor to the last message's (created_at, id)
                _cursor = (messages[-1].created_at, messages[-1].id)

                # Step 2: Extract app_ids from this batch
                app_ids = list({msg.app_id for msg in messages})

                if not app_ids:
                    logger.info("clean_messages (batch %s): no app_ids found, skip", stats["batches"])
                    continue

                # Step 3: Query tenant_ids from apps
                app_stmt = select(App.id, App.tenant_id).where(App.id.in_(app_ids))
                apps = list(session.execute(app_stmt).all())

                if not apps:
                    logger.info("clean_messages (batch %s): no apps found, skip", stats["batches"])
                    continue

            # Step 4: End the session before calling the billing API to avoid a long-running transaction.
            # Build app_id -> tenant_id mapping
            app_to_tenant: dict[str, str] = {app.id: app.tenant_id for app in apps}
            tenant_ids = list(set(app_to_tenant.values()))

            # Batch fetch subscription plans
            tenant_plans = cls._batch_fetch_tenant_plans(tenant_ids)

            # Step 5: Filter messages from sandbox tenants
            sandbox_message_ids = cls._filter_expired_sandbox_messages(
                messages=messages,
                app_to_tenant=app_to_tenant,
                tenant_plans=tenant_plans,
                tenant_whitelist=tenant_whitelist,
                graceful_period_days=graceful_period,
            )

            if not sandbox_message_ids:
                logger.info("clean_messages (batch %s): no sandbox messages found, skip", stats["batches"])
                continue

            stats["total_messages"] += len(sandbox_message_ids)

            # Step 6: Batch delete messages and their relations
            if not dry_run:
                with Session(db.engine, expire_on_commit=False) as session:
                    # Delete related records first
                    cls._batch_delete_message_relations(session, sandbox_message_ids)

                    # Delete messages
                    delete_stmt = delete(Message).where(Message.id.in_(sandbox_message_ids))
                    delete_result = cast(CursorResult, session.execute(delete_stmt))
                    messages_deleted = delete_result.rowcount
                    session.commit()

                stats["total_deleted"] += messages_deleted

                logger.info(
                    "clean_messages (batch %s): processed %s messages, deleted %s sandbox messages",
                    stats["batches"],
                    len(messages),
                    messages_deleted,
                )
            else:
                sample_ids = ", ".join(sample_id for sample_id in sandbox_message_ids[:5])
                logger.info(
                    "clean_messages (batch %s, dry_run): would delete %s sandbox messages, sample ids: %s",
                    stats["batches"],
                    len(sandbox_message_ids),
                    sample_ids,
                )

        logger.info(
            "clean_messages completed: total batches: %s, total messages: %s, total deleted: %s",
            stats["batches"],
            stats["total_messages"],
            stats["total_deleted"],
        )

        return stats

    @classmethod
    def _filter_expired_sandbox_messages(
        cls,
        messages: Sequence[SimpleMessage],
        app_to_tenant: dict[str, str],
        tenant_plans: dict[str, SubscriptionPlan],
        tenant_whitelist: Sequence[str],
        graceful_period_days: int,
        current_timestamp: int | None = None,
    ) -> list[str]:
        """
        Filter messages that should be deleted based on sandbox plan expiration.

        A message should be deleted if:
        1. It belongs to a sandbox tenant AND
        2. Either:
           a) The tenant has no previous subscription (expiration_date == -1), OR
           b) The subscription expired more than graceful_period_days ago

        Args:
            messages: List of message objects with id and app_id attributes
            app_to_tenant: Mapping from app_id to tenant_id
            tenant_plans: Mapping from tenant_id to subscription plan info
            tenant_whitelist: Tenant IDs that are never cleaned
            graceful_period_days: Grace period in days after expiration
            current_timestamp: Current Unix timestamp (defaults to now, injectable for testing)

        Returns:
            List of message IDs that should be deleted
        """
        if current_timestamp is None:
            current_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())

        sandbox_message_ids: list[str] = []
        graceful_period_seconds = graceful_period_days * 24 * 60 * 60
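        # Worked example: graceful_period_days=21 -> 21 * 86400 = 1,814,400 seconds, so a
        # subscription that expired 30 days (2,592,000 s) ago is past the grace window.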

        for msg in messages:
            # Get the tenant_id for this message's app
            tenant_id = app_to_tenant.get(msg.app_id)
            if not tenant_id:
                continue

            # Skip messages from whitelisted tenants
            if tenant_id in tenant_whitelist:
                continue

            # Get the subscription plan for this tenant
            tenant_plan = tenant_plans.get(tenant_id)
            if not tenant_plan:
                continue

            plan = str(tenant_plan["plan"])
            expiration_date = int(tenant_plan["expiration_date"])

            # Only process sandbox plans
            if plan != CloudPlan.SANDBOX:
                continue

            # Case 1: No previous subscription (-1 means never had a paid subscription)
            if expiration_date == -1:
                sandbox_message_ids.append(msg.id)
                continue

            # Case 2: Subscription expired beyond the grace period
            if current_timestamp - expiration_date > graceful_period_seconds:
                sandbox_message_ids.append(msg.id)

        return sandbox_message_ids

    @classmethod
    def _get_tenant_whitelist(cls) -> Sequence[str]:
        return BillingService.get_expired_subscription_cleanup_whitelist()

    @classmethod
    def _batch_fetch_tenant_plans(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]:
        """
        Batch fetch tenant plans with Redis caching.

        This method uses a two-tier strategy:
        1. First, batch fetch from the Redis cache using mget
        2. For cache misses, fetch from the billing API
        3. Update the Redis cache using a pipeline for new entries

        Args:
            tenant_ids: List of tenant IDs

        Returns:
            Dict mapping tenant_id to SubscriptionPlan (with "plan" and "expiration_date" keys)
        """
        if not tenant_ids:
            return {}

        tenant_plans: dict[str, SubscriptionPlan] = {}

        # Step 1: Batch fetch from the Redis cache using mget
        redis_keys = [f"{cls.PLAN_CACHE_KEY_PREFIX}{tenant_id}" for tenant_id in tenant_ids]
        try:
            cached_values = redis_client.mget(redis_keys)

            # Map cached values back to tenant_ids
            cache_hits: dict[str, SubscriptionPlan] = {}
            cache_misses: list[str] = []

            for tenant_id, cached_value in zip(tenant_ids, cached_values):
                if cached_value:
                    # Redis returns bytes; decode to string and parse JSON
                    json_str = cached_value.decode("utf-8") if isinstance(cached_value, bytes) else cached_value
                    try:
                        plan_dict = json.loads(json_str)
                        if isinstance(plan_dict, dict) and "plan" in plan_dict:
                            cache_hits[tenant_id] = cast(SubscriptionPlan, plan_dict)
                            tenant_plans[tenant_id] = cast(SubscriptionPlan, plan_dict)
                        else:
                            cache_misses.append(tenant_id)
                    except json.JSONDecodeError:
                        cache_misses.append(tenant_id)
                else:
                    cache_misses.append(tenant_id)

            logger.info(
                "clean_messages: fetch_tenant_plans(cache hits=%s, cache misses=%s)",
                len(cache_hits),
                len(cache_misses),
            )
        except Exception as e:
            logger.warning("clean_messages: fetch_tenant_plans(redis mget failed: %s, falling back to API)", e)
            cache_misses = list(tenant_ids)

        # Step 2: Fetch missing plans from the billing API
        if cache_misses:
            bulk_plans = BillingService.get_plan_bulk(cache_misses)

            if bulk_plans:
                plans_to_cache: dict[str, SubscriptionPlan] = {}

                for tenant_id, plan_dict in bulk_plans.items():
                    if isinstance(plan_dict, dict):
                        tenant_plans[tenant_id] = plan_dict  # type: ignore
                        plans_to_cache[tenant_id] = plan_dict  # type: ignore

                # Step 3: Batch update the Redis cache using a pipeline
                if plans_to_cache:
                    try:
                        pipe = redis_client.pipeline()
                        for tenant_id, plan_dict in plans_to_cache.items():
                            redis_key = f"{cls.PLAN_CACHE_KEY_PREFIX}{tenant_id}"
                            # Serialize the dict to a JSON string
                            json_str = json.dumps(plan_dict)
                            pipe.setex(redis_key, cls.PLAN_CACHE_TTL, json_str)
                        pipe.execute()

                        logger.info(
                            "clean_messages: cached %s new tenant plans to Redis",
                            len(plans_to_cache),
                        )
                    except Exception as e:
                        logger.warning("clean_messages: Redis pipeline failed: %s", e)

        return tenant_plans

    @classmethod
    def _batch_delete_message_relations(cls, session: Session, message_ids: Sequence[str]) -> None:
        """
        Batch delete all related records for the given message IDs.

        Args:
            session: Database session
            message_ids: List of message IDs to delete relations for
        """
        if not message_ids:
            return

        # Delete all related records in batch
        session.execute(delete(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)))

        session.execute(delete(MessageAnnotation).where(MessageAnnotation.message_id.in_(message_ids)))

        session.execute(delete(MessageChain).where(MessageChain.message_id.in_(message_ids)))

        session.execute(delete(MessageAgentThought).where(MessageAgentThought.message_id.in_(message_ids)))

        session.execute(delete(MessageFile).where(MessageFile.message_id.in_(message_ids)))

        session.execute(delete(SavedMessage).where(SavedMessage.message_id.in_(message_ids)))

        session.execute(delete(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id.in_(message_ids)))

        session.execute(delete(DatasetRetrieverResource).where(DatasetRetrieverResource.message_id.in_(message_ids)))
@@ -0,0 +1,996 @@
"""
Integration tests for SandboxMessagesCleanService using testcontainers.

This module provides comprehensive integration tests for the sandbox message cleanup service
using TestContainers infrastructure with real PostgreSQL and Redis.
"""

import datetime
import json
import uuid
from decimal import Decimal
from unittest.mock import MagicMock, patch

import pytest
from faker import Faker

from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.model import (
    App,
    AppAnnotationHitHistory,
    Conversation,
    DatasetRetrieverResource,
    Message,
    MessageAgentThought,
    MessageAnnotation,
    MessageChain,
    MessageFeedback,
    MessageFile,
)
from models.web import SavedMessage
from services.sandbox_messages_clean_service import SandboxMessagesCleanService


class TestSandboxMessagesCleanServiceIntegration:
    """Integration tests for SandboxMessagesCleanService._clean_sandbox_messages_by_time_range."""

    @pytest.fixture(autouse=True)
    def cleanup_database(self, db_session_with_containers):
        """Clean up the database after each test to ensure isolation."""
        yield
        # Clear all test data in the correct order (respecting foreign key constraints)
        db.session.query(DatasetRetrieverResource).delete()
        db.session.query(AppAnnotationHitHistory).delete()
        db.session.query(SavedMessage).delete()
        db.session.query(MessageFile).delete()
        db.session.query(MessageAgentThought).delete()
        db.session.query(MessageChain).delete()
        db.session.query(MessageAnnotation).delete()
        db.session.query(MessageFeedback).delete()
        db.session.query(Message).delete()
        db.session.query(Conversation).delete()
        db.session.query(App).delete()
        db.session.query(TenantAccountJoin).delete()
        db.session.query(Tenant).delete()
        db.session.query(Account).delete()
        db.session.commit()

    @pytest.fixture(autouse=True)
    def cleanup_redis(self):
        """Clean up the Redis cache before and after each test."""
        # Clear the tenant plan cache
        try:
            keys = redis_client.keys(f"{SandboxMessagesCleanService.PLAN_CACHE_KEY_PREFIX}*")
            if keys:
                redis_client.delete(*keys)
        except Exception:
            pass  # Redis might not be available in some test environments
        yield
        # Clean up after the test
        try:
            keys = redis_client.keys(f"{SandboxMessagesCleanService.PLAN_CACHE_KEY_PREFIX}*")
            if keys:
                redis_client.delete(*keys)
        except Exception:
            pass

    @pytest.fixture(autouse=True)
    def mock_whitelist(self):
        """Mock the whitelist to return an empty list by default."""
        with patch(
            "services.sandbox_messages_clean_service.BillingService.get_expired_subscription_cleanup_whitelist"
        ) as mock:
            mock.return_value = []
            yield mock

    @pytest.fixture(autouse=True)
    def mock_billing_enabled(self):
        """Mock BILLING_ENABLED to be True for all tests."""
        with patch("services.sandbox_messages_clean_service.dify_config.BILLING_ENABLED", True):
            yield

    def _create_account_and_tenant(self, plan="sandbox"):
        """Helper to create an account and tenant."""
        fake = Faker()

        account = Account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            status="active",
        )
        db.session.add(account)
        db.session.flush()

        tenant = Tenant(
            name=fake.company(),
            plan=plan,
            status="normal",
        )
        db.session.add(tenant)
        db.session.flush()

        tenant_account_join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=TenantAccountRole.OWNER,
        )
        db.session.add(tenant_account_join)
        db.session.commit()

        return account, tenant

    def _create_app(self, tenant, account):
        """Helper to create an app."""
        fake = Faker()

        app = App(
            tenant_id=tenant.id,
            name=fake.company(),
            description="Test app",
            mode="chat",
            enable_site=True,
            enable_api=True,
            api_rpm=60,
            api_rph=3600,
            is_demo=False,
            is_public=False,
            created_by=account.id,
            updated_by=account.id,
        )
        db.session.add(app)
        db.session.commit()

        return app

    def _create_conversation(self, app):
        """Helper to create a conversation."""
        conversation = Conversation(
            app_id=app.id,
            app_model_config_id=str(uuid.uuid4()),
            model_provider="openai",
            model_id="gpt-3.5-turbo",
            mode="chat",
            name="Test conversation",
            inputs={},
            status="normal",
            from_source="api",
            from_end_user_id=str(uuid.uuid4()),
        )
        db.session.add(conversation)
        db.session.commit()

        return conversation

    def _create_message(self, app, conversation, created_at=None, with_relations=True):
        """Helper to create a message with optional related records."""
        if created_at is None:
            created_at = datetime.datetime.now()

        message = Message(
            app_id=app.id,
            conversation_id=conversation.id,
            model_provider="openai",
            model_id="gpt-3.5-turbo",
            inputs={},
            query="Test query",
            answer="Test answer",
            message=[{"role": "user", "text": "Test message"}],
            message_tokens=10,
            message_unit_price=Decimal("0.001"),
            answer_tokens=20,
            answer_unit_price=Decimal("0.002"),
            total_price=Decimal("0.003"),
            currency="USD",
            from_source="api",
            from_account_id=conversation.from_end_user_id,
            created_at=created_at,
        )
        db.session.add(message)
        db.session.flush()

        if with_relations:
            self._create_message_relations(message)

        db.session.commit()
        return message

    def _create_message_relations(self, message):
        """Helper to create all message-related records."""
        # MessageFeedback
        feedback = MessageFeedback(
            app_id=message.app_id,
            conversation_id=message.conversation_id,
            message_id=message.id,
            rating="like",
            from_source="api",
            from_end_user_id=str(uuid.uuid4()),
        )
        db.session.add(feedback)

        # MessageAnnotation
        annotation = MessageAnnotation(
            app_id=message.app_id,
            conversation_id=message.conversation_id,
            message_id=message.id,
            question="Test question",
            content="Test annotation",
            account_id=message.from_account_id,
        )
        db.session.add(annotation)

        # MessageChain
        chain = MessageChain(
            message_id=message.id,
            type="system",
            input=json.dumps({"test": "input"}),
            output=json.dumps({"test": "output"}),
        )
        db.session.add(chain)
        db.session.flush()

        # MessageFile
        file = MessageFile(
            message_id=message.id,
            type="image",
            transfer_method="local_file",
            url="http://example.com/test.jpg",
            belongs_to="user",
            created_by_role="end_user",
            created_by=str(uuid.uuid4()),
        )
        db.session.add(file)

        # SavedMessage
        saved = SavedMessage(
            app_id=message.app_id,
            message_id=message.id,
            created_by_role="end_user",
            created_by=str(uuid.uuid4()),
        )
        db.session.add(saved)

        db.session.flush()

        # AppAnnotationHitHistory
        hit = AppAnnotationHitHistory(
            app_id=message.app_id,
            annotation_id=annotation.id,
            message_id=message.id,
            source="annotation",
            question="Test question",
            account_id=message.from_account_id,
            annotation_question="Test annotation question",
            annotation_content="Test annotation content",
        )
        db.session.add(hit)

        # DatasetRetrieverResource
        resource = DatasetRetrieverResource(
            message_id=message.id,
            position=1,
            dataset_id=str(uuid.uuid4()),
            dataset_name="Test dataset",
            document_id=str(uuid.uuid4()),
            document_name="Test document",
            data_source_type="upload_file",
            segment_id=str(uuid.uuid4()),
            score=0.9,
            content="Test content",
            hit_count=1,
            word_count=10,
            segment_position=1,
            index_node_hash="test_hash",
            retriever_from="dataset",
            created_by=message.from_account_id,
        )
        db.session.add(resource)

    def test_clean_no_messages_to_delete(self, db_session_with_containers):
        """Test cleaning when there are no messages to delete."""
        # Arrange
        end_before = datetime.datetime.now() - datetime.timedelta(days=30)

        with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
            mock_billing.return_value = {}

            # Act
            stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
                end_before=end_before,
                graceful_period=21,  # Use the default graceful period
                batch_size=100,
            )

        # Assert
        # Even with no messages, the loop runs once to check
        assert stats["batches"] == 1
        assert stats["total_messages"] == 0
        assert stats["total_deleted"] == 0

    def test_clean_mixed_sandbox_and_paid_tenants(self, db_session_with_containers):
        """Test cleaning with mixed sandbox and paid tenants, correctly filtering sandbox messages."""
        # Arrange - Create sandbox tenants with expired messages
        sandbox_tenants = []
        sandbox_message_ids = []
        for i in range(2):
            account, tenant = self._create_account_and_tenant(plan="sandbox")
            sandbox_tenants.append(tenant)
            app = self._create_app(tenant, account)
            conv = self._create_conversation(app)

            # Create 3 expired messages per sandbox tenant
            expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
            for j in range(3):
                msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j))
                sandbox_message_ids.append(msg.id)

        # Create paid tenants with expired messages (should NOT be deleted)
        paid_tenants = []
        paid_message_ids = []
        for i in range(2):
            account, tenant = self._create_account_and_tenant(plan="professional")
            paid_tenants.append(tenant)
            app = self._create_app(tenant, account)
            conv = self._create_conversation(app)

            # Create 2 expired messages per paid tenant
            expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
            for j in range(2):
                msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j))
                paid_message_ids.append(msg.id)

        # Mock the billing service - return plan and expiration_date
        now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
        expired_15_days_ago = now_timestamp - (15 * 24 * 60 * 60)  # Beyond the 7-day grace period

        plan_map = {}
        for tenant in sandbox_tenants:
            plan_map[tenant.id] = {
                "plan": CloudPlan.SANDBOX,
                "expiration_date": expired_15_days_ago,
            }
        for tenant in paid_tenants:
            plan_map[tenant.id] = {
                "plan": CloudPlan.PROFESSIONAL,
                "expiration_date": now_timestamp + (365 * 24 * 60 * 60),  # Active for 1 year
            }

        with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
            mock_billing.return_value = plan_map

            # Act
            end_before = datetime.datetime.now() - datetime.timedelta(days=30)
            stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
                end_before=end_before,
                graceful_period=7,
                batch_size=100,
            )

        # Assert
        assert stats["total_messages"] == 6  # 2 sandbox tenants * 3 messages
        assert stats["total_deleted"] == 6

        # Only sandbox messages should be deleted
        assert db.session.query(Message).where(Message.id.in_(sandbox_message_ids)).count() == 0
        # Paid messages should remain
        assert db.session.query(Message).where(Message.id.in_(paid_message_ids)).count() == 4

        # Related records of sandbox messages should be deleted
        assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(sandbox_message_ids)).count() == 0
        assert (
            db.session.query(MessageAnnotation).where(MessageAnnotation.message_id.in_(sandbox_message_ids)).count()
            == 0
        )

    def test_clean_with_cursor_pagination(self, db_session_with_containers):
        """Test that cursor pagination works correctly across multiple batches."""
        # Arrange - Create a sandbox tenant with messages that will span multiple batches
        account, tenant = self._create_account_and_tenant(plan="sandbox")
        app = self._create_app(tenant, account)
        conv = self._create_conversation(app)

        # Create 10 expired messages with different timestamps
        base_date = datetime.datetime.now() - datetime.timedelta(days=35)
        message_ids = []
        for i in range(10):
            msg = self._create_message(
                app,
                conv,
                created_at=base_date + datetime.timedelta(hours=i),
                with_relations=False,  # Skip relations for speed
            )
            message_ids.append(msg.id)

        with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
            mock_billing.return_value = {
                tenant.id: {
                    "plan": CloudPlan.SANDBOX,
                    "expiration_date": -1,  # No previous subscription
                }
            }

            # Act - Use a small batch size to trigger multiple batches
            end_before = datetime.datetime.now() - datetime.timedelta(days=30)
            stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
                end_before=end_before,
                graceful_period=21,  # Use the default graceful period
                batch_size=3,  # Small batch size to test pagination
            )

        # 5 batches for 10 messages with batch_size=3; the last batch is empty
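        # ceil(10 / 3) = 4 non-empty batches, plus one final empty fetch that ends the loop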
        assert stats["batches"] == 5
        assert stats["total_messages"] == 10
        assert stats["total_deleted"] == 10

        # All messages should be deleted
        assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 0

    def test_clean_with_dry_run(self, db_session_with_containers):
        """Test that dry_run mode does not delete messages."""
        # Arrange
        account, tenant = self._create_account_and_tenant(plan="sandbox")
        app = self._create_app(tenant, account)
        conv = self._create_conversation(app)

        # Create expired messages
        expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
        message_ids = []
        for i in range(3):
            msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i))
            message_ids.append(msg.id)

        with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
            mock_billing.return_value = {
                tenant.id: {
                    "plan": CloudPlan.SANDBOX,
                    "expiration_date": -1,  # No previous subscription
                }
            }

            # Act
            end_before = datetime.datetime.now() - datetime.timedelta(days=30)
            stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
                end_before=end_before,
                graceful_period=21,  # Use the default graceful period
                batch_size=100,
                dry_run=True,  # Dry run mode
            )

        # Assert
        assert stats["total_messages"] == 3  # Messages identified
        assert stats["total_deleted"] == 0  # But NOT deleted

        # All messages should still exist
        assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 3
        # Related records should also still exist
        assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)).count() == 3

    def test_clean_with_billing_partial_exception_some_known_plans(self, db_session_with_containers):
        """Test that when the billing service returns only partial data, only known sandbox messages are deleted."""
        # Arrange - Create 3 tenants
        tenants_data = []
        for i in range(3):
            account, tenant = self._create_account_and_tenant(plan="sandbox")
            app = self._create_app(tenant, account)
            conv = self._create_conversation(app)

            expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
            msg = self._create_message(app, conv, created_at=expired_date)

            tenants_data.append(
                {
                    "tenant": tenant,
                    "message_id": msg.id,
                }
            )

        # Mock the billing service to return partial data with the new structure
        now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())

        # Only tenant[0] is confirmed as sandbox, tenant[1] is professional, tenant[2] is missing
        partial_plan_map = {
            tenants_data[0]["tenant"].id: {
                "plan": CloudPlan.SANDBOX,
                "expiration_date": -1,  # No previous subscription
            },
            tenants_data[1]["tenant"].id: {
                "plan": CloudPlan.PROFESSIONAL,
                "expiration_date": now_timestamp + (365 * 24 * 60 * 60),  # Active for 1 year
            },
            # tenants_data[2] is missing from the response
        }

        with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
            mock_billing.return_value = partial_plan_map

            # Act
            end_before = datetime.datetime.now() - datetime.timedelta(days=30)
            stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
                end_before=end_before,
                graceful_period=21,  # Use the default graceful period
                batch_size=100,
            )

        # Assert - Only tenant[0]'s message should be deleted
        assert stats["total_messages"] == 1
        assert stats["total_deleted"] == 1

        # Check which messages were deleted
        assert (
            db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 0
        )  # Sandbox tenant's message deleted

        assert (
            db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1
        )  # Professional tenant's message preserved

        assert (
            db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 1
        )  # Unknown tenant's message preserved (safe default)

    def test_clean_with_billing_exception_no_data(self, db_session_with_containers):
        """Test that when the billing service returns empty data, deletion is skipped for that batch."""
        # Arrange
        account, tenant = self._create_account_and_tenant(plan="sandbox")
        app = self._create_app(tenant, account)
        conv = self._create_conversation(app)

        expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
        msg = self._create_message(app, conv, created_at=expired_date)
        msg_id = msg.id  # Store the ID before any operations
        db.session.commit()

        # Mock the billing service to return empty data (simulating a failure/no-data scenario)
        with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
            mock_billing.return_value = {}  # Empty response, tenant plan unknown

            # Act - Should not raise an exception, just skip deletion
            end_before = datetime.datetime.now() - datetime.timedelta(days=30)
            stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
                end_before=end_before,
                graceful_period=21,  # Use the default graceful period
                batch_size=100,
            )

        # Assert - No messages should be deleted when the plan is unknown
        assert stats["total_messages"] == 0  # Cannot determine sandbox messages
        assert stats["total_deleted"] == 0

        # The message should still exist (safe default - don't delete if the plan is unknown)
        assert db.session.query(Message).where(Message.id == msg_id).count() == 1

    def test_redis_cache_for_tenant_plans(self, db_session_with_containers):
        """Test that tenant plans are cached in Redis and reused."""
        # Arrange
        account, tenant = self._create_account_and_tenant(plan="sandbox")
        app = self._create_app(tenant, account)
        conv = self._create_conversation(app)

        # Create messages in two batches (to test cache reuse)
        expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
        batch1_msgs = []
        for i in range(2):
            msg = self._create_message(
                app, conv, created_at=expired_date + datetime.timedelta(hours=i), with_relations=False
            )
            batch1_msgs.append(msg.id)

        batch2_msgs = []
        for i in range(2):
            msg = self._create_message(
                app, conv, created_at=expired_date + datetime.timedelta(hours=10 + i), with_relations=False
            )
            batch2_msgs.append(msg.id)

        # Mock the billing service with the new structure
        mock_get_plan_bulk = MagicMock(
            return_value={
                tenant.id: {
                    "plan": CloudPlan.SANDBOX,
                    "expiration_date": -1,  # No previous subscription
                }
            }
        )

        with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk", mock_get_plan_bulk):
            # Act - First call
            end_before = datetime.datetime.now() - datetime.timedelta(days=30)
            stats1 = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
                end_before=end_before,
                graceful_period=21,  # Use the default graceful period
                batch_size=2,  # Process 2 messages per batch
            )

            # Check the billing service was called (cache miss)
            assert mock_get_plan_bulk.call_count == 1
            first_call_count = mock_get_plan_bulk.call_count

            # Verify the Redis cache was populated
            cache_key = f"{SandboxMessagesCleanService.PLAN_CACHE_KEY_PREFIX}{tenant.id}"
            cached_plan = redis_client.get(cache_key)
            assert cached_plan is not None
            cached_plan_data = json.loads(cached_plan.decode("utf-8"))
            assert cached_plan_data["plan"] == CloudPlan.SANDBOX
            assert cached_plan_data["expiration_date"] == -1

            # Act - Second call with the same tenant (should use the cache)
            # Create more messages for the same tenant
            batch3_msgs = []
            for i in range(2):
                msg = self._create_message(
                    app, conv, created_at=expired_date + datetime.timedelta(hours=20 + i), with_relations=False
                )
                batch3_msgs.append(msg.id)

            stats2 = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
                end_before=end_before,
                graceful_period=21,  # Use the default graceful period
                batch_size=2,
            )

        # Assert - The billing service should not be called again (cache hit),
        # so the call count should be unchanged
        assert mock_get_plan_bulk.call_count == first_call_count  # Same tenant, should use the cache

        # Verify all messages were deleted
        total_expected = len(batch1_msgs) + len(batch2_msgs) + len(batch3_msgs)
        assert stats1["total_deleted"] + stats2["total_deleted"] == total_expected

    def test_time_range_filtering(self, db_session_with_containers):
        """Test that messages are correctly filtered by the [start_from, end_before) time range."""
        # Arrange
        account, tenant = self._create_account_and_tenant(plan="sandbox")
        app = self._create_app(tenant, account)
        conv = self._create_conversation(app)

        base_date = datetime.datetime(2024, 1, 15, 12, 0, 0)

        # Create messages: before the range, in the range, after the range
        msg_before = self._create_message(
            app,
            conv,
            created_at=datetime.datetime(2024, 1, 1, 12, 0, 0),  # Before start_from
            with_relations=False,
        )
        msg_before_id = msg_before.id

        msg_at_start = self._create_message(
            app,
            conv,
            created_at=datetime.datetime(2024, 1, 10, 12, 0, 0),  # At start_from (inclusive)
            with_relations=False,
        )
        msg_at_start_id = msg_at_start.id

        msg_in_range = self._create_message(
            app,
            conv,
            created_at=datetime.datetime(2024, 1, 15, 12, 0, 0),  # In range
            with_relations=False,
        )
        msg_in_range_id = msg_in_range.id

        msg_at_end = self._create_message(
            app,
            conv,
            created_at=datetime.datetime(2024, 1, 20, 12, 0, 0),  # At end_before (exclusive)
            with_relations=False,
        )
        msg_at_end_id = msg_at_end.id

        msg_after = self._create_message(
            app,
            conv,
            created_at=datetime.datetime(2024, 1, 25, 12, 0, 0),  # After end_before
            with_relations=False,
        )
        msg_after_id = msg_after.id

        db.session.commit()  # Commit all messages

        # Mock the billing service with the new structure
        with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
            mock_billing.return_value = {
                tenant.id: {
                    "plan": CloudPlan.SANDBOX,
                    "expiration_date": -1,  # No previous subscription
                }
            }

            # Act - Clean with the specific time range [2024-01-10, 2024-01-20)
            stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
                start_from=datetime.datetime(2024, 1, 10, 12, 0, 0),
                end_before=datetime.datetime(2024, 1, 20, 12, 0, 0),
                graceful_period=21,  # Use the default graceful period
                batch_size=100,
            )

        # Assert - Only messages in [start_from, end_before) should be deleted
        assert stats["total_messages"] == 2  # msg_at_start and msg_in_range
        assert stats["total_deleted"] == 2

        # Verify specific messages using the stored IDs
        # Before the range, kept
        assert db.session.query(Message).where(Message.id == msg_before_id).count() == 1
        # At start (inclusive), deleted
        assert db.session.query(Message).where(Message.id == msg_at_start_id).count() == 0
        # In range, deleted
        assert db.session.query(Message).where(Message.id == msg_in_range_id).count() == 0
        # At end (exclusive), kept
        assert db.session.query(Message).where(Message.id == msg_at_end_id).count() == 1
        # After the range, kept
        assert db.session.query(Message).where(Message.id == msg_after_id).count() == 1

    def test_clean_with_graceful_period_scenarios(self, db_session_with_containers):
        """Test cleaning with different graceful period scenarios."""
        # Arrange - Create 5 tenants with different plan and expiration scenarios
        now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
        graceful_period = 8  # Use 8 days for this test to validate boundary conditions

        # Scenario 1: Sandbox plan with expiration within the graceful period (5 days ago)
        # Should NOT be deleted
        account1, tenant1 = self._create_account_and_tenant(plan="sandbox")
        app1 = self._create_app(tenant1, account1)
        conv1 = self._create_conversation(app1)
        expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
        msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False)
        msg1_id = msg1.id  # Save the ID before potential deletion
        expired_5_days_ago = now_timestamp - (5 * 24 * 60 * 60)  # Within the grace period

        # Scenario 2: Sandbox plan with expiration beyond the graceful period (10 days ago)
        # Should be deleted
        account2, tenant2 = self._create_account_and_tenant(plan="sandbox")
        app2 = self._create_app(tenant2, account2)
        conv2 = self._create_conversation(app2)
        msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False)
        msg2_id = msg2.id  # Save the ID before potential deletion
        expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60)  # Beyond the grace period

        # Scenario 3: Sandbox plan with expiration_date = -1 (no previous subscription)
        # Should be deleted
        account3, tenant3 = self._create_account_and_tenant(plan="sandbox")
        app3 = self._create_app(tenant3, account3)
        conv3 = self._create_conversation(app3)
        msg3 = self._create_message(app3, conv3, created_at=expired_date, with_relations=False)
        msg3_id = msg3.id  # Save the ID before potential deletion

        # Scenario 4: Non-sandbox plan (professional) with no expiration (future date)
        # Should NOT be deleted
        account4, tenant4 = self._create_account_and_tenant(plan="professional")
        app4 = self._create_app(tenant4, account4)
        conv4 = self._create_conversation(app4)
        msg4 = self._create_message(app4, conv4, created_at=expired_date, with_relations=False)
        msg4_id = msg4.id  # Save the ID before potential deletion
        future_expiration = now_timestamp + (365 * 24 * 60 * 60)  # Active for 1 year

        # Scenario 5: Sandbox plan with expiration exactly at the grace period boundary (8 days ago)
        # Should NOT be deleted (the boundary is exclusive: > graceful_period)
        account5, tenant5 = self._create_account_and_tenant(plan="sandbox")
        app5 = self._create_app(tenant5, account5)
        conv5 = self._create_conversation(app5)
        msg5 = self._create_message(app5, conv5, created_at=expired_date, with_relations=False)
        msg5_id = msg5.id  # Save the ID before potential deletion
        expired_exactly_8_days_ago = now_timestamp - (8 * 24 * 60 * 60)  # Exactly at the boundary

        db.session.commit()

        # Mock the billing service with all scenarios
        plan_map = {
            tenant1.id: {
                "plan": CloudPlan.SANDBOX,
                "expiration_date": expired_5_days_ago,
            },
            tenant2.id: {
                "plan": CloudPlan.SANDBOX,
                "expiration_date": expired_10_days_ago,
            },
            tenant3.id: {
                "plan": CloudPlan.SANDBOX,
                "expiration_date": -1,
            },
            tenant4.id: {
                "plan": CloudPlan.PROFESSIONAL,
                "expiration_date": future_expiration,
            },
            tenant5.id: {
                "plan": CloudPlan.SANDBOX,
                "expiration_date": expired_exactly_8_days_ago,
            },
        }

        with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
            mock_billing.return_value = plan_map

            # Act
            end_before = datetime.datetime.now() - datetime.timedelta(days=30)

            # Mock datetime.now() to use the same timestamp as the test setup.
            # This ensures deterministic behavior for boundary conditions (scenario 5).
            with patch("services.sandbox_messages_clean_service.datetime") as mock_datetime:
                mock_datetime.datetime.now.return_value = datetime.datetime.fromtimestamp(
                    now_timestamp, tz=datetime.UTC
                )
                mock_datetime.timedelta = datetime.timedelta  # Keep the original timedelta

                stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
                    end_before=end_before,
                    graceful_period=graceful_period,
                    batch_size=100,
                )

        # Assert - Only messages from scenarios 2 and 3 should be deleted
        assert stats["total_messages"] == 2
        assert stats["total_deleted"] == 2

        # Verify each scenario using the saved IDs
        assert db.session.query(Message).where(Message.id == msg1_id).count() == 1  # Within grace, kept
        assert db.session.query(Message).where(Message.id == msg2_id).count() == 0  # Beyond grace, deleted
        assert db.session.query(Message).where(Message.id == msg3_id).count() == 0  # No subscription, deleted
        assert db.session.query(Message).where(Message.id == msg4_id).count() == 1  # Professional plan, kept
        assert db.session.query(Message).where(Message.id == msg5_id).count() == 1  # At boundary, kept

    def test_clean_with_tenant_whitelist(self, db_session_with_containers, mock_whitelist):
        """Test that whitelisted tenants' messages are not deleted even if they are sandbox and expired."""
        # Arrange - Create 3 sandbox tenants with expired messages
        tenants_data = []
        for i in range(3):
            account, tenant = self._create_account_and_tenant(plan="sandbox")
            app = self._create_app(tenant, account)
            conv = self._create_conversation(app)

            expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
            msg = self._create_message(app, conv, created_at=expired_date, with_relations=False)

            tenants_data.append(
                {
                    "tenant": tenant,
                    "message_id": msg.id,
                }
            )

        # Mock the billing service - all tenants are sandbox with no subscription
        plan_map = {
            tenants_data[0]["tenant"].id: {
                "plan": CloudPlan.SANDBOX,
                "expiration_date": -1,  # No previous subscription
            },
            tenants_data[1]["tenant"].id: {
                "plan": CloudPlan.SANDBOX,
                "expiration_date": -1,  # No previous subscription
            },
            tenants_data[2]["tenant"].id: {
                "plan": CloudPlan.SANDBOX,
                "expiration_date": -1,  # No previous subscription
            },
        }

        # Set up the whitelist - tenant0 and tenant1 are whitelisted, tenant2 is not
        whitelist = [tenants_data[0]["tenant"].id, tenants_data[1]["tenant"].id]
        mock_whitelist.return_value = whitelist

        with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
            mock_billing.return_value = plan_map

            # Act
            end_before = datetime.datetime.now() - datetime.timedelta(days=30)
            stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
                end_before=end_before,
                graceful_period=21,  # Use the default graceful period
                batch_size=100,
            )

        # Assert - Only tenant2's message should be deleted (not whitelisted)
        assert stats["total_messages"] == 1
        assert stats["total_deleted"] == 1

        # Verify tenant0's message still exists (whitelisted)
        assert db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 1

        # Verify tenant1's message still exists (whitelisted)
        assert db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1

        # Verify tenant2's message was deleted (not whitelisted)
        assert db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 0

    def test_clean_with_whitelist_and_grace_period(self, db_session_with_containers, mock_whitelist):
        """Test that the whitelist takes precedence over the grace period logic."""
        # Arrange - Create 2 sandbox tenants
        now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())

        # Tenant1: whitelisted, expired beyond the grace period
        account1, tenant1 = self._create_account_and_tenant(plan="sandbox")
        app1 = self._create_app(tenant1, account1)
        conv1 = self._create_conversation(app1)
        expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
        msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False)
        expired_30_days_ago = now_timestamp - (30 * 24 * 60 * 60)  # Well beyond the 21-day grace

        # Tenant2: not whitelisted, within the grace period
        account2, tenant2 = self._create_account_and_tenant(plan="sandbox")
        app2 = self._create_app(tenant2, account2)
        conv2 = self._create_conversation(app2)
        msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False)
        expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60)  # Within the 21-day grace

        # Mock the billing service
        plan_map = {
            tenant1.id: {
                "plan": CloudPlan.SANDBOX,
                "expiration_date": expired_30_days_ago,  # Beyond the grace period
            },
            tenant2.id: {
                "plan": CloudPlan.SANDBOX,
                "expiration_date": expired_10_days_ago,  # Within the grace period
            },
        }

        # Set up the whitelist - only tenant1 is whitelisted
        whitelist = [tenant1.id]
        mock_whitelist.return_value = whitelist

        with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
            mock_billing.return_value = plan_map

            # Act
            end_before = datetime.datetime.now() - datetime.timedelta(days=30)
            stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
                end_before=end_before,
                graceful_period=21,  # Use the default graceful period
                batch_size=100,
            )

        # Assert - No messages should be deleted:
        # tenant1 is whitelisted (would be deleted based on the grace period, but is protected),
        # tenant2 is within the grace period (not eligible for deletion)
        assert stats["total_messages"] == 0
        assert stats["total_deleted"] == 0

        # Verify both messages still exist
        assert db.session.query(Message).where(Message.id == msg1.id).count() == 1  # Whitelisted
        assert db.session.query(Message).where(Message.id == msg2.id).count() == 1  # Within the grace period

    def test_clean_with_empty_whitelist(self, db_session_with_containers, mock_whitelist):
        """Test that an empty whitelist behaves as no whitelist (all eligible messages are deleted)."""
        # Arrange - Create a sandbox tenant with expired messages
        account, tenant = self._create_account_and_tenant(plan="sandbox")
        app = self._create_app(tenant, account)
        conv = self._create_conversation(app)

        expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
        msg_ids = []
        for i in range(3):
            msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i))
            msg_ids.append(msg.id)

        # Mock the billing service
        plan_map = {
            tenant.id: {
                "plan": CloudPlan.SANDBOX,
                "expiration_date": -1,  # No previous subscription
            }
        }

        # Set up an empty whitelist (the default behavior from the fixture)
        mock_whitelist.return_value = []

        with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
            mock_billing.return_value = plan_map

            # Act
            end_before = datetime.datetime.now() - datetime.timedelta(days=30)
            stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
                end_before=end_before,
                graceful_period=21,  # Use the default graceful period
                batch_size=100,
            )

        # Assert - All messages should be deleted (no whitelist protection)
        assert stats["total_messages"] == 3
        assert stats["total_deleted"] == 3

        # Verify all messages were deleted
        assert db.session.query(Message).where(Message.id.in_(msg_ids)).count() == 0
|
||||
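
The integration tests above drive the service through its public entry point. As a minimal usage sketch (the maintenance-script context is an assumption, not part of this change), a caller would dry-run the same window first:

# Sketch: dry-run the cleanup window exercised by the tests above.
import datetime

from services.sandbox_messages_clean_service import SandboxMessagesCleanService

stats = SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
    start_from=datetime.datetime(2024, 1, 1),
    end_before=datetime.datetime.now() - datetime.timedelta(days=30),
    batch_size=1000,
    dry_run=True,  # report candidates without deleting anything
)
print(stats["total_messages"], "candidates in", stats["batches"], "batches")
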
@@ -149,6 +149,7 @@ class TestWebConversationService:
            from_end_user_id=user.id if isinstance(user, EndUser) else None,
            from_account_id=user.id if isinstance(user, Account) else None,
            dialogue_count=0,
            is_deleted=False,
        )

        from extensions.ext_database import db
@@ -2,9 +2,7 @@ from unittest.mock import patch

import pytest
from faker import Faker
from pydantic import TypeAdapter, ValidationError

from core.tools.entities.tool_entities import ApiProviderSchemaType
from models import Account, Tenant
from models.tools import ApiToolProvider
from services.tools.api_tools_manage_service import ApiToolManageService
@@ -300,7 +298,7 @@ class TestApiToolManageService:
        provider_name = fake.company()
        icon = {"type": "emoji", "value": "🔧"}
        credentials = {"auth_type": "none", "api_key_header": "X-API-Key", "api_key_value": ""}
        schema_type = ApiProviderSchemaType.OPENAPI
        schema_type = "openapi"
        schema = self._create_test_openapi_schema()
        privacy_policy = "https://example.com/privacy"
        custom_disclaimer = "Custom disclaimer text"
@@ -366,7 +364,7 @@ class TestApiToolManageService:
        provider_name = fake.company()
        icon = {"type": "emoji", "value": "🔧"}
        credentials = {"auth_type": "none"}
        schema_type = ApiProviderSchemaType.OPENAPI
        schema_type = "openapi"
        schema = self._create_test_openapi_schema()
        privacy_policy = "https://example.com/privacy"
        custom_disclaimer = "Custom disclaimer text"
@@ -430,10 +428,21 @@ class TestApiToolManageService:
        labels = ["test"]

        # Act & Assert: Try to create provider with invalid schema type
        with pytest.raises(ValidationError) as exc_info:
            TypeAdapter(ApiProviderSchemaType).validate_python(schema_type)
        with pytest.raises(ValueError) as exc_info:
            ApiToolManageService.create_api_tool_provider(
                user_id=account.id,
                tenant_id=tenant.id,
                provider_name=provider_name,
                icon=icon,
                credentials=credentials,
                schema_type=schema_type,
                schema=schema,
                privacy_policy=privacy_policy,
                custom_disclaimer=custom_disclaimer,
                labels=labels,
            )

        assert "validation error" in str(exc_info.value)
        assert "invalid schema type" in str(exc_info.value)

    def test_create_api_tool_provider_missing_auth_type(
        self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
@@ -455,7 +464,7 @@ class TestApiToolManageService:
        provider_name = fake.company()
        icon = {"type": "emoji", "value": "🔧"}
        credentials = {}  # Missing auth_type
        schema_type = ApiProviderSchemaType.OPENAPI
        schema_type = "openapi"
        schema = self._create_test_openapi_schema()
        privacy_policy = "https://example.com/privacy"
        custom_disclaimer = "Custom disclaimer text"
@@ -498,7 +507,7 @@ class TestApiToolManageService:
        provider_name = fake.company()
        icon = {"type": "emoji", "value": "🔑"}
        credentials = {"auth_type": "api_key", "api_key_header": "X-API-Key", "api_key_value": fake.uuid4()}
        schema_type = ApiProviderSchemaType.OPENAPI
        schema_type = "openapi"
        schema = self._create_test_openapi_schema()
        privacy_policy = "https://example.com/privacy"
        custom_disclaimer = "Custom disclaimer text"
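
The hunks above swap the enum-typed schema_type for a plain string that the service validates itself, raising a ValueError containing "invalid schema type" instead of a pydantic ValidationError. A rough sketch of that check (the helper name is hypothetical; the service's actual structure is not shown in this diff):

# Hypothetical validation helper mirroring the asserted error message.
from core.tools.entities.tool_entities import ApiProviderSchemaType

def _validate_schema_type(schema_type: str) -> None:
    valid_values = {member.value for member in ApiProviderSchemaType}
    if schema_type not in valid_values:
        raise ValueError(f"invalid schema type {schema_type}")
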
@@ -1,420 +0,0 @@
from types import SimpleNamespace
from unittest.mock import ANY, Mock, patch

import pytest

from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
from core.app.entities.queue_entities import (
    QueueAgentMessageEvent,
    QueueErrorEvent,
    QueueLLMChunkEvent,
    QueueMessageEndEvent,
    QueueMessageFileEvent,
    QueuePingEvent,
)
from core.app.entities.task_entities import (
    EasyUITaskState,
    ErrorStreamResponse,
    MessageEndStreamResponse,
    MessageFileStreamResponse,
    MessageReplaceStreamResponse,
    MessageStreamResponse,
    PingStreamResponse,
    StreamEvent,
)
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher
from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
from core.model_runtime.entities.message_entities import TextPromptMessageContent
from core.ops.ops_trace_manager import TraceQueueManager
from models.model import AppMode


class TestEasyUIBasedGenerateTaskPipelineProcessStreamResponse:
    """Test cases for EasyUIBasedGenerateTaskPipeline._process_stream_response method."""

    @pytest.fixture
    def mock_application_generate_entity(self):
        """Create a mock application generate entity."""
        entity = Mock(spec=ChatAppGenerateEntity)
        entity.task_id = "test-task-id"
        entity.app_id = "test-app-id"
        # minimal app_config used by pipeline internals
        entity.app_config = SimpleNamespace(
            tenant_id="test-tenant-id",
            app_id="test-app-id",
            app_mode=AppMode.CHAT,
            app_model_config_dict={},
            additional_features=None,
            sensitive_word_avoidance=None,
        )
        # minimal model_conf for LLMResult init
        entity.model_conf = SimpleNamespace(
            model="test-model",
            provider_model_bundle=SimpleNamespace(model_type_instance=Mock()),
            credentials={},
        )
        return entity

    @pytest.fixture
    def mock_queue_manager(self):
        """Create a mock queue manager."""
        manager = Mock(spec=AppQueueManager)
        return manager

    @pytest.fixture
    def mock_message_cycle_manager(self):
        """Create a mock message cycle manager."""
        manager = Mock()
        manager.get_message_event_type.return_value = StreamEvent.MESSAGE
        manager.message_to_stream_response.return_value = Mock(spec=MessageStreamResponse)
        manager.message_file_to_stream_response.return_value = Mock(spec=MessageFileStreamResponse)
        manager.message_replace_to_stream_response.return_value = Mock(spec=MessageReplaceStreamResponse)
        manager.handle_retriever_resources = Mock()
        manager.handle_annotation_reply.return_value = None
        return manager

    @pytest.fixture
    def mock_conversation(self):
        """Create a mock conversation."""
        conversation = Mock()
        conversation.id = "test-conversation-id"
        conversation.mode = "chat"
        return conversation

    @pytest.fixture
    def mock_message(self):
        """Create a mock message."""
        message = Mock()
        message.id = "test-message-id"
        message.created_at = Mock()
        message.created_at.timestamp.return_value = 1234567890
        return message

    @pytest.fixture
    def mock_task_state(self):
        """Create a mock task state."""
        task_state = Mock(spec=EasyUITaskState)

        # Create LLM result mock
        llm_result = Mock(spec=RuntimeLLMResult)
        llm_result.prompt_messages = []
        llm_result.message = Mock()
        llm_result.message.content = ""

        task_state.llm_result = llm_result
        task_state.answer = ""

        return task_state

    @pytest.fixture
    def pipeline(
        self,
        mock_application_generate_entity,
        mock_queue_manager,
        mock_conversation,
        mock_message,
        mock_message_cycle_manager,
        mock_task_state,
    ):
        """Create an EasyUIBasedGenerateTaskPipeline instance with mocked dependencies."""
        with patch(
            "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.EasyUITaskState", return_value=mock_task_state
        ):
            pipeline = EasyUIBasedGenerateTaskPipeline(
                application_generate_entity=mock_application_generate_entity,
                queue_manager=mock_queue_manager,
                conversation=mock_conversation,
                message=mock_message,
                stream=True,
            )
            pipeline._message_cycle_manager = mock_message_cycle_manager
            pipeline._task_state = mock_task_state
            return pipeline

    def test_get_message_event_type_called_once_when_first_llm_chunk_arrives(
        self, pipeline, mock_message_cycle_manager
    ):
        """Expect get_message_event_type to be called when processing the first LLM chunk event."""
        # Setup a minimal LLM chunk event
        chunk = Mock()
        chunk.delta.message.content = "hi"
        chunk.prompt_messages = []
        llm_chunk_event = Mock(spec=QueueLLMChunkEvent)
        llm_chunk_event.chunk = chunk
        mock_queue_message = Mock()
        mock_queue_message.event = llm_chunk_event
        pipeline.queue_manager.listen.return_value = [mock_queue_message]

        # Execute
        list(pipeline._process_stream_response(publisher=None, trace_manager=None))

        # Assert
        mock_message_cycle_manager.get_message_event_type.assert_called_once_with(message_id="test-message-id")

    def test_llm_chunk_event_with_text_content(self, pipeline, mock_message_cycle_manager, mock_task_state):
        """Test handling of LLM chunk events with text content."""
        # Setup
        chunk = Mock()
        chunk.delta.message.content = "Hello, world!"
        chunk.prompt_messages = []

        llm_chunk_event = Mock(spec=QueueLLMChunkEvent)
        llm_chunk_event.chunk = chunk

        mock_queue_message = Mock()
        mock_queue_message.event = llm_chunk_event
        pipeline.queue_manager.listen.return_value = [mock_queue_message]

        mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE

        # Execute
        responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))

        # Assert
        assert len(responses) == 1
        mock_message_cycle_manager.message_to_stream_response.assert_called_once_with(
            answer="Hello, world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
        )
        assert mock_task_state.llm_result.message.content == "Hello, world!"

    def test_llm_chunk_event_with_list_content(self, pipeline, mock_message_cycle_manager, mock_task_state):
        """Test handling of LLM chunk events with list content."""
        # Setup
        text_content = Mock(spec=TextPromptMessageContent)
        text_content.data = "Hello"

        chunk = Mock()
        chunk.delta.message.content = [text_content, " world!"]
        chunk.prompt_messages = []

        llm_chunk_event = Mock(spec=QueueLLMChunkEvent)
        llm_chunk_event.chunk = chunk

        mock_queue_message = Mock()
        mock_queue_message.event = llm_chunk_event
        pipeline.queue_manager.listen.return_value = [mock_queue_message]

        mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE

        # Execute
        responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))

        # Assert
        assert len(responses) == 1
        mock_message_cycle_manager.message_to_stream_response.assert_called_once_with(
            answer="Hello world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
        )
        assert mock_task_state.llm_result.message.content == "Hello world!"

    def test_agent_message_event(self, pipeline, mock_message_cycle_manager, mock_task_state):
        """Test handling of agent message events."""
        # Setup
        chunk = Mock()
        chunk.delta.message.content = "Agent response"

        agent_message_event = Mock(spec=QueueAgentMessageEvent)
        agent_message_event.chunk = chunk

        mock_queue_message = Mock()
        mock_queue_message.event = agent_message_event
        pipeline.queue_manager.listen.return_value = [mock_queue_message]

        # Ensure method under assertion is a mock to track calls
        pipeline._agent_message_to_stream_response = Mock(return_value=Mock())

        # Execute
        responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))

        # Assert
        assert len(responses) == 1
        # Agent messages should use _agent_message_to_stream_response
        pipeline._agent_message_to_stream_response.assert_called_once_with(
            answer="Agent response", message_id="test-message-id"
        )

    def test_message_end_event(self, pipeline, mock_message_cycle_manager, mock_task_state):
        """Test handling of message end events."""
        # Setup
        llm_result = Mock(spec=RuntimeLLMResult)
        llm_result.message = Mock()
        llm_result.message.content = "Final response"

        message_end_event = Mock(spec=QueueMessageEndEvent)
        message_end_event.llm_result = llm_result

        mock_queue_message = Mock()
        mock_queue_message.event = message_end_event
        pipeline.queue_manager.listen.return_value = [mock_queue_message]

        pipeline._save_message = Mock()
        pipeline._message_end_to_stream_response = Mock(return_value=Mock(spec=MessageEndStreamResponse))

        # Patch db.engine used inside pipeline for session creation
        with patch(
            "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock())
        ):
            # Execute
            responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))

        # Assert
        assert len(responses) == 1
        assert mock_task_state.llm_result == llm_result
        pipeline._save_message.assert_called_once()
        pipeline._message_end_to_stream_response.assert_called_once()

    def test_error_event(self, pipeline):
        """Test handling of error events."""
        # Setup
        error_event = Mock(spec=QueueErrorEvent)
        error_event.error = Exception("Test error")

        mock_queue_message = Mock()
        mock_queue_message.event = error_event
        pipeline.queue_manager.listen.return_value = [mock_queue_message]

        pipeline.handle_error = Mock(return_value=Exception("Test error"))
        pipeline.error_to_stream_response = Mock(return_value=Mock(spec=ErrorStreamResponse))

        # Patch db.engine used inside pipeline for session creation
        with patch(
            "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock())
        ):
            # Execute
            responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))

        # Assert
        assert len(responses) == 1
        pipeline.handle_error.assert_called_once()
        pipeline.error_to_stream_response.assert_called_once()

    def test_ping_event(self, pipeline):
        """Test handling of ping events."""
        # Setup
        ping_event = Mock(spec=QueuePingEvent)

        mock_queue_message = Mock()
        mock_queue_message.event = ping_event
        pipeline.queue_manager.listen.return_value = [mock_queue_message]

        pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))

        # Execute
        responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))

        # Assert
        assert len(responses) == 1
        pipeline.ping_stream_response.assert_called_once()

    def test_file_event(self, pipeline, mock_message_cycle_manager):
        """Test handling of file events."""
        # Setup
        file_event = Mock(spec=QueueMessageFileEvent)
        file_event.message_file_id = "file-id"

        mock_queue_message = Mock()
        mock_queue_message.event = file_event
        pipeline.queue_manager.listen.return_value = [mock_queue_message]

        file_response = Mock(spec=MessageFileStreamResponse)
        mock_message_cycle_manager.message_file_to_stream_response.return_value = file_response

        # Execute
        responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))

        # Assert
        assert len(responses) == 1
        assert responses[0] == file_response
        mock_message_cycle_manager.message_file_to_stream_response.assert_called_once_with(file_event)

    def test_publisher_is_called_with_messages(self, pipeline):
        """Test that publisher publishes messages when provided."""
        # Setup
        publisher = Mock(spec=AppGeneratorTTSPublisher)

        ping_event = Mock(spec=QueuePingEvent)
        mock_queue_message = Mock()
        mock_queue_message.event = ping_event
        pipeline.queue_manager.listen.return_value = [mock_queue_message]

        pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))

        # Execute
        list(pipeline._process_stream_response(publisher=publisher, trace_manager=None))

        # Assert
        # Called once with message and once with None at the end
        assert publisher.publish.call_count == 2
        publisher.publish.assert_any_call(mock_queue_message)
        publisher.publish.assert_any_call(None)

    def test_trace_manager_passed_to_save_message(self, pipeline):
        """Test that trace manager is passed to _save_message."""
        # Setup
        trace_manager = Mock(spec=TraceQueueManager)

        message_end_event = Mock(spec=QueueMessageEndEvent)
        message_end_event.llm_result = None

        mock_queue_message = Mock()
        mock_queue_message.event = message_end_event
        pipeline.queue_manager.listen.return_value = [mock_queue_message]

        pipeline._save_message = Mock()
        pipeline._message_end_to_stream_response = Mock(return_value=Mock(spec=MessageEndStreamResponse))

        # Patch db.engine used inside pipeline for session creation
        with patch(
            "core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock())
        ):
            # Execute
            list(pipeline._process_stream_response(publisher=None, trace_manager=trace_manager))

        # Assert
        pipeline._save_message.assert_called_once_with(session=ANY, trace_manager=trace_manager)

    def test_multiple_events_sequence(self, pipeline, mock_message_cycle_manager, mock_task_state):
        """Test handling multiple events in sequence."""
        # Setup
        chunk1 = Mock()
        chunk1.delta.message.content = "Hello"
        chunk1.prompt_messages = []

        chunk2 = Mock()
        chunk2.delta.message.content = " world!"
        chunk2.prompt_messages = []

        llm_chunk_event1 = Mock(spec=QueueLLMChunkEvent)
        llm_chunk_event1.chunk = chunk1

        ping_event = Mock(spec=QueuePingEvent)

        llm_chunk_event2 = Mock(spec=QueueLLMChunkEvent)
        llm_chunk_event2.chunk = chunk2

        mock_queue_messages = [
            Mock(event=llm_chunk_event1),
            Mock(event=ping_event),
            Mock(event=llm_chunk_event2),
        ]
        pipeline.queue_manager.listen.return_value = mock_queue_messages

        mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE
        pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))

        # Execute
        responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))

        # Assert
        assert len(responses) == 3
        assert mock_task_state.llm_result.message.content == "Hello world!"

        # Verify calls to message_to_stream_response
        assert mock_message_cycle_manager.message_to_stream_response.call_count == 2
        mock_message_cycle_manager.message_to_stream_response.assert_any_call(
            answer="Hello", message_id="test-message-id", event_type=StreamEvent.MESSAGE
        )
        mock_message_cycle_manager.message_to_stream_response.assert_any_call(
            answer=" world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
        )
@@ -1,166 +0,0 @@
"""Unit tests for the message cycle manager optimization."""

from types import SimpleNamespace
from unittest.mock import ANY, Mock, patch

import pytest
from flask import current_app

from core.app.entities.task_entities import MessageStreamResponse, StreamEvent
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager


class TestMessageCycleManagerOptimization:
    """Test cases for the message cycle manager optimization that prevents N+1 queries."""

    @pytest.fixture
    def mock_application_generate_entity(self):
        """Create a mock application generate entity."""
        entity = Mock()
        entity.task_id = "test-task-id"
        return entity

    @pytest.fixture
    def message_cycle_manager(self, mock_application_generate_entity):
        """Create a message cycle manager instance."""
        task_state = Mock()
        return MessageCycleManager(application_generate_entity=mock_application_generate_entity, task_state=task_state)

    def test_get_message_event_type_with_message_file(self, message_cycle_manager):
        """Test get_message_event_type returns MESSAGE_FILE when message has files."""
        with (
            patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
            patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
        ):
            # Setup mock session and message file
            mock_session = Mock()
            mock_session_class.return_value.__enter__.return_value = mock_session

            mock_message_file = Mock()
            # Current implementation uses session.query(...).scalar()
            mock_session.query.return_value.scalar.return_value = mock_message_file

            # Execute
            with current_app.app_context():
                result = message_cycle_manager.get_message_event_type("test-message-id")

            # Assert
            assert result == StreamEvent.MESSAGE_FILE
            mock_session.query.return_value.scalar.assert_called_once()

    def test_get_message_event_type_without_message_file(self, message_cycle_manager):
        """Test get_message_event_type returns MESSAGE when message has no files."""
        with (
            patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
            patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
        ):
            # Setup mock session and no message file
            mock_session = Mock()
            mock_session_class.return_value.__enter__.return_value = mock_session
            # Current implementation uses session.query(...).scalar()
            mock_session.query.return_value.scalar.return_value = None

            # Execute
            with current_app.app_context():
                result = message_cycle_manager.get_message_event_type("test-message-id")

            # Assert
            assert result == StreamEvent.MESSAGE
            mock_session.query.return_value.scalar.assert_called_once()

    def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager):
        """MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it."""
        with (
            patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
            patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
        ):
            # Setup mock session and message file
            mock_session = Mock()
            mock_session_class.return_value.__enter__.return_value = mock_session

            mock_message_file = Mock()
            # Current implementation uses session.query(...).scalar()
            mock_session.query.return_value.scalar.return_value = mock_message_file

            # Execute: compute event type once, then pass to message_to_stream_response
            with current_app.app_context():
                event_type = message_cycle_manager.get_message_event_type("test-message-id")
                result = message_cycle_manager.message_to_stream_response(
                    answer="Hello world", message_id="test-message-id", event_type=event_type
                )

            # Assert
            assert isinstance(result, MessageStreamResponse)
            assert result.answer == "Hello world"
            assert result.id == "test-message-id"
            assert result.event == StreamEvent.MESSAGE_FILE
            mock_session.query.return_value.scalar.assert_called_once()

    def test_message_to_stream_response_with_event_type_skips_query(self, message_cycle_manager):
        """Test that message_to_stream_response skips database query when event_type is provided."""
        with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
            # Execute with event_type provided
            result = message_cycle_manager.message_to_stream_response(
                answer="Hello world", message_id="test-message-id", event_type=StreamEvent.MESSAGE
            )

            # Assert
            assert isinstance(result, MessageStreamResponse)
            assert result.answer == "Hello world"
            assert result.id == "test-message-id"
            assert result.event == StreamEvent.MESSAGE
            # Should not query database when event_type is provided
            mock_session_class.assert_not_called()

    def test_message_to_stream_response_with_from_variable_selector(self, message_cycle_manager):
        """Test message_to_stream_response with from_variable_selector parameter."""
        result = message_cycle_manager.message_to_stream_response(
            answer="Hello world",
            message_id="test-message-id",
            from_variable_selector=["var1", "var2"],
            event_type=StreamEvent.MESSAGE,
        )

        assert isinstance(result, MessageStreamResponse)
        assert result.answer == "Hello world"
        assert result.id == "test-message-id"
        assert result.from_variable_selector == ["var1", "var2"]
        assert result.event == StreamEvent.MESSAGE

    def test_optimization_usage_example(self, message_cycle_manager):
        """Test the optimization pattern that should be used by callers."""
        # Step 1: Get event type once (this queries database)
        with (
            patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
            patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
        ):
            mock_session = Mock()
            mock_session_class.return_value.__enter__.return_value = mock_session
            # Current implementation uses session.query(...).scalar()
            mock_session.query.return_value.scalar.return_value = None  # No files
            with current_app.app_context():
                event_type = message_cycle_manager.get_message_event_type("test-message-id")

            # Should query database once
            mock_session_class.assert_called_once_with(ANY, expire_on_commit=False)
            assert event_type == StreamEvent.MESSAGE

        # Step 2: Use event_type for multiple calls (no additional queries)
        with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
            mock_session_class.return_value.__enter__.return_value = Mock()

            chunk1_response = message_cycle_manager.message_to_stream_response(
                answer="Chunk 1", message_id="test-message-id", event_type=event_type
            )

            chunk2_response = message_cycle_manager.message_to_stream_response(
                answer="Chunk 2", message_id="test-message-id", event_type=event_type
            )

            # Should not query database again
            mock_session_class.assert_not_called()

            assert chunk1_response.event == StreamEvent.MESSAGE
            assert chunk2_response.event == StreamEvent.MESSAGE
            assert chunk1_response.answer == "Chunk 1"
            assert chunk2_response.answer == "Chunk 2"
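
The pattern these tests pin down is: resolve the stream event type once per message, then reuse it for every chunk. On the caller side that looks roughly like the sketch below (illustrative only; `chunks` stands in for the LLM stream):

def stream_chunks(manager, message_id, chunks):
    # One DB query up front instead of one per chunk (the N+1 being avoided).
    event_type = manager.get_message_event_type(message_id)
    for chunk_text in chunks:
        # No further queries: the precomputed event_type is passed through.
        yield manager.message_to_stream_response(
            answer=chunk_text, message_id=message_id, event_type=event_type
        )
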
@@ -901,13 +901,6 @@ class TestFixedRecursiveCharacterTextSplitter:
        # Verify no empty chunks
        assert all(len(chunk) > 0 for chunk in result)

    def test_double_slash_n(self):
        data = "chunk 1\n\nsubchunk 1.\nsubchunk 2.\n\n---\n\nchunk 2\n\nsubchunk 1\nsubchunk 2."
        separator = "\\n\\n---\\n\\n"
        splitter = FixedRecursiveCharacterTextSplitter(fixed_separator=separator)
        chunks = splitter.split_text(data)
        assert chunks == ["chunk 1\n\nsubchunk 1.\nsubchunk 2.", "chunk 2\n\nsubchunk 1\nsubchunk 2."]


# ============================================================================
# Test Metadata Preservation
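
The removed test relied on the splitter treating an escaped separator string ("\\n\\n---\\n\\n") as its literal newline form. How FixedRecursiveCharacterTextSplitter performs that normalization is not shown in this diff; a minimal stand-in would be:

# Sketch only: unescape a user-supplied separator before splitting.
def normalize_separator(separator: str) -> str:
    return separator.replace("\\n", "\n")

assert normalize_separator("\\n\\n---\\n\\n") == "\n\n---\n\n"
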
@@ -1,4 +1,3 @@
import json
import time

import pytest
@@ -47,16 +46,14 @@ def make_start_node(user_inputs, variables):


def test_json_object_valid_schema():
    schema = json.dumps(
        {
            "type": "object",
            "properties": {
                "age": {"type": "number"},
                "name": {"type": "string"},
            },
            "required": ["age"],
        }
    )
    schema = {
        "type": "object",
        "properties": {
            "age": {"type": "number"},
            "name": {"type": "string"},
        },
        "required": ["age"],
    }

    variables = [
        VariableEntity(
@@ -68,7 +65,7 @@ def test_json_object_valid_schema():
        )
    ]

    user_inputs = {"profile": json.dumps({"age": 20, "name": "Tom"})}
    user_inputs = {"profile": {"age": 20, "name": "Tom"}}

    node = make_start_node(user_inputs, variables)
    result = node._run()
@@ -77,23 +74,12 @@ def test_json_object_valid_schema():


def test_json_object_invalid_json_string():
    schema = json.dumps(
        {
            "type": "object",
            "properties": {
                "age": {"type": "number"},
                "name": {"type": "string"},
            },
            "required": ["age", "name"],
        }
    )
    variables = [
        VariableEntity(
            variable="profile",
            label="profile",
            type=VariableEntityType.JSON_OBJECT,
            required=True,
            json_schema=schema,
        )
    ]

@@ -102,21 +88,38 @@ def test_json_object_invalid_json_string():

    node = make_start_node(user_inputs, variables)

    with pytest.raises(ValueError, match='{"age": 20, "name": "Tom" must be a valid JSON object'):
    with pytest.raises(ValueError, match="profile must be a JSON object"):
        node._run()


@pytest.mark.parametrize("value", ["[1, 2, 3]", "123"])
def test_json_object_valid_json_but_not_object(value):
    variables = [
        VariableEntity(
            variable="profile",
            label="profile",
            type=VariableEntityType.JSON_OBJECT,
            required=True,
        )
    ]

    user_inputs = {"profile": value}

    node = make_start_node(user_inputs, variables)

    with pytest.raises(ValueError, match="profile must be a JSON object"):
        node._run()


def test_json_object_does_not_match_schema():
    schema = json.dumps(
        {
            "type": "object",
            "properties": {
                "age": {"type": "number"},
                "name": {"type": "string"},
            },
            "required": ["age", "name"],
        }
    )
    schema = {
        "type": "object",
        "properties": {
            "age": {"type": "number"},
            "name": {"type": "string"},
        },
        "required": ["age", "name"],
    }

    variables = [
        VariableEntity(
@@ -129,7 +132,7 @@ def test_json_object_does_not_match_schema():
    ]

    # age is a string, which violates the schema (expects number)
    user_inputs = {"profile": json.dumps({"age": "twenty", "name": "Tom"})}
    user_inputs = {"profile": {"age": "twenty", "name": "Tom"}}

    node = make_start_node(user_inputs, variables)

@@ -138,16 +141,14 @@ def test_json_object_does_not_match_schema():


def test_json_object_missing_required_schema_field():
    schema = json.dumps(
        {
            "type": "object",
            "properties": {
                "age": {"type": "number"},
                "name": {"type": "string"},
            },
            "required": ["age", "name"],
        }
    )
    schema = {
        "type": "object",
        "properties": {
            "age": {"type": "number"},
            "name": {"type": "string"},
        },
        "required": ["age", "name"],
    }

    variables = [
        VariableEntity(
@@ -160,7 +161,7 @@ def test_json_object_missing_required_schema_field():
    ]

    # Missing required field "name"
    user_inputs = {"profile": json.dumps({"age": 20})}
    user_inputs = {"profile": {"age": 20}}

    node = make_start_node(user_inputs, variables)

@@ -213,7 +214,7 @@ def test_json_object_optional_variable_not_provided():
            variable="profile",
            label="profile",
            type=VariableEntityType.JSON_OBJECT,
            required=True,
            required=False,
        )
    ]

@@ -222,5 +223,5 @@ def test_json_object_optional_variable_not_provided():
    node = make_start_node(user_inputs, variables)

    # Current implementation raises a validation error even when the variable is optional
    with pytest.raises(ValueError, match="profile is required in input form"):
    with pytest.raises(ValueError, match="profile must be a JSON object"):
        node._run()
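
After this change, JSON_OBJECT inputs arrive as already-parsed objects rather than JSON strings, and anything else fails with "profile must be a JSON object". A rough sketch of the validation step implied by these tests (the helper name is hypothetical, and jsonschema is an assumed choice for the schema check):

from jsonschema import validate  # assumed validator; the node may use another

def check_json_object_input(name: str, value, schema: dict | None = None):
    # Parsed objects only: strings, lists, and numbers are rejected outright.
    if not isinstance(value, dict):
        raise ValueError(f"{name} must be a JSON object")
    if schema is not None:
        validate(instance=value, schema=schema)  # raises on schema mismatch
    return value
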
@@ -0,0 +1,588 @@
"""
Unit tests for SandboxMessagesCleanService.

This module tests parameter validation, method invocation, and error handling
without database dependencies (using mocks).
"""

import datetime
from unittest.mock import patch

import pytest

from enums.cloud_plan import CloudPlan
from services.sandbox_messages_clean_service import SandboxMessagesCleanService


class MockMessage:
    """Mock message object for testing."""

    def __init__(self, id: str, app_id: str, created_at: datetime.datetime | None = None):
        self.id = id
        self.app_id = app_id
        self.created_at = created_at or datetime.datetime.now()


class TestFilterExpiredSandboxMessages:
    """Unit tests for _filter_expired_sandbox_messages method."""

    def test_filter_missing_tenant_mapping(self):
        """Test that messages with missing app-to-tenant mapping are excluded."""
        # Arrange
        messages = [
            MockMessage("msg1", "app1"),
            MockMessage("msg2", "app2"),
        ]
        app_to_tenant = {}  # No mapping
        tenant_plans = {"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}}

        # Act
        result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
            messages=messages,
            app_to_tenant=app_to_tenant,
            tenant_plans=tenant_plans,
            tenant_whitelist=[],
            graceful_period_days=8,
            current_timestamp=1000000,
        )

        # Assert
        assert result == []

    def test_filter_missing_tenant_plan(self):
        """Test that messages with missing tenant plan are excluded."""
        # Arrange
        messages = [
            MockMessage("msg1", "app1"),
            MockMessage("msg2", "app2"),
        ]
        app_to_tenant = {
            "app1": "tenant1",
            "app2": "tenant2",
        }
        tenant_plans = {}  # No plans

        # Act
        result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
            messages=messages,
            app_to_tenant=app_to_tenant,
            tenant_plans=tenant_plans,
            tenant_whitelist=[],
            graceful_period_days=8,
            current_timestamp=1000000,
        )

        # Assert
        assert result == []

    def test_filter_no_previous_subscription(self):
        """Test that messages with no previous subscription (expiration_date=-1) are deleted."""
        # Arrange
        messages = [
            MockMessage("msg1", "app1"),
            MockMessage("msg2", "app2"),
            MockMessage("msg3", "app3"),
        ]
        app_to_tenant = {
            "app1": "tenant1",
            "app2": "tenant2",
            "app3": "tenant3",
        }
        tenant_plans = {
            "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
            "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
            "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
        }

        # Act
        result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
            messages=messages,
            app_to_tenant=app_to_tenant,
            tenant_plans=tenant_plans,
            tenant_whitelist=[],
            graceful_period_days=8,
            current_timestamp=1000000,
        )

        # Assert - all messages should be deleted
        assert set(result) == {"msg1", "msg2", "msg3"}

    def test_filter_all_within_grace_period(self):
        """Test that no messages are deleted when all are within grace period."""
        # Arrange
        now = 1000000
        # All expired recently (within 8 day grace period)
        expired_1_day_ago = now - (1 * 24 * 60 * 60)
        expired_3_days_ago = now - (3 * 24 * 60 * 60)
        expired_7_days_ago = now - (7 * 24 * 60 * 60)

        messages = [
            MockMessage("msg1", "app1"),
            MockMessage("msg2", "app2"),
            MockMessage("msg3", "app3"),
        ]
        app_to_tenant = {
            "app1": "tenant1",
            "app2": "tenant2",
            "app3": "tenant3",
        }
        tenant_plans = {
            "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_1_day_ago},
            "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_3_days_ago},
            "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_7_days_ago},
        }

        # Act
        result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
            messages=messages,
            app_to_tenant=app_to_tenant,
            tenant_plans=tenant_plans,
            tenant_whitelist=[],
            graceful_period_days=8,
            current_timestamp=now,
        )

        # Assert - no messages should be deleted
        assert result == []

    def test_filter_partial_expired_beyond_grace_period(self):
        """Test filtering when some messages expired beyond grace period."""
        # Arrange
        now = 1000000
        graceful_period = 8

        # Different expiration scenarios
        expired_5_days_ago = now - (5 * 24 * 60 * 60)  # Within grace - keep
        expired_10_days_ago = now - (10 * 24 * 60 * 60)  # Beyond grace - delete
        expired_30_days_ago = now - (30 * 24 * 60 * 60)  # Beyond grace - delete
        expired_exactly_8_days_ago = now - (8 * 24 * 60 * 60)  # Exactly at boundary - keep
        expired_9_days_ago = now - (9 * 24 * 60 * 60)  # Just beyond - delete

        messages = [
            MockMessage("msg1", "app1"),  # Within grace
            MockMessage("msg2", "app2"),  # Beyond grace
            MockMessage("msg3", "app3"),  # Beyond grace
            MockMessage("msg4", "app4"),  # No subscription - delete
            MockMessage("msg5", "app5"),  # Exactly at boundary
            MockMessage("msg6", "app6"),  # Just beyond grace
        ]
        app_to_tenant = {
            "app1": "tenant1",
            "app2": "tenant2",
            "app3": "tenant3",
            "app4": "tenant4",
            "app5": "tenant5",
            "app6": "tenant6",
        }
        tenant_plans = {
            "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_5_days_ago},
            "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_10_days_ago},
            "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_30_days_ago},
            "tenant4": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
            "tenant5": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_exactly_8_days_ago},
            "tenant6": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_9_days_ago},
        }

        # Act
        result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
            messages=messages,
            app_to_tenant=app_to_tenant,
            tenant_plans=tenant_plans,
            tenant_whitelist=[],
            graceful_period_days=graceful_period,
            current_timestamp=now,
        )

        # Assert - msg2, msg3, msg4, msg6 should be deleted
        # msg1 and msg5 are within/at grace period boundary
        assert set(result) == {"msg2", "msg3", "msg4", "msg6"}

    def test_filter_complex_mixed_scenario(self):
        """Test complex scenario with mixed plans, expirations, and missing mappings."""
        # Arrange
        now = 1000000
        sandbox_expired_old = now - (15 * 24 * 60 * 60)  # 15 days ago - beyond grace
        sandbox_expired_recent = now - (3 * 24 * 60 * 60)  # 3 days ago - within grace
        future_expiration = now + (30 * 24 * 60 * 60)  # 30 days in future - active paid plan

        messages = [
            MockMessage("msg1", "app1"),  # Sandbox, no subscription - delete
            MockMessage("msg2", "app2"),  # Sandbox, expired old - delete
            MockMessage("msg3", "app3"),  # Sandbox, within grace - keep
            MockMessage("msg4", "app4"),  # Team plan, active - keep
            MockMessage("msg5", "app5"),  # No tenant mapping - keep
            MockMessage("msg6", "app6"),  # No plan info - keep
            MockMessage("msg7", "app7"),  # Sandbox, expired old - delete
        ]
        app_to_tenant = {
            "app1": "tenant1",
            "app2": "tenant2",
            "app3": "tenant3",
            "app4": "tenant4",
            "app6": "tenant6",  # Has mapping but no plan
            "app7": "tenant7",
            # app5 has no mapping
        }
        tenant_plans = {
            "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
            "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_old},
            "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_recent},
            "tenant4": {"plan": CloudPlan.TEAM, "expiration_date": future_expiration},
            "tenant7": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_old},
            # tenant6 has no plan
        }

        # Act
        result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
            messages=messages,
            app_to_tenant=app_to_tenant,
            tenant_plans=tenant_plans,
            tenant_whitelist=[],
            graceful_period_days=8,
            current_timestamp=now,
        )

        # Assert - only sandbox expired beyond grace period and no subscription
        assert set(result) == {"msg1", "msg2", "msg7"}

    def test_filter_empty_inputs(self):
        """Test filtering with empty inputs returns empty list."""
        # Arrange - empty messages
        result1 = SandboxMessagesCleanService._filter_expired_sandbox_messages(
            messages=[],
            app_to_tenant={"app1": "tenant1"},
            tenant_plans={"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}},
            tenant_whitelist=[],
            graceful_period_days=8,
            current_timestamp=1000000,
        )

        # Assert
        assert result1 == []

    def test_filter_uses_default_timestamp(self):
        """Test that method uses current time when timestamp not provided."""
        # Arrange
        messages = [MockMessage("msg1", "app1")]
        app_to_tenant = {"app1": "tenant1"}
        tenant_plans = {"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}}

        # Act - don't provide current_timestamp
        result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
            messages=messages,
            app_to_tenant=app_to_tenant,
            tenant_plans=tenant_plans,
            tenant_whitelist=[],
            graceful_period_days=8,
            # current_timestamp not provided - should use datetime.now()
        )

        # Assert - should still work and return msg1 (no subscription)
        assert result == ["msg1"]

    def test_filter_with_whitelist(self):
        """Test that messages from whitelisted tenants are excluded from deletion."""
        # Arrange
        messages = [
            MockMessage("msg1", "app1"),  # Whitelisted tenant - should be kept
            MockMessage("msg2", "app2"),  # Not whitelisted - should be deleted
            MockMessage("msg3", "app3"),  # Whitelisted tenant - should be kept
            MockMessage("msg4", "app4"),  # Not whitelisted - should be deleted
        ]
        app_to_tenant = {
            "app1": "tenant1",
            "app2": "tenant2",
            "app3": "tenant3",
            "app4": "tenant4",
        }
        tenant_plans = {
            "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
            "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
            "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
            "tenant4": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
        }
        tenant_whitelist = ["tenant1", "tenant3"]  # Whitelist tenant1 and tenant3

        # Act
        result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
            messages=messages,
            app_to_tenant=app_to_tenant,
            tenant_plans=tenant_plans,
            tenant_whitelist=tenant_whitelist,
            graceful_period_days=8,
            current_timestamp=1000000,
        )

        # Assert - only msg2 and msg4 should be deleted (not whitelisted)
        assert set(result) == {"msg2", "msg4"}

    def test_filter_with_whitelist_and_grace_period(self):
        """Test whitelist takes precedence over grace period logic."""
        # Arrange
        now = 1000000
        expired_long_ago = now - (30 * 24 * 60 * 60)  # Expired 30 days ago

        messages = [
            MockMessage("msg1", "app1"),  # Whitelisted, expired long ago - should be kept
            MockMessage("msg2", "app2"),  # Not whitelisted, expired long ago - should be deleted
        ]
        app_to_tenant = {
            "app1": "tenant1",
            "app2": "tenant2",
        }
        tenant_plans = {
            "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_long_ago},
            "tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_long_ago},
        }
        tenant_whitelist = ["tenant1"]  # Only tenant1 is whitelisted

        # Act
        result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
            messages=messages,
            app_to_tenant=app_to_tenant,
            tenant_plans=tenant_plans,
            tenant_whitelist=tenant_whitelist,
            graceful_period_days=8,
            current_timestamp=now,
        )

        # Assert - only msg2 should be deleted
        assert result == ["msg2"]

    def test_filter_whitelist_with_non_sandbox_plans(self):
        """Test that whitelist only affects sandbox plan messages."""
        # Arrange
        now = 1000000
        future_expiration = now + (30 * 24 * 60 * 60)

        messages = [
            MockMessage("msg1", "app1"),  # Sandbox, whitelisted - kept
            MockMessage("msg2", "app2"),  # Team plan, whitelisted - kept (not sandbox)
            MockMessage("msg3", "app3"),  # Sandbox, not whitelisted - deleted
        ]
        app_to_tenant = {
            "app1": "tenant1",
            "app2": "tenant2",
            "app3": "tenant3",
        }
        tenant_plans = {
            "tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
            "tenant2": {"plan": CloudPlan.TEAM, "expiration_date": future_expiration},
            "tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
        }
        tenant_whitelist = ["tenant1", "tenant2"]

        # Act
        result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
            messages=messages,
            app_to_tenant=app_to_tenant,
            tenant_plans=tenant_plans,
            tenant_whitelist=tenant_whitelist,
            graceful_period_days=8,
            current_timestamp=now,
        )

        # Assert - only msg3 should be deleted (sandbox, not whitelisted)
        assert result == ["msg3"]

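Taken together, the cases above pin down a single deletion predicate; condensed into one function it reads roughly as follows (a sketch, not the service's actual implementation):

def should_delete(tenant_id, plan_info, whitelist, grace_days, now):
    if tenant_id is None or plan_info is None:  # unknown tenant or plan: keep
        return False
    if plan_info["plan"] != CloudPlan.SANDBOX:  # paid plans are never cleaned
        return False
    if tenant_id in whitelist:  # whitelist beats every other rule
        return False
    expiration = plan_info["expiration_date"]
    if expiration == -1:  # never subscribed: always eligible
        return True
    # Strictly beyond the grace window; the exact boundary is kept.
    return now - expiration > grace_days * 24 * 60 * 60
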
class TestCleanSandboxMessagesByTimeRange:
    """Unit tests for clean_sandbox_messages_by_time_range method."""

    @patch.object(SandboxMessagesCleanService, "_clean_sandbox_messages_by_time_range")
    def test_valid_time_range_and_args(self, mock_clean):
        """Test with valid time range and other parameters."""
        # Arrange
        start_from = datetime.datetime(2024, 1, 1, 0, 0, 0)
        end_before = datetime.datetime(2024, 12, 31, 23, 59, 59)
        batch_size = 500
        dry_run = True

        mock_clean.return_value = {
            "batches": 5,
            "total_messages": 100,
            "total_deleted": 100,
        }

        # Act
        SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
            start_from=start_from,
            end_before=end_before,
            batch_size=batch_size,
            dry_run=dry_run,
        )

        # Assert, expected no exception raised
        mock_clean.assert_called_once_with(
            start_from=start_from,
            end_before=end_before,
            graceful_period=21,
            batch_size=batch_size,
            dry_run=dry_run,
        )

    @patch.object(SandboxMessagesCleanService, "_clean_sandbox_messages_by_time_range")
    def test_with_default_args(self, mock_clean):
        """Test with default args."""
        # Arrange
        start_from = datetime.datetime(2024, 1, 1)
        end_before = datetime.datetime(2024, 2, 1)

        mock_clean.return_value = {
            "batches": 2,
            "total_messages": 50,
            "total_deleted": 0,
        }

        # Act
        SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
            start_from=start_from,
            end_before=end_before,
        )

        # Assert
        mock_clean.assert_called_once_with(
            start_from=start_from,
            end_before=end_before,
            graceful_period=21,
            batch_size=1000,
            dry_run=False,
        )

    def test_invalid_time_range(self):
        """Test invalid time range raises ValueError."""
        # Arrange
        same_time = datetime.datetime(2024, 1, 1, 12, 0, 0)

        # Act & Assert start equals end
        with pytest.raises(ValueError, match="start_from .* must be less than end_before"):
            SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
                start_from=same_time,
                end_before=same_time,
            )

        # Arrange
        start_from = datetime.datetime(2024, 12, 31)
        end_before = datetime.datetime(2024, 1, 1)

        # Act & Assert start after end
        with pytest.raises(ValueError, match="start_from .* must be less than end_before"):
            SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
                start_from=start_from,
                end_before=end_before,
            )

    def test_invalid_batch_size(self):
        """Test invalid batch_size raises ValueError."""
        # Arrange
        start_from = datetime.datetime(2024, 1, 1)
        end_before = datetime.datetime(2024, 2, 1)

        # Act & Assert batch_size = 0
        with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
            SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
                start_from=start_from,
                end_before=end_before,
                batch_size=0,
            )

        # Act & Assert batch_size < 0
        with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
            SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
                start_from=start_from,
                end_before=end_before,
                batch_size=-100,
            )


class TestCleanSandboxMessagesByDays:
    """Unit tests for clean_sandbox_messages_by_days method."""

    @patch.object(SandboxMessagesCleanService, "_clean_sandbox_messages_by_time_range")
    def test_default_days(self, mock_clean):
        """Test with default 30 days."""
        # Arrange
        mock_clean.return_value = {"batches": 3, "total_messages": 75, "total_deleted": 75}

        # Act
        with patch("services.sandbox_messages_clean_service.datetime") as mock_datetime:
            fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0)
            mock_datetime.datetime.now.return_value = fixed_now
            mock_datetime.timedelta = datetime.timedelta  # Keep original timedelta

            SandboxMessagesCleanService.clean_sandbox_messages_by_days()

        # Assert
        expected_end_before = fixed_now - datetime.timedelta(days=30)  # default days=30
        mock_clean.assert_called_once_with(
            end_before=expected_end_before,
            start_from=None,
            graceful_period=21,
            batch_size=1000,
            dry_run=False,
        )

    @patch.object(SandboxMessagesCleanService, "_clean_sandbox_messages_by_time_range")
    def test_custom_days(self, mock_clean):
        """Test with custom number of days."""
        # Arrange
        custom_days = 90
        mock_clean.return_value = {"batches": 10, "total_messages": 500, "total_deleted": 500}

        # Act
        with patch("services.sandbox_messages_clean_service.datetime") as mock_datetime:
            fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0)
            mock_datetime.datetime.now.return_value = fixed_now
            mock_datetime.timedelta = datetime.timedelta  # Keep original timedelta

            result = SandboxMessagesCleanService.clean_sandbox_messages_by_days(days=custom_days)

        # Assert
        expected_end_before = fixed_now - datetime.timedelta(days=custom_days)
        mock_clean.assert_called_once_with(
            end_before=expected_end_before,
            start_from=None,
            graceful_period=21,
            batch_size=1000,
            dry_run=False,
        )

    @patch.object(SandboxMessagesCleanService, "_clean_sandbox_messages_by_time_range")
    def test_zero_days(self, mock_clean):
        """Test with days=0 (clean all messages before now)."""
        # Arrange
        mock_clean.return_value = {"batches": 0, "total_messages": 0, "total_deleted": 0}

        # Act
        with patch("services.sandbox_messages_clean_service.datetime") as mock_datetime:
            fixed_now = datetime.datetime(2024, 6, 15, 14, 0, 0)
            mock_datetime.datetime.now.return_value = fixed_now
            mock_datetime.timedelta = datetime.timedelta  # Keep original timedelta

            SandboxMessagesCleanService.clean_sandbox_messages_by_days(days=0)

        # Assert
        expected_end_before = fixed_now - datetime.timedelta(days=0)  # same as fixed_now
        mock_clean.assert_called_once_with(
            end_before=expected_end_before,
            start_from=None,
            graceful_period=21,
            batch_size=1000,
            dry_run=False,
        )

    def test_invalid_batch_size(self):
        """Test invalid batch_size raises ValueError."""
        # Act & Assert batch_size = 0
        with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
            SandboxMessagesCleanService.clean_sandbox_messages_by_days(
                days=30,
                batch_size=0,
            )

        # Act & Assert batch_size < 0
        with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
            SandboxMessagesCleanService.clean_sandbox_messages_by_days(
                days=30,
                batch_size=-500,
            )
@@ -1369,10 +1369,7 @@ PLUGIN_STDIO_BUFFER_SIZE=1024
PLUGIN_STDIO_MAX_BUFFER_SIZE=5242880

PLUGIN_PYTHON_ENV_INIT_TIMEOUT=120
# Plugin Daemon side timeout (configure to match the API side below)
PLUGIN_MAX_EXECUTION_TIMEOUT=600
# API side timeout (configure to match the Plugin Daemon side above)
PLUGIN_DAEMON_TIMEOUT=600.0
# PIP_MIRROR_URL=https://pypi.tuna.tsinghua.edu.cn/simple
PIP_MIRROR_URL=

@@ -34,7 +34,6 @@ services:
      PLUGIN_REMOTE_INSTALL_HOST: ${EXPOSE_PLUGIN_DEBUGGING_HOST:-localhost}
      PLUGIN_REMOTE_INSTALL_PORT: ${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}
      PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800}
      PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0}
      INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1}
    depends_on:
      init_permissions:

@@ -591,7 +591,6 @@ x-shared-env: &shared-api-worker-env
  PLUGIN_STDIO_MAX_BUFFER_SIZE: ${PLUGIN_STDIO_MAX_BUFFER_SIZE:-5242880}
  PLUGIN_PYTHON_ENV_INIT_TIMEOUT: ${PLUGIN_PYTHON_ENV_INIT_TIMEOUT:-120}
  PLUGIN_MAX_EXECUTION_TIMEOUT: ${PLUGIN_MAX_EXECUTION_TIMEOUT:-600}
  PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0}
  PIP_MIRROR_URL: ${PIP_MIRROR_URL:-}
  PLUGIN_STORAGE_TYPE: ${PLUGIN_STORAGE_TYPE:-local}
  PLUGIN_STORAGE_LOCAL_ROOT: ${PLUGIN_STORAGE_LOCAL_ROOT:-/app/storage}
@@ -703,7 +702,6 @@ services:
      PLUGIN_REMOTE_INSTALL_HOST: ${EXPOSE_PLUGIN_DEBUGGING_HOST:-localhost}
      PLUGIN_REMOTE_INSTALL_PORT: ${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}
      PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800}
      PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0}
      INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1}
    depends_on:
      init_permissions:

@@ -1,71 +0,0 @@
/**
 * Mock for ky HTTP client
 * This mock is used to avoid ESM issues in Jest tests
 */

type KyResponse = {
  ok: boolean
  status: number
  statusText: string
  headers: Headers
  json: jest.Mock
  text: jest.Mock
  blob: jest.Mock
  arrayBuffer: jest.Mock
  clone: jest.Mock
}

type KyInstance = jest.Mock & {
  get: jest.Mock
  post: jest.Mock
  put: jest.Mock
  patch: jest.Mock
  delete: jest.Mock
  head: jest.Mock
  create: jest.Mock
  extend: jest.Mock
  stop: symbol
}

const createResponse = (data: unknown = {}, status = 200): KyResponse => {
  const response: KyResponse = {
    ok: status >= 200 && status < 300,
    status,
    statusText: status === 200 ? 'OK' : 'Error',
    headers: new Headers(),
    json: jest.fn().mockResolvedValue(data),
    text: jest.fn().mockResolvedValue(JSON.stringify(data)),
    blob: jest.fn().mockResolvedValue(new Blob()),
    arrayBuffer: jest.fn().mockResolvedValue(new ArrayBuffer(0)),
    clone: jest.fn(),
  }
  // Ensure clone returns a new response-like object, not the same instance
  response.clone.mockImplementation(() => createResponse(data, status))
  return response
}

const createKyInstance = (): KyInstance => {
  const instance = jest.fn().mockImplementation(() => Promise.resolve(createResponse())) as KyInstance

  // HTTP methods
  instance.get = jest.fn().mockImplementation(() => Promise.resolve(createResponse()))
  instance.post = jest.fn().mockImplementation(() => Promise.resolve(createResponse()))
  instance.put = jest.fn().mockImplementation(() => Promise.resolve(createResponse()))
  instance.patch = jest.fn().mockImplementation(() => Promise.resolve(createResponse()))
  instance.delete = jest.fn().mockImplementation(() => Promise.resolve(createResponse()))
  instance.head = jest.fn().mockImplementation(() => Promise.resolve(createResponse()))

  // Create new instance with custom options
  instance.create = jest.fn().mockImplementation(() => createKyInstance())
  instance.extend = jest.fn().mockImplementation(() => createKyInstance())

  // Stop method for AbortController
  instance.stop = Symbol('stop')

  return instance
}

const ky = createKyInstance()

export default ky
export { ky }
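
A minimal consumption sketch for the mock above (assumptions: Jest resolves the real 'ky' module to this file via moduleNameMapper or a __mocks__ directory, and the request path shown is hypothetical):

import ky from 'ky'

it('resolves the stubbed JSON body from a mocked ky.get call', async () => {
  // The mock's get() resolves to a response whose json() yields {}.
  const res = await ky.get('console/api/apps')
  await expect(res.json()).resolves.toEqual({})
  expect(ky.get).toHaveBeenCalledWith('console/api/apps')
})
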
@@ -16,7 +16,7 @@ import {
import { useTranslation } from 'react-i18next'
import { useShallow } from 'zustand/react/shallow'
import s from './style.module.css'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { useStore } from '@/app/components/app/store'
import AppSideBar from '@/app/components/app-sidebar'
import type { NavIcon } from '@/app/components/app-sidebar/navLink'

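This hunk, and the similar ones that follow, swap the named import '{ cn }' for a default import of the same utility. A minimal sketch of a '@/utils/classnames' module that supports both import forms (an assumption — the real util in this repo may additionally merge Tailwind classes):

import classNames from 'classnames'

// Default export enables `import cn from '@/utils/classnames'`;
// the named export keeps `import { cn } from '@/utils/classnames'` call sites compiling.
const cn = (...args: Parameters<typeof classNames>): string => classNames(...args)

export default cn
export { cn }
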
@@ -3,7 +3,7 @@ import { RiCalendarLine } from '@remixicon/react'
import type { Dayjs } from 'dayjs'
import type { FC } from 'react'
import React, { useCallback } from 'react'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { formatToLocalTime } from '@/utils/format'
import { useI18N } from '@/context/i18n'
import Picker from '@/app/components/base/date-and-time-picker/date-picker'

@@ -6,7 +6,7 @@ import { SimpleSelect } from '@/app/components/base/select'
import type { Item } from '@/app/components/base/select'
import dayjs from 'dayjs'
import { RiArrowDownSLine, RiCheckLine } from '@remixicon/react'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { useTranslation } from 'react-i18next'

const today = dayjs()

@@ -4,7 +4,7 @@ import React, { useCallback, useRef, useState } from 'react'

import type { PopupProps } from './config-popup'
import ConfigPopup from './config-popup'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import {
  PortalToFollowElem,
  PortalToFollowElemContent,

@@ -12,7 +12,7 @@ import Indicator from '@/app/components/header/indicator'
import Switch from '@/app/components/base/switch'
import Tooltip from '@/app/components/base/tooltip'
import Divider from '@/app/components/base/divider'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'

const I18N_PREFIX = 'app.tracing'


@@ -1,7 +1,7 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import Input from '@/app/components/base/input'

type Props = {

@@ -12,7 +12,7 @@ import type { AliyunConfig, ArizeConfig, DatabricksConfig, LangFuseConfig, LangS
import { TracingProvider } from './type'
import TracingIcon from './tracing-icon'
import ConfigButton from './config-button'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { AliyunIcon, ArizeIcon, DatabricksIcon, LangfuseIcon, LangsmithIcon, MlflowIcon, OpikIcon, PhoenixIcon, TencentIcon, WeaveIcon } from '@/app/components/base/icons/src/public/tracing'
import Indicator from '@/app/components/header/indicator'
import { fetchTracingConfig as doFetchTracingConfig, fetchTracingStatus, updateTracingStatus } from '@/service/apps'

@@ -6,7 +6,7 @@ import {
} from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import { TracingProvider } from './type'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { AliyunIconBig, ArizeIconBig, DatabricksIconBig, LangfuseIconBig, LangsmithIconBig, MlflowIconBig, OpikIconBig, PhoenixIconBig, TencentIconBig, WeaveIconBig } from '@/app/components/base/icons/src/public/tracing'
import { Eye as View } from '@/app/components/base/icons/src/vender/solid/general'


@@ -1,7 +1,7 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { TracingIcon as Icon } from '@/app/components/base/icons/src/public/tracing'

type Props = {

@@ -23,7 +23,7 @@ import { useDatasetDetail, useDatasetRelatedApps } from '@/service/knowledge/use
import useDocumentTitle from '@/hooks/use-document-title'
import ExtraInfo from '@/app/components/datasets/extra-info'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'

export type IAppDetailLayoutProps = {
  children: React.ReactNode

@@ -1,7 +1,7 @@
'use client'
import Header from '@/app/signin/_header'

import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { useGlobalPublicStore } from '@/context/global-public-context'

export default function SignInLayout({ children }: any) {

@@ -2,7 +2,7 @@
import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useRouter, useSearchParams } from 'next/navigation'
import { cn } from '@/utils/classnames'
import cn from 'classnames'
import { RiCheckboxCircleFill } from '@remixicon/react'
import { useCountDown } from 'ahooks'
import Button from '@/app/components/base/button'

@@ -1,6 +1,6 @@
'use client'

import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { useGlobalPublicStore } from '@/context/global-public-context'
import useDocumentTitle from '@/hooks/use-document-title'
import type { PropsWithChildren } from 'react'

@@ -7,7 +7,7 @@ import Loading from '@/app/components/base/loading'
import MailAndCodeAuth from './components/mail-and-code-auth'
import MailAndPasswordAuth from './components/mail-and-password-auth'
import SSOAuth from './components/sso-auth'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { LicenseStatus } from '@/types/feature'
import { IS_CE_EDITION } from '@/config'
import { useGlobalPublicStore } from '@/context/global-public-context'

@@ -1,7 +1,7 @@
'use client'
import Header from '@/app/signin/_header'

import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { useGlobalPublicStore } from '@/context/global-public-context'
import useDocumentTitle from '@/hooks/use-document-title'
import { AppContextProvider } from '@/context/app-context'

@@ -1,13 +1,13 @@
'use client'
import { useEffect } from 'react'
import { useTranslation } from 'react-i18next'
import useSWR from 'swr'
import { useRouter, useSearchParams } from 'next/navigation'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import Button from '@/app/components/base/button'

import { invitationCheck } from '@/service/common'
import Loading from '@/app/components/base/loading'
import useDocumentTitle from '@/hooks/use-document-title'
import { useInvitationCheck } from '@/service/use-common'

const ActivateForm = () => {
  useDocumentTitle('')
@@ -26,21 +26,19 @@ const ActivateForm = () => {
      token,
    },
  }
  const { data: checkRes } = useInvitationCheck({
    ...checkParams.params,
    token: token || undefined,
  }, true)

  useEffect(() => {
    if (checkRes?.is_valid) {
      const params = new URLSearchParams(searchParams)
      const { email, workspace_id } = checkRes.data
      params.set('email', encodeURIComponent(email))
      params.set('workspace_id', encodeURIComponent(workspace_id))
      params.set('invite_token', encodeURIComponent(token as string))
      router.replace(`/signin?${params.toString()}`)
    }
  }, [checkRes, router, searchParams, token])
  const { data: checkRes } = useSWR(checkParams, invitationCheck, {
    revalidateOnFocus: false,
    onSuccess(data) {
      if (data.is_valid) {
        const params = new URLSearchParams(searchParams)
        const { email, workspace_id } = data.data
        params.set('email', encodeURIComponent(email))
        params.set('workspace_id', encodeURIComponent(workspace_id))
        params.set('invite_token', encodeURIComponent(token as string))
        router.replace(`/signin?${params.toString()}`)
      }
    },
  })

  return (
    <div className={

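The hunk above moves the invitation check from an imperative useSWR onSuccess callback to a useInvitationCheck query hook plus a useEffect. A plausible shape for that hook, inferred only from the call site (an assumption — the real hook lives in '@/service/use-common', and the request URL below is a guess):

import { useQuery } from '@tanstack/react-query'
import { invitationCheck } from '@/service/common'

type InvitationCheckParams = {
  workspace_id: string
  email: string
  token?: string
}

// 'enabled' gates the request until the route params are available.
export const useInvitationCheck = (params: InvitationCheckParams, enabled: boolean) =>
  useQuery({
    queryKey: ['common', 'invitation-check', params],
    queryFn: () => invitationCheck({ url: '/activate/check', params }),
    enabled,
  })
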
@@ -2,7 +2,7 @@
import React from 'react'
import Header from '../signin/_header'
import ActivateForm from './activateForm'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { useGlobalPublicStore } from '@/context/global-public-context'

const Activate = () => {

@@ -29,7 +29,7 @@ import CardView from '@/app/(commonLayout)/app/(appDetailLayout)/[appId]/overvie
import type { Operation } from './app-operations'
import AppOperations from './app-operations'
import dynamic from 'next/dynamic'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { AppModeEnum } from '@/types/app'

const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), {

@@ -16,7 +16,7 @@ import AppInfo from './app-info'
import NavLink from './navLink'
import { useStore as useAppStore } from '@/app/components/app/store'
import type { NavIcon } from './navLink'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { AppModeEnum } from '@/types/app'

type Props = {

@@ -2,7 +2,7 @@ import React, { useCallback, useState } from 'react'
import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '../../base/portal-to-follow-elem'
import ActionButton from '../../base/action-button'
import { RiMoreFill } from '@remixicon/react'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import Menu from './menu'
import { useSelector as useAppContextWithSelector } from '@/context/app-context'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'

@@ -1,379 +0,0 @@
import React from 'react'
import { render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import DatasetInfo from './index'
import Dropdown from './dropdown'
import Menu from './menu'
import MenuItem from './menu-item'
import type { DataSet } from '@/models/datasets'
import {
  ChunkingMode,
  DataSourceType,
  DatasetPermission,
} from '@/models/datasets'
import { RETRIEVE_METHOD } from '@/types/app'
import { RiEditLine } from '@remixicon/react'

let mockDataset: DataSet
let mockIsDatasetOperator = false
const mockReplace = jest.fn()
const mockInvalidDatasetList = jest.fn()
const mockInvalidDatasetDetail = jest.fn()
const mockExportPipeline = jest.fn()
const mockCheckIsUsedInApp = jest.fn()
const mockDeleteDataset = jest.fn()

const createDataset = (overrides: Partial<DataSet> = {}): DataSet => ({
  id: 'dataset-1',
  name: 'Dataset Name',
  indexing_status: 'completed',
  icon_info: {
    icon: '📙',
    icon_background: '#FFF4ED',
    icon_type: 'emoji',
    icon_url: '',
  },
  description: 'Dataset description',
  permission: DatasetPermission.onlyMe,
  data_source_type: DataSourceType.FILE,
  indexing_technique: 'high_quality' as DataSet['indexing_technique'],
  created_by: 'user-1',
  updated_by: 'user-1',
  updated_at: 1690000000,
  app_count: 0,
  doc_form: ChunkingMode.text,
  document_count: 1,
  total_document_count: 1,
  word_count: 1000,
  provider: 'internal',
  embedding_model: 'text-embedding-3',
  embedding_model_provider: 'openai',
  embedding_available: true,
  retrieval_model_dict: {
    search_method: RETRIEVE_METHOD.semantic,
    reranking_enable: false,
    reranking_model: {
      reranking_provider_name: '',
      reranking_model_name: '',
    },
    top_k: 5,
    score_threshold_enabled: false,
    score_threshold: 0,
  },
  retrieval_model: {
    search_method: RETRIEVE_METHOD.semantic,
    reranking_enable: false,
    reranking_model: {
      reranking_provider_name: '',
      reranking_model_name: '',
    },
    top_k: 5,
    score_threshold_enabled: false,
    score_threshold: 0,
  },
  tags: [],
  external_knowledge_info: {
    external_knowledge_id: '',
    external_knowledge_api_id: '',
    external_knowledge_api_name: '',
    external_knowledge_api_endpoint: '',
  },
  external_retrieval_model: {
    top_k: 0,
    score_threshold: 0,
    score_threshold_enabled: false,
  },
  built_in_field_enabled: false,
  runtime_mode: 'rag_pipeline',
  enable_api: false,
  is_multimodal: false,
  ...overrides,
})

jest.mock('next/navigation', () => ({
  useRouter: () => ({
    replace: mockReplace,
  }),
}))

jest.mock('@/context/dataset-detail', () => ({
  useDatasetDetailContextWithSelector: (selector: (state: { dataset?: DataSet }) => unknown) => selector({ dataset: mockDataset }),
}))

jest.mock('@/context/app-context', () => ({
  useSelector: (selector: (state: { isCurrentWorkspaceDatasetOperator: boolean }) => unknown) =>
    selector({ isCurrentWorkspaceDatasetOperator: mockIsDatasetOperator }),
}))

jest.mock('@/service/knowledge/use-dataset', () => ({
  datasetDetailQueryKeyPrefix: ['dataset', 'detail'],
  useInvalidDatasetList: () => mockInvalidDatasetList,
}))

jest.mock('@/service/use-base', () => ({
  useInvalid: () => mockInvalidDatasetDetail,
}))

jest.mock('@/service/use-pipeline', () => ({
  useExportPipelineDSL: () => ({
    mutateAsync: mockExportPipeline,
  }),
}))

jest.mock('@/service/datasets', () => ({
  checkIsUsedInApp: (...args: unknown[]) => mockCheckIsUsedInApp(...args),
  deleteDataset: (...args: unknown[]) => mockDeleteDataset(...args),
}))

jest.mock('@/hooks/use-knowledge', () => ({
  useKnowledge: () => ({
    formatIndexingTechniqueAndMethod: () => 'indexing-technique',
  }),
}))

jest.mock('@/app/components/datasets/rename-modal', () => ({
  __esModule: true,
  default: ({
    show,
    onClose,
    onSuccess,
  }: {
    show: boolean
    onClose: () => void
    onSuccess?: () => void
  }) => {
    if (!show)
      return null
    return (
      <div data-testid="rename-modal">
        <button type="button" onClick={onSuccess}>Success</button>
        <button type="button" onClick={onClose}>Close</button>
      </div>
    )
  },
}))

const openMenu = async (user: ReturnType<typeof userEvent.setup>) => {
  const trigger = screen.getByRole('button')
  await user.click(trigger)
}

describe('DatasetInfo', () => {
  beforeEach(() => {
    jest.clearAllMocks()
    mockDataset = createDataset()
    mockIsDatasetOperator = false
  })

  // Rendering of dataset summary details based on expand and dataset state.
  describe('Rendering', () => {
    it('should show dataset details when expanded', () => {
      // Arrange
      mockDataset = createDataset({ is_published: true })
      render(<DatasetInfo expand />)

      // Assert
      expect(screen.getByText('Dataset Name')).toBeInTheDocument()
      expect(screen.getByText('Dataset description')).toBeInTheDocument()
      expect(screen.getByText('dataset.chunkingMode.general')).toBeInTheDocument()
      expect(screen.getByText('indexing-technique')).toBeInTheDocument()
    })

    it('should show external tag when provider is external', () => {
      // Arrange
      mockDataset = createDataset({ provider: 'external', is_published: false })
      render(<DatasetInfo expand />)

      // Assert
      expect(screen.getByText('dataset.externalTag')).toBeInTheDocument()
      expect(screen.queryByText('dataset.chunkingMode.general')).not.toBeInTheDocument()
    })

    it('should hide detailed fields when collapsed', () => {
      // Arrange
      render(<DatasetInfo expand={false} />)

      // Assert
      expect(screen.queryByText('Dataset Name')).not.toBeInTheDocument()
      expect(screen.queryByText('Dataset description')).not.toBeInTheDocument()
    })
  })
})

describe('MenuItem', () => {
  beforeEach(() => {
    jest.clearAllMocks()
  })

  // Event handling for menu item interactions.
  describe('Interactions', () => {
    it('should call handler when clicked', async () => {
      const user = userEvent.setup()
      const handleClick = jest.fn()
      // Arrange
      render(<MenuItem name="Edit" Icon={RiEditLine} handleClick={handleClick} />)

      // Act
      await user.click(screen.getByText('Edit'))

      // Assert
      expect(handleClick).toHaveBeenCalledTimes(1)
    })
  })
})

describe('Menu', () => {
  beforeEach(() => {
    jest.clearAllMocks()
    mockDataset = createDataset()
  })

  // Rendering of menu options based on runtime mode and delete visibility.
  describe('Rendering', () => {
    it('should show edit, export, and delete options when rag pipeline and deletable', () => {
      // Arrange
      mockDataset = createDataset({ runtime_mode: 'rag_pipeline' })
      render(
        <Menu
          showDelete
          openRenameModal={jest.fn()}
          handleExportPipeline={jest.fn()}
          detectIsUsedByApp={jest.fn()}
        />,
      )

      // Assert
      expect(screen.getByText('common.operation.edit')).toBeInTheDocument()
      expect(screen.getByText('datasetPipeline.operations.exportPipeline')).toBeInTheDocument()
      expect(screen.getByText('common.operation.delete')).toBeInTheDocument()
    })

    it('should hide export and delete options when not rag pipeline and not deletable', () => {
      // Arrange
      mockDataset = createDataset({ runtime_mode: 'general' })
      render(
        <Menu
          showDelete={false}
          openRenameModal={jest.fn()}
          handleExportPipeline={jest.fn()}
          detectIsUsedByApp={jest.fn()}
        />,
      )

      // Assert
      expect(screen.getByText('common.operation.edit')).toBeInTheDocument()
      expect(screen.queryByText('datasetPipeline.operations.exportPipeline')).not.toBeInTheDocument()
      expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument()
    })
  })
})

describe('Dropdown', () => {
  beforeEach(() => {
    jest.clearAllMocks()
    mockDataset = createDataset({ pipeline_id: 'pipeline-1', runtime_mode: 'rag_pipeline' })
    mockIsDatasetOperator = false
    mockExportPipeline.mockResolvedValue({ data: 'pipeline-content' })
    mockCheckIsUsedInApp.mockResolvedValue({ is_using: false })
    mockDeleteDataset.mockResolvedValue({})
    if (!('createObjectURL' in URL)) {
      Object.defineProperty(URL, 'createObjectURL', {
        value: jest.fn(),
        writable: true,
      })
    }
    if (!('revokeObjectURL' in URL)) {
      Object.defineProperty(URL, 'revokeObjectURL', {
        value: jest.fn(),
        writable: true,
      })
    }
  })

  // Rendering behavior based on workspace role.
  describe('Rendering', () => {
    it('should hide delete option when user is dataset operator', async () => {
      const user = userEvent.setup()
      // Arrange
      mockIsDatasetOperator = true
      render(<Dropdown expand />)

      // Act
      await openMenu(user)

      // Assert
      expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument()
    })
  })

  // User interactions that trigger modals and exports.
  describe('Interactions', () => {
    it('should open rename modal when edit is clicked', async () => {
      const user = userEvent.setup()
      // Arrange
      render(<Dropdown expand />)

      // Act
      await openMenu(user)
      await user.click(screen.getByText('common.operation.edit'))

      // Assert
      expect(screen.getByTestId('rename-modal')).toBeInTheDocument()
    })

    it('should export pipeline when export is clicked', async () => {
      const user = userEvent.setup()
      const anchorClickSpy = jest.spyOn(HTMLAnchorElement.prototype, 'click')
      const createObjectURLSpy = jest.spyOn(URL, 'createObjectURL')
      // Arrange
      render(<Dropdown expand />)

      // Act
      await openMenu(user)
      await user.click(screen.getByText('datasetPipeline.operations.exportPipeline'))

      // Assert
      await waitFor(() => {
        expect(mockExportPipeline).toHaveBeenCalledWith({
          pipelineId: 'pipeline-1',
          include: false,
        })
      })
      expect(createObjectURLSpy).toHaveBeenCalledTimes(1)
      expect(anchorClickSpy).toHaveBeenCalledTimes(1)
    })

    it('should show delete confirmation when delete is clicked', async () => {
      const user = userEvent.setup()
      // Arrange
      render(<Dropdown expand />)

      // Act
      await openMenu(user)
      await user.click(screen.getByText('common.operation.delete'))

      // Assert
      await waitFor(() => {
        expect(screen.getByText('dataset.deleteDatasetConfirmContent')).toBeInTheDocument()
      })
    })

    it('should delete dataset and redirect when confirm is clicked', async () => {
      const user = userEvent.setup()
      // Arrange
      render(<Dropdown expand />)

      // Act
      await openMenu(user)
      await user.click(screen.getByText('common.operation.delete'))
      await user.click(await screen.findByRole('button', { name: 'common.operation.confirm' }))

      // Assert
      await waitFor(() => {
        expect(mockDeleteDataset).toHaveBeenCalledWith('dataset-1')
      })
      expect(mockInvalidDatasetList).toHaveBeenCalledTimes(1)
      expect(mockReplace).toHaveBeenCalledWith('/datasets')
    })
  })
})
@@ -8,7 +8,7 @@ import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import type { DataSet } from '@/models/datasets'
import { DOC_FORM_TEXT } from '@/models/datasets'
import { useKnowledge } from '@/hooks/use-knowledge'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import Dropdown from './dropdown'

type DatasetInfoProps = {

@@ -11,7 +11,7 @@ import AppIcon from '../base/app-icon'
import Divider from '../base/divider'
import NavLink from './navLink'
import type { NavIcon } from './navLink'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import Effect from '../base/effect'
import Dropdown from './dataset-info/dropdown'

@@ -9,7 +9,7 @@ import AppSidebarDropdown from './app-sidebar-dropdown'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import { useStore as useAppStore } from '@/app/components/app/store'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import Divider from '../base/divider'
import { useHover, useKeyPress } from 'ahooks'
import ToggleButton from './toggle-button'

@@ -2,7 +2,7 @@
import React from 'react'
import { useSelectedLayoutSegment } from 'next/navigation'
import Link from 'next/link'
import { cn } from '@/utils/classnames'
import classNames from '@/utils/classnames'
import type { RemixiconComponentType } from '@remixicon/react'

export type NavIcon = React.ComponentType<
@@ -42,7 +42,7 @@ const NavLink = ({
  const NavIcon = isActive ? iconMap.selected : iconMap.normal

  const renderIcon = () => (
    <div className={cn(mode !== 'expand' && '-ml-1')}>
    <div className={classNames(mode !== 'expand' && '-ml-1')}>
      <NavIcon className="h-4 w-4 shrink-0" aria-hidden="true" />
    </div>
  )
@@ -53,17 +53,21 @@
        key={name}
        type='button'
        disabled
        className={cn('system-sm-medium flex h-8 cursor-not-allowed items-center rounded-lg text-components-menu-item-text opacity-30 hover:bg-components-menu-item-bg-hover',
          'pl-3 pr-1')}
        className={classNames(
          'system-sm-medium flex h-8 cursor-not-allowed items-center rounded-lg text-components-menu-item-text opacity-30 hover:bg-components-menu-item-bg-hover',
          'pl-3 pr-1',
        )}
        title={mode === 'collapse' ? name : ''}
        aria-disabled
      >
        {renderIcon()}
        <span
          className={cn('overflow-hidden whitespace-nowrap transition-all duration-200 ease-in-out',
          className={classNames(
            'overflow-hidden whitespace-nowrap transition-all duration-200 ease-in-out',
            mode === 'expand'
              ? 'ml-2 max-w-none opacity-100'
              : 'ml-0 max-w-0 opacity-0')}
              : 'ml-0 max-w-0 opacity-0',
          )}
        >
          {name}
        </span>
@@ -75,18 +79,22 @@
    <Link
      key={name}
      href={href}
      className={cn(isActive
        ? 'system-sm-semibold border-b-[0.25px] border-l-[0.75px] border-r-[0.25px] border-t-[0.75px] border-effects-highlight-lightmode-off bg-components-menu-item-bg-active text-text-accent-light-mode-only'
        : 'system-sm-medium text-components-menu-item-text hover:bg-components-menu-item-bg-hover hover:text-components-menu-item-text-hover',
        'flex h-8 items-center rounded-lg pl-3 pr-1')}
      className={classNames(
        isActive
          ? 'system-sm-semibold border-b-[0.25px] border-l-[0.75px] border-r-[0.25px] border-t-[0.75px] border-effects-highlight-lightmode-off bg-components-menu-item-bg-active text-text-accent-light-mode-only'
          : 'system-sm-medium text-components-menu-item-text hover:bg-components-menu-item-bg-hover hover:text-components-menu-item-text-hover',
        'flex h-8 items-center rounded-lg pl-3 pr-1',
      )}
      title={mode === 'collapse' ? name : ''}
    >
      {renderIcon()}
      <span
        className={cn('overflow-hidden whitespace-nowrap transition-all duration-200 ease-in-out',
        className={classNames(
          'overflow-hidden whitespace-nowrap transition-all duration-200 ease-in-out',
          mode === 'expand'
            ? 'ml-2 max-w-none opacity-100'
            : 'ml-0 max-w-0 opacity-0')}
            : 'ml-0 max-w-0 opacity-0',
        )}
      >
        {name}
      </span>

@@ -1,7 +1,7 @@
import React from 'react'
import Button from '../base/button'
import { RiArrowLeftSLine, RiArrowRightSLine } from '@remixicon/react'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import Tooltip from '../base/tooltip'
import { useTranslation } from 'react-i18next'
import { getKeyboardKeyNameBySystem } from '../workflow/utils'

@@ -1,42 +0,0 @@
import React from 'react'
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import BatchAction from './batch-action'

describe('BatchAction', () => {
  const baseProps = {
    selectedIds: ['1', '2', '3'],
    onBatchDelete: jest.fn(),
    onCancel: jest.fn(),
  }

  beforeEach(() => {
    jest.clearAllMocks()
  })

  it('should show the selected count and trigger cancel action', () => {
    render(<BatchAction {...baseProps} className='custom-class' />)

    expect(screen.getByText('3')).toBeInTheDocument()
    expect(screen.getByText('appAnnotation.batchAction.selected')).toBeInTheDocument()

    fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' }))

    expect(baseProps.onCancel).toHaveBeenCalledTimes(1)
  })

  it('should confirm before running batch delete', async () => {
    const onBatchDelete = jest.fn().mockResolvedValue(undefined)
    render(<BatchAction {...baseProps} onBatchDelete={onBatchDelete} />)

    fireEvent.click(screen.getByRole('button', { name: 'common.operation.delete' }))
    await screen.findByText('appAnnotation.list.delete.title')

    await act(async () => {
      fireEvent.click(screen.getAllByRole('button', { name: 'common.operation.delete' })[1])
    })

    await waitFor(() => {
      expect(onBatchDelete).toHaveBeenCalledTimes(1)
    })
  })
})
@@ -3,7 +3,7 @@ import { RiDeleteBinLine } from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import { useBoolean } from 'ahooks'
import Divider from '@/app/components/base/divider'
import { cn } from '@/utils/classnames'
import classNames from '@/utils/classnames'
import Confirm from '@/app/components/base/confirm'

const i18nPrefix = 'appAnnotation.batchAction'
@@ -38,7 +38,7 @@ const BatchAction: FC<IBatchActionProps> = ({
    setIsNotDeleting()
  }
  return (
    <div className={cn('pointer-events-none flex w-full justify-center', className)}>
    <div className={classNames('pointer-events-none flex w-full justify-center', className)}>
      <div className='pointer-events-auto flex items-center gap-x-1 rounded-[10px] border border-components-actionbar-border-accent bg-components-actionbar-bg-accent p-1 shadow-xl shadow-shadow-shadow-5 backdrop-blur-[5px]'>
        <div className='inline-flex items-center gap-x-2 py-1 pl-2 pr-3'>
          <span className='flex h-5 w-5 items-center justify-center rounded-md bg-text-accent px-1 py-0.5 text-xs font-medium text-text-primary-on-surface'>

@@ -1,72 +0,0 @@
import React from 'react'
import { render, screen } from '@testing-library/react'
import CSVDownload from './csv-downloader'
import I18nContext from '@/context/i18n'
import { LanguagesSupported } from '@/i18n-config/language'
import type { Locale } from '@/i18n-config'

const downloaderProps: any[] = []

jest.mock('react-papaparse', () => ({
  useCSVDownloader: jest.fn(() => ({
    CSVDownloader: ({ children, ...props }: any) => {
      downloaderProps.push(props)
      return <div data-testid="mock-csv-downloader">{children}</div>
    },
    Type: { Link: 'link' },
  })),
}))

const renderWithLocale = (locale: Locale) => {
  return render(
    <I18nContext.Provider value={{
      locale,
      i18n: {},
      setLocaleOnClient: jest.fn().mockResolvedValue(undefined),
    }}
    >
      <CSVDownload />
    </I18nContext.Provider>,
  )
}

describe('CSVDownload', () => {
  const englishTemplate = [
    ['question', 'answer'],
    ['question1', 'answer1'],
    ['question2', 'answer2'],
  ]
  const chineseTemplate = [
    ['问题', '答案'],
    ['问题 1', '答案 1'],
    ['问题 2', '答案 2'],
  ]

  beforeEach(() => {
    downloaderProps.length = 0
  })

  it('should render the structure preview and pass English template data by default', () => {
    renderWithLocale('en-US' as Locale)

    expect(screen.getByText('share.generation.csvStructureTitle')).toBeInTheDocument()
    expect(screen.getByText('appAnnotation.batchModal.template')).toBeInTheDocument()

    expect(downloaderProps[0]).toMatchObject({
      filename: 'template-en-US',
      type: 'link',
      bom: true,
      data: englishTemplate,
    })
  })

  it('should switch to the Chinese template when locale matches the secondary language', () => {
    const locale = LanguagesSupported[1] as Locale
    renderWithLocale(locale)

    expect(downloaderProps[0]).toMatchObject({
      filename: `template-${locale}`,
      data: chineseTemplate,
    })
  })
})
@@ -4,7 +4,7 @@ import React, { useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
import { RiDeleteBinLine } from '@remixicon/react'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { Csv as CSVIcon } from '@/app/components/base/icons/src/public/files'
import { ToastContext } from '@/app/components/base/toast'
import Button from '@/app/components/base/button'

@@ -1,164 +0,0 @@
import React from 'react'
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import BatchModal, { ProcessStatus } from './index'
import { useProviderContext } from '@/context/provider-context'
import { annotationBatchImport, checkAnnotationBatchImportProgress } from '@/service/annotation'
import type { IBatchModalProps } from './index'
import Toast from '@/app/components/base/toast'

jest.mock('@/app/components/base/toast', () => ({
  __esModule: true,
  default: {
    notify: jest.fn(),
  },
}))

jest.mock('@/service/annotation', () => ({
  annotationBatchImport: jest.fn(),
  checkAnnotationBatchImportProgress: jest.fn(),
}))

jest.mock('@/context/provider-context', () => ({
  useProviderContext: jest.fn(),
}))

jest.mock('./csv-downloader', () => ({
  __esModule: true,
  default: () => <div data-testid="csv-downloader-stub" />,
}))

let lastUploadedFile: File | undefined

jest.mock('./csv-uploader', () => ({
  __esModule: true,
  default: ({ file, updateFile }: { file?: File; updateFile: (file?: File) => void }) => (
    <div>
      <button
        data-testid="mock-uploader"
        onClick={() => {
          lastUploadedFile = new File(['question,answer'], 'batch.csv', { type: 'text/csv' })
          updateFile(lastUploadedFile)
        }}
      >
        upload
      </button>
      {file && <span data-testid="selected-file">{file.name}</span>}
    </div>
  ),
}))

jest.mock('@/app/components/billing/annotation-full', () => ({
  __esModule: true,
  default: () => <div data-testid="annotation-full" />,
}))

const mockNotify = Toast.notify as jest.Mock
const useProviderContextMock = useProviderContext as jest.Mock
const annotationBatchImportMock = annotationBatchImport as jest.Mock
const checkAnnotationBatchImportProgressMock = checkAnnotationBatchImportProgress as jest.Mock

const renderComponent = (props: Partial<IBatchModalProps> = {}) => {
  const mergedProps: IBatchModalProps = {
    appId: 'app-id',
    isShow: true,
    onCancel: jest.fn(),
    onAdded: jest.fn(),
    ...props,
  }
  return {
    ...render(<BatchModal {...mergedProps} />),
    props: mergedProps,
  }
}

describe('BatchModal', () => {
  beforeEach(() => {
    jest.clearAllMocks()
    lastUploadedFile = undefined
    useProviderContextMock.mockReturnValue({
      plan: {
        usage: { annotatedResponse: 0 },
        total: { annotatedResponse: 10 },
      },
      enableBilling: false,
    })
  })

  it('should disable run action and show billing hint when annotation quota is full', () => {
    useProviderContextMock.mockReturnValue({
      plan: {
        usage: { annotatedResponse: 10 },
        total: { annotatedResponse: 10 },
      },
      enableBilling: true,
    })

    renderComponent()

    expect(screen.getByTestId('annotation-full')).toBeInTheDocument()
    expect(screen.getByRole('button', { name: 'appAnnotation.batchModal.run' })).toBeDisabled()
  })

  it('should reset uploader state when modal closes and allow manual cancellation', () => {
    const { rerender, props } = renderComponent()

    fireEvent.click(screen.getByTestId('mock-uploader'))
    expect(screen.getByTestId('selected-file')).toHaveTextContent('batch.csv')

    rerender(<BatchModal {...props} isShow={false} />)
    rerender(<BatchModal {...props} isShow />)

    expect(screen.queryByTestId('selected-file')).toBeNull()

    fireEvent.click(screen.getByRole('button', { name: 'appAnnotation.batchModal.cancel' }))
    expect(props.onCancel).toHaveBeenCalledTimes(1)
  })

  it('should submit the csv file, poll status, and notify when import completes', async () => {
    jest.useFakeTimers()
    const { props } = renderComponent()
    const fileTrigger = screen.getByTestId('mock-uploader')
    fireEvent.click(fileTrigger)

    const runButton = screen.getByRole('button', { name: 'appAnnotation.batchModal.run' })
    expect(runButton).not.toBeDisabled()

    annotationBatchImportMock.mockResolvedValue({ job_id: 'job-1', job_status: ProcessStatus.PROCESSING })
    checkAnnotationBatchImportProgressMock
      .mockResolvedValueOnce({ job_id: 'job-1', job_status: ProcessStatus.PROCESSING })
      .mockResolvedValueOnce({ job_id: 'job-1', job_status: ProcessStatus.COMPLETED })

    await act(async () => {
      fireEvent.click(runButton)
    })

    await waitFor(() => {
      expect(annotationBatchImportMock).toHaveBeenCalledTimes(1)
    })

    const formData = annotationBatchImportMock.mock.calls[0][0].body as FormData
    expect(formData.get('file')).toBe(lastUploadedFile)

    await waitFor(() => {
      expect(checkAnnotationBatchImportProgressMock).toHaveBeenCalledTimes(1)
    })

    await act(async () => {
      jest.runOnlyPendingTimers()
    })

    await waitFor(() => {
      expect(checkAnnotationBatchImportProgressMock).toHaveBeenCalledTimes(2)
    })

    await waitFor(() => {
      expect(mockNotify).toHaveBeenCalledWith({
        type: 'success',
        message: 'appAnnotation.batchModal.completed',
      })
      expect(props.onAdded).toHaveBeenCalledTimes(1)
      expect(props.onCancel).toHaveBeenCalledTimes(1)
    })
    jest.useRealTimers()
  })
})
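The deleted test above drives a poll-until-complete loop with fake timers. A minimal sketch of that polling pattern (assumptions: the two-second interval and the callback shape are illustrative, not the component's actual values):

// Re-check job status on a fixed interval until it leaves the processing state.
const pollJobStatus = (
  checkProgress: () => Promise<{ job_status: string }>,
  onDone: (finalStatus: string) => void,
  intervalMs = 2000,
) => {
  const tick = async () => {
    const { job_status } = await checkProgress()
    if (job_status === 'processing')
      setTimeout(tick, intervalMs)
    else
      onDone(job_status)
  }
  void tick()
}
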
@@ -245,7 +245,7 @@ describe('EditItem', () => {
      expect(mockSave).toHaveBeenCalledWith('Test save content')
    })

    it('should show delete option and restore original content when delete is clicked', async () => {
    it('should show delete option when content changes', async () => {
      // Arrange
      const mockSave = jest.fn().mockResolvedValue(undefined)
      const props = {
@@ -267,13 +267,7 @@ describe('EditItem', () => {
      await user.click(screen.getByRole('button', { name: 'common.operation.save' }))

      // Assert
      expect(mockSave).toHaveBeenNthCalledWith(1, 'Modified content')
      expect(await screen.findByText('common.operation.delete')).toBeInTheDocument()

      await user.click(screen.getByText('common.operation.delete'))

      expect(mockSave).toHaveBeenNthCalledWith(2, 'Test content')
      expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument()
      expect(mockSave).toHaveBeenCalledWith('Modified content')
    })

    it('should handle keyboard interactions in edit mode', async () => {
@@ -399,68 +393,5 @@ describe('EditItem', () => {
      expect(screen.queryByRole('textbox')).not.toBeInTheDocument()
      expect(screen.getByText('Test content')).toBeInTheDocument()
    })

    it('should handle save failure gracefully in edit mode', async () => {
      // Arrange
      const mockSave = jest.fn().mockRejectedValueOnce(new Error('Save failed'))
      const props = {
        ...defaultProps,
        onSave: mockSave,
      }
      const user = userEvent.setup()

      // Act
      render(<EditItem {...props} />)

      // Enter edit mode and save (should fail)
      await user.click(screen.getByText('common.operation.edit'))
      const textarea = screen.getByRole('textbox')
      await user.type(textarea, 'New content')

      // Save should fail but not throw
      await user.click(screen.getByRole('button', { name: 'common.operation.save' }))

      // Assert - Should remain in edit mode when save fails
      expect(screen.getByRole('textbox')).toBeInTheDocument()
      expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument()
      expect(mockSave).toHaveBeenCalledWith('New content')
    })

    it('should handle delete action failure gracefully', async () => {
      // Arrange
      const mockSave = jest.fn()
        .mockResolvedValueOnce(undefined) // First save succeeds
        .mockRejectedValueOnce(new Error('Delete failed')) // Delete fails
      const props = {
        ...defaultProps,
        onSave: mockSave,
      }
      const user = userEvent.setup()

      // Act
      render(<EditItem {...props} />)

      // Edit content to show delete button
      await user.click(screen.getByText('common.operation.edit'))
      const textarea = screen.getByRole('textbox')
      await user.clear(textarea)
      await user.type(textarea, 'Modified content')

      // Save to create new content
      await user.click(screen.getByRole('button', { name: 'common.operation.save' }))
      await screen.findByText('common.operation.delete')

      // Click delete (should fail but not throw)
      await user.click(screen.getByText('common.operation.delete'))

      // Assert - Delete action should handle error gracefully
      expect(mockSave).toHaveBeenCalledTimes(2)
      expect(mockSave).toHaveBeenNthCalledWith(1, 'Modified content')
      expect(mockSave).toHaveBeenNthCalledWith(2, 'Test content')

      // When delete fails, the delete button should still be visible (state not changed)
      expect(screen.getByText('common.operation.delete')).toBeInTheDocument()
      expect(screen.getByText('Modified content')).toBeInTheDocument()
    })
  })
})

@@ -6,7 +6,7 @@ import { RiDeleteBinLine, RiEditFill, RiEditLine } from '@remixicon/react'
import { Robot, User } from '@/app/components/base/icons/src/public/avatar'
import Textarea from '@/app/components/base/textarea'
import Button from '@/app/components/base/button'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'

export enum EditItemType {
  Query = 'query',
@@ -52,14 +52,8 @@ const EditItem: FC<Props> = ({
  }, [content])

  const handleSave = async () => {
    try {
      await onSave(newContent)
      setIsEdit(false)
    }
    catch {
      // Keep edit mode open when save fails
      // Error notification is handled by the parent component
    }
    await onSave(newContent)
    setIsEdit(false)
  }

  const handleCancel = () => {
@@ -102,16 +96,9 @@ const EditItem: FC<Props> = ({
          <div className='mr-2'>·</div>
          <div
            className='flex cursor-pointer items-center space-x-1'
            onClick={async () => {
              try {
                await onSave(content)
                // Only update UI state after successful delete
                setNewContent(content)
              }
              catch {
                // Delete action failed - error is already handled by parent
                // UI state remains unchanged, user can retry
              }
            onClick={() => {
              setNewContent(content)
              onSave(content)
            }}
          >
            <div className='h-3.5 w-3.5'>

@@ -1,4 +1,4 @@
import { render, screen, waitFor } from '@testing-library/react'
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import Toast, { type IToastProps, type ToastHandle } from '@/app/components/base/toast'
import EditAnnotationModal from './index'
@@ -405,276 +405,4 @@ describe('EditAnnotationModal', () => {
      expect(editLinks).toHaveLength(1) // Only answer should have edit button
    })
  })

  // Error Handling (CRITICAL for coverage)
  describe('Error Handling', () => {
    it('should show error toast and skip callbacks when addAnnotation fails', async () => {
      // Arrange
      const mockOnAdded = jest.fn()
      const props = {
        ...defaultProps,
        onAdded: mockOnAdded,
      }
      const user = userEvent.setup()

      // Mock API failure
      mockAddAnnotation.mockRejectedValueOnce(new Error('API Error'))

      // Act
      render(<EditAnnotationModal {...props} />)

      // Find and click edit link for query
      const editLinks = screen.getAllByText(/common\.operation\.edit/i)
      await user.click(editLinks[0])

      // Find textarea and enter new content
      const textarea = screen.getByRole('textbox')
      await user.clear(textarea)
      await user.type(textarea, 'New query content')

      // Click save button
      const saveButton = screen.getByRole('button', { name: 'common.operation.save' })
      await user.click(saveButton)

      // Assert
      await waitFor(() => {
        expect(toastNotifySpy).toHaveBeenCalledWith({
          message: 'API Error',
          type: 'error',
        })
      })
      expect(mockOnAdded).not.toHaveBeenCalled()

      // Verify edit mode remains open (textarea should still be visible)
      expect(screen.getByRole('textbox')).toBeInTheDocument()
      expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument()
    })

    it('should show fallback error message when addAnnotation error has no message', async () => {
      // Arrange
      const mockOnAdded = jest.fn()
      const props = {
        ...defaultProps,
        onAdded: mockOnAdded,
      }
      const user = userEvent.setup()

      mockAddAnnotation.mockRejectedValueOnce({})

      // Act
      render(<EditAnnotationModal {...props} />)

      const editLinks = screen.getAllByText(/common\.operation\.edit/i)
      await user.click(editLinks[0])

      const textarea = screen.getByRole('textbox')
      await user.clear(textarea)
      await user.type(textarea, 'New query content')

      const saveButton = screen.getByRole('button', { name: 'common.operation.save' })
      await user.click(saveButton)

      // Assert
      await waitFor(() => {
        expect(toastNotifySpy).toHaveBeenCalledWith({
          message: 'common.api.actionFailed',
          type: 'error',
        })
      })
      expect(mockOnAdded).not.toHaveBeenCalled()

      // Verify edit mode remains open (textarea should still be visible)
      expect(screen.getByRole('textbox')).toBeInTheDocument()
      expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument()
    })

    it('should show error toast and skip callbacks when editAnnotation fails', async () => {
      // Arrange
      const mockOnEdited = jest.fn()
      const props = {
        ...defaultProps,
        annotationId: 'test-annotation-id',
        messageId: 'test-message-id',
        onEdited: mockOnEdited,
      }
      const user = userEvent.setup()

      // Mock API failure
      mockEditAnnotation.mockRejectedValueOnce(new Error('API Error'))

      // Act
      render(<EditAnnotationModal {...props} />)

      // Edit query content
      const editLinks = screen.getAllByText(/common\.operation\.edit/i)
      await user.click(editLinks[0])

      const textarea = screen.getByRole('textbox')
      await user.clear(textarea)
      await user.type(textarea, 'Modified query')

      const saveButton = screen.getByRole('button', { name: 'common.operation.save' })
      await user.click(saveButton)

      // Assert
      await waitFor(() => {
        expect(toastNotifySpy).toHaveBeenCalledWith({
          message: 'API Error',
          type: 'error',
        })
      })
      expect(mockOnEdited).not.toHaveBeenCalled()

      // Verify edit mode remains open (textarea should still be visible)
      expect(screen.getByRole('textbox')).toBeInTheDocument()
      expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument()
    })

    it('should show fallback error message when editAnnotation error is not an Error instance', async () => {
      // Arrange
      const mockOnEdited = jest.fn()
      const props = {
        ...defaultProps,
        annotationId: 'test-annotation-id',
        messageId: 'test-message-id',
        onEdited: mockOnEdited,
      }
      const user = userEvent.setup()

      mockEditAnnotation.mockRejectedValueOnce('oops')

      // Act
      render(<EditAnnotationModal {...props} />)

      const editLinks = screen.getAllByText(/common\.operation\.edit/i)
      await user.click(editLinks[0])

      const textarea = screen.getByRole('textbox')
      await user.clear(textarea)
      await user.type(textarea, 'Modified query')

      const saveButton = screen.getByRole('button', { name: 'common.operation.save' })
      await user.click(saveButton)

      // Assert
      await waitFor(() => {
        expect(toastNotifySpy).toHaveBeenCalledWith({
          message: 'common.api.actionFailed',
          type: 'error',
        })
      })
      expect(mockOnEdited).not.toHaveBeenCalled()

      // Verify edit mode remains open (textarea should still be visible)
      expect(screen.getByRole('textbox')).toBeInTheDocument()
      expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument()
    })
  })

  // Billing & Plan Features
  describe('Billing & Plan Features', () => {
    it('should show createdAt time when provided', () => {
      // Arrange
      const props = {
        ...defaultProps,
        annotationId: 'test-annotation-id',
        createdAt: 1701381000, // 2023-12-01 10:30:00
      }

      // Act
      render(<EditAnnotationModal {...props} />)

      // Assert - Check that the formatted time appears somewhere in the component
      const container = screen.getByRole('dialog')
      expect(container).toHaveTextContent('2023-12-01 10:30:00')
    })

    it('should not show createdAt when not provided', () => {
      // Arrange
      const props = {
        ...defaultProps,
        annotationId: 'test-annotation-id',
        // createdAt is undefined
      }

      // Act
      render(<EditAnnotationModal {...props} />)

      // Assert - Should not contain any timestamp
      const container = screen.getByRole('dialog')
      expect(container).not.toHaveTextContent('2023-12-01 10:30:00')
    })

    it('should display remove section when annotationId exists', () => {
      // Arrange
      const props = {
        ...defaultProps,
        annotationId: 'test-annotation-id',
      }

      // Act
      render(<EditAnnotationModal {...props} />)

      // Assert - Should have remove functionality
      expect(screen.getByText('appAnnotation.editModal.removeThisCache')).toBeInTheDocument()
    })
  })

  // Toast Notifications (Success)
  describe('Toast Notifications', () => {
    it('should show success notification when save operation completes', async () => {
      // Arrange
      const props = { ...defaultProps }
      const user = userEvent.setup()

      // Act
      render(<EditAnnotationModal {...props} />)

      const editLinks = screen.getAllByText(/common\.operation\.edit/i)
      await user.click(editLinks[0])

      const textarea = screen.getByRole('textbox')
      await user.clear(textarea)
      await user.type(textarea, 'Updated query')

      const saveButton = screen.getByRole('button', { name: 'common.operation.save' })
      await user.click(saveButton)

      // Assert
      await waitFor(() => {
        expect(toastNotifySpy).toHaveBeenCalledWith({
          message: 'common.api.actionSuccess',
          type: 'success',
        })
      })
    })
  })

  // React.memo Performance Testing
  describe('React.memo Performance', () => {
    it('should not re-render when props are the same', () => {
      // Arrange
      const props = { ...defaultProps }
      const { rerender } = render(<EditAnnotationModal {...props} />)

      // Act - Re-render with same props
      rerender(<EditAnnotationModal {...props} />)

      // Assert - Component should still be visible (no errors thrown)
      expect(screen.getByText('appAnnotation.editModal.title')).toBeInTheDocument()
    })

    it('should re-render when props change', () => {
      // Arrange
      const props = { ...defaultProps }
      const { rerender } = render(<EditAnnotationModal {...props} />)

      // Act - Re-render with different props
      const newProps = { ...props, query: 'New query content' }
      rerender(<EditAnnotationModal {...newProps} />)

      // Assert - Should show new content
      expect(screen.getByText('New query content')).toBeInTheDocument()
    })
  })
})
||||
@@ -53,39 +53,27 @@ const EditAnnotationModal: FC<Props> = ({
      postQuery = editedContent
    else
      postAnswer = editedContent
    try {
      if (!isAdd) {
        await editAnnotation(appId, annotationId, {
          message_id: messageId,
          question: postQuery,
          answer: postAnswer,
        })
        onEdited(postQuery, postAnswer)
      }
      else {
        const res = await addAnnotation(appId, {
          question: postQuery,
          answer: postAnswer,
          message_id: messageId,
        })
        onAdded(res.id, res.account?.name ?? '', postQuery, postAnswer)
      }
    if (!isAdd) {
      await editAnnotation(appId, annotationId, {
        message_id: messageId,
        question: postQuery,
        answer: postAnswer,
      })
      onEdited(postQuery, postAnswer)
    }
    else {
      const res: any = await addAnnotation(appId, {
        question: postQuery,
        answer: postAnswer,
        message_id: messageId,
      })
      onAdded(res.id, res.account?.name, postQuery, postAnswer)
    }

      Toast.notify({
        message: t('common.api.actionSuccess') as string,
        type: 'success',
      })
    }
    catch (error) {
      const fallbackMessage = t('common.api.actionFailed') as string
      const message = error instanceof Error && error.message ? error.message : fallbackMessage
      Toast.notify({
        message,
        type: 'error',
      })
      // Re-throw to preserve edit mode behavior for UI components
      throw error
    }
    Toast.notify({
      message: t('common.api.actionSuccess') as string,
      type: 'success',
    })
  }
  const [showModal, setShowModal] = useState(false)

@@ -1,13 +0,0 @@
import React from 'react'
import { render, screen } from '@testing-library/react'
import EmptyElement from './empty-element'

describe('EmptyElement', () => {
  it('should render the empty state copy and supporting icon', () => {
    const { container } = render(<EmptyElement />)

    expect(screen.getByText('appAnnotation.noData.title')).toBeInTheDocument()
    expect(screen.getByText('appAnnotation.noData.description')).toBeInTheDocument()
    expect(container.querySelector('svg')).not.toBeNull()
  })
})

@@ -1,70 +0,0 @@
import React from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import Filter, { type QueryParam } from './filter'
import useSWR from 'swr'

jest.mock('swr', () => ({
  __esModule: true,
  default: jest.fn(),
}))

jest.mock('@/service/log', () => ({
  fetchAnnotationsCount: jest.fn(),
}))

const mockUseSWR = useSWR as unknown as jest.Mock

describe('Filter', () => {
  const appId = 'app-1'
  const childContent = 'child-content'

  beforeEach(() => {
    jest.clearAllMocks()
  })

  it('should render nothing until annotation count is fetched', () => {
    mockUseSWR.mockReturnValue({ data: undefined })

    const { container } = render(
      <Filter
        appId={appId}
        queryParams={{ keyword: '' }}
        setQueryParams={jest.fn()}
      >
        <div>{childContent}</div>
      </Filter>,
    )

    expect(container.firstChild).toBeNull()
    expect(mockUseSWR).toHaveBeenCalledWith(
      { url: `/apps/${appId}/annotations/count` },
      expect.any(Function),
    )
  })

  it('should propagate keyword changes and clearing behavior', () => {
    mockUseSWR.mockReturnValue({ data: { total: 20 } })
    const queryParams: QueryParam = { keyword: 'prefill' }
    const setQueryParams = jest.fn()

    const { container } = render(
      <Filter
        appId={appId}
        queryParams={queryParams}
        setQueryParams={setQueryParams}
      >
        <div>{childContent}</div>
      </Filter>,
    )

    const input = screen.getByPlaceholderText('common.operation.search') as HTMLInputElement
    fireEvent.change(input, { target: { value: 'updated' } })
    expect(setQueryParams).toHaveBeenCalledWith({ ...queryParams, keyword: 'updated' })

    const clearButton = input.parentElement?.querySelector('div.cursor-pointer') as HTMLElement
    fireEvent.click(clearButton)
    expect(setQueryParams).toHaveBeenCalledWith({ ...queryParams, keyword: '' })

    expect(container).toHaveTextContent(childContent)
  })
})

@@ -1,439 +0,0 @@
import * as React from 'react'
import { render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import type { ComponentProps } from 'react'
import HeaderOptions from './index'
import I18NContext from '@/context/i18n'
import { LanguagesSupported } from '@/i18n-config/language'
import type { AnnotationItemBasic } from '../type'
import { clearAllAnnotations, fetchExportAnnotationList } from '@/service/annotation'

jest.mock('@headlessui/react', () => {
  type PopoverContextValue = { open: boolean; setOpen: (open: boolean) => void }
  type MenuContextValue = { open: boolean; setOpen: (open: boolean) => void }
  const PopoverContext = React.createContext<PopoverContextValue | null>(null)
  const MenuContext = React.createContext<MenuContextValue | null>(null)

  const Popover = ({ children }: { children: React.ReactNode | ((props: { open: boolean }) => React.ReactNode) }) => {
    const [open, setOpen] = React.useState(false)
    const value = React.useMemo(() => ({ open, setOpen }), [open])
    return (
      <PopoverContext.Provider value={value}>
        {typeof children === 'function' ? children({ open }) : children}
      </PopoverContext.Provider>
    )
  }

  const PopoverButton = React.forwardRef(({ onClick, children, ...props }: { onClick?: () => void; children?: React.ReactNode }, ref: React.Ref<HTMLButtonElement>) => {
    const context = React.useContext(PopoverContext)
    const handleClick = () => {
      context?.setOpen(!context.open)
      onClick?.()
    }
    return (
      <button
        ref={ref}
        type="button"
        aria-expanded={context?.open ?? false}
        onClick={handleClick}
        {...props}
      >
        {children}
      </button>
    )
  })

  const PopoverPanel = React.forwardRef(({ children, ...props }: { children: React.ReactNode | ((props: { close: () => void }) => React.ReactNode) }, ref: React.Ref<HTMLDivElement>) => {
    const context = React.useContext(PopoverContext)
    if (!context?.open) return null
    const content = typeof children === 'function' ? children({ close: () => context.setOpen(false) }) : children
    return (
      <div ref={ref} {...props}>
        {content}
      </div>
    )
  })

  const Menu = ({ children }: { children: React.ReactNode }) => {
    const [open, setOpen] = React.useState(false)
    const value = React.useMemo(() => ({ open, setOpen }), [open])
    return (
      <MenuContext.Provider value={value}>
        {children}
      </MenuContext.Provider>
    )
  }

  const MenuButton = ({ onClick, children, ...props }: { onClick?: () => void; children?: React.ReactNode }) => {
    const context = React.useContext(MenuContext)
    const handleClick = () => {
      context?.setOpen(!context.open)
      onClick?.()
    }
    return (
      <button type="button" aria-expanded={context?.open ?? false} onClick={handleClick} {...props}>
        {children}
      </button>
    )
  }

  const MenuItems = ({ children, ...props }: { children: React.ReactNode }) => {
    const context = React.useContext(MenuContext)
    if (!context?.open) return null
    return (
      <div {...props}>
        {children}
      </div>
    )
  }

  return {
    Dialog: ({ open, children, className }: { open?: boolean; children: React.ReactNode; className?: string }) => {
      if (open === false) return null
      return (
        <div role="dialog" className={className}>
          {children}
        </div>
      )
    },
    DialogBackdrop: ({ children, className, onClick }: { children?: React.ReactNode; className?: string; onClick?: () => void }) => (
      <div className={className} onClick={onClick}>
        {children}
      </div>
    ),
    DialogPanel: ({ children, className, ...props }: { children: React.ReactNode; className?: string }) => (
      <div className={className} {...props}>
        {children}
      </div>
    ),
    DialogTitle: ({ children, className, ...props }: { children: React.ReactNode; className?: string }) => (
      <div className={className} {...props}>
        {children}
      </div>
    ),
    Popover,
    PopoverButton,
    PopoverPanel,
    Menu,
    MenuButton,
    MenuItems,
    Transition: ({ show = true, children }: { show?: boolean; children: React.ReactNode }) => (show ? <>{children}</> : null),
    TransitionChild: ({ children }: { children: React.ReactNode }) => <>{children}</>,
  }
})

let lastCSVDownloaderProps: Record<string, unknown> | undefined
const mockCSVDownloader = jest.fn(({ children, ...props }) => {
  lastCSVDownloaderProps = props
  return (
    <div data-testid="csv-downloader">
      {children}
    </div>
  )
})

jest.mock('react-papaparse', () => ({
  useCSVDownloader: () => ({
    CSVDownloader: (props: any) => mockCSVDownloader(props),
    Type: { Link: 'link' },
  }),
}))

jest.mock('@/service/annotation', () => ({
  fetchExportAnnotationList: jest.fn(),
  clearAllAnnotations: jest.fn(),
}))

jest.mock('@/context/provider-context', () => ({
  useProviderContext: () => ({
    plan: {
      usage: { annotatedResponse: 0 },
      total: { annotatedResponse: 10 },
    },
    enableBilling: false,
  }),
}))

jest.mock('@/app/components/billing/annotation-full', () => ({
  __esModule: true,
  default: () => <div data-testid="annotation-full" />,
}))

type HeaderOptionsProps = ComponentProps<typeof HeaderOptions>

const renderComponent = (
  props: Partial<HeaderOptionsProps> = {},
  locale: string = LanguagesSupported[0] as string,
) => {
  const defaultProps: HeaderOptionsProps = {
    appId: 'test-app-id',
    onAdd: jest.fn(),
    onAdded: jest.fn(),
    controlUpdateList: 0,
    ...props,
  }

  return render(
    <I18NContext.Provider
      value={{
        locale,
        i18n: {},
        setLocaleOnClient: jest.fn(),
      }}
    >
      <HeaderOptions {...defaultProps} />
    </I18NContext.Provider>,
  )
}

const openOperationsPopover = async (user: ReturnType<typeof userEvent.setup>) => {
  const trigger = document.querySelector('button.btn.btn-secondary') as HTMLButtonElement
  expect(trigger).toBeTruthy()
  await user.click(trigger)
}

const expandExportMenu = async (user: ReturnType<typeof userEvent.setup>) => {
  await openOperationsPopover(user)
  const exportLabel = await screen.findByText('appAnnotation.table.header.bulkExport')
  const exportButton = exportLabel.closest('button') as HTMLButtonElement
  expect(exportButton).toBeTruthy()
  await user.click(exportButton)
}

const getExportButtons = async () => {
  const csvLabel = await screen.findByText('CSV')
  const jsonLabel = await screen.findByText('JSONL')
  const csvButton = csvLabel.closest('button') as HTMLButtonElement
  const jsonButton = jsonLabel.closest('button') as HTMLButtonElement
  expect(csvButton).toBeTruthy()
  expect(jsonButton).toBeTruthy()
  return {
    csvButton,
    jsonButton,
  }
}

const clickOperationAction = async (
  user: ReturnType<typeof userEvent.setup>,
  translationKey: string,
) => {
  const label = await screen.findByText(translationKey)
  const button = label.closest('button') as HTMLButtonElement
  expect(button).toBeTruthy()
  await user.click(button)
}

const mockAnnotations: AnnotationItemBasic[] = [
  {
    question: 'Question 1',
    answer: 'Answer 1',
  },
]

const mockedFetchAnnotations = jest.mocked(fetchExportAnnotationList)
const mockedClearAllAnnotations = jest.mocked(clearAllAnnotations)

describe('HeaderOptions', () => {
  beforeEach(() => {
    jest.clearAllMocks()
    jest.useRealTimers()
    mockCSVDownloader.mockClear()
    lastCSVDownloaderProps = undefined
    mockedFetchAnnotations.mockResolvedValue({ data: [] })
  })

  it('should fetch annotations on mount and render enabled export actions when data exist', async () => {
    mockedFetchAnnotations.mockResolvedValue({ data: mockAnnotations })
    const user = userEvent.setup()
    renderComponent()

    await waitFor(() => {
      expect(mockedFetchAnnotations).toHaveBeenCalledWith('test-app-id')
    })

    await expandExportMenu(user)

    const { csvButton, jsonButton } = await getExportButtons()

    expect(csvButton).not.toBeDisabled()
    expect(jsonButton).not.toBeDisabled()

    await waitFor(() => {
      expect(lastCSVDownloaderProps).toMatchObject({
        bom: true,
        filename: 'annotations-en-US',
        type: 'link',
        data: [
          ['Question', 'Answer'],
          ['Question 1', 'Answer 1'],
        ],
      })
    })
  })

  it('should disable export actions when there are no annotations', async () => {
    const user = userEvent.setup()
    renderComponent()

    await expandExportMenu(user)

    const { csvButton, jsonButton } = await getExportButtons()

    expect(csvButton).toBeDisabled()
    expect(jsonButton).toBeDisabled()

    expect(lastCSVDownloaderProps).toMatchObject({
      data: [['Question', 'Answer']],
    })
  })

  it('should open the add annotation modal and forward the onAdd callback', async () => {
    mockedFetchAnnotations.mockResolvedValue({ data: mockAnnotations })
    const user = userEvent.setup()
    const onAdd = jest.fn().mockResolvedValue(undefined)
    renderComponent({ onAdd })

    await waitFor(() => expect(mockedFetchAnnotations).toHaveBeenCalled())

    await user.click(
      screen.getByRole('button', { name: 'appAnnotation.table.header.addAnnotation' }),
    )

    await screen.findByText('appAnnotation.addModal.title')
    const questionInput = screen.getByPlaceholderText('appAnnotation.addModal.queryPlaceholder')
    const answerInput = screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder')

    await user.type(questionInput, 'Integration question')
    await user.type(answerInput, 'Integration answer')
    await user.click(screen.getByRole('button', { name: 'common.operation.add' }))

    await waitFor(() => {
      expect(onAdd).toHaveBeenCalledWith({
        question: 'Integration question',
        answer: 'Integration answer',
      })
    })
  })

  it('should allow bulk import through the batch modal', async () => {
    const user = userEvent.setup()
    const onAdded = jest.fn()
    renderComponent({ onAdded })

    await openOperationsPopover(user)
    await clickOperationAction(user, 'appAnnotation.table.header.bulkImport')

    expect(await screen.findByText('appAnnotation.batchModal.title')).toBeInTheDocument()
    await user.click(
      screen.getByRole('button', { name: 'appAnnotation.batchModal.cancel' }),
    )
    expect(onAdded).not.toHaveBeenCalled()
  })

  it('should trigger JSONL download with locale-specific filename', async () => {
    mockedFetchAnnotations.mockResolvedValue({ data: mockAnnotations })
    const user = userEvent.setup()
    const originalCreateElement = document.createElement.bind(document)
    const anchor = originalCreateElement('a') as HTMLAnchorElement
    const clickSpy = jest.spyOn(anchor, 'click').mockImplementation(jest.fn())
    const createElementSpy = jest
      .spyOn(document, 'createElement')
      .mockImplementation((tagName: Parameters<Document['createElement']>[0]) => {
        if (tagName === 'a')
          return anchor
        return originalCreateElement(tagName)
      })
    const objectURLSpy = jest
      .spyOn(URL, 'createObjectURL')
      .mockReturnValue('blob://mock-url')
    const revokeSpy = jest.spyOn(URL, 'revokeObjectURL').mockImplementation(jest.fn())

    renderComponent({}, LanguagesSupported[1] as string)

    await expandExportMenu(user)

    await waitFor(() => expect(mockCSVDownloader).toHaveBeenCalled())

    const { jsonButton } = await getExportButtons()
    await user.click(jsonButton)

    expect(createElementSpy).toHaveBeenCalled()
    expect(anchor.download).toBe(`annotations-${LanguagesSupported[1]}.jsonl`)
    expect(clickSpy).toHaveBeenCalled()
    expect(revokeSpy).toHaveBeenCalledWith('blob://mock-url')

    const blobArg = objectURLSpy.mock.calls[0][0] as Blob
    await expect(blobArg.text()).resolves.toContain('"Question 1"')

    clickSpy.mockRestore()
    createElementSpy.mockRestore()
    objectURLSpy.mockRestore()
    revokeSpy.mockRestore()
  })

  it('should clear all annotations when confirmation succeeds', async () => {
    mockedClearAllAnnotations.mockResolvedValue(undefined)
    const user = userEvent.setup()
    const onAdded = jest.fn()
    renderComponent({ onAdded })

    await openOperationsPopover(user)
    await clickOperationAction(user, 'appAnnotation.table.header.clearAll')

    await screen.findByText('appAnnotation.table.header.clearAllConfirm')
    const confirmButton = screen.getByRole('button', { name: 'common.operation.confirm' })
    await user.click(confirmButton)

    await waitFor(() => {
      expect(mockedClearAllAnnotations).toHaveBeenCalledWith('test-app-id')
      expect(onAdded).toHaveBeenCalled()
    })
  })

  it('should handle clear all failures gracefully', async () => {
    const consoleSpy = jest.spyOn(console, 'error').mockImplementation(jest.fn())
    mockedClearAllAnnotations.mockRejectedValue(new Error('network'))
    const user = userEvent.setup()
    const onAdded = jest.fn()
    renderComponent({ onAdded })

    await openOperationsPopover(user)
    await clickOperationAction(user, 'appAnnotation.table.header.clearAll')
    await screen.findByText('appAnnotation.table.header.clearAllConfirm')
    const confirmButton = screen.getByRole('button', { name: 'common.operation.confirm' })
    await user.click(confirmButton)

    await waitFor(() => {
      expect(mockedClearAllAnnotations).toHaveBeenCalled()
      expect(onAdded).not.toHaveBeenCalled()
      expect(consoleSpy).toHaveBeenCalled()
    })

    consoleSpy.mockRestore()
  })

  it('should refetch annotations when controlUpdateList changes', async () => {
    const view = renderComponent({ controlUpdateList: 0 })

    await waitFor(() => expect(mockedFetchAnnotations).toHaveBeenCalledTimes(1))

    view.rerender(
      <I18NContext.Provider
        value={{
          locale: LanguagesSupported[0] as string,
          i18n: {},
          setLocaleOnClient: jest.fn(),
        }}
      >
        <HeaderOptions
          appId="test-app-id"
          onAdd={jest.fn()}
          onAdded={jest.fn()}
          controlUpdateList={1}
        />
      </I18NContext.Provider>,
    )

    await waitFor(() => expect(mockedFetchAnnotations).toHaveBeenCalledTimes(2))
  })
})

@@ -17,7 +17,7 @@ import Button from '../../../base/button'
import AddAnnotationModal from '../add-annotation-modal'
import type { AnnotationItemBasic } from '../type'
import BatchAddModal from '../batch-add-annotation-modal'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import CustomPopover from '@/app/components/base/popover'
import { FileDownload02, FilePlus02 } from '@/app/components/base/icons/src/vender/line/files'
import { ChevronRight } from '@/app/components/base/icons/src/vender/line/arrows'

@@ -1,233 +0,0 @@
import React from 'react'
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import Annotation from './index'
import type { AnnotationItem } from './type'
import { JobStatus } from './type'
import { type App, AppModeEnum } from '@/types/app'
import {
  addAnnotation,
  delAnnotation,
  delAnnotations,
  fetchAnnotationConfig,
  fetchAnnotationList,
  queryAnnotationJobStatus,
} from '@/service/annotation'
import { useProviderContext } from '@/context/provider-context'
import Toast from '@/app/components/base/toast'

jest.mock('@/app/components/base/toast', () => ({
  __esModule: true,
  default: { notify: jest.fn() },
}))

jest.mock('ahooks', () => ({
  useDebounce: (value: any) => value,
}))

jest.mock('@/service/annotation', () => ({
  addAnnotation: jest.fn(),
  delAnnotation: jest.fn(),
  delAnnotations: jest.fn(),
  fetchAnnotationConfig: jest.fn(),
  editAnnotation: jest.fn(),
  fetchAnnotationList: jest.fn(),
  queryAnnotationJobStatus: jest.fn(),
  updateAnnotationScore: jest.fn(),
  updateAnnotationStatus: jest.fn(),
}))

jest.mock('@/context/provider-context', () => ({
  useProviderContext: jest.fn(),
}))

jest.mock('./filter', () => ({ children }: { children: React.ReactNode }) => (
  <div data-testid="filter">{children}</div>
))

jest.mock('./empty-element', () => () => <div data-testid="empty-element" />)

jest.mock('./header-opts', () => (props: any) => (
  <div data-testid="header-opts">
    <button data-testid="trigger-add" onClick={() => props.onAdd({ question: 'new question', answer: 'new answer' })}>
      add
    </button>
  </div>
))

let latestListProps: any

jest.mock('./list', () => (props: any) => {
  latestListProps = props
  if (!props.list.length)
    return <div data-testid="list-empty" />
  return (
    <div data-testid="list">
      <button data-testid="list-view" onClick={() => props.onView(props.list[0])}>view</button>
      <button data-testid="list-remove" onClick={() => props.onRemove(props.list[0].id)}>remove</button>
      <button data-testid="list-batch-delete" onClick={() => props.onBatchDelete()}>batch-delete</button>
    </div>
  )
})

jest.mock('./view-annotation-modal', () => (props: any) => {
  if (!props.isShow)
    return null
  return (
    <div data-testid="view-modal">
      <div>{props.item.question}</div>
      <button data-testid="view-modal-remove" onClick={props.onRemove}>remove</button>
      <button data-testid="view-modal-close" onClick={props.onHide}>close</button>
    </div>
  )
})

jest.mock('@/app/components/base/pagination', () => () => <div data-testid="pagination" />)
jest.mock('@/app/components/base/loading', () => () => <div data-testid="loading" />)
jest.mock('@/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal', () => (props: any) => props.isShow ? <div data-testid="config-modal" /> : null)
jest.mock('@/app/components/billing/annotation-full/modal', () => (props: any) => props.show ? <div data-testid="annotation-full-modal" /> : null)

const mockNotify = Toast.notify as jest.Mock
const addAnnotationMock = addAnnotation as jest.Mock
const delAnnotationMock = delAnnotation as jest.Mock
const delAnnotationsMock = delAnnotations as jest.Mock
const fetchAnnotationConfigMock = fetchAnnotationConfig as jest.Mock
const fetchAnnotationListMock = fetchAnnotationList as jest.Mock
const queryAnnotationJobStatusMock = queryAnnotationJobStatus as jest.Mock
const useProviderContextMock = useProviderContext as jest.Mock

const appDetail = {
  id: 'app-id',
  mode: AppModeEnum.CHAT,
} as App

const createAnnotation = (overrides: Partial<AnnotationItem> = {}): AnnotationItem => ({
  id: overrides.id ?? 'annotation-1',
  question: overrides.question ?? 'Question 1',
  answer: overrides.answer ?? 'Answer 1',
  created_at: overrides.created_at ?? 1700000000,
  hit_count: overrides.hit_count ?? 0,
})

const renderComponent = () => render(<Annotation appDetail={appDetail} />)

describe('Annotation', () => {
  beforeEach(() => {
    jest.clearAllMocks()
    latestListProps = undefined
    fetchAnnotationConfigMock.mockResolvedValue({
      id: 'config-id',
      enabled: false,
      embedding_model: {
        embedding_model_name: 'model',
        embedding_provider_name: 'provider',
      },
      score_threshold: 0.5,
    })
    fetchAnnotationListMock.mockResolvedValue({ data: [], total: 0 })
    queryAnnotationJobStatusMock.mockResolvedValue({ job_status: JobStatus.completed })
    useProviderContextMock.mockReturnValue({
      plan: {
        usage: { annotatedResponse: 0 },
        total: { annotatedResponse: 10 },
      },
      enableBilling: false,
    })
  })

  it('should render empty element when no annotations are returned', async () => {
    renderComponent()

    expect(await screen.findByTestId('empty-element')).toBeInTheDocument()
    expect(fetchAnnotationListMock).toHaveBeenCalledWith(appDetail.id, expect.objectContaining({
      page: 1,
      keyword: '',
    }))
  })

  it('should handle annotation creation and refresh list data', async () => {
    const annotation = createAnnotation()
    fetchAnnotationListMock.mockResolvedValue({ data: [annotation], total: 1 })
    addAnnotationMock.mockResolvedValue(undefined)

    renderComponent()

    await screen.findByTestId('list')
    fireEvent.click(screen.getByTestId('trigger-add'))

    await waitFor(() => {
      expect(addAnnotationMock).toHaveBeenCalledWith(appDetail.id, { question: 'new question', answer: 'new answer' })
      expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
        message: 'common.api.actionSuccess',
        type: 'success',
      }))
    })
    expect(fetchAnnotationListMock).toHaveBeenCalledTimes(2)
  })

  it('should support viewing items and running batch deletion success flow', async () => {
    const annotation = createAnnotation()
    fetchAnnotationListMock.mockResolvedValue({ data: [annotation], total: 1 })
    delAnnotationsMock.mockResolvedValue(undefined)
    delAnnotationMock.mockResolvedValue(undefined)

    renderComponent()
    await screen.findByTestId('list')

    await act(async () => {
      latestListProps.onSelectedIdsChange([annotation.id])
    })
    await waitFor(() => {
      expect(latestListProps.selectedIds).toEqual([annotation.id])
    })

    await act(async () => {
      await latestListProps.onBatchDelete()
    })
    await waitFor(() => {
      expect(delAnnotationsMock).toHaveBeenCalledWith(appDetail.id, [annotation.id])
      expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
        type: 'success',
      }))
      expect(latestListProps.selectedIds).toEqual([])
    })

    fireEvent.click(screen.getByTestId('list-view'))
    expect(screen.getByTestId('view-modal')).toBeInTheDocument()

    await act(async () => {
      fireEvent.click(screen.getByTestId('view-modal-remove'))
    })
    await waitFor(() => {
      expect(delAnnotationMock).toHaveBeenCalledWith(appDetail.id, annotation.id)
    })
  })

  it('should show an error notification when batch deletion fails', async () => {
    const annotation = createAnnotation()
    fetchAnnotationListMock.mockResolvedValue({ data: [annotation], total: 1 })
    const error = new Error('failed')
    delAnnotationsMock.mockRejectedValue(error)

    renderComponent()
    await screen.findByTestId('list')

    await act(async () => {
      latestListProps.onSelectedIdsChange([annotation.id])
    })
    await waitFor(() => {
      expect(latestListProps.selectedIds).toEqual([annotation.id])
    })

    await act(async () => {
      await latestListProps.onBatchDelete()
    })

    await waitFor(() => {
      expect(mockNotify).toHaveBeenCalledWith({
        type: 'error',
        message: error.message,
      })
      expect(latestListProps.selectedIds).toEqual([annotation.id])
    })
  })
})

@@ -25,7 +25,7 @@ import { sleep } from '@/utils'
import { useProviderContext } from '@/context/provider-context'
import AnnotationFullModal from '@/app/components/billing/annotation-full/modal'
import { type App, AppModeEnum } from '@/types/app'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { delAnnotations } from '@/service/annotation'

type Props = {

@@ -1,116 +0,0 @@
import React from 'react'
import { fireEvent, render, screen, within } from '@testing-library/react'
import List from './list'
import type { AnnotationItem } from './type'

const mockFormatTime = jest.fn(() => 'formatted-time')

jest.mock('@/hooks/use-timestamp', () => ({
  __esModule: true,
  default: () => ({
    formatTime: mockFormatTime,
  }),
}))

const createAnnotation = (overrides: Partial<AnnotationItem> = {}): AnnotationItem => ({
  id: overrides.id ?? 'annotation-id',
  question: overrides.question ?? 'question 1',
  answer: overrides.answer ?? 'answer 1',
  created_at: overrides.created_at ?? 1700000000,
  hit_count: overrides.hit_count ?? 2,
})

const getCheckboxes = (container: HTMLElement) => container.querySelectorAll('[data-testid^="checkbox"]')

describe('List', () => {
  beforeEach(() => {
    jest.clearAllMocks()
  })

  it('should render annotation rows and call onView when clicking a row', () => {
    const item = createAnnotation()
    const onView = jest.fn()

    render(
      <List
        list={[item]}
        onView={onView}
        onRemove={jest.fn()}
        selectedIds={[]}
        onSelectedIdsChange={jest.fn()}
        onBatchDelete={jest.fn()}
        onCancel={jest.fn()}
      />,
    )

    fireEvent.click(screen.getByText(item.question))

    expect(onView).toHaveBeenCalledWith(item)
    expect(mockFormatTime).toHaveBeenCalledWith(item.created_at, 'appLog.dateTimeFormat')
  })

  it('should toggle single and bulk selection states', () => {
    const list = [createAnnotation({ id: 'a', question: 'A' }), createAnnotation({ id: 'b', question: 'B' })]
    const onSelectedIdsChange = jest.fn()
    const { container, rerender } = render(
      <List
        list={list}
        onView={jest.fn()}
        onRemove={jest.fn()}
        selectedIds={[]}
        onSelectedIdsChange={onSelectedIdsChange}
        onBatchDelete={jest.fn()}
        onCancel={jest.fn()}
      />,
    )

    const checkboxes = getCheckboxes(container)
    fireEvent.click(checkboxes[1])
    expect(onSelectedIdsChange).toHaveBeenCalledWith(['a'])

    rerender(
      <List
        list={list}
        onView={jest.fn()}
        onRemove={jest.fn()}
        selectedIds={['a']}
        onSelectedIdsChange={onSelectedIdsChange}
        onBatchDelete={jest.fn()}
        onCancel={jest.fn()}
      />,
    )
    const updatedCheckboxes = getCheckboxes(container)
    fireEvent.click(updatedCheckboxes[1])
    expect(onSelectedIdsChange).toHaveBeenCalledWith([])

    fireEvent.click(updatedCheckboxes[0])
    expect(onSelectedIdsChange).toHaveBeenCalledWith(['a', 'b'])
  })

  it('should confirm before removing an annotation and expose batch actions', async () => {
    const item = createAnnotation({ id: 'to-delete', question: 'Delete me' })
    const onRemove = jest.fn()
    render(
      <List
        list={[item]}
        onView={jest.fn()}
        onRemove={onRemove}
        selectedIds={[item.id]}
        onSelectedIdsChange={jest.fn()}
        onBatchDelete={jest.fn()}
        onCancel={jest.fn()}
      />,
    )

    const row = screen.getByText(item.question).closest('tr') as HTMLTableRowElement
    const actionButtons = within(row).getAllByRole('button')
    fireEvent.click(actionButtons[1])

    expect(await screen.findByText('appDebug.feature.annotation.removeConfirm')).toBeInTheDocument()
    const confirmButton = await screen.findByRole('button', { name: 'common.operation.confirm' })
    fireEvent.click(confirmButton)
    expect(onRemove).toHaveBeenCalledWith(item.id)

    expect(screen.getByText('appAnnotation.batchAction.selected')).toBeInTheDocument()
  })
})

@@ -7,7 +7,7 @@ import type { AnnotationItem } from './type'
import RemoveAnnotationConfirmModal from './remove-annotation-confirm-modal'
import ActionButton from '@/app/components/base/action-button'
import useTimestamp from '@/hooks/use-timestamp'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import Checkbox from '@/app/components/base/checkbox'
import BatchAction from './batch-action'

@@ -12,12 +12,6 @@ export type AnnotationItem = {
  hit_count: number
}

export type AnnotationCreateResponse = AnnotationItem & {
  account?: {
    name?: string
  }
}

export type HitHistoryItem = {
  id: string
  question: string

@@ -1,158 +0,0 @@
import React from 'react'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import ViewAnnotationModal from './index'
import type { AnnotationItem, HitHistoryItem } from '../type'
import { fetchHitHistoryList } from '@/service/annotation'

const mockFormatTime = jest.fn(() => 'formatted-time')

jest.mock('@/hooks/use-timestamp', () => ({
  __esModule: true,
  default: () => ({
    formatTime: mockFormatTime,
  }),
}))

jest.mock('@/service/annotation', () => ({
  fetchHitHistoryList: jest.fn(),
}))

jest.mock('../edit-annotation-modal/edit-item', () => {
  const EditItemType = {
    Query: 'query',
    Answer: 'answer',
  }
  return {
    __esModule: true,
    default: ({ type, content, onSave }: { type: string; content: string; onSave: (value: string) => void }) => (
      <div>
        <div data-testid={`content-${type}`}>{content}</div>
        <button data-testid={`edit-${type}`} onClick={() => onSave(`${type}-updated`)}>edit-{type}</button>
      </div>
    ),
    EditItemType,
  }
})

const fetchHitHistoryListMock = fetchHitHistoryList as jest.Mock

const createAnnotationItem = (overrides: Partial<AnnotationItem> = {}): AnnotationItem => ({
  id: overrides.id ?? 'annotation-id',
  question: overrides.question ?? 'question',
  answer: overrides.answer ?? 'answer',
  created_at: overrides.created_at ?? 1700000000,
  hit_count: overrides.hit_count ?? 0,
})

const createHitHistoryItem = (overrides: Partial<HitHistoryItem> = {}): HitHistoryItem => ({
  id: overrides.id ?? 'hit-id',
  question: overrides.question ?? 'query',
  match: overrides.match ?? 'match',
  response: overrides.response ?? 'response',
  source: overrides.source ?? 'source',
  score: overrides.score ?? 0.42,
  created_at: overrides.created_at ?? 1700000000,
})

const renderComponent = (props?: Partial<React.ComponentProps<typeof ViewAnnotationModal>>) => {
  const item = createAnnotationItem()
  const mergedProps: React.ComponentProps<typeof ViewAnnotationModal> = {
    appId: 'app-id',
    isShow: true,
    onHide: jest.fn(),
    item,
    onSave: jest.fn().mockResolvedValue(undefined),
    onRemove: jest.fn().mockResolvedValue(undefined),
    ...props,
  }
  return {
    ...render(<ViewAnnotationModal {...mergedProps} />),
    props: mergedProps,
  }
}

describe('ViewAnnotationModal', () => {
  beforeEach(() => {
    jest.clearAllMocks()
    fetchHitHistoryListMock.mockResolvedValue({ data: [], total: 0 })
  })

  it('should render annotation tab and allow saving updated query', async () => {
    // Arrange
    const { props } = renderComponent()

    await waitFor(() => {
      expect(fetchHitHistoryListMock).toHaveBeenCalled()
    })

    // Act
    fireEvent.click(screen.getByTestId('edit-query'))

    // Assert
    await waitFor(() => {
      expect(props.onSave).toHaveBeenCalledWith('query-updated', props.item.answer)
    })
  })

  it('should render annotation tab and allow saving updated answer', async () => {
    // Arrange
    const { props } = renderComponent()

    await waitFor(() => {
      expect(fetchHitHistoryListMock).toHaveBeenCalled()
    })

    // Act
    fireEvent.click(screen.getByTestId('edit-answer'))

    // Assert
    await waitFor(() => {
      expect(props.onSave).toHaveBeenCalledWith(props.item.question, 'answer-updated')
    })
  })

  it('should switch to hit history tab and show no data message', async () => {
    // Arrange
    const { props } = renderComponent()

    await waitFor(() => {
      expect(fetchHitHistoryListMock).toHaveBeenCalled()
    })

    // Act
    fireEvent.click(screen.getByText('appAnnotation.viewModal.hitHistory'))

    // Assert
    expect(await screen.findByText('appAnnotation.viewModal.noHitHistory')).toBeInTheDocument()
    expect(mockFormatTime).toHaveBeenCalledWith(props.item.created_at, 'appLog.dateTimeFormat')
  })

  it('should render hit history entries with pagination badge when data exists', async () => {
    const hits = [createHitHistoryItem({ question: 'user input' }), createHitHistoryItem({ id: 'hit-2', question: 'second' })]
    fetchHitHistoryListMock.mockResolvedValue({ data: hits, total: 15 })

    renderComponent()

    fireEvent.click(await screen.findByText('appAnnotation.viewModal.hitHistory'))

    expect(await screen.findByText('user input')).toBeInTheDocument()
    expect(screen.getByText('15 appAnnotation.viewModal.hits')).toBeInTheDocument()
    expect(mockFormatTime).toHaveBeenCalledWith(hits[0].created_at, 'appLog.dateTimeFormat')
  })

  it('should confirm before removing the annotation and hide on success', async () => {
    const { props } = renderComponent()

    fireEvent.click(screen.getByText('appAnnotation.editModal.removeThisCache'))
    expect(await screen.findByText('appDebug.feature.annotation.removeConfirm')).toBeInTheDocument()

    const confirmButton = await screen.findByRole('button', { name: 'common.operation.confirm' })
    fireEvent.click(confirmButton)

    await waitFor(() => {
      expect(props.onRemove).toHaveBeenCalledTimes(1)
      expect(props.onHide).toHaveBeenCalledTimes(1)
    })
  })
})

@@ -14,7 +14,7 @@ import TabSlider from '@/app/components/base/tab-slider-plain'
import { fetchHitHistoryList } from '@/service/annotation'
import { APP_PAGE_LIMIT } from '@/config'
import useTimestamp from '@/hooks/use-timestamp'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'

type Props = {
  appId: string

@@ -2,7 +2,7 @@ import { Fragment, useCallback } from 'react'
import type { ReactNode } from 'react'
import { Dialog, Transition } from '@headlessui/react'
import { RiCloseLine } from '@remixicon/react'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'

type DialogProps = {
  className?: string

@@ -181,7 +181,7 @@ describe('AccessControlItem', () => {
    expect(useAccessControlStore.getState().currentMenu).toBe(AccessMode.ORGANIZATION)
  })

  it('should keep current menu when clicking the selected access type', () => {
  it('should render selected styles when the current menu matches the type', () => {
    useAccessControlStore.setState({ currentMenu: AccessMode.ORGANIZATION })
    render(
      <AccessControlItem type={AccessMode.ORGANIZATION}>
@@ -190,9 +190,8 @@ describe('AccessControlItem', () => {
    )

    const option = screen.getByText('Organization Only').parentElement as HTMLElement
    fireEvent.click(option)

    expect(useAccessControlStore.getState().currentMenu).toBe(AccessMode.ORGANIZATION)
    expect(option.className).toContain('border-[1.5px]')
    expect(option.className).not.toContain('cursor-pointer')
  })
})

@@ -11,7 +11,7 @@ import Input from '../../base/input'
import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '../../base/portal-to-follow-elem'
import Loading from '../../base/loading'
import useAccessControlStore from '../../../../context/access-control-store'
import { cn } from '@/utils/classnames'
import classNames from '@/utils/classnames'
import { useSearchForWhiteListCandidates } from '@/service/access-control'
import type { AccessControlAccount, AccessControlGroup, Subject, SubjectAccount, SubjectGroup } from '@/models/access-control'
import { SubjectType } from '@/models/access-control'
@@ -106,7 +106,7 @@ function SelectedGroupsBreadCrumb() {
    setSelectedGroupsForBreadcrumb([])
  }, [setSelectedGroupsForBreadcrumb])
  return <div className='flex h-7 items-center gap-x-0.5 px-2 py-0.5'>
    <span className={cn('system-xs-regular text-text-tertiary', selectedGroupsForBreadcrumb.length > 0 && 'cursor-pointer text-text-accent')} onClick={handleReset}>{t('app.accessControlDialog.operateGroupAndMember.allMembers')}</span>
    <span className={classNames('system-xs-regular text-text-tertiary', selectedGroupsForBreadcrumb.length > 0 && 'cursor-pointer text-text-accent')} onClick={handleReset}>{t('app.accessControlDialog.operateGroupAndMember.allMembers')}</span>
    {selectedGroupsForBreadcrumb.map((group, index) => {
      return <div key={index} className='system-xs-regular flex items-center gap-x-0.5 text-text-tertiary'>
        <span>/</span>
@@ -198,7 +198,7 @@ type BaseItemProps = {
  children: React.ReactNode
}
function BaseItem({ children, className }: BaseItemProps) {
  return <div className={cn('flex cursor-pointer items-center space-x-2 p-1 pl-2 hover:rounded-lg hover:bg-state-base-hover', className)}>
  return <div className={classNames('flex cursor-pointer items-center space-x-2 p-1 pl-2 hover:rounded-lg hover:bg-state-base-hover', className)}>
    {children}
  </div>
}

@@ -1,6 +1,6 @@
import type { HTMLProps, PropsWithChildren } from 'react'
import { RiArrowRightUpLine } from '@remixicon/react'
import { cn } from '@/utils/classnames'
import classNames from '@/utils/classnames'

export type SuggestedActionProps = PropsWithChildren<HTMLProps<HTMLAnchorElement> & {
  icon?: React.ReactNode
@@ -19,9 +19,11 @@ const SuggestedAction = ({ icon, link, disabled, children, className, onClick, ...
    href={disabled ? undefined : link}
    target='_blank'
    rel='noreferrer'
    className={cn('flex items-center justify-start gap-2 rounded-lg bg-background-section-burn px-2.5 py-2 text-text-secondary transition-colors [&:not(:first-child)]:mt-1',
    className={classNames(
      'flex items-center justify-start gap-2 rounded-lg bg-background-section-burn px-2.5 py-2 text-text-secondary transition-colors [&:not(:first-child)]:mt-1',
      disabled ? 'cursor-not-allowed opacity-30 shadow-xs' : 'cursor-pointer text-text-secondary hover:bg-state-accent-hover hover:text-text-accent',
      className)}
      className,
    )}
    onClick={handleClick}
    {...props}
  >

@@ -1,7 +1,7 @@
'use client'
import type { FC, ReactNode } from 'react'
import React from 'react'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'

export type IFeaturePanelProps = {
  className?: string

@@ -6,7 +6,7 @@ import {
  RiAddLine,
  RiEditLine,
} from '@remixicon/react'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { noop } from 'lodash-es'

export type IOperationBtnProps = {

@@ -14,7 +14,7 @@ import s from './style.module.css'
import MessageTypeSelector from './message-type-selector'
import ConfirmAddVar from './confirm-add-var'
import PromptEditorHeightResizeWrap from './prompt-editor-height-resize-wrap'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import type { PromptRole, PromptVariable } from '@/models/debug'
import {
  Copy,

@@ -2,7 +2,7 @@
import type { FC } from 'react'
import React from 'react'
import { useBoolean, useClickAway } from 'ahooks'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { PromptRole } from '@/models/debug'
import { ChevronSelectorVertical } from '@/app/components/base/icons/src/vender/line/arrows'
type Props = {

@@ -2,7 +2,7 @@
import React, { useCallback, useEffect, useState } from 'react'
import type { FC } from 'react'
import { useDebounceFn } from 'ahooks'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'

type Props = {
  className?: string

@@ -7,7 +7,7 @@ import { produce } from 'immer'
import { useContext } from 'use-context-selector'
import ConfirmAddVar from './confirm-add-var'
import PromptEditorHeightResizeWrap from './prompt-editor-height-resize-wrap'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import type { PromptVariable } from '@/models/debug'
import Tooltip from '@/app/components/base/tooltip'
import { AppModeEnum } from '@/types/app'

@@ -1,7 +1,7 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { useTranslation } from 'react-i18next'

type Props = {

@@ -2,6 +2,7 @@
import type { FC } from 'react'
import React, { useState } from 'react'
import { ChevronDownIcon } from '@heroicons/react/20/solid'
import classNames from '@/utils/classnames'
import {
  PortalToFollowElem,
  PortalToFollowElemContent,
@@ -9,7 +10,7 @@ import {
} from '@/app/components/base/portal-to-follow-elem'
import InputVarTypeIcon from '@/app/components/workflow/nodes/_base/components/input-var-type-icon'
import type { InputVarType } from '@/app/components/workflow/types'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import Badge from '@/app/components/base/badge'
import { inputVarTypeToVarType } from '@/app/components/workflow/nodes/_base/components/variable/utils'

@@ -46,7 +47,7 @@ const TypeSelector: FC<Props> = ({
    >
      <PortalToFollowElemTrigger onClick={() => !readonly && setOpen(v => !v)} className='w-full'>
        <div
          className={cn(`group flex h-9 items-center justify-between rounded-lg border-0 bg-components-input-bg-normal px-2 text-sm hover:bg-state-base-hover-alt ${readonly ? 'cursor-not-allowed' : 'cursor-pointer'}`)}
          className={classNames(`group flex h-9 items-center justify-between rounded-lg border-0 bg-components-input-bg-normal px-2 text-sm hover:bg-state-base-hover-alt ${readonly ? 'cursor-not-allowed' : 'cursor-pointer'}`)}
          title={selectedItem?.name}
        >
          <div className='flex items-center'>
@@ -68,7 +69,7 @@ const TypeSelector: FC<Props> = ({
      </PortalToFollowElemTrigger>
      <PortalToFollowElemContent className='z-[61]'>
        <div
          className={cn('w-[432px] rounded-md border-[0.5px] border-components-panel-border bg-components-panel-bg px-1 py-1 text-base shadow-lg focus:outline-none sm:text-sm', popupInnerClassName)}
          className={classNames('w-[432px] rounded-md border-[0.5px] border-components-panel-border bg-components-panel-bg px-1 py-1 text-base shadow-lg focus:outline-none sm:text-sm', popupInnerClassName)}
        >
          {items.map((item: Item) => (
            <div

Some files were not shown because too many files have changed in this diff.