Compare commits

..

2 Commits

Author SHA1 Message Date
autofix-ci[bot]
c8cabc9bdb [autofix.ci] apply automated fixes 2025-12-18 09:23:37 +00:00
hj24
45e2d4627f refactor: clean messages task 2025-12-18 17:20:40 +08:00
982 changed files with 5202 additions and 37222 deletions

2
.github/CODEOWNERS vendored
View File

@@ -122,7 +122,7 @@ api/controllers/console/feature.py @GarfieldDai @GareArc
api/controllers/web/feature.py @GarfieldDai @GareArc
# Backend - Database Migrations
api/migrations/ @snakevash @laipz8200 @MRZHUH
api/migrations/ @snakevash @laipz8200
# Frontend
web/ @iamjoel

View File

@@ -79,7 +79,7 @@ jobs:
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
cache-dependency-path: ./web/package.json
- name: Web dependencies
working-directory: ./web

View File

@@ -90,7 +90,7 @@ jobs:
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
cache-dependency-path: ./web/package.json
- name: Web dependencies
if: steps.changed-files.outputs.any_changed == 'true'

View File

@@ -55,7 +55,7 @@ jobs:
with:
node-version: 'lts/*'
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
cache-dependency-path: ./web/package.json
- name: Install dependencies
if: env.FILES_CHANGED == 'true'

View File

@@ -13,7 +13,6 @@ jobs:
runs-on: ubuntu-latest
defaults:
run:
shell: bash
working-directory: ./web
steps:
@@ -22,7 +21,14 @@ jobs:
with:
persist-credentials: false
- name: Check changed files
id: changed-files
uses: tj-actions/changed-files@v46
with:
files: web/**
- name: Install pnpm
if: steps.changed-files.outputs.any_changed == 'true'
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
@@ -30,355 +36,23 @@ jobs:
- name: Setup Node.js
uses: actions/setup-node@v4
if: steps.changed-files.outputs.any_changed == 'true'
with:
node-version: 22
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Restore Jest cache
uses: actions/cache@v4
with:
path: web/.cache/jest
key: ${{ runner.os }}-jest-${{ hashFiles('web/pnpm-lock.yaml') }}
restore-keys: |
${{ runner.os }}-jest-
cache-dependency-path: ./web/package.json
- name: Install dependencies
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm install --frozen-lockfile
- name: Check i18n types synchronization
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run check:i18n-types
- name: Run tests
run: |
pnpm exec jest \
--ci \
--maxWorkers=100% \
--coverage \
--passWithNoTests
- name: Coverage Summary
if: always()
id: coverage-summary
run: |
set -eo pipefail
COVERAGE_FILE="coverage/coverage-final.json"
COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json"
if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then
echo "has_coverage=false" >> "$GITHUB_OUTPUT"
echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY"
echo "Coverage data not found. Ensure Jest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
exit 0
fi
echo "has_coverage=true" >> "$GITHUB_OUTPUT"
node <<'NODE' >> "$GITHUB_STEP_SUMMARY"
const fs = require('fs');
const path = require('path');
let libCoverage = null;
try {
libCoverage = require('istanbul-lib-coverage');
} catch (error) {
libCoverage = null;
}
const summaryPath = path.join('coverage', 'coverage-summary.json');
const finalPath = path.join('coverage', 'coverage-final.json');
const hasSummary = fs.existsSync(summaryPath);
const hasFinal = fs.existsSync(finalPath);
if (!hasSummary && !hasFinal) {
console.log('### Test Coverage Summary :test_tube:');
console.log('');
console.log('No coverage data found.');
process.exit(0);
}
const summary = hasSummary
? JSON.parse(fs.readFileSync(summaryPath, 'utf8'))
: null;
const coverage = hasFinal
? JSON.parse(fs.readFileSync(finalPath, 'utf8'))
: null;
const getLineCoverageFromStatements = (statementMap, statementHits) => {
const lineHits = {};
if (!statementMap || !statementHits) {
return lineHits;
}
Object.entries(statementMap).forEach(([key, statement]) => {
const line = statement?.start?.line;
if (!line) {
return;
}
const hits = statementHits[key] ?? 0;
const previous = lineHits[line];
lineHits[line] = previous === undefined ? hits : Math.max(previous, hits);
});
return lineHits;
};
const getFileCoverage = (entry) => (
libCoverage ? libCoverage.createFileCoverage(entry) : null
);
const getLineHits = (entry, fileCoverage) => {
const lineHits = entry.l ?? {};
if (Object.keys(lineHits).length > 0) {
return lineHits;
}
if (fileCoverage) {
return fileCoverage.getLineCoverage();
}
return getLineCoverageFromStatements(entry.statementMap ?? {}, entry.s ?? {});
};
const getUncoveredLines = (entry, fileCoverage, lineHits) => {
if (lineHits && Object.keys(lineHits).length > 0) {
return Object.entries(lineHits)
.filter(([, count]) => count === 0)
.map(([line]) => Number(line))
.sort((a, b) => a - b);
}
if (fileCoverage) {
return fileCoverage.getUncoveredLines();
}
return [];
};
const totals = {
lines: { covered: 0, total: 0 },
statements: { covered: 0, total: 0 },
branches: { covered: 0, total: 0 },
functions: { covered: 0, total: 0 },
};
const fileSummaries = [];
if (summary) {
const totalEntry = summary.total ?? {};
['lines', 'statements', 'branches', 'functions'].forEach((key) => {
if (totalEntry[key]) {
totals[key].covered = totalEntry[key].covered ?? 0;
totals[key].total = totalEntry[key].total ?? 0;
}
});
Object.entries(summary)
.filter(([file]) => file !== 'total')
.forEach(([file, data]) => {
fileSummaries.push({
file,
pct: data.lines?.pct ?? data.statements?.pct ?? 0,
lines: {
covered: data.lines?.covered ?? 0,
total: data.lines?.total ?? 0,
},
});
});
} else if (coverage) {
Object.entries(coverage).forEach(([file, entry]) => {
const fileCoverage = getFileCoverage(entry);
const lineHits = getLineHits(entry, fileCoverage);
const statementHits = entry.s ?? {};
const branchHits = entry.b ?? {};
const functionHits = entry.f ?? {};
const lineTotal = Object.keys(lineHits).length;
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
const statementTotal = Object.keys(statementHits).length;
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
const branchCovered = Object.values(branchHits).reduce(
(acc, branches) => acc + branches.filter((n) => n > 0).length,
0,
);
const functionTotal = Object.keys(functionHits).length;
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
totals.lines.total += lineTotal;
totals.lines.covered += lineCovered;
totals.statements.total += statementTotal;
totals.statements.covered += statementCovered;
totals.branches.total += branchTotal;
totals.branches.covered += branchCovered;
totals.functions.total += functionTotal;
totals.functions.covered += functionCovered;
const pct = (covered, tot) => (tot > 0 ? (covered / tot) * 100 : 0);
fileSummaries.push({
file,
pct: pct(lineCovered || statementCovered, lineTotal || statementTotal),
lines: {
covered: lineCovered || statementCovered,
total: lineTotal || statementTotal,
},
});
});
}
const pct = (covered, tot) => (tot > 0 ? ((covered / tot) * 100).toFixed(2) : '0.00');
console.log('### Test Coverage Summary :test_tube:');
console.log('');
console.log('| Metric | Coverage | Covered / Total |');
console.log('|--------|----------|-----------------|');
console.log(`| Lines | ${pct(totals.lines.covered, totals.lines.total)}% | ${totals.lines.covered} / ${totals.lines.total} |`);
console.log(`| Statements | ${pct(totals.statements.covered, totals.statements.total)}% | ${totals.statements.covered} / ${totals.statements.total} |`);
console.log(`| Branches | ${pct(totals.branches.covered, totals.branches.total)}% | ${totals.branches.covered} / ${totals.branches.total} |`);
console.log(`| Functions | ${pct(totals.functions.covered, totals.functions.total)}% | ${totals.functions.covered} / ${totals.functions.total} |`);
console.log('');
console.log('<details><summary>File coverage (lowest lines first)</summary>');
console.log('');
console.log('```');
fileSummaries
.sort((a, b) => (a.pct - b.pct) || (b.lines.total - a.lines.total))
.slice(0, 25)
.forEach(({ file, pct, lines }) => {
console.log(`${pct.toFixed(2)}%\t${lines.covered}/${lines.total}\t${file}`);
});
console.log('```');
console.log('</details>');
if (coverage) {
const pctValue = (covered, tot) => {
if (tot === 0) {
return '0';
}
return ((covered / tot) * 100)
.toFixed(2)
.replace(/\.?0+$/, '');
};
const formatLineRanges = (lines) => {
if (lines.length === 0) {
return '';
}
const ranges = [];
let start = lines[0];
let end = lines[0];
for (let i = 1; i < lines.length; i += 1) {
const current = lines[i];
if (current === end + 1) {
end = current;
continue;
}
ranges.push(start === end ? `${start}` : `${start}-${end}`);
start = current;
end = current;
}
ranges.push(start === end ? `${start}` : `${start}-${end}`);
return ranges.join(',');
};
const tableTotals = {
statements: { covered: 0, total: 0 },
branches: { covered: 0, total: 0 },
functions: { covered: 0, total: 0 },
lines: { covered: 0, total: 0 },
};
const tableRows = Object.entries(coverage)
.map(([file, entry]) => {
const fileCoverage = getFileCoverage(entry);
const lineHits = getLineHits(entry, fileCoverage);
const statementHits = entry.s ?? {};
const branchHits = entry.b ?? {};
const functionHits = entry.f ?? {};
const lineTotal = Object.keys(lineHits).length;
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
const statementTotal = Object.keys(statementHits).length;
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
const branchCovered = Object.values(branchHits).reduce(
(acc, branches) => acc + branches.filter((n) => n > 0).length,
0,
);
const functionTotal = Object.keys(functionHits).length;
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
tableTotals.lines.total += lineTotal;
tableTotals.lines.covered += lineCovered;
tableTotals.statements.total += statementTotal;
tableTotals.statements.covered += statementCovered;
tableTotals.branches.total += branchTotal;
tableTotals.branches.covered += branchCovered;
tableTotals.functions.total += functionTotal;
tableTotals.functions.covered += functionCovered;
const uncoveredLines = getUncoveredLines(entry, fileCoverage, lineHits);
const filePath = entry.path ?? file;
const relativePath = path.isAbsolute(filePath)
? path.relative(process.cwd(), filePath)
: filePath;
return {
file: relativePath || file,
statements: pctValue(statementCovered, statementTotal),
branches: pctValue(branchCovered, branchTotal),
functions: pctValue(functionCovered, functionTotal),
lines: pctValue(lineCovered, lineTotal),
uncovered: formatLineRanges(uncoveredLines),
};
})
.sort((a, b) => a.file.localeCompare(b.file));
const columns = [
{ key: 'file', header: 'File', align: 'left' },
{ key: 'statements', header: '% Stmts', align: 'right' },
{ key: 'branches', header: '% Branch', align: 'right' },
{ key: 'functions', header: '% Funcs', align: 'right' },
{ key: 'lines', header: '% Lines', align: 'right' },
{ key: 'uncovered', header: 'Uncovered Line #s', align: 'left' },
];
const allFilesRow = {
file: 'All files',
statements: pctValue(tableTotals.statements.covered, tableTotals.statements.total),
branches: pctValue(tableTotals.branches.covered, tableTotals.branches.total),
functions: pctValue(tableTotals.functions.covered, tableTotals.functions.total),
lines: pctValue(tableTotals.lines.covered, tableTotals.lines.total),
uncovered: '',
};
const rowsForOutput = [allFilesRow, ...tableRows];
const formatRow = (row) => `| ${columns
.map(({ key }) => String(row[key] ?? ''))
.join(' | ')} |`;
const headerRow = `| ${columns.map(({ header }) => header).join(' | ')} |`;
const dividerRow = `| ${columns
.map(({ align }) => (align === 'right' ? '---:' : ':---'))
.join(' | ')} |`;
console.log('');
console.log('<details><summary>Jest coverage table</summary>');
console.log('');
console.log(headerRow);
console.log(dividerRow);
rowsForOutput.forEach((row) => console.log(formatRow(row)));
console.log('</details>');
}
NODE
- name: Upload Coverage Artifact
if: steps.coverage-summary.outputs.has_coverage == 'true'
uses: actions/upload-artifact@v4
with:
name: web-coverage-report
path: web/coverage
retention-days: 30
if-no-files-found: error
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm test

View File

@@ -1,7 +1,9 @@
import base64
import datetime
import json
import logging
import secrets
import time
from typing import Any
import click
@@ -45,6 +47,7 @@ from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpi
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration
from services.plugin.plugin_service import PluginService
from services.sandbox_messages_clean_service import SandboxMessagesCleanService
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
logger = logging.getLogger(__name__)
@@ -1900,3 +1903,76 @@ def migrate_oss(
except Exception as e:
db.session.rollback()
click.echo(click.style(f"Failed to update DB storage_type: {str(e)}", fg="red"))
@click.command("clean-expired-sandbox-messages", help="Clean expired sandbox messages.")
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.")
@click.option(
"--graceful-period",
default=21,
show_default=True,
help="Graceful period in days after subscription expiration.",
)
@click.option(
"--start-from",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
default=None,
help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
)
@click.option(
"--end-before",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
default=None,
help="Optional upper bound (exclusive) for created_at; must be paired with --start-after.",
)
@click.option("--dry-run", is_flag=True, default=False, help="Show messages logs would be cleaned without deleteing")
def clean_expired_sandbox_messages(
batch_size: int,
graceful_period: int,
start_from: datetime.datetime,
end_before: datetime.datetime,
dry_run: bool,
):
"""
Clean expired messages and related data for sandbox tenants.
"""
if not dify_config.BILLING_ENABLED:
click.echo(click.style("Billing is not enabled. Skipping sandbox messages cleanup.", fg="yellow"))
return
click.echo(click.style("clean_messages: start clean messages.", fg="green"))
start_at = time.perf_counter()
try:
stats = SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
start_from=start_from,
end_before=end_before,
graceful_period=graceful_period,
batch_size=batch_size,
dry_run=dry_run,
)
end_at = time.perf_counter()
click.echo(
click.style(
f"clean_messages: completed successfully\n"
f" - Latency: {end_at - start_at:.2f}s\n"
f" - Batches processed: {stats['batches']}\n"
f" - Messages found: {stats['total_messages']}\n"
f" - Messages deleted: {stats['total_deleted']}",
fg="green",
)
)
except Exception as e:
end_at = time.perf_counter()
logger.exception("clean_messages failed")
click.echo(
click.style(
f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}",
fg="red",
)
)
raise
click.echo(click.style("Sandbox messages cleanup completed.", fg="green"))

View File

@@ -218,7 +218,7 @@ class PluginConfig(BaseSettings):
PLUGIN_DAEMON_TIMEOUT: PositiveFloat | None = Field(
description="Timeout in seconds for requests to the plugin daemon (set to None to disable)",
default=600.0,
default=300.0,
)
INNER_API_KEY_FOR_PLUGIN: str = Field(description="Inner api key for plugin", default="inner-api-key")

View File

@@ -338,7 +338,9 @@ class CompletionConversationApi(Resource):
current_user, _ = current_account_with_tenant()
args = CompletionConversationQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion")
query = sa.select(Conversation).where(
Conversation.app_id == app_model.id, Conversation.mode == "completion", Conversation.is_deleted.is_(False)
)
if args.keyword:
query = query.join(Message, Message.conversation_id == Conversation.id).where(
@@ -450,7 +452,7 @@ class ChatConversationApi(Resource):
.subquery()
)
query = sa.select(Conversation).where(Conversation.app_id == app_model.id)
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))
if args.keyword:
keyword_filter = f"%{args.keyword}%"

View File

@@ -146,7 +146,7 @@ class DatasetUpdatePayload(BaseModel):
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: dict[str, Any] | None = None
partial_member_list: list[dict[str, str]] | None = None
partial_member_list: list[str] | None = None
external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None
external_knowledge_api_id: str | None = None

View File

@@ -40,7 +40,7 @@ from .. import console_ns
logger = logging.getLogger(__name__)
class CompletionMessageExplorePayload(BaseModel):
class CompletionMessagePayload(BaseModel):
inputs: dict[str, Any]
query: str = ""
files: list[dict[str, Any]] | None = None
@@ -71,7 +71,7 @@ class ChatMessagePayload(BaseModel):
raise ValueError("must be a valid UUID") from exc
register_schema_models(console_ns, CompletionMessageExplorePayload, ChatMessagePayload)
register_schema_models(console_ns, CompletionMessagePayload, ChatMessagePayload)
# define completion api for user
@@ -80,13 +80,13 @@ register_schema_models(console_ns, CompletionMessageExplorePayload, ChatMessageP
endpoint="installed_app_completion",
)
class CompletionApi(InstalledAppResource):
@console_ns.expect(console_ns.models[CompletionMessageExplorePayload.__name__])
@console_ns.expect(console_ns.models[CompletionMessagePayload.__name__])
def post(self, installed_app):
app_model = installed_app.app
if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError()
payload = CompletionMessageExplorePayload.model_validate(console_ns.payload or {})
payload = CompletionMessagePayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
streaming = payload.response_mode == "streaming"

View File

@@ -1,40 +1,31 @@
from typing import Literal
from flask import request
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field
from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from fields.tag_fields import dataset_tag_fields
from libs.login import current_account_with_tenant, login_required
from models.model import Tag
from services.tag_service import TagService
class TagBasePayload(BaseModel):
name: str = Field(description="Tag name", min_length=1, max_length=50)
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
def _validate_name(name):
if not name or len(name) < 1 or len(name) > 50:
raise ValueError("Name must be between 1 to 50 characters.")
return name
class TagBindingPayload(BaseModel):
tag_ids: list[str] = Field(description="Tag IDs to bind")
target_id: str = Field(description="Target ID to bind tags to")
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
class TagBindingRemovePayload(BaseModel):
tag_id: str = Field(description="Tag ID to remove")
target_id: str = Field(description="Target ID to unbind tag from")
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
register_schema_models(
console_ns,
TagBasePayload,
TagBindingPayload,
TagBindingRemovePayload,
parser_tags = (
reqparse.RequestParser()
.add_argument(
"name",
nullable=False,
required=True,
help="Name must be between 1 to 50 characters.",
type=_validate_name,
)
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
)
@@ -52,7 +43,7 @@ class TagListApi(Resource):
return tags, 200
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
@console_ns.expect(parser_tags)
@setup_required
@login_required
@account_initialization_required
@@ -62,17 +53,22 @@ class TagListApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.save_tags(payload.model_dump())
args = parser_tags.parse_args()
tag = TagService.save_tags(args)
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
return response, 200
parser_tag_id = reqparse.RequestParser().add_argument(
"name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name
)
@console_ns.route("/tags/<uuid:tag_id>")
class TagUpdateDeleteApi(Resource):
@console_ns.expect(console_ns.models[TagBasePayload.__name__])
@console_ns.expect(parser_tag_id)
@setup_required
@login_required
@account_initialization_required
@@ -83,8 +79,8 @@ class TagUpdateDeleteApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
payload = TagBasePayload.model_validate(console_ns.payload or {})
tag = TagService.update_tags(payload.model_dump(), tag_id)
args = parser_tag_id.parse_args()
tag = TagService.update_tags(args, tag_id)
binding_count = TagService.get_tag_binding_count(tag_id)
@@ -104,9 +100,17 @@ class TagUpdateDeleteApi(Resource):
return 204
parser_create = (
reqparse.RequestParser()
.add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.")
.add_argument("target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required.")
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
)
@console_ns.route("/tag-bindings/create")
class TagBindingCreateApi(Resource):
@console_ns.expect(console_ns.models[TagBindingPayload.__name__])
@console_ns.expect(parser_create)
@setup_required
@login_required
@account_initialization_required
@@ -116,15 +120,23 @@ class TagBindingCreateApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(payload.model_dump())
args = parser_create.parse_args()
TagService.save_tag_binding(args)
return {"result": "success"}, 200
parser_remove = (
reqparse.RequestParser()
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
.add_argument("type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type.")
)
@console_ns.route("/tag-bindings/remove")
class TagBindingDeleteApi(Resource):
@console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__])
@console_ns.expect(parser_remove)
@setup_required
@login_required
@account_initialization_required
@@ -134,7 +146,7 @@ class TagBindingDeleteApi(Resource):
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(payload.model_dump())
args = parser_remove.parse_args()
TagService.delete_tag_binding(args)
return {"result": "success"}, 200

View File

@@ -49,7 +49,7 @@ class DatasetUpdatePayload(BaseModel):
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: RetrievalModel | None = None
partial_member_list: list[dict[str, str]] | None = None
partial_member_list: list[str] | None = None
external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None
external_knowledge_api_id: str | None = None

View File

@@ -1,8 +1,7 @@
import logging
from flask import request
from flask_restx import fields, marshal_with
from pydantic import BaseModel, field_validator
from flask_restx import fields, marshal_with, reqparse
from werkzeug.exceptions import InternalServerError
import services
@@ -21,7 +20,6 @@ from controllers.web.error import (
from controllers.web.wraps import WebApiResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
from libs.helper import uuid_value
from models.model import App
from services.audio_service import AudioService
from services.errors.audio import (
@@ -31,25 +29,6 @@ from services.errors.audio import (
UnsupportedAudioTypeServiceError,
)
from ..common.schema import register_schema_models
class TextToAudioPayload(BaseModel):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = None
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
register_schema_models(web_ns, TextToAudioPayload)
logger = logging.getLogger(__name__)
@@ -109,7 +88,6 @@ class AudioApi(WebApiResource):
@web_ns.route("/text-to-audio")
class TextApi(WebApiResource):
@web_ns.expect(web_ns.models[TextToAudioPayload.__name__])
@web_ns.doc("Text to Audio")
@web_ns.doc(description="Convert text to audio using text-to-speech service.")
@web_ns.doc(
@@ -124,11 +102,18 @@ class TextApi(WebApiResource):
def post(self, app_model: App, end_user):
"""Convert text to audio"""
try:
payload = TextToAudioPayload.model_validate(web_ns.payload or {})
parser = (
reqparse.RequestParser()
.add_argument("message_id", type=str, required=False, location="json")
.add_argument("voice", type=str, location="json")
.add_argument("text", type=str, location="json")
.add_argument("streaming", type=bool, location="json")
)
args = parser.parse_args()
message_id = payload.message_id
text = payload.text
voice = payload.voice
message_id = args.get("message_id", None)
text = args.get("text", None)
voice = args.get("voice", None)
response = AudioService.transcript_tts(
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
)

View File

@@ -1,11 +1,9 @@
import logging
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from flask_restx import reqparse
from werkzeug.exceptions import InternalServerError, NotFound
import services
from controllers.common.schema import register_schema_models
from controllers.web import web_ns
from controllers.web.error import (
AppUnavailableError,
@@ -36,44 +34,25 @@ from services.errors.llm import InvokeRateLimitError
logger = logging.getLogger(__name__)
class CompletionMessagePayload(BaseModel):
inputs: dict[str, Any] = Field(description="Input variables for the completion")
query: str = Field(default="", description="Query text for completion")
files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed")
response_mode: Literal["blocking", "streaming"] | None = Field(
default=None, description="Response mode: blocking or streaming"
)
retriever_from: str = Field(default="web_app", description="Source of retriever")
class ChatMessagePayload(BaseModel):
inputs: dict[str, Any] = Field(description="Input variables for the chat")
query: str = Field(description="User query/message")
files: list[dict[str, Any]] | None = Field(default=None, description="Files to be processed")
response_mode: Literal["blocking", "streaming"] | None = Field(
default=None, description="Response mode: blocking or streaming"
)
conversation_id: str | None = Field(default=None, description="Conversation ID")
parent_message_id: str | None = Field(default=None, description="Parent message ID")
retriever_from: str = Field(default="web_app", description="Source of retriever")
@field_validator("conversation_id", "parent_message_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
register_schema_models(web_ns, CompletionMessagePayload, ChatMessagePayload)
# define completion api for user
@web_ns.route("/completion-messages")
class CompletionApi(WebApiResource):
@web_ns.doc("Create Completion Message")
@web_ns.doc(description="Create a completion message for text generation applications.")
@web_ns.expect(web_ns.models[CompletionMessagePayload.__name__])
@web_ns.doc(
params={
"inputs": {"description": "Input variables for the completion", "type": "object", "required": True},
"query": {"description": "Query text for completion", "type": "string", "required": False},
"files": {"description": "Files to be processed", "type": "array", "required": False},
"response_mode": {
"description": "Response mode: blocking or streaming",
"type": "string",
"enum": ["blocking", "streaming"],
"required": False,
},
"retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
}
)
@web_ns.doc(
responses={
200: "Success",
@@ -88,10 +67,18 @@ class CompletionApi(WebApiResource):
if app_model.mode != AppMode.COMPLETION:
raise NotCompletionAppError()
payload = CompletionMessagePayload.model_validate(web_ns.payload or {})
args = payload.model_dump(exclude_none=True)
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, location="json", default="")
.add_argument("files", type=list, required=False, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
)
streaming = payload.response_mode == "streaming"
args = parser.parse_args()
streaming = args["response_mode"] == "streaming"
args["auto_generate_name"] = False
try:
@@ -155,7 +142,22 @@ class CompletionStopApi(WebApiResource):
class ChatApi(WebApiResource):
@web_ns.doc("Create Chat Message")
@web_ns.doc(description="Create a chat message for conversational applications.")
@web_ns.expect(web_ns.models[ChatMessagePayload.__name__])
@web_ns.doc(
params={
"inputs": {"description": "Input variables for the chat", "type": "object", "required": True},
"query": {"description": "User query/message", "type": "string", "required": True},
"files": {"description": "Files to be processed", "type": "array", "required": False},
"response_mode": {
"description": "Response mode: blocking or streaming",
"type": "string",
"enum": ["blocking", "streaming"],
"required": False,
},
"conversation_id": {"description": "Conversation UUID", "type": "string", "required": False},
"parent_message_id": {"description": "Parent message UUID", "type": "string", "required": False},
"retriever_from": {"description": "Source of retriever", "type": "string", "required": False},
}
)
@web_ns.doc(
responses={
200: "Success",
@@ -171,10 +173,20 @@ class ChatApi(WebApiResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
payload = ChatMessagePayload.model_validate(web_ns.payload or {})
args = payload.model_dump(exclude_none=True)
parser = (
reqparse.RequestParser()
.add_argument("inputs", type=dict, required=True, location="json")
.add_argument("query", type=str, required=True, location="json")
.add_argument("files", type=list, required=False, location="json")
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
.add_argument("conversation_id", type=uuid_value, location="json")
.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
.add_argument("retriever_from", type=str, required=False, default="web_app", location="json")
)
streaming = payload.response_mode == "streaming"
args = parser.parse_args()
streaming = args["response_mode"] == "streaming"
args["auto_generate_name"] = False
try:

View File

@@ -1,4 +1,3 @@
import json
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal
@@ -121,7 +120,7 @@ class VariableEntity(BaseModel):
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
json_schema: str | None = Field(default=None)
json_schema: dict[str, Any] | None = Field(default=None)
@field_validator("description", mode="before")
@classmethod
@@ -135,17 +134,11 @@ class VariableEntity(BaseModel):
@field_validator("json_schema")
@classmethod
def validate_json_schema(cls, schema: str | None) -> str | None:
def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None:
if schema is None:
return None
try:
json_schema = json.loads(schema)
except json.JSONDecodeError:
raise ValueError(f"invalid json_schema value {schema}")
try:
Draft7Validator.check_schema(json_schema)
Draft7Validator.check_schema(schema)
except SchemaError as e:
raise ValueError(f"Invalid JSON schema: {e.message}")
return schema

View File

@@ -1,4 +1,3 @@
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Union, final
@@ -176,13 +175,6 @@ class BaseAppGenerator:
value = True
elif value == 0:
value = False
case VariableEntityType.JSON_OBJECT:
if not isinstance(value, str):
raise ValueError(f"{variable_entity.variable} in input form must be a string")
try:
json.loads(value)
except json.JSONDecodeError:
raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object")
case _:
raise AssertionError("this statement should be unreachable.")

View File

@@ -342,11 +342,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
self._task_state.llm_result.message.content = current_content
if isinstance(event, QueueLLMChunkEvent):
event_type = self._message_cycle_manager.get_message_event_type(message_id=self._message_id)
yield self._message_cycle_manager.message_to_stream_response(
answer=cast(str, delta_text),
message_id=self._message_id,
event_type=event_type,
)
else:
yield self._agent_message_to_stream_response(

View File

@@ -5,7 +5,7 @@ from threading import Thread
from typing import Union
from flask import Flask, current_app
from sqlalchemy import exists, select
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
@@ -54,20 +54,6 @@ class MessageCycleManager:
):
self._application_generate_entity = application_generate_entity
self._task_state = task_state
self._message_has_file: set[str] = set()
def get_message_event_type(self, message_id: str) -> StreamEvent:
if message_id in self._message_has_file:
return StreamEvent.MESSAGE_FILE
with Session(db.engine, expire_on_commit=False) as session:
has_file = session.query(exists().where(MessageFile.message_id == message_id)).scalar()
if has_file:
self._message_has_file.add(message_id)
return StreamEvent.MESSAGE_FILE
return StreamEvent.MESSAGE
def generate_conversation_name(self, *, conversation_id: str, query: str) -> Thread | None:
"""
@@ -228,11 +214,7 @@ class MessageCycleManager:
return None
def message_to_stream_response(
self,
answer: str,
message_id: str,
from_variable_selector: list[str] | None = None,
event_type: StreamEvent | None = None,
self, answer: str, message_id: str, from_variable_selector: list[str] | None = None
) -> MessageStreamResponse:
"""
Message to stream response.
@@ -240,12 +222,16 @@ class MessageCycleManager:
:param message_id: message id
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id))
event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE
return MessageStreamResponse(
task_id=self._application_generate_entity.task_id,
id=message_id,
answer=answer,
from_variable_selector=from_variable_selector,
event=event_type or StreamEvent.MESSAGE,
event=event_type,
)
def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:

View File

@@ -39,7 +39,7 @@ from core.trigger.errors import (
plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL))
_plugin_daemon_timeout_config = cast(
float | httpx.Timeout | None,
getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 600.0),
getattr(dify_config, "PLUGIN_DAEMON_TIMEOUT", 300.0),
)
plugin_daemon_request_timeout: httpx.Timeout | None
if _plugin_daemon_timeout_config is None:

View File

@@ -2,7 +2,6 @@
from __future__ import annotations
import codecs
import re
from typing import Any
@@ -53,7 +52,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
def __init__(self, fixed_separator: str = "\n\n", separators: list[str] | None = None, **kwargs: Any):
"""Create a new TextSplitter."""
super().__init__(**kwargs)
self._fixed_separator = codecs.decode(fixed_separator, "unicode_escape")
self._fixed_separator = fixed_separator
self._separators = separators or ["\n\n", "\n", "", ". ", " ", ""]
def split_text(self, text: str) -> list[str]:
@@ -95,8 +94,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
splits = re.split(r" +", text)
else:
splits = text.split(separator)
if self._keep_separator:
splits = [s + separator for s in splits[:-1]] + splits[-1:]
splits = [item + separator if i < len(splits) else item for i, item in enumerate(splits)]
else:
splits = list(text)
if separator == "\n":
@@ -105,7 +103,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
splits = [s for s in splits if (s not in {"", "\n"})]
_good_splits = []
_good_splits_lengths = [] # cache the lengths of the splits
_separator = "" if self._keep_separator else separator
_separator = separator if self._keep_separator else ""
s_lens = self._length_function(splits)
if separator != "":
for s, s_len in zip(splits, s_lens):

View File

@@ -1,4 +1,3 @@
import json
from typing import Any
from jsonschema import Draft7Validator, ValidationError
@@ -43,25 +42,15 @@ class StartNode(Node[StartNodeData]):
if value is None and variable.required:
raise ValueError(f"{key} is required in input form")
if not isinstance(value, dict):
raise ValueError(f"{key} must be a JSON object")
schema = variable.json_schema
if not schema:
continue
if not value:
continue
try:
json_schema = json.loads(schema)
except json.JSONDecodeError as e:
raise ValueError(f"{schema} must be a valid JSON object")
try:
json_value = json.loads(value)
except json.JSONDecodeError as e:
raise ValueError(f"{value} must be a valid JSON object")
try:
Draft7Validator(json_schema).validate(json_value)
Draft7Validator(schema).validate(value)
except ValidationError as e:
raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}")
node_inputs[key] = json_value
node_inputs[key] = value

View File

@@ -4,6 +4,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp):
from commands import (
add_qdrant_index,
clean_expired_sandbox_messages,
cleanup_orphaned_draft_variables,
clear_free_plan_tenant_expired_logs,
clear_orphaned_file_records,
@@ -54,6 +55,7 @@ def init_app(app: DifyApp):
setup_datasource_oauth_client,
transform_datasource_credentials,
install_rag_pipeline_plugins,
clean_expired_sandbox_messages,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)

View File

@@ -11,7 +11,6 @@ from collections.abc import Generator, Mapping
from datetime import datetime
from hashlib import sha256
from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast
from uuid import UUID
from zoneinfo import available_timezones
from flask import Response, stream_with_context
@@ -120,19 +119,6 @@ def uuid_value(value: Any) -> str:
raise ValueError(error)
def normalize_uuid(value: str | UUID) -> str:
if not value:
return ""
try:
return uuid_value(value)
except ValueError as exc:
raise ValueError("must be a valid UUID") from exc
UUIDStrOrEmpty = Annotated[str, AfterValidator(normalize_uuid)]
def alphanumeric(value: str):
# check if the value is alphanumeric and underlined
if re.match(r"^[a-zA-Z0-9_]+$", value):

View File

@@ -1,29 +0,0 @@
"""remove unused is_deleted from conversations
Revision ID: e5d7a95e676f
Revises: d57accd375ae
Create Date: 2025-11-27 18:27:09.006691
"""
import sqlalchemy as sa
from alembic import op
revision = "e5d7a95e676f"
down_revision = "d57accd375ae"
branch_labels = None
depends_on = None
def upgrade():
conversations = sa.table("conversations", sa.column("is_deleted", sa.Boolean))
op.execute(sa.delete(conversations).where(conversations.c.is_deleted == sa.true()))
with op.batch_alter_table("conversations", schema=None) as batch_op:
batch_op.drop_column("is_deleted")
def downgrade():
with op.batch_alter_table("conversations", schema=None) as batch_op:
batch_op.add_column(
sa.Column("is_deleted", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False)
)

View File

@@ -0,0 +1,33 @@
"""feat: add created_at id index to messages
Revision ID: 649d817a739e
Revises: 03ea244985ce
Create Date: 2025-12-18 16:39:33.090454
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '649d817a739e'
down_revision = '03ea244985ce'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('messages', schema=None) as batch_op:
batch_op.create_index('message_created_at_id_idx', ['created_at', 'id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('messages', schema=None) as batch_op:
batch_op.drop_index('message_created_at_id_idx')
# ### end Alembic commands ###

View File

@@ -676,6 +676,8 @@ class Conversation(Base):
"MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all"
)
is_deleted: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
@property
def inputs(self) -> dict[str, Any]:
inputs = self._inputs.copy()
@@ -963,6 +965,7 @@ class Message(Base):
Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"),
Index("message_created_at_idx", "created_at"),
Index("message_app_mode_idx", "app_mode"),
Index("message_created_at_id_idx", "created_at", "id"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))

View File

@@ -1,90 +1,54 @@
import datetime
import logging
import time
import click
from sqlalchemy.exc import SQLAlchemyError
import app
from configs import dify_config
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import (
App,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
)
from models.web import SavedMessage
from services.feature_service import FeatureService
from services.sandbox_messages_clean_service import SandboxMessagesCleanService
logger = logging.getLogger(__name__)
@app.celery.task(queue="dataset")
@app.celery.task(queue="retention")
def clean_messages():
click.echo(click.style("Start clean messages.", fg="green"))
start_at = time.perf_counter()
plan_sandbox_clean_message_day = datetime.datetime.now() - datetime.timedelta(
days=dify_config.PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING
)
while True:
try:
# Main query with join and filter
messages = (
db.session.query(Message)
.where(Message.created_at < plan_sandbox_clean_message_day)
.order_by(Message.created_at.desc())
.limit(100)
.all()
)
"""
Clean expired messages from sandbox plan tenants.
except SQLAlchemyError:
raise
if not messages:
break
for message in messages:
app = db.session.query(App).filter_by(id=message.app_id).first()
if not app:
logger.warning(
"Expected App record to exist, but none was found, app_id=%s, message_id=%s",
message.app_id,
message.id,
)
continue
features_cache_key = f"features:{app.tenant_id}"
plan_cache = redis_client.get(features_cache_key)
if plan_cache is None:
features = FeatureService.get_features(app.tenant_id)
redis_client.setex(features_cache_key, 600, features.billing.subscription.plan)
plan = features.billing.subscription.plan
else:
plan = plan_cache.decode()
if plan == CloudPlan.SANDBOX:
# clean related message
db.session.query(MessageFeedback).where(MessageFeedback.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageChain).where(MessageChain.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(MessageFile).where(MessageFile.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(SavedMessage).where(SavedMessage.message_id == message.id).delete(
synchronize_session=False
)
db.session.query(Message).where(Message.id == message.id).delete()
db.session.commit()
end_at = time.perf_counter()
click.echo(click.style(f"Cleaned messages from db success latency: {end_at - start_at}", fg="green"))
This task uses SandboxMessagesCleanService to efficiently clean messages in batches.
"""
if not dify_config.BILLING_ENABLED:
click.echo(click.style("Billing is not enabled. Skipping sandbox messages cleanup.", fg="yellow"))
return
click.echo(click.style("clean_messages: start clean messages.", fg="green"))
start_at = time.perf_counter()
try:
stats = SandboxMessagesCleanService.clean_sandbox_messages_by_days(
days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS,
graceful_period=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD,
batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE,
)
end_at = time.perf_counter()
click.echo(
click.style(
f"clean_messages: completed successfully\n"
f" - Latency: {end_at - start_at:.2f}s\n"
f" - Batches processed: {stats['batches']}\n"
f" - Messages found: {stats['total_messages']}\n"
f" - Messages deleted: {stats['total_deleted']}",
fg="green",
)
)
except Exception as e:
end_at = time.perf_counter()
logger.exception("clean_messages failed")
click.echo(
click.style(
f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}",
fg="red",
)
)
raise

View File

@@ -47,6 +47,7 @@ class ConversationService:
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
stmt = select(Conversation).where(
Conversation.is_deleted == False,
Conversation.app_id == app_model.id,
Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),
Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
@@ -165,6 +166,7 @@ class ConversationService:
Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),
Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
Conversation.is_deleted == False,
)
.first()
)

View File

@@ -23,7 +23,7 @@ class RagPipelineDatasetCreateEntity(BaseModel):
description: str
icon_info: IconInfo
permission: str
partial_member_list: list[dict[str, str]] | None = None
partial_member_list: list[str] | None = None
yaml_content: str | None = None

View File

@@ -0,0 +1,488 @@
import datetime
import json
import logging
from collections.abc import Sequence
from dataclasses import dataclass
from typing import cast
from sqlalchemy import delete, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session
from configs import dify_config
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import (
App,
AppAnnotationHitHistory,
DatasetRetrieverResource,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
)
from models.web import SavedMessage
from services.billing_service import BillingService, SubscriptionPlan
logger = logging.getLogger(__name__)
@dataclass
class SimpleMessage:
"""Lightweight message info containing only essential fields for cleaning."""
id: str
app_id: str
created_at: datetime.datetime
class SandboxMessagesCleanService:
"""
Service for cleaning expired messages from sandbox plan tenants.
"""
# Redis key prefix for tenant plan cache
PLAN_CACHE_KEY_PREFIX = "tenant_plan:"
# Cache TTL: 10 minutes
PLAN_CACHE_TTL = 600
@classmethod
def clean_sandbox_messages_by_time_range(
cls,
start_from: datetime.datetime,
end_before: datetime.datetime,
graceful_period: int = 21,
batch_size: int = 1000,
dry_run: bool = False,
) -> dict[str, int]:
"""
Clean sandbox messages within a specific time range [start_from, end_before).
Args:
start_from: Start time (inclusive) of the range
end_before: End time (exclusive) of the range
batch_size: Number of messages to process per batch
dry_run: Whether to perform a dry run (no actual deletion)
Returns:
Statistics about the cleaning operation
Raises:
ValueError: If start_from >= end_before
"""
if start_from >= end_before:
raise ValueError(f"start_from ({start_from}) must be less than end_before ({end_before})")
if batch_size <= 0:
raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
if graceful_period < 0:
raise ValueError(f"graceful_period ({graceful_period}) must be greater than or equal to 0")
logger.info("clean_messages: start_from=%s, end_before=%s, batch_size=%s", start_from, end_before, batch_size)
return cls._clean_sandbox_messages_by_time_range(
start_from=start_from,
end_before=end_before,
graceful_period=graceful_period,
batch_size=batch_size,
dry_run=dry_run,
)
@classmethod
def clean_sandbox_messages_by_days(
cls,
days: int = 30,
graceful_period: int = 21,
batch_size: int = 1000,
dry_run: bool = False,
) -> dict[str, int]:
"""
Clean sandbox messages older than specified days.
Args:
days: Number of days to look back from now
batch_size: Number of messages to process per batch
dry_run: Whether to perform a dry run (no actual deletion)
Returns:
Statistics about the cleaning operation
"""
if days < 0:
raise ValueError(f"days ({days}) must be greater than or equal to 0")
if batch_size <= 0:
raise ValueError(f"batch_size ({batch_size}) must be greater than 0")
if graceful_period < 0:
raise ValueError(f"graceful_period ({graceful_period}) must be greater than or equal to 0")
end_before = datetime.datetime.now() - datetime.timedelta(days=days)
logger.info("clean_messages: days=%s, end_before=%s, batch_size=%s", days, end_before, batch_size)
return cls._clean_sandbox_messages_by_time_range(
end_before=end_before,
start_from=None,
graceful_period=graceful_period,
batch_size=batch_size,
dry_run=dry_run,
)
@classmethod
def _clean_sandbox_messages_by_time_range(
cls,
end_before: datetime.datetime,
start_from: datetime.datetime | None = None,
graceful_period: int = 21,
batch_size: int = 1000,
dry_run: bool = False,
) -> dict[str, int]:
"""
Internal method to clean sandbox messages within a time range using cursor-based pagination.
Time range is [start_from, end_before) - left-closed, right-open interval.
Steps:
1. Iterate messages using cursor pagination (by created_at, id)
2. Extract app_ids from messages
3. Query tenant_ids from apps
4. Batch fetch subscription plans
5. Delete messages from sandbox tenants
Args:
end_before: End time (exclusive) of the range
start_from: Optional start time (inclusive) of the range
batch_size: Number of messages to process per batch
dry_run: Whether to perform a dry run (no actual deletion)
Returns:
Dict with statistics: batches, total_messages, total_deleted
"""
stats = {
"batches": 0,
"total_messages": 0,
"total_deleted": 0,
}
if not dify_config.BILLING_ENABLED:
logger.info("clean_messages: billing is not enabled, skip cleaning messages")
return stats
tenant_whitelist = cls._get_tenant_whitelist()
logger.info("clean_messages: tenant_whitelist=%s", tenant_whitelist)
# Cursor-based pagination using (created_at, id) to avoid infinite loops
# and ensure proper ordering with time-based filtering
_cursor: tuple[datetime.datetime, str] | None = None
logger.info(
"clean_messages: start cleaning messages (dry_run=%s), start_from=%s, end_before=%s",
dry_run,
start_from,
end_before,
)
while True:
stats["batches"] += 1
# Step 1: Fetch a batch of messages using cursor
with Session(db.engine, expire_on_commit=False) as session:
msg_stmt = (
select(Message.id, Message.app_id, Message.created_at)
.where(Message.created_at < end_before)
.order_by(Message.created_at, Message.id)
.limit(batch_size)
)
if start_from:
msg_stmt = msg_stmt.where(Message.created_at >= start_from)
# Apply cursor condition: (created_at, id) > (last_created_at, last_message_id)
# This translates to:
# created_at > last_created_at OR (created_at = last_created_at AND id > last_message_id)
if _cursor:
# Continuing from previous batch
msg_stmt = msg_stmt.where(
(Message.created_at > _cursor[0])
| ((Message.created_at == _cursor[0]) & (Message.id > _cursor[1]))
)
raw_messages = list(session.execute(msg_stmt).all())
messages = [
SimpleMessage(id=msg_id, app_id=app_id, created_at=msg_created_at)
for msg_id, app_id, msg_created_at in raw_messages
]
if not messages:
logger.info("clean_messages (batch %s): no more messages to process", stats["batches"])
break
# Update cursor to the last message's (created_at, id)
_cursor = (messages[-1].created_at, messages[-1].id)
# Step 2: Extract app_ids from this batch
app_ids = list({msg.app_id for msg in messages})
if not app_ids:
logger.info("clean_messages (batch %s): no app_ids found, skip", stats["batches"])
continue
# Step 3: Query tenant_ids from apps
app_stmt = select(App.id, App.tenant_id).where(App.id.in_(app_ids))
apps = list(session.execute(app_stmt).all())
if not apps:
logger.info("clean_messages (batch %s): no apps found, skip", stats["batches"])
continue
# Step 4: End sesion to call billing API to avoid long-running transaction.
# Build app_id -> tenant_id mapping
app_to_tenant: dict[str, str] = {app.id: app.tenant_id for app in apps}
tenant_ids = list(set(app_to_tenant.values()))
# Batch fetch subscription plans
tenant_plans = cls._batch_fetch_tenant_plans(tenant_ids)
# Step 5: Filter messages from sandbox tenants
sandbox_message_ids = cls._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
tenant_whitelist=tenant_whitelist,
graceful_period_days=graceful_period,
)
if not sandbox_message_ids:
logger.info("clean_messages (batch %s): no sandbox messages found, skip", stats["batches"])
continue
stats["total_messages"] += len(sandbox_message_ids)
# Step 6: Batch delete messages and their relations
if not dry_run:
with Session(db.engine, expire_on_commit=False) as session:
# Delete related records first
cls._batch_delete_message_relations(session, sandbox_message_ids)
# Delete messages
delete_stmt = delete(Message).where(Message.id.in_(sandbox_message_ids))
delete_result = cast(CursorResult, session.execute(delete_stmt))
messages_deleted = delete_result.rowcount
session.commit()
stats["total_deleted"] += messages_deleted
logger.info(
"clean_messages (batch %s): processed %s messages, deleted %s sandbox messages",
stats["batches"],
len(messages),
messages_deleted,
)
else:
sample_ids = ", ".join(sample_id for sample_id in sandbox_message_ids[:5])
logger.info(
"clean_messages (batch %s, dry_run): would delete %s sandbox messages, sample ids: %s",
stats["batches"],
len(sandbox_message_ids),
sample_ids,
)
logger.info(
"clean_messages completed: total batches: %s, total messages: %s, total deleted: %s",
stats["batches"],
stats["total_messages"],
stats["total_deleted"],
)
return stats
@classmethod
def _filter_expired_sandbox_messages(
cls,
messages: Sequence[SimpleMessage],
app_to_tenant: dict[str, str],
tenant_plans: dict[str, SubscriptionPlan],
tenant_whitelist: Sequence[str],
graceful_period_days: int,
current_timestamp: int | None = None,
) -> list[str]:
"""
Filter messages that should be deleted based on sandbox plan expiration.
A message should be deleted if:
1. It belongs to a sandbox tenant AND
2. Either:
a) The tenant has no previous subscription (expiration_date == -1), OR
b) The subscription expired more than graceful_period_days ago
Args:
messages: List of message objects with id and app_id attributes
app_to_tenant: Mapping from app_id to tenant_id
tenant_plans: Mapping from tenant_id to subscription plan info
graceful_period_days: Grace period in days after expiration
current_timestamp: Current Unix timestamp (defaults to now, injectable for testing)
Returns:
List of message IDs that should be deleted
"""
if current_timestamp is None:
current_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
sandbox_message_ids: list[str] = []
graceful_period_seconds = graceful_period_days * 24 * 60 * 60
for msg in messages:
# Get tenant_id for this message's app
tenant_id = app_to_tenant.get(msg.app_id)
if not tenant_id:
continue
# Skip tenant messages in whitelist
if tenant_id in tenant_whitelist:
continue
# Get subscription plan for this tenant
tenant_plan = tenant_plans.get(tenant_id)
if not tenant_plan:
continue
plan = str(tenant_plan["plan"])
expiration_date = int(tenant_plan["expiration_date"])
# Only process sandbox plans
if plan != CloudPlan.SANDBOX:
continue
# Case 1: No previous subscription (-1 means never had a paid subscription)
if expiration_date == -1:
sandbox_message_ids.append(msg.id)
continue
# Case 2: Subscription expired beyond grace period
if current_timestamp - expiration_date > graceful_period_seconds:
sandbox_message_ids.append(msg.id)
return sandbox_message_ids
@classmethod
def _get_tenant_whitelist(cls) -> Sequence[str]:
return BillingService.get_expired_subscription_cleanup_whitelist()
@classmethod
def _batch_fetch_tenant_plans(cls, tenant_ids: Sequence[str]) -> dict[str, SubscriptionPlan]:
"""
Batch fetch tenant plans with Redis caching.
This method uses a two-tier strategy:
1. First, batch fetch from Redis cache using mget
2. For cache misses, fetch from billing API
3. Update Redis cache using pipeline for new entries
Args:
tenant_ids: List of tenant IDs
Returns:
Dict mapping tenant_id to SubscriptionPlan (with "plan" and "expiration_date" keys)
"""
if not tenant_ids:
return {}
tenant_plans: dict[str, SubscriptionPlan] = {}
# Step 1: Batch fetch from Redis cache using mget
redis_keys = [f"{cls.PLAN_CACHE_KEY_PREFIX}{tenant_id}" for tenant_id in tenant_ids]
try:
cached_values = redis_client.mget(redis_keys)
# Map cached values back to tenant_ids
cache_hits: dict[str, SubscriptionPlan] = {}
cache_misses: list[str] = []
for tenant_id, cached_value in zip(tenant_ids, cached_values):
if cached_value:
# Redis returns bytes, decode to string and parse JSON
json_str = cached_value.decode("utf-8") if isinstance(cached_value, bytes) else cached_value
try:
plan_dict = json.loads(json_str)
if isinstance(plan_dict, dict) and "plan" in plan_dict:
cache_hits[tenant_id] = cast(SubscriptionPlan, plan_dict)
tenant_plans[tenant_id] = cast(SubscriptionPlan, plan_dict)
else:
cache_misses.append(tenant_id)
except json.JSONDecodeError:
cache_misses.append(tenant_id)
else:
cache_misses.append(tenant_id)
logger.info(
"clean_messages: fetch_tenant_plans(cache hits=%s, cache misses=%s)",
len(cache_hits),
len(cache_misses),
)
except Exception as e:
logger.warning("clean_messages: fetch_tenant_plans(redis mget failed: %s, falling back to API)", e)
cache_misses = list(tenant_ids)
# Step 2: Fetch missing plans from billing API
if cache_misses:
bulk_plans = BillingService.get_plan_bulk(cache_misses)
if bulk_plans:
plans_to_cache: dict[str, SubscriptionPlan] = {}
for tenant_id, plan_dict in bulk_plans.items():
if isinstance(plan_dict, dict):
tenant_plans[tenant_id] = plan_dict # type: ignore
plans_to_cache[tenant_id] = plan_dict # type: ignore
# Step 3: Batch update Redis cache using pipeline
if plans_to_cache:
try:
pipe = redis_client.pipeline()
for tenant_id, plan_dict in plans_to_cache.items():
redis_key = f"{cls.PLAN_CACHE_KEY_PREFIX}{tenant_id}"
# Serialize dict to JSON string
json_str = json.dumps(plan_dict)
pipe.setex(redis_key, cls.PLAN_CACHE_TTL, json_str)
pipe.execute()
logger.info(
"clean_messages: cached %s new tenant plans to Redis",
len(plans_to_cache),
)
except Exception as e:
logger.warning("clean_messages: Redis pipeline failed: %s", e)
return tenant_plans
@classmethod
def _batch_delete_message_relations(cls, session: Session, message_ids: Sequence[str]) -> None:
"""
Batch delete all related records for given message IDs.
Args:
session: Database session
message_ids: List of message IDs to delete relations for
"""
if not message_ids:
return
# Delete all related records in batch
session.execute(delete(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)))
session.execute(delete(MessageAnnotation).where(MessageAnnotation.message_id.in_(message_ids)))
session.execute(delete(MessageChain).where(MessageChain.message_id.in_(message_ids)))
session.execute(delete(MessageAgentThought).where(MessageAgentThought.message_id.in_(message_ids)))
session.execute(delete(MessageFile).where(MessageFile.message_id.in_(message_ids)))
session.execute(delete(SavedMessage).where(SavedMessage.message_id.in_(message_ids)))
session.execute(delete(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id.in_(message_ids)))
session.execute(delete(DatasetRetrieverResource).where(DatasetRetrieverResource.message_id.in_(message_ids)))

View File

@@ -0,0 +1,996 @@
"""
Integration tests for SandboxMessagesCleanService using testcontainers.
This module provides comprehensive integration tests for the sandbox message cleanup service
using TestContainers infrastructure with real PostgreSQL and Redis.
"""
import datetime
import json
import uuid
from decimal import Decimal
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.model import (
App,
AppAnnotationHitHistory,
Conversation,
DatasetRetrieverResource,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
)
from models.web import SavedMessage
from services.sandbox_messages_clean_service import SandboxMessagesCleanService
class TestSandboxMessagesCleanServiceIntegration:
"""Integration tests for SandboxMessagesCleanService._clean_sandbox_messages_by_time_range."""
@pytest.fixture(autouse=True)
def cleanup_database(self, db_session_with_containers):
"""Clean up database before and after each test to ensure isolation."""
yield
# Clear all test data in correct order (respecting foreign key constraints)
db.session.query(DatasetRetrieverResource).delete()
db.session.query(AppAnnotationHitHistory).delete()
db.session.query(SavedMessage).delete()
db.session.query(MessageFile).delete()
db.session.query(MessageAgentThought).delete()
db.session.query(MessageChain).delete()
db.session.query(MessageAnnotation).delete()
db.session.query(MessageFeedback).delete()
db.session.query(Message).delete()
db.session.query(Conversation).delete()
db.session.query(App).delete()
db.session.query(TenantAccountJoin).delete()
db.session.query(Tenant).delete()
db.session.query(Account).delete()
db.session.commit()
@pytest.fixture(autouse=True)
def cleanup_redis(self):
"""Clean up Redis cache before each test."""
# Clear tenant plan cache
try:
keys = redis_client.keys(f"{SandboxMessagesCleanService.PLAN_CACHE_KEY_PREFIX}*")
if keys:
redis_client.delete(*keys)
except Exception:
pass # Redis might not be available in some test environments
yield
# Clean up after test
try:
keys = redis_client.keys(f"{SandboxMessagesCleanService.PLAN_CACHE_KEY_PREFIX}*")
if keys:
redis_client.delete(*keys)
except Exception:
pass
@pytest.fixture(autouse=True)
def mock_whitelist(self):
"""Mock whitelist to return empty list by default."""
with patch(
"services.sandbox_messages_clean_service.BillingService.get_expired_subscription_cleanup_whitelist"
) as mock:
mock.return_value = []
yield mock
@pytest.fixture(autouse=True)
def mock_billing_enabled(self):
"""Mock BILLING_ENABLED to be True for all tests."""
with patch("services.sandbox_messages_clean_service.dify_config.BILLING_ENABLED", True):
yield
def _create_account_and_tenant(self, plan="sandbox"):
"""Helper to create account and tenant."""
fake = Faker()
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.flush()
tenant = Tenant(
name=fake.company(),
plan=plan,
status="normal",
)
db.session.add(tenant)
db.session.flush()
tenant_account_join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
)
db.session.add(tenant_account_join)
db.session.commit()
return account, tenant
def _create_app(self, tenant, account):
"""Helper to create an app."""
fake = Faker()
app = App(
tenant_id=tenant.id,
name=fake.company(),
description="Test app",
mode="chat",
enable_site=True,
enable_api=True,
api_rpm=60,
api_rph=3600,
is_demo=False,
is_public=False,
created_by=account.id,
updated_by=account.id,
)
db.session.add(app)
db.session.commit()
return app
def _create_conversation(self, app):
"""Helper to create a conversation."""
conversation = Conversation(
app_id=app.id,
app_model_config_id=str(uuid.uuid4()),
model_provider="openai",
model_id="gpt-3.5-turbo",
mode="chat",
name="Test conversation",
inputs={},
status="normal",
from_source="api",
from_end_user_id=str(uuid.uuid4()),
)
db.session.add(conversation)
db.session.commit()
return conversation
def _create_message(self, app, conversation, created_at=None, with_relations=True):
"""Helper to create a message with optional related records."""
if created_at is None:
created_at = datetime.datetime.now()
message = Message(
app_id=app.id,
conversation_id=conversation.id,
model_provider="openai",
model_id="gpt-3.5-turbo",
inputs={},
query="Test query",
answer="Test answer",
message=[{"role": "user", "text": "Test message"}],
message_tokens=10,
message_unit_price=Decimal("0.001"),
answer_tokens=20,
answer_unit_price=Decimal("0.002"),
total_price=Decimal("0.003"),
currency="USD",
from_source="api",
from_account_id=conversation.from_end_user_id,
created_at=created_at,
)
db.session.add(message)
db.session.flush()
if with_relations:
self._create_message_relations(message)
db.session.commit()
return message
def _create_message_relations(self, message):
"""Helper to create all message-related records."""
# MessageFeedback
feedback = MessageFeedback(
app_id=message.app_id,
conversation_id=message.conversation_id,
message_id=message.id,
rating="like",
from_source="api",
from_end_user_id=str(uuid.uuid4()),
)
db.session.add(feedback)
# MessageAnnotation
annotation = MessageAnnotation(
app_id=message.app_id,
conversation_id=message.conversation_id,
message_id=message.id,
question="Test question",
content="Test annotation",
account_id=message.from_account_id,
)
db.session.add(annotation)
# MessageChain
chain = MessageChain(
message_id=message.id,
type="system",
input=json.dumps({"test": "input"}),
output=json.dumps({"test": "output"}),
)
db.session.add(chain)
db.session.flush()
# MessageFile
file = MessageFile(
message_id=message.id,
type="image",
transfer_method="local_file",
url="http://example.com/test.jpg",
belongs_to="user",
created_by_role="end_user",
created_by=str(uuid.uuid4()),
)
db.session.add(file)
# SavedMessage
saved = SavedMessage(
app_id=message.app_id,
message_id=message.id,
created_by_role="end_user",
created_by=str(uuid.uuid4()),
)
db.session.add(saved)
db.session.flush()
# AppAnnotationHitHistory
hit = AppAnnotationHitHistory(
app_id=message.app_id,
annotation_id=annotation.id,
message_id=message.id,
source="annotation",
question="Test question",
account_id=message.from_account_id,
annotation_question="Test annotation question",
annotation_content="Test annotation content",
)
db.session.add(hit)
# DatasetRetrieverResource
resource = DatasetRetrieverResource(
message_id=message.id,
position=1,
dataset_id=str(uuid.uuid4()),
dataset_name="Test dataset",
document_id=str(uuid.uuid4()),
document_name="Test document",
data_source_type="upload_file",
segment_id=str(uuid.uuid4()),
score=0.9,
content="Test content",
hit_count=1,
word_count=10,
segment_position=1,
index_node_hash="test_hash",
retriever_from="dataset",
created_by=message.from_account_id,
)
db.session.add(resource)
def test_clean_no_messages_to_delete(self, db_session_with_containers):
"""Test cleaning when there are no messages to delete."""
# Arrange
end_before = datetime.datetime.now() - datetime.timedelta(days=30)
with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
mock_billing.return_value = {}
# Act
stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
end_before=end_before,
graceful_period=21, # Use default graceful period
batch_size=100,
)
# Assert
# Even with no messages, the loop runs once to check
assert stats["batches"] == 1
assert stats["total_messages"] == 0
assert stats["total_deleted"] == 0
def test_clean_mixed_sandbox_and_paid_tenants(self, db_session_with_containers):
"""Test cleaning with mixed sandbox and paid tenants, correctly filtering sandbox messages."""
# Arrange - Create sandbox tenants with expired messages
sandbox_tenants = []
sandbox_message_ids = []
for i in range(2):
account, tenant = self._create_account_and_tenant(plan="sandbox")
sandbox_tenants.append(tenant)
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
# Create 3 expired messages per sandbox tenant
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
for j in range(3):
msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j))
sandbox_message_ids.append(msg.id)
# Create paid tenants with expired messages (should NOT be deleted)
paid_tenants = []
paid_message_ids = []
for i in range(2):
account, tenant = self._create_account_and_tenant(plan="professional")
paid_tenants.append(tenant)
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
# Create 2 expired messages per paid tenant
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
for j in range(2):
msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j))
paid_message_ids.append(msg.id)
# Mock billing service - return plan and expiration_date
now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
expired_15_days_ago = now_timestamp - (15 * 24 * 60 * 60) # Beyond 7-day grace period
plan_map = {}
for tenant in sandbox_tenants:
plan_map[tenant.id] = {
"plan": CloudPlan.SANDBOX,
"expiration_date": expired_15_days_ago,
}
for tenant in paid_tenants:
plan_map[tenant.id] = {
"plan": CloudPlan.PROFESSIONAL,
"expiration_date": now_timestamp + (365 * 24 * 60 * 60), # Active for 1 year
}
with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
mock_billing.return_value = plan_map
# Act
end_before = datetime.datetime.now() - datetime.timedelta(days=30)
stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
end_before=end_before,
graceful_period=7,
batch_size=100,
)
# Assert
assert stats["total_messages"] == 6 # 2 sandbox tenants * 3 messages
assert stats["total_deleted"] == 6
# Only sandbox messages should be deleted
assert db.session.query(Message).where(Message.id.in_(sandbox_message_ids)).count() == 0
# Paid messages should remain
assert db.session.query(Message).where(Message.id.in_(paid_message_ids)).count() == 4
# Related records of sandbox messages should be deleted
assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(sandbox_message_ids)).count() == 0
assert (
db.session.query(MessageAnnotation).where(MessageAnnotation.message_id.in_(sandbox_message_ids)).count()
== 0
)
def test_clean_with_cursor_pagination(self, db_session_with_containers):
"""Test cursor pagination works correctly across multiple batches."""
# Arrange - Create sandbox tenant with messages that will span multiple batches
account, tenant = self._create_account_and_tenant(plan="sandbox")
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
# Create 10 expired messages with different timestamps
base_date = datetime.datetime.now() - datetime.timedelta(days=35)
message_ids = []
for i in range(10):
msg = self._create_message(
app,
conv,
created_at=base_date + datetime.timedelta(hours=i),
with_relations=False, # Skip relations for speed
)
message_ids.append(msg.id)
with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
mock_billing.return_value = {
tenant.id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": -1, # No previous subscription
}
}
# Act - Use small batch size to trigger multiple batches
end_before = datetime.datetime.now() - datetime.timedelta(days=30)
stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
end_before=end_before,
graceful_period=21, # Use default graceful period
batch_size=3, # Small batch size to test pagination
)
# 5 batches for 10 messages with batch_size=3, the last batch is empty
assert stats["batches"] == 5
assert stats["total_messages"] == 10
assert stats["total_deleted"] == 10
# All messages should be deleted
assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 0
def test_clean_with_dry_run(self, db_session_with_containers):
"""Test dry_run mode does not delete messages."""
# Arrange
account, tenant = self._create_account_and_tenant(plan="sandbox")
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
# Create expired messages
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
message_ids = []
for i in range(3):
msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i))
message_ids.append(msg.id)
with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
mock_billing.return_value = {
tenant.id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": -1, # No previous subscription
}
}
# Act
end_before = datetime.datetime.now() - datetime.timedelta(days=30)
stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
end_before=end_before,
graceful_period=21, # Use default graceful period
batch_size=100,
dry_run=True, # Dry run mode
)
# Assert
assert stats["total_messages"] == 3 # Messages identified
assert stats["total_deleted"] == 0 # But NOT deleted
# All messages should still exist
assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 3
# Related records should also still exist
assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)).count() == 3
def test_clean_with_billing_partial_exception_some_known_plans(self, db_session_with_containers):
"""Test when billing service fails but returns partial data, only delete known sandbox messages."""
# Arrange - Create 3 tenants
tenants_data = []
for i in range(3):
account, tenant = self._create_account_and_tenant(plan="sandbox")
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
msg = self._create_message(app, conv, created_at=expired_date)
tenants_data.append(
{
"tenant": tenant,
"message_id": msg.id,
}
)
# Mock billing service to return partial data with new structure
now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
# Only tenant[0] is confirmed as sandbox, tenant[1] is professional, tenant[2] is missing
partial_plan_map = {
tenants_data[0]["tenant"].id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": -1, # No previous subscription
},
tenants_data[1]["tenant"].id: {
"plan": CloudPlan.PROFESSIONAL,
"expiration_date": now_timestamp + (365 * 24 * 60 * 60), # Active for 1 year
},
# tenants_data[2] is missing from response
}
with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
mock_billing.return_value = partial_plan_map
# Act
end_before = datetime.datetime.now() - datetime.timedelta(days=30)
stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
end_before=end_before,
graceful_period=21, # Use default graceful period
batch_size=100,
)
# Assert - Only tenant[0]'s message should be deleted
assert stats["total_messages"] == 1
assert stats["total_deleted"] == 1
# Check which messages were deleted
assert (
db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 0
) # Sandbox tenant's message deleted
assert (
db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1
) # Professional tenant's message preserved
assert (
db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 1
) # Unknown tenant's message preserved (safe default)
def test_clean_with_billing_exception_no_data(self, db_session_with_containers):
"""Test when billing service returns empty data, skip deletion for that batch."""
# Arrange
account, tenant = self._create_account_and_tenant(plan="sandbox")
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
msg_id = None
msg = self._create_message(app, conv, created_at=expired_date)
msg_id = msg.id # Store ID before any operations
db.session.commit()
# Mock billing service to return empty data (simulating failure/no data scenario)
with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
mock_billing.return_value = {} # Empty response, tenant plan unknown
# Act - Should not raise exception, just skip deletion
end_before = datetime.datetime.now() - datetime.timedelta(days=30)
stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
end_before=end_before,
graceful_period=21, # Use default graceful period
batch_size=100,
)
# Assert - No messages should be deleted when plan is unknown
assert stats["total_messages"] == 0 # Cannot determine sandbox messages
assert stats["total_deleted"] == 0
# Message should still exist (safe default - don't delete if plan is unknown)
assert db.session.query(Message).where(Message.id == msg_id).count() == 1
def test_redis_cache_for_tenant_plans(self, db_session_with_containers):
"""Test that tenant plans are cached in Redis and reused."""
# Arrange
account, tenant = self._create_account_and_tenant(plan="sandbox")
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
# Create messages in two batches (to test cache reuse)
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
batch1_msgs = []
for i in range(2):
msg = self._create_message(
app, conv, created_at=expired_date + datetime.timedelta(hours=i), with_relations=False
)
batch1_msgs.append(msg.id)
batch2_msgs = []
for i in range(2):
msg = self._create_message(
app, conv, created_at=expired_date + datetime.timedelta(hours=10 + i), with_relations=False
)
batch2_msgs.append(msg.id)
# Mock billing service with new structure
mock_get_plan_bulk = MagicMock(
return_value={
tenant.id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": -1, # No previous subscription
}
}
)
with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk", mock_get_plan_bulk):
# Act - First call
end_before = datetime.datetime.now() - datetime.timedelta(days=30)
stats1 = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
end_before=end_before,
graceful_period=21, # Use default graceful period
batch_size=2, # Process 2 messages per batch
)
# Check billing service was called (cache miss)
assert mock_get_plan_bulk.call_count == 1
first_call_count = mock_get_plan_bulk.call_count
# Verify Redis cache was populated
cache_key = f"{SandboxMessagesCleanService.PLAN_CACHE_KEY_PREFIX}{tenant.id}"
cached_plan = redis_client.get(cache_key)
assert cached_plan is not None
cached_plan_data = json.loads(cached_plan.decode("utf-8"))
assert cached_plan_data["plan"] == CloudPlan.SANDBOX
assert cached_plan_data["expiration_date"] == -1
# Act - Second call with same tenant (should use cache)
# Create more messages for the same tenant
batch3_msgs = []
for i in range(2):
msg = self._create_message(
app, conv, created_at=expired_date + datetime.timedelta(hours=20 + i), with_relations=False
)
batch3_msgs.append(msg.id)
stats2 = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
end_before=end_before,
graceful_period=21, # Use default graceful period
batch_size=2,
)
# Assert - Billing service should not be called again (cache hit)
# The call count should be the same
assert mock_get_plan_bulk.call_count == first_call_count # Same tenant, should use cache
# Verify all messages were deleted
total_expected = len(batch1_msgs) + len(batch2_msgs) + len(batch3_msgs)
assert stats1["total_deleted"] + stats2["total_deleted"] == total_expected
def test_time_range_filtering(self, db_session_with_containers):
"""Test that messages are correctly filtered by [start_from, end_before) time range."""
# Arrange
account, tenant = self._create_account_and_tenant(plan="sandbox")
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
base_date = datetime.datetime(2024, 1, 15, 12, 0, 0)
# Create messages: before range, in range, after range
msg_before = self._create_message(
app,
conv,
created_at=datetime.datetime(2024, 1, 1, 12, 0, 0), # Before start_from
with_relations=False,
)
msg_before_id = msg_before.id
msg_at_start = self._create_message(
app,
conv,
created_at=datetime.datetime(2024, 1, 10, 12, 0, 0), # At start_from (inclusive)
with_relations=False,
)
msg_at_start_id = msg_at_start.id
msg_in_range = self._create_message(
app,
conv,
created_at=datetime.datetime(2024, 1, 15, 12, 0, 0), # In range
with_relations=False,
)
msg_in_range_id = msg_in_range.id
msg_at_end = self._create_message(
app,
conv,
created_at=datetime.datetime(2024, 1, 20, 12, 0, 0), # At end_before (exclusive)
with_relations=False,
)
msg_at_end_id = msg_at_end.id
msg_after = self._create_message(
app,
conv,
created_at=datetime.datetime(2024, 1, 25, 12, 0, 0), # After end_before
with_relations=False,
)
msg_after_id = msg_after.id
db.session.commit() # Commit all messages
# Mock billing service with new structure
with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
mock_billing.return_value = {
tenant.id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": -1, # No previous subscription
}
}
# Act - Clean with specific time range [2024-01-10, 2024-01-20)
stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
start_from=datetime.datetime(2024, 1, 10, 12, 0, 0),
end_before=datetime.datetime(2024, 1, 20, 12, 0, 0),
graceful_period=21, # Use default graceful period
batch_size=100,
)
# Assert - Only messages in [start_from, end_before) should be deleted
assert stats["total_messages"] == 2 # msg_at_start and msg_in_range
assert stats["total_deleted"] == 2
# Verify specific messages using stored IDs
# Before range, kept
assert db.session.query(Message).where(Message.id == msg_before_id).count() == 1
# At start (inclusive), deleted
assert db.session.query(Message).where(Message.id == msg_at_start_id).count() == 0
# In range, deleted
assert db.session.query(Message).where(Message.id == msg_in_range_id).count() == 0
# At end (exclusive), kept
assert db.session.query(Message).where(Message.id == msg_at_end_id).count() == 1
# After range, kept
assert db.session.query(Message).where(Message.id == msg_after_id).count() == 1
def test_clean_with_graceful_period_scenarios(self, db_session_with_containers):
"""Test cleaning with different graceful period scenarios."""
# Arrange - Create 5 different tenants with different plan and expiration scenarios
now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
graceful_period = 8 # Use 8 days for this test to validate boundary conditions
# Scenario 1: Sandbox plan with expiration within graceful period (5 days ago)
# Should NOT be deleted
account1, tenant1 = self._create_account_and_tenant(plan="sandbox")
app1 = self._create_app(tenant1, account1)
conv1 = self._create_conversation(app1)
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False)
msg1_id = msg1.id # Save ID before potential deletion
expired_5_days_ago = now_timestamp - (5 * 24 * 60 * 60) # Within grace period
# Scenario 2: Sandbox plan with expiration beyond graceful period (10 days ago)
# Should be deleted
account2, tenant2 = self._create_account_and_tenant(plan="sandbox")
app2 = self._create_app(tenant2, account2)
conv2 = self._create_conversation(app2)
msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False)
msg2_id = msg2.id # Save ID before potential deletion
expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Beyond grace period
# Scenario 3: Sandbox plan with expiration_date = -1 (no previous subscription)
# Should be deleted
account3, tenant3 = self._create_account_and_tenant(plan="sandbox")
app3 = self._create_app(tenant3, account3)
conv3 = self._create_conversation(app3)
msg3 = self._create_message(app3, conv3, created_at=expired_date, with_relations=False)
msg3_id = msg3.id # Save ID before potential deletion
# Scenario 4: Non-sandbox plan (professional) with no expiration (future date)
# Should NOT be deleted
account4, tenant4 = self._create_account_and_tenant(plan="professional")
app4 = self._create_app(tenant4, account4)
conv4 = self._create_conversation(app4)
msg4 = self._create_message(app4, conv4, created_at=expired_date, with_relations=False)
msg4_id = msg4.id # Save ID before potential deletion
future_expiration = now_timestamp + (365 * 24 * 60 * 60) # Active for 1 year
# Scenario 5: Sandbox plan with expiration exactly at grace period boundary (8 days ago)
# Should NOT be deleted (boundary is exclusive: > graceful_period)
account5, tenant5 = self._create_account_and_tenant(plan="sandbox")
app5 = self._create_app(tenant5, account5)
conv5 = self._create_conversation(app5)
msg5 = self._create_message(app5, conv5, created_at=expired_date, with_relations=False)
msg5_id = msg5.id # Save ID before potential deletion
expired_exactly_8_days_ago = now_timestamp - (8 * 24 * 60 * 60) # Exactly at boundary
db.session.commit()
# Mock billing service with all scenarios
plan_map = {
tenant1.id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": expired_5_days_ago,
},
tenant2.id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": expired_10_days_ago,
},
tenant3.id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": -1,
},
tenant4.id: {
"plan": CloudPlan.PROFESSIONAL,
"expiration_date": future_expiration,
},
tenant5.id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": expired_exactly_8_days_ago,
},
}
with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
mock_billing.return_value = plan_map
# Act
end_before = datetime.datetime.now() - datetime.timedelta(days=30)
# Mock datetime.now() to use the same timestamp as test setup
# This ensures deterministic behavior for boundary conditions (scenario 5)
with patch("services.sandbox_messages_clean_service.datetime") as mock_datetime:
mock_datetime.datetime.now.return_value = datetime.datetime.fromtimestamp(
now_timestamp, tz=datetime.UTC
)
mock_datetime.timedelta = datetime.timedelta # Keep original timedelta
stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
end_before=end_before,
graceful_period=graceful_period,
batch_size=100,
)
# Assert - Only messages from scenario 2 and 3 should be deleted
assert stats["total_messages"] == 2
assert stats["total_deleted"] == 2
# Verify each scenario using saved IDs
assert db.session.query(Message).where(Message.id == msg1_id).count() == 1 # Within grace, kept
assert db.session.query(Message).where(Message.id == msg2_id).count() == 0 # Beyond grace, deleted
assert db.session.query(Message).where(Message.id == msg3_id).count() == 0 # No subscription, deleted
assert db.session.query(Message).where(Message.id == msg4_id).count() == 1 # Professional plan, kept
assert db.session.query(Message).where(Message.id == msg5_id).count() == 1 # At boundary, kept
def test_clean_with_tenant_whitelist(self, db_session_with_containers, mock_whitelist):
"""Test that whitelisted tenants' messages are not deleted even if they are sandbox and expired."""
# Arrange - Create 3 sandbox tenants with expired messages
tenants_data = []
for i in range(3):
account, tenant = self._create_account_and_tenant(plan="sandbox")
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
msg = self._create_message(app, conv, created_at=expired_date, with_relations=False)
tenants_data.append(
{
"tenant": tenant,
"message_id": msg.id,
}
)
# Mock billing service - all tenants are sandbox with no subscription
plan_map = {
tenants_data[0]["tenant"].id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": -1, # No previous subscription
},
tenants_data[1]["tenant"].id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": -1, # No previous subscription
},
tenants_data[2]["tenant"].id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": -1, # No previous subscription
},
}
# Setup whitelist - tenant0 and tenant1 are whitelisted, tenant2 is not
whitelist = [tenants_data[0]["tenant"].id, tenants_data[1]["tenant"].id]
mock_whitelist.return_value = whitelist
with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
mock_billing.return_value = plan_map
# Act
end_before = datetime.datetime.now() - datetime.timedelta(days=30)
stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
end_before=end_before,
graceful_period=21, # Use default graceful period
batch_size=100,
)
# Assert - Only tenant2's message should be deleted (not whitelisted)
assert stats["total_messages"] == 1
assert stats["total_deleted"] == 1
# Verify tenant0's message still exists (whitelisted)
assert db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 1
# Verify tenant1's message still exists (whitelisted)
assert db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1
# Verify tenant2's message was deleted (not whitelisted)
assert db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 0
def test_clean_with_whitelist_and_grace_period(self, db_session_with_containers, mock_whitelist):
"""Test that whitelist takes precedence over grace period logic."""
# Arrange - Create 2 sandbox tenants
now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
# Tenant1: whitelisted, expired beyond grace period
account1, tenant1 = self._create_account_and_tenant(plan="sandbox")
app1 = self._create_app(tenant1, account1)
conv1 = self._create_conversation(app1)
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False)
expired_30_days_ago = now_timestamp - (30 * 24 * 60 * 60) # Well beyond 21-day grace
# Tenant2: not whitelisted, within grace period
account2, tenant2 = self._create_account_and_tenant(plan="sandbox")
app2 = self._create_app(tenant2, account2)
conv2 = self._create_conversation(app2)
msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False)
expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Within 21-day grace
# Mock billing service
plan_map = {
tenant1.id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": expired_30_days_ago, # Beyond grace period
},
tenant2.id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": expired_10_days_ago, # Within grace period
},
}
# Setup whitelist - only tenant1 is whitelisted
whitelist = [tenant1.id]
mock_whitelist.return_value = whitelist
with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
mock_billing.return_value = plan_map
# Act
end_before = datetime.datetime.now() - datetime.timedelta(days=30)
stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
end_before=end_before,
graceful_period=21, # Use default graceful period
batch_size=100,
)
# Assert - No messages should be deleted
# tenant1: whitelisted (would be deleted based on grace period, but protected by whitelist)
# tenant2: within grace period (not eligible for deletion)
assert stats["total_messages"] == 0
assert stats["total_deleted"] == 0
# Verify both messages still exist
assert db.session.query(Message).where(Message.id == msg1.id).count() == 1 # Whitelisted
assert db.session.query(Message).where(Message.id == msg2.id).count() == 1 # Within grace period
def test_clean_with_empty_whitelist(self, db_session_with_containers, mock_whitelist):
"""Test that empty whitelist behaves as no whitelist (all eligible messages are deleted)."""
# Arrange - Create sandbox tenant with expired messages
account, tenant = self._create_account_and_tenant(plan="sandbox")
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
msg_ids = []
for i in range(3):
msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i))
msg_ids.append(msg.id)
# Mock billing service
plan_map = {
tenant.id: {
"plan": CloudPlan.SANDBOX,
"expiration_date": -1, # No previous subscription
}
}
# Setup empty whitelist (default behavior from fixture)
mock_whitelist.return_value = []
with patch("services.sandbox_messages_clean_service.BillingService.get_plan_bulk") as mock_billing:
mock_billing.return_value = plan_map
# Act
end_before = datetime.datetime.now() - datetime.timedelta(days=30)
stats = SandboxMessagesCleanService._clean_sandbox_messages_by_time_range(
end_before=end_before,
graceful_period=21, # Use default graceful period
batch_size=100,
)
# Assert - All messages should be deleted (no whitelist protection)
assert stats["total_messages"] == 3
assert stats["total_deleted"] == 3
# Verify all messages were deleted
assert db.session.query(Message).where(Message.id.in_(msg_ids)).count() == 0

View File

@@ -149,6 +149,7 @@ class TestWebConversationService:
from_end_user_id=user.id if isinstance(user, EndUser) else None,
from_account_id=user.id if isinstance(user, Account) else None,
dialogue_count=0,
is_deleted=False,
)
from extensions.ext_database import db

View File

@@ -2,9 +2,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from pydantic import TypeAdapter, ValidationError
from core.tools.entities.tool_entities import ApiProviderSchemaType
from models import Account, Tenant
from models.tools import ApiToolProvider
from services.tools.api_tools_manage_service import ApiToolManageService
@@ -300,7 +298,7 @@ class TestApiToolManageService:
provider_name = fake.company()
icon = {"type": "emoji", "value": "🔧"}
credentials = {"auth_type": "none", "api_key_header": "X-API-Key", "api_key_value": ""}
schema_type = ApiProviderSchemaType.OPENAPI
schema_type = "openapi"
schema = self._create_test_openapi_schema()
privacy_policy = "https://example.com/privacy"
custom_disclaimer = "Custom disclaimer text"
@@ -366,7 +364,7 @@ class TestApiToolManageService:
provider_name = fake.company()
icon = {"type": "emoji", "value": "🔧"}
credentials = {"auth_type": "none"}
schema_type = ApiProviderSchemaType.OPENAPI
schema_type = "openapi"
schema = self._create_test_openapi_schema()
privacy_policy = "https://example.com/privacy"
custom_disclaimer = "Custom disclaimer text"
@@ -430,10 +428,21 @@ class TestApiToolManageService:
labels = ["test"]
# Act & Assert: Try to create provider with invalid schema type
with pytest.raises(ValidationError) as exc_info:
TypeAdapter(ApiProviderSchemaType).validate_python(schema_type)
with pytest.raises(ValueError) as exc_info:
ApiToolManageService.create_api_tool_provider(
user_id=account.id,
tenant_id=tenant.id,
provider_name=provider_name,
icon=icon,
credentials=credentials,
schema_type=schema_type,
schema=schema,
privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer,
labels=labels,
)
assert "validation error" in str(exc_info.value)
assert "invalid schema type" in str(exc_info.value)
def test_create_api_tool_provider_missing_auth_type(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
@@ -455,7 +464,7 @@ class TestApiToolManageService:
provider_name = fake.company()
icon = {"type": "emoji", "value": "🔧"}
credentials = {} # Missing auth_type
schema_type = ApiProviderSchemaType.OPENAPI
schema_type = "openapi"
schema = self._create_test_openapi_schema()
privacy_policy = "https://example.com/privacy"
custom_disclaimer = "Custom disclaimer text"
@@ -498,7 +507,7 @@ class TestApiToolManageService:
provider_name = fake.company()
icon = {"type": "emoji", "value": "🔑"}
credentials = {"auth_type": "api_key", "api_key_header": "X-API-Key", "api_key_value": fake.uuid4()}
schema_type = ApiProviderSchemaType.OPENAPI
schema_type = "openapi"
schema = self._create_test_openapi_schema()
privacy_policy = "https://example.com/privacy"
custom_disclaimer = "Custom disclaimer text"

View File

@@ -1,420 +0,0 @@
from types import SimpleNamespace
from unittest.mock import ANY, Mock, patch
import pytest
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
from core.app.entities.queue_entities import (
QueueAgentMessageEvent,
QueueErrorEvent,
QueueLLMChunkEvent,
QueueMessageEndEvent,
QueueMessageFileEvent,
QueuePingEvent,
)
from core.app.entities.task_entities import (
EasyUITaskState,
ErrorStreamResponse,
MessageEndStreamResponse,
MessageFileStreamResponse,
MessageReplaceStreamResponse,
MessageStreamResponse,
PingStreamResponse,
StreamEvent,
)
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher
from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
from core.model_runtime.entities.message_entities import TextPromptMessageContent
from core.ops.ops_trace_manager import TraceQueueManager
from models.model import AppMode
class TestEasyUIBasedGenerateTaskPipelineProcessStreamResponse:
"""Test cases for EasyUIBasedGenerateTaskPipeline._process_stream_response method."""
@pytest.fixture
def mock_application_generate_entity(self):
"""Create a mock application generate entity."""
entity = Mock(spec=ChatAppGenerateEntity)
entity.task_id = "test-task-id"
entity.app_id = "test-app-id"
# minimal app_config used by pipeline internals
entity.app_config = SimpleNamespace(
tenant_id="test-tenant-id",
app_id="test-app-id",
app_mode=AppMode.CHAT,
app_model_config_dict={},
additional_features=None,
sensitive_word_avoidance=None,
)
# minimal model_conf for LLMResult init
entity.model_conf = SimpleNamespace(
model="test-model",
provider_model_bundle=SimpleNamespace(model_type_instance=Mock()),
credentials={},
)
return entity
@pytest.fixture
def mock_queue_manager(self):
"""Create a mock queue manager."""
manager = Mock(spec=AppQueueManager)
return manager
@pytest.fixture
def mock_message_cycle_manager(self):
"""Create a mock message cycle manager."""
manager = Mock()
manager.get_message_event_type.return_value = StreamEvent.MESSAGE
manager.message_to_stream_response.return_value = Mock(spec=MessageStreamResponse)
manager.message_file_to_stream_response.return_value = Mock(spec=MessageFileStreamResponse)
manager.message_replace_to_stream_response.return_value = Mock(spec=MessageReplaceStreamResponse)
manager.handle_retriever_resources = Mock()
manager.handle_annotation_reply.return_value = None
return manager
@pytest.fixture
def mock_conversation(self):
"""Create a mock conversation."""
conversation = Mock()
conversation.id = "test-conversation-id"
conversation.mode = "chat"
return conversation
@pytest.fixture
def mock_message(self):
"""Create a mock message."""
message = Mock()
message.id = "test-message-id"
message.created_at = Mock()
message.created_at.timestamp.return_value = 1234567890
return message
@pytest.fixture
def mock_task_state(self):
"""Create a mock task state."""
task_state = Mock(spec=EasyUITaskState)
# Create LLM result mock
llm_result = Mock(spec=RuntimeLLMResult)
llm_result.prompt_messages = []
llm_result.message = Mock()
llm_result.message.content = ""
task_state.llm_result = llm_result
task_state.answer = ""
return task_state
@pytest.fixture
def pipeline(
self,
mock_application_generate_entity,
mock_queue_manager,
mock_conversation,
mock_message,
mock_message_cycle_manager,
mock_task_state,
):
"""Create an EasyUIBasedGenerateTaskPipeline instance with mocked dependencies."""
with patch(
"core.app.task_pipeline.easy_ui_based_generate_task_pipeline.EasyUITaskState", return_value=mock_task_state
):
pipeline = EasyUIBasedGenerateTaskPipeline(
application_generate_entity=mock_application_generate_entity,
queue_manager=mock_queue_manager,
conversation=mock_conversation,
message=mock_message,
stream=True,
)
pipeline._message_cycle_manager = mock_message_cycle_manager
pipeline._task_state = mock_task_state
return pipeline
def test_get_message_event_type_called_once_when_first_llm_chunk_arrives(
self, pipeline, mock_message_cycle_manager
):
"""Expect get_message_event_type to be called when processing the first LLM chunk event."""
# Setup a minimal LLM chunk event
chunk = Mock()
chunk.delta.message.content = "hi"
chunk.prompt_messages = []
llm_chunk_event = Mock(spec=QueueLLMChunkEvent)
llm_chunk_event.chunk = chunk
mock_queue_message = Mock()
mock_queue_message.event = llm_chunk_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
# Execute
list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
mock_message_cycle_manager.get_message_event_type.assert_called_once_with(message_id="test-message-id")
def test_llm_chunk_event_with_text_content(self, pipeline, mock_message_cycle_manager, mock_task_state):
"""Test handling of LLM chunk events with text content."""
# Setup
chunk = Mock()
chunk.delta.message.content = "Hello, world!"
chunk.prompt_messages = []
llm_chunk_event = Mock(spec=QueueLLMChunkEvent)
llm_chunk_event.chunk = chunk
mock_queue_message = Mock()
mock_queue_message.event = llm_chunk_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE
# Execute
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
assert len(responses) == 1
mock_message_cycle_manager.message_to_stream_response.assert_called_once_with(
answer="Hello, world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
)
assert mock_task_state.llm_result.message.content == "Hello, world!"
def test_llm_chunk_event_with_list_content(self, pipeline, mock_message_cycle_manager, mock_task_state):
"""Test handling of LLM chunk events with list content."""
# Setup
text_content = Mock(spec=TextPromptMessageContent)
text_content.data = "Hello"
chunk = Mock()
chunk.delta.message.content = [text_content, " world!"]
chunk.prompt_messages = []
llm_chunk_event = Mock(spec=QueueLLMChunkEvent)
llm_chunk_event.chunk = chunk
mock_queue_message = Mock()
mock_queue_message.event = llm_chunk_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE
# Execute
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
assert len(responses) == 1
mock_message_cycle_manager.message_to_stream_response.assert_called_once_with(
answer="Hello world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
)
assert mock_task_state.llm_result.message.content == "Hello world!"
def test_agent_message_event(self, pipeline, mock_message_cycle_manager, mock_task_state):
"""Test handling of agent message events."""
# Setup
chunk = Mock()
chunk.delta.message.content = "Agent response"
agent_message_event = Mock(spec=QueueAgentMessageEvent)
agent_message_event.chunk = chunk
mock_queue_message = Mock()
mock_queue_message.event = agent_message_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
# Ensure method under assertion is a mock to track calls
pipeline._agent_message_to_stream_response = Mock(return_value=Mock())
# Execute
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
assert len(responses) == 1
# Agent messages should use _agent_message_to_stream_response
pipeline._agent_message_to_stream_response.assert_called_once_with(
answer="Agent response", message_id="test-message-id"
)
def test_message_end_event(self, pipeline, mock_message_cycle_manager, mock_task_state):
"""Test handling of message end events."""
# Setup
llm_result = Mock(spec=RuntimeLLMResult)
llm_result.message = Mock()
llm_result.message.content = "Final response"
message_end_event = Mock(spec=QueueMessageEndEvent)
message_end_event.llm_result = llm_result
mock_queue_message = Mock()
mock_queue_message.event = message_end_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
pipeline._save_message = Mock()
pipeline._message_end_to_stream_response = Mock(return_value=Mock(spec=MessageEndStreamResponse))
# Patch db.engine used inside pipeline for session creation
with patch(
"core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock())
):
# Execute
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
assert len(responses) == 1
assert mock_task_state.llm_result == llm_result
pipeline._save_message.assert_called_once()
pipeline._message_end_to_stream_response.assert_called_once()
def test_error_event(self, pipeline):
"""Test handling of error events."""
# Setup
error_event = Mock(spec=QueueErrorEvent)
error_event.error = Exception("Test error")
mock_queue_message = Mock()
mock_queue_message.event = error_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
pipeline.handle_error = Mock(return_value=Exception("Test error"))
pipeline.error_to_stream_response = Mock(return_value=Mock(spec=ErrorStreamResponse))
# Patch db.engine used inside pipeline for session creation
with patch(
"core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock())
):
# Execute
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
assert len(responses) == 1
pipeline.handle_error.assert_called_once()
pipeline.error_to_stream_response.assert_called_once()
def test_ping_event(self, pipeline):
"""Test handling of ping events."""
# Setup
ping_event = Mock(spec=QueuePingEvent)
mock_queue_message = Mock()
mock_queue_message.event = ping_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))
# Execute
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
assert len(responses) == 1
pipeline.ping_stream_response.assert_called_once()
def test_file_event(self, pipeline, mock_message_cycle_manager):
"""Test handling of file events."""
# Setup
file_event = Mock(spec=QueueMessageFileEvent)
file_event.message_file_id = "file-id"
mock_queue_message = Mock()
mock_queue_message.event = file_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
file_response = Mock(spec=MessageFileStreamResponse)
mock_message_cycle_manager.message_file_to_stream_response.return_value = file_response
# Execute
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
assert len(responses) == 1
assert responses[0] == file_response
mock_message_cycle_manager.message_file_to_stream_response.assert_called_once_with(file_event)
def test_publisher_is_called_with_messages(self, pipeline):
"""Test that publisher publishes messages when provided."""
# Setup
publisher = Mock(spec=AppGeneratorTTSPublisher)
ping_event = Mock(spec=QueuePingEvent)
mock_queue_message = Mock()
mock_queue_message.event = ping_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))
# Execute
list(pipeline._process_stream_response(publisher=publisher, trace_manager=None))
# Assert
# Called once with message and once with None at the end
assert publisher.publish.call_count == 2
publisher.publish.assert_any_call(mock_queue_message)
publisher.publish.assert_any_call(None)
def test_trace_manager_passed_to_save_message(self, pipeline):
"""Test that trace manager is passed to _save_message."""
# Setup
trace_manager = Mock(spec=TraceQueueManager)
message_end_event = Mock(spec=QueueMessageEndEvent)
message_end_event.llm_result = None
mock_queue_message = Mock()
mock_queue_message.event = message_end_event
pipeline.queue_manager.listen.return_value = [mock_queue_message]
pipeline._save_message = Mock()
pipeline._message_end_to_stream_response = Mock(return_value=Mock(spec=MessageEndStreamResponse))
# Patch db.engine used inside pipeline for session creation
with patch(
"core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db", new=SimpleNamespace(engine=Mock())
):
# Execute
list(pipeline._process_stream_response(publisher=None, trace_manager=trace_manager))
# Assert
pipeline._save_message.assert_called_once_with(session=ANY, trace_manager=trace_manager)
def test_multiple_events_sequence(self, pipeline, mock_message_cycle_manager, mock_task_state):
"""Test handling multiple events in sequence."""
# Setup
chunk1 = Mock()
chunk1.delta.message.content = "Hello"
chunk1.prompt_messages = []
chunk2 = Mock()
chunk2.delta.message.content = " world!"
chunk2.prompt_messages = []
llm_chunk_event1 = Mock(spec=QueueLLMChunkEvent)
llm_chunk_event1.chunk = chunk1
ping_event = Mock(spec=QueuePingEvent)
llm_chunk_event2 = Mock(spec=QueueLLMChunkEvent)
llm_chunk_event2.chunk = chunk2
mock_queue_messages = [
Mock(event=llm_chunk_event1),
Mock(event=ping_event),
Mock(event=llm_chunk_event2),
]
pipeline.queue_manager.listen.return_value = mock_queue_messages
mock_message_cycle_manager.get_message_event_type.return_value = StreamEvent.MESSAGE
pipeline.ping_stream_response = Mock(return_value=Mock(spec=PingStreamResponse))
# Execute
responses = list(pipeline._process_stream_response(publisher=None, trace_manager=None))
# Assert
assert len(responses) == 3
assert mock_task_state.llm_result.message.content == "Hello world!"
# Verify calls to message_to_stream_response
assert mock_message_cycle_manager.message_to_stream_response.call_count == 2
mock_message_cycle_manager.message_to_stream_response.assert_any_call(
answer="Hello", message_id="test-message-id", event_type=StreamEvent.MESSAGE
)
mock_message_cycle_manager.message_to_stream_response.assert_any_call(
answer=" world!", message_id="test-message-id", event_type=StreamEvent.MESSAGE
)

View File

@@ -1,166 +0,0 @@
"""Unit tests for the message cycle manager optimization."""
from types import SimpleNamespace
from unittest.mock import ANY, Mock, patch
import pytest
from flask import current_app
from core.app.entities.task_entities import MessageStreamResponse, StreamEvent
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
class TestMessageCycleManagerOptimization:
"""Test cases for the message cycle manager optimization that prevents N+1 queries."""
@pytest.fixture
def mock_application_generate_entity(self):
"""Create a mock application generate entity."""
entity = Mock()
entity.task_id = "test-task-id"
return entity
@pytest.fixture
def message_cycle_manager(self, mock_application_generate_entity):
"""Create a message cycle manager instance."""
task_state = Mock()
return MessageCycleManager(application_generate_entity=mock_application_generate_entity, task_state=task_state)
def test_get_message_event_type_with_message_file(self, message_cycle_manager):
"""Test get_message_event_type returns MESSAGE_FILE when message has files."""
with (
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
):
# Setup mock session and message file
mock_session = Mock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_message_file = Mock()
# Current implementation uses session.query(...).scalar()
mock_session.query.return_value.scalar.return_value = mock_message_file
# Execute
with current_app.app_context():
result = message_cycle_manager.get_message_event_type("test-message-id")
# Assert
assert result == StreamEvent.MESSAGE_FILE
mock_session.query.return_value.scalar.assert_called_once()
def test_get_message_event_type_without_message_file(self, message_cycle_manager):
"""Test get_message_event_type returns MESSAGE when message has no files."""
with (
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
):
# Setup mock session and no message file
mock_session = Mock()
mock_session_class.return_value.__enter__.return_value = mock_session
# Current implementation uses session.query(...).scalar()
mock_session.query.return_value.scalar.return_value = None
# Execute
with current_app.app_context():
result = message_cycle_manager.get_message_event_type("test-message-id")
# Assert
assert result == StreamEvent.MESSAGE
mock_session.query.return_value.scalar.assert_called_once()
def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager):
"""MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it."""
with (
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
):
# Setup mock session and message file
mock_session = Mock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_message_file = Mock()
# Current implementation uses session.query(...).scalar()
mock_session.query.return_value.scalar.return_value = mock_message_file
# Execute: compute event type once, then pass to message_to_stream_response
with current_app.app_context():
event_type = message_cycle_manager.get_message_event_type("test-message-id")
result = message_cycle_manager.message_to_stream_response(
answer="Hello world", message_id="test-message-id", event_type=event_type
)
# Assert
assert isinstance(result, MessageStreamResponse)
assert result.answer == "Hello world"
assert result.id == "test-message-id"
assert result.event == StreamEvent.MESSAGE_FILE
mock_session.query.return_value.scalar.assert_called_once()
def test_message_to_stream_response_with_event_type_skips_query(self, message_cycle_manager):
"""Test that message_to_stream_response skips database query when event_type is provided."""
with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
# Execute with event_type provided
result = message_cycle_manager.message_to_stream_response(
answer="Hello world", message_id="test-message-id", event_type=StreamEvent.MESSAGE
)
# Assert
assert isinstance(result, MessageStreamResponse)
assert result.answer == "Hello world"
assert result.id == "test-message-id"
assert result.event == StreamEvent.MESSAGE
# Should not query database when event_type is provided
mock_session_class.assert_not_called()
def test_message_to_stream_response_with_from_variable_selector(self, message_cycle_manager):
"""Test message_to_stream_response with from_variable_selector parameter."""
result = message_cycle_manager.message_to_stream_response(
answer="Hello world",
message_id="test-message-id",
from_variable_selector=["var1", "var2"],
event_type=StreamEvent.MESSAGE,
)
assert isinstance(result, MessageStreamResponse)
assert result.answer == "Hello world"
assert result.id == "test-message-id"
assert result.from_variable_selector == ["var1", "var2"]
assert result.event == StreamEvent.MESSAGE
def test_optimization_usage_example(self, message_cycle_manager):
"""Test the optimization pattern that should be used by callers."""
# Step 1: Get event type once (this queries database)
with (
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
):
mock_session = Mock()
mock_session_class.return_value.__enter__.return_value = mock_session
# Current implementation uses session.query(...).scalar()
mock_session.query.return_value.scalar.return_value = None # No files
with current_app.app_context():
event_type = message_cycle_manager.get_message_event_type("test-message-id")
# Should query database once
mock_session_class.assert_called_once_with(ANY, expire_on_commit=False)
assert event_type == StreamEvent.MESSAGE
# Step 2: Use event_type for multiple calls (no additional queries)
with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
mock_session_class.return_value.__enter__.return_value = Mock()
chunk1_response = message_cycle_manager.message_to_stream_response(
answer="Chunk 1", message_id="test-message-id", event_type=event_type
)
chunk2_response = message_cycle_manager.message_to_stream_response(
answer="Chunk 2", message_id="test-message-id", event_type=event_type
)
# Should not query database again
mock_session_class.assert_not_called()
assert chunk1_response.event == StreamEvent.MESSAGE
assert chunk2_response.event == StreamEvent.MESSAGE
assert chunk1_response.answer == "Chunk 1"
assert chunk2_response.answer == "Chunk 2"

View File

@@ -901,13 +901,6 @@ class TestFixedRecursiveCharacterTextSplitter:
# Verify no empty chunks
assert all(len(chunk) > 0 for chunk in result)
def test_double_slash_n(self):
data = "chunk 1\n\nsubchunk 1.\nsubchunk 2.\n\n---\n\nchunk 2\n\nsubchunk 1\nsubchunk 2."
separator = "\\n\\n---\\n\\n"
splitter = FixedRecursiveCharacterTextSplitter(fixed_separator=separator)
chunks = splitter.split_text(data)
assert chunks == ["chunk 1\n\nsubchunk 1.\nsubchunk 2.", "chunk 2\n\nsubchunk 1\nsubchunk 2."]
# ============================================================================
# Test Metadata Preservation

View File

@@ -1,4 +1,3 @@
import json
import time
import pytest
@@ -47,16 +46,14 @@ def make_start_node(user_inputs, variables):
def test_json_object_valid_schema():
schema = json.dumps(
{
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age"],
}
)
schema = {
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age"],
}
variables = [
VariableEntity(
@@ -68,7 +65,7 @@ def test_json_object_valid_schema():
)
]
user_inputs = {"profile": json.dumps({"age": 20, "name": "Tom"})}
user_inputs = {"profile": {"age": 20, "name": "Tom"}}
node = make_start_node(user_inputs, variables)
result = node._run()
@@ -77,23 +74,12 @@ def test_json_object_valid_schema():
def test_json_object_invalid_json_string():
schema = json.dumps(
{
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age", "name"],
}
)
variables = [
VariableEntity(
variable="profile",
label="profile",
type=VariableEntityType.JSON_OBJECT,
required=True,
json_schema=schema,
)
]
@@ -102,21 +88,38 @@ def test_json_object_invalid_json_string():
node = make_start_node(user_inputs, variables)
with pytest.raises(ValueError, match='{"age": 20, "name": "Tom" must be a valid JSON object'):
with pytest.raises(ValueError, match="profile must be a JSON object"):
node._run()
@pytest.mark.parametrize("value", ["[1, 2, 3]", "123"])
def test_json_object_valid_json_but_not_object(value):
variables = [
VariableEntity(
variable="profile",
label="profile",
type=VariableEntityType.JSON_OBJECT,
required=True,
)
]
user_inputs = {"profile": value}
node = make_start_node(user_inputs, variables)
with pytest.raises(ValueError, match="profile must be a JSON object"):
node._run()
def test_json_object_does_not_match_schema():
schema = json.dumps(
{
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age", "name"],
}
)
schema = {
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age", "name"],
}
variables = [
VariableEntity(
@@ -129,7 +132,7 @@ def test_json_object_does_not_match_schema():
]
# age is a string, which violates the schema (expects number)
user_inputs = {"profile": json.dumps({"age": "twenty", "name": "Tom"})}
user_inputs = {"profile": {"age": "twenty", "name": "Tom"}}
node = make_start_node(user_inputs, variables)
@@ -138,16 +141,14 @@ def test_json_object_does_not_match_schema():
def test_json_object_missing_required_schema_field():
schema = json.dumps(
{
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age", "name"],
}
)
schema = {
"type": "object",
"properties": {
"age": {"type": "number"},
"name": {"type": "string"},
},
"required": ["age", "name"],
}
variables = [
VariableEntity(
@@ -160,7 +161,7 @@ def test_json_object_missing_required_schema_field():
]
# Missing required field "name"
user_inputs = {"profile": json.dumps({"age": 20})}
user_inputs = {"profile": {"age": 20}}
node = make_start_node(user_inputs, variables)
@@ -213,7 +214,7 @@ def test_json_object_optional_variable_not_provided():
variable="profile",
label="profile",
type=VariableEntityType.JSON_OBJECT,
required=True,
required=False,
)
]
@@ -222,5 +223,5 @@ def test_json_object_optional_variable_not_provided():
node = make_start_node(user_inputs, variables)
# Current implementation raises a validation error even when the variable is optional
with pytest.raises(ValueError, match="profile is required in input form"):
with pytest.raises(ValueError, match="profile must be a JSON object"):
node._run()

View File

@@ -0,0 +1,588 @@
"""
Unit tests for SandboxMessagesCleanService.
This module tests parameter validation, method invocation, and error handling
without database dependencies (using mocks).
"""
import datetime
from unittest.mock import patch
import pytest
from enums.cloud_plan import CloudPlan
from services.sandbox_messages_clean_service import SandboxMessagesCleanService
class MockMessage:
"""Mock message object for testing."""
def __init__(self, id: str, app_id: str, created_at: datetime.datetime | None = None):
self.id = id
self.app_id = app_id
self.created_at = created_at or datetime.datetime.now()
class TestFilterExpiredSandboxMessages:
"""Unit tests for _filter_expired_sandbox_messages method."""
def test_filter_missing_tenant_mapping(self):
"""Test that messages with missing app-to-tenant mapping are excluded."""
# Arrange
messages = [
MockMessage("msg1", "app1"),
MockMessage("msg2", "app2"),
]
app_to_tenant = {} # No mapping
tenant_plans = {"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}}
# Act
result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
tenant_whitelist=[],
graceful_period_days=8,
current_timestamp=1000000,
)
# Assert
assert result == []
def test_filter_missing_tenant_plan(self):
"""Test that messages with missing tenant plan are excluded."""
# Arrange
messages = [
MockMessage("msg1", "app1"),
MockMessage("msg2", "app2"),
]
app_to_tenant = {
"app1": "tenant1",
"app2": "tenant2",
}
tenant_plans = {} # No plans
# Act
result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
tenant_whitelist=[],
graceful_period_days=8,
current_timestamp=1000000,
)
# Assert
assert result == []
def test_filter_no_previous_subscription(self):
"""Test that messages with no previous subscription (expiration_date=-1) are deleted."""
# Arrange
messages = [
MockMessage("msg1", "app1"),
MockMessage("msg2", "app2"),
MockMessage("msg3", "app3"),
]
app_to_tenant = {
"app1": "tenant1",
"app2": "tenant2",
"app3": "tenant3",
}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
}
# Act
result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
tenant_whitelist=[],
graceful_period_days=8,
current_timestamp=1000000,
)
# Assert - all messages should be deleted
assert set(result) == {"msg1", "msg2", "msg3"}
def test_filter_all_within_grace_period(self):
"""Test that no messages are deleted when all are within grace period."""
# Arrange
now = 1000000
# All expired recently (within 8 day grace period)
expired_1_day_ago = now - (1 * 24 * 60 * 60)
expired_3_days_ago = now - (3 * 24 * 60 * 60)
expired_7_days_ago = now - (7 * 24 * 60 * 60)
messages = [
MockMessage("msg1", "app1"),
MockMessage("msg2", "app2"),
MockMessage("msg3", "app3"),
]
app_to_tenant = {
"app1": "tenant1",
"app2": "tenant2",
"app3": "tenant3",
}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_1_day_ago},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_3_days_ago},
"tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_7_days_ago},
}
# Act
result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
tenant_whitelist=[],
graceful_period_days=8,
current_timestamp=now,
)
# Assert - no messages should be deleted
assert result == []
def test_filter_partial_expired_beyond_grace_period(self):
"""Test filtering when some messages expired beyond grace period."""
# Arrange
now = 1000000
graceful_period = 8
# Different expiration scenarios
expired_5_days_ago = now - (5 * 24 * 60 * 60) # Within grace - keep
expired_10_days_ago = now - (10 * 24 * 60 * 60) # Beyond grace - delete
expired_30_days_ago = now - (30 * 24 * 60 * 60) # Beyond grace - delete
expired_exactly_8_days_ago = now - (8 * 24 * 60 * 60) # Exactly at boundary - keep
expired_9_days_ago = now - (9 * 24 * 60 * 60) # Just beyond - delete
messages = [
MockMessage("msg1", "app1"), # Within grace
MockMessage("msg2", "app2"), # Beyond grace
MockMessage("msg3", "app3"), # Beyond grace
MockMessage("msg4", "app4"), # No subscription - delete
MockMessage("msg5", "app5"), # Exactly at boundary
MockMessage("msg6", "app6"), # Just beyond grace
]
app_to_tenant = {
"app1": "tenant1",
"app2": "tenant2",
"app3": "tenant3",
"app4": "tenant4",
"app5": "tenant5",
"app6": "tenant6",
}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_5_days_ago},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_10_days_ago},
"tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_30_days_ago},
"tenant4": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant5": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_exactly_8_days_ago},
"tenant6": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_9_days_ago},
}
# Act
result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
tenant_whitelist=[],
graceful_period_days=graceful_period,
current_timestamp=now,
)
# Assert - msg2, msg3, msg4, msg6 should be deleted
# msg1 and msg5 are within/at grace period boundary
assert set(result) == {"msg2", "msg3", "msg4", "msg6"}
def test_filter_complex_mixed_scenario(self):
"""Test complex scenario with mixed plans, expirations, and missing mappings."""
# Arrange
now = 1000000
sandbox_expired_old = now - (15 * 24 * 60 * 60) # 15 days ago - beyond grace
sandbox_expired_recent = now - (3 * 24 * 60 * 60) # 3 days ago - within grace
future_expiration = now + (30 * 24 * 60 * 60) # 30 days in future - active paid plan
messages = [
MockMessage("msg1", "app1"), # Sandbox, no subscription - delete
MockMessage("msg2", "app2"), # Sandbox, expired old - delete
MockMessage("msg3", "app3"), # Sandbox, within grace - keep
MockMessage("msg4", "app4"), # Team plan, active - keep
MockMessage("msg5", "app5"), # No tenant mapping - keep
MockMessage("msg6", "app6"), # No plan info - keep
MockMessage("msg7", "app7"), # Sandbox, expired old - delete
]
app_to_tenant = {
"app1": "tenant1",
"app2": "tenant2",
"app3": "tenant3",
"app4": "tenant4",
"app6": "tenant6", # Has mapping but no plan
"app7": "tenant7",
# app5 has no mapping
}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_old},
"tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_recent},
"tenant4": {"plan": CloudPlan.TEAM, "expiration_date": future_expiration},
"tenant7": {"plan": CloudPlan.SANDBOX, "expiration_date": sandbox_expired_old},
# tenant6 has no plan
}
# Act
result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
tenant_whitelist=[],
graceful_period_days=8,
current_timestamp=now,
)
# Assert - only sandbox expired beyond grace period and no subscription
assert set(result) == {"msg1", "msg2", "msg7"}
def test_filter_empty_inputs(self):
"""Test filtering with empty inputs returns empty list."""
# Arrange - empty messages
result1 = SandboxMessagesCleanService._filter_expired_sandbox_messages(
messages=[],
app_to_tenant={"app1": "tenant1"},
tenant_plans={"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}},
tenant_whitelist=[],
graceful_period_days=8,
current_timestamp=1000000,
)
# Assert
assert result1 == []
def test_filter_uses_default_timestamp(self):
"""Test that method uses current time when timestamp not provided."""
# Arrange
messages = [MockMessage("msg1", "app1")]
app_to_tenant = {"app1": "tenant1"}
tenant_plans = {"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}}
# Act - don't provide current_timestamp
result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
tenant_whitelist=[],
graceful_period_days=8,
# current_timestamp not provided - should use datetime.now()
)
# Assert - should still work and return msg1 (no subscription)
assert result == ["msg1"]
def test_filter_with_whitelist(self):
"""Test that messages from whitelisted tenants are excluded from deletion."""
# Arrange
messages = [
MockMessage("msg1", "app1"), # Whitelisted tenant - should be kept
MockMessage("msg2", "app2"), # Not whitelisted - should be deleted
MockMessage("msg3", "app3"), # Whitelisted tenant - should be kept
MockMessage("msg4", "app4"), # Not whitelisted - should be deleted
]
app_to_tenant = {
"app1": "tenant1",
"app2": "tenant2",
"app3": "tenant3",
"app4": "tenant4",
}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant4": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
}
tenant_whitelist = ["tenant1", "tenant3"] # Whitelist tenant1 and tenant3
# Act
result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
tenant_whitelist=tenant_whitelist,
graceful_period_days=8,
current_timestamp=1000000,
)
# Assert - only msg2 and msg4 should be deleted (not whitelisted)
assert set(result) == {"msg2", "msg4"}
def test_filter_with_whitelist_and_grace_period(self):
"""Test whitelist takes precedence over grace period logic."""
# Arrange
now = 1000000
expired_long_ago = now - (30 * 24 * 60 * 60) # Expired 30 days ago
messages = [
MockMessage("msg1", "app1"), # Whitelisted, expired long ago - should be kept
MockMessage("msg2", "app2"), # Not whitelisted, expired long ago - should be deleted
]
app_to_tenant = {
"app1": "tenant1",
"app2": "tenant2",
}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_long_ago},
"tenant2": {"plan": CloudPlan.SANDBOX, "expiration_date": expired_long_ago},
}
tenant_whitelist = ["tenant1"] # Only tenant1 is whitelisted
# Act
result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
tenant_whitelist=tenant_whitelist,
graceful_period_days=8,
current_timestamp=now,
)
# Assert - only msg2 should be deleted
assert result == ["msg2"]
def test_filter_whitelist_with_non_sandbox_plans(self):
"""Test that whitelist only affects sandbox plan messages."""
# Arrange
now = 1000000
future_expiration = now + (30 * 24 * 60 * 60)
messages = [
MockMessage("msg1", "app1"), # Sandbox, whitelisted - kept
MockMessage("msg2", "app2"), # Team plan, whitelisted - kept (not sandbox)
MockMessage("msg3", "app3"), # Sandbox, not whitelisted - deleted
]
app_to_tenant = {
"app1": "tenant1",
"app2": "tenant2",
"app3": "tenant3",
}
tenant_plans = {
"tenant1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
"tenant2": {"plan": CloudPlan.TEAM, "expiration_date": future_expiration},
"tenant3": {"plan": CloudPlan.SANDBOX, "expiration_date": -1},
}
tenant_whitelist = ["tenant1", "tenant2"]
# Act
result = SandboxMessagesCleanService._filter_expired_sandbox_messages(
messages=messages,
app_to_tenant=app_to_tenant,
tenant_plans=tenant_plans,
tenant_whitelist=tenant_whitelist,
graceful_period_days=8,
current_timestamp=now,
)
# Assert - only msg3 should be deleted (sandbox, not whitelisted)
assert result == ["msg3"]
class TestCleanSandboxMessagesByTimeRange:
"""Unit tests for clean_sandbox_messages_by_time_range method."""
@patch.object(SandboxMessagesCleanService, "_clean_sandbox_messages_by_time_range")
def test_valid_time_range_and_args(self, mock_clean):
"""Test with valid time range and other parameters."""
# Arrange
start_from = datetime.datetime(2024, 1, 1, 0, 0, 0)
end_before = datetime.datetime(2024, 12, 31, 23, 59, 59)
batch_size = 500
dry_run = True
mock_clean.return_value = {
"batches": 5,
"total_messages": 100,
"total_deleted": 100,
}
# Act
SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
start_from=start_from,
end_before=end_before,
batch_size=batch_size,
dry_run=dry_run,
)
# Assert, expected no exception raised
mock_clean.assert_called_once_with(
start_from=start_from,
end_before=end_before,
graceful_period=21,
batch_size=batch_size,
dry_run=dry_run,
)
@patch.object(SandboxMessagesCleanService, "_clean_sandbox_messages_by_time_range")
def test_with_default_args(self, mock_clean):
"""Test with default args."""
# Arrange
start_from = datetime.datetime(2024, 1, 1)
end_before = datetime.datetime(2024, 2, 1)
mock_clean.return_value = {
"batches": 2,
"total_messages": 50,
"total_deleted": 0,
}
# Act
SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
start_from=start_from,
end_before=end_before,
)
# Assert
mock_clean.assert_called_once_with(
start_from=start_from,
end_before=end_before,
graceful_period=21,
batch_size=1000,
dry_run=False,
)
def test_invalid_time_range(self):
"""Test invalid time range raises ValueError."""
# Arrange
same_time = datetime.datetime(2024, 1, 1, 12, 0, 0)
# Act & Assert start equals end
with pytest.raises(ValueError, match="start_from .* must be less than end_before"):
SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
start_from=same_time,
end_before=same_time,
)
# Arrange
start_from = datetime.datetime(2024, 12, 31)
end_before = datetime.datetime(2024, 1, 1)
# Act & Assert start after end
with pytest.raises(ValueError, match="start_from .* must be less than end_before"):
SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
start_from=start_from,
end_before=end_before,
)
def test_invalid_batch_size(self):
"""Test invalid batch_size raises ValueError."""
# Arrange
start_from = datetime.datetime(2024, 1, 1)
end_before = datetime.datetime(2024, 2, 1)
# Act & Assert batch_size = 0
with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
start_from=start_from,
end_before=end_before,
batch_size=0,
)
# Act & Assert batch_size < 0
with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
SandboxMessagesCleanService.clean_sandbox_messages_by_time_range(
start_from=start_from,
end_before=end_before,
batch_size=-100,
)
class TestCleanSandboxMessagesByDays:
"""Unit tests for clean_sandbox_messages_by_days method."""
@patch.object(SandboxMessagesCleanService, "_clean_sandbox_messages_by_time_range")
def test_default_days(self, mock_clean):
"""Test with default 30 days."""
# Arrange
mock_clean.return_value = {"batches": 3, "total_messages": 75, "total_deleted": 75}
# Act
with patch("services.sandbox_messages_clean_service.datetime") as mock_datetime:
fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0)
mock_datetime.datetime.now.return_value = fixed_now
mock_datetime.timedelta = datetime.timedelta # Keep original timedelta
SandboxMessagesCleanService.clean_sandbox_messages_by_days()
# Assert
expected_end_before = fixed_now - datetime.timedelta(days=30) # default days=30
mock_clean.assert_called_once_with(
end_before=expected_end_before,
start_from=None,
graceful_period=21,
batch_size=1000,
dry_run=False,
)
@patch.object(SandboxMessagesCleanService, "_clean_sandbox_messages_by_time_range")
def test_custom_days(self, mock_clean):
"""Test with custom number of days."""
# Arrange
custom_days = 90
mock_clean.return_value = {"batches": 10, "total_messages": 500, "total_deleted": 500}
# Act
with patch("services.sandbox_messages_clean_service.datetime") as mock_datetime:
fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0)
mock_datetime.datetime.now.return_value = fixed_now
mock_datetime.timedelta = datetime.timedelta # Keep original timedelta
result = SandboxMessagesCleanService.clean_sandbox_messages_by_days(days=custom_days)
# Assert
expected_end_before = fixed_now - datetime.timedelta(days=custom_days)
mock_clean.assert_called_once_with(
end_before=expected_end_before,
start_from=None,
graceful_period=21,
batch_size=1000,
dry_run=False,
)
@patch.object(SandboxMessagesCleanService, "_clean_sandbox_messages_by_time_range")
def test_zero_days(self, mock_clean):
"""Test with days=0 (clean all messages before now)."""
# Arrange
mock_clean.return_value = {"batches": 0, "total_messages": 0, "total_deleted": 0}
# Act
with patch("services.sandbox_messages_clean_service.datetime") as mock_datetime:
fixed_now = datetime.datetime(2024, 6, 15, 14, 0, 0)
mock_datetime.datetime.now.return_value = fixed_now
mock_datetime.timedelta = datetime.timedelta # Keep original timedelta
SandboxMessagesCleanService.clean_sandbox_messages_by_days(days=0)
# Assert
expected_end_before = fixed_now - datetime.timedelta(days=0) # same as fixed_now
mock_clean.assert_called_once_with(
end_before=expected_end_before,
start_from=None,
graceful_period=21,
batch_size=1000,
dry_run=False,
)
def test_invalid_batch_size(self):
"""Test invalid batch_size raises ValueError."""
# Act & Assert batch_size = 0
with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
SandboxMessagesCleanService.clean_sandbox_messages_by_days(
days=30,
batch_size=0,
)
# Act & Assert batch_size < 0
with pytest.raises(ValueError, match="batch_size .* must be greater than 0"):
SandboxMessagesCleanService.clean_sandbox_messages_by_days(
days=30,
batch_size=-500,
)

View File

@@ -1369,10 +1369,7 @@ PLUGIN_STDIO_BUFFER_SIZE=1024
PLUGIN_STDIO_MAX_BUFFER_SIZE=5242880
PLUGIN_PYTHON_ENV_INIT_TIMEOUT=120
# Plugin Daemon side timeout (configure to match the API side below)
PLUGIN_MAX_EXECUTION_TIMEOUT=600
# API side timeout (configure to match the Plugin Daemon side above)
PLUGIN_DAEMON_TIMEOUT=600.0
# PIP_MIRROR_URL=https://pypi.tuna.tsinghua.edu.cn/simple
PIP_MIRROR_URL=

View File

@@ -34,7 +34,6 @@ services:
PLUGIN_REMOTE_INSTALL_HOST: ${EXPOSE_PLUGIN_DEBUGGING_HOST:-localhost}
PLUGIN_REMOTE_INSTALL_PORT: ${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}
PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800}
PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0}
INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1}
depends_on:
init_permissions:

View File

@@ -591,7 +591,6 @@ x-shared-env: &shared-api-worker-env
PLUGIN_STDIO_MAX_BUFFER_SIZE: ${PLUGIN_STDIO_MAX_BUFFER_SIZE:-5242880}
PLUGIN_PYTHON_ENV_INIT_TIMEOUT: ${PLUGIN_PYTHON_ENV_INIT_TIMEOUT:-120}
PLUGIN_MAX_EXECUTION_TIMEOUT: ${PLUGIN_MAX_EXECUTION_TIMEOUT:-600}
PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0}
PIP_MIRROR_URL: ${PIP_MIRROR_URL:-}
PLUGIN_STORAGE_TYPE: ${PLUGIN_STORAGE_TYPE:-local}
PLUGIN_STORAGE_LOCAL_ROOT: ${PLUGIN_STORAGE_LOCAL_ROOT:-/app/storage}
@@ -703,7 +702,6 @@ services:
PLUGIN_REMOTE_INSTALL_HOST: ${EXPOSE_PLUGIN_DEBUGGING_HOST:-localhost}
PLUGIN_REMOTE_INSTALL_PORT: ${EXPOSE_PLUGIN_DEBUGGING_PORT:-5003}
PLUGIN_MAX_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800}
PLUGIN_DAEMON_TIMEOUT: ${PLUGIN_DAEMON_TIMEOUT:-600.0}
INNER_API_KEY_FOR_PLUGIN: ${PLUGIN_DIFY_INNER_API_KEY:-QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1}
depends_on:
init_permissions:

View File

@@ -1,71 +0,0 @@
/**
* Mock for ky HTTP client
* This mock is used to avoid ESM issues in Jest tests
*/
type KyResponse = {
ok: boolean
status: number
statusText: string
headers: Headers
json: jest.Mock
text: jest.Mock
blob: jest.Mock
arrayBuffer: jest.Mock
clone: jest.Mock
}
type KyInstance = jest.Mock & {
get: jest.Mock
post: jest.Mock
put: jest.Mock
patch: jest.Mock
delete: jest.Mock
head: jest.Mock
create: jest.Mock
extend: jest.Mock
stop: symbol
}
const createResponse = (data: unknown = {}, status = 200): KyResponse => {
const response: KyResponse = {
ok: status >= 200 && status < 300,
status,
statusText: status === 200 ? 'OK' : 'Error',
headers: new Headers(),
json: jest.fn().mockResolvedValue(data),
text: jest.fn().mockResolvedValue(JSON.stringify(data)),
blob: jest.fn().mockResolvedValue(new Blob()),
arrayBuffer: jest.fn().mockResolvedValue(new ArrayBuffer(0)),
clone: jest.fn(),
}
// Ensure clone returns a new response-like object, not the same instance
response.clone.mockImplementation(() => createResponse(data, status))
return response
}
const createKyInstance = (): KyInstance => {
const instance = jest.fn().mockImplementation(() => Promise.resolve(createResponse())) as KyInstance
// HTTP methods
instance.get = jest.fn().mockImplementation(() => Promise.resolve(createResponse()))
instance.post = jest.fn().mockImplementation(() => Promise.resolve(createResponse()))
instance.put = jest.fn().mockImplementation(() => Promise.resolve(createResponse()))
instance.patch = jest.fn().mockImplementation(() => Promise.resolve(createResponse()))
instance.delete = jest.fn().mockImplementation(() => Promise.resolve(createResponse()))
instance.head = jest.fn().mockImplementation(() => Promise.resolve(createResponse()))
// Create new instance with custom options
instance.create = jest.fn().mockImplementation(() => createKyInstance())
instance.extend = jest.fn().mockImplementation(() => createKyInstance())
// Stop method for AbortController
instance.stop = Symbol('stop')
return instance
}
const ky = createKyInstance()
export default ky
export { ky }

View File

@@ -16,7 +16,7 @@ import {
import { useTranslation } from 'react-i18next'
import { useShallow } from 'zustand/react/shallow'
import s from './style.module.css'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { useStore } from '@/app/components/app/store'
import AppSideBar from '@/app/components/app-sidebar'
import type { NavIcon } from '@/app/components/app-sidebar/navLink'

View File

@@ -3,7 +3,7 @@ import { RiCalendarLine } from '@remixicon/react'
import type { Dayjs } from 'dayjs'
import type { FC } from 'react'
import React, { useCallback } from 'react'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { formatToLocalTime } from '@/utils/format'
import { useI18N } from '@/context/i18n'
import Picker from '@/app/components/base/date-and-time-picker/date-picker'

View File

@@ -6,7 +6,7 @@ import { SimpleSelect } from '@/app/components/base/select'
import type { Item } from '@/app/components/base/select'
import dayjs from 'dayjs'
import { RiArrowDownSLine, RiCheckLine } from '@remixicon/react'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { useTranslation } from 'react-i18next'
const today = dayjs()

View File

@@ -4,7 +4,7 @@ import React, { useCallback, useRef, useState } from 'react'
import type { PopupProps } from './config-popup'
import ConfigPopup from './config-popup'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import {
PortalToFollowElem,
PortalToFollowElemContent,

View File

@@ -12,7 +12,7 @@ import Indicator from '@/app/components/header/indicator'
import Switch from '@/app/components/base/switch'
import Tooltip from '@/app/components/base/tooltip'
import Divider from '@/app/components/base/divider'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
const I18N_PREFIX = 'app.tracing'

View File

@@ -1,7 +1,7 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import Input from '@/app/components/base/input'
type Props = {

View File

@@ -12,7 +12,7 @@ import type { AliyunConfig, ArizeConfig, DatabricksConfig, LangFuseConfig, LangS
import { TracingProvider } from './type'
import TracingIcon from './tracing-icon'
import ConfigButton from './config-button'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { AliyunIcon, ArizeIcon, DatabricksIcon, LangfuseIcon, LangsmithIcon, MlflowIcon, OpikIcon, PhoenixIcon, TencentIcon, WeaveIcon } from '@/app/components/base/icons/src/public/tracing'
import Indicator from '@/app/components/header/indicator'
import { fetchTracingConfig as doFetchTracingConfig, fetchTracingStatus, updateTracingStatus } from '@/service/apps'

View File

@@ -6,7 +6,7 @@ import {
} from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import { TracingProvider } from './type'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { AliyunIconBig, ArizeIconBig, DatabricksIconBig, LangfuseIconBig, LangsmithIconBig, MlflowIconBig, OpikIconBig, PhoenixIconBig, TencentIconBig, WeaveIconBig } from '@/app/components/base/icons/src/public/tracing'
import { Eye as View } from '@/app/components/base/icons/src/vender/solid/general'

View File

@@ -1,7 +1,7 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { TracingIcon as Icon } from '@/app/components/base/icons/src/public/tracing'
type Props = {

View File

@@ -23,7 +23,7 @@ import { useDatasetDetail, useDatasetRelatedApps } from '@/service/knowledge/use
import useDocumentTitle from '@/hooks/use-document-title'
import ExtraInfo from '@/app/components/datasets/extra-info'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
export type IAppDetailLayoutProps = {
children: React.ReactNode

View File

@@ -1,7 +1,7 @@
'use client'
import Header from '@/app/signin/_header'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { useGlobalPublicStore } from '@/context/global-public-context'
export default function SignInLayout({ children }: any) {

View File

@@ -2,7 +2,7 @@
import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useRouter, useSearchParams } from 'next/navigation'
import { cn } from '@/utils/classnames'
import cn from 'classnames'
import { RiCheckboxCircleFill } from '@remixicon/react'
import { useCountDown } from 'ahooks'
import Button from '@/app/components/base/button'

View File

@@ -1,6 +1,6 @@
'use client'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { useGlobalPublicStore } from '@/context/global-public-context'
import useDocumentTitle from '@/hooks/use-document-title'
import type { PropsWithChildren } from 'react'

View File

@@ -7,7 +7,7 @@ import Loading from '@/app/components/base/loading'
import MailAndCodeAuth from './components/mail-and-code-auth'
import MailAndPasswordAuth from './components/mail-and-password-auth'
import SSOAuth from './components/sso-auth'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { LicenseStatus } from '@/types/feature'
import { IS_CE_EDITION } from '@/config'
import { useGlobalPublicStore } from '@/context/global-public-context'

View File

@@ -1,7 +1,7 @@
'use client'
import Header from '@/app/signin/_header'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { useGlobalPublicStore } from '@/context/global-public-context'
import useDocumentTitle from '@/hooks/use-document-title'
import { AppContextProvider } from '@/context/app-context'

View File

@@ -1,13 +1,13 @@
'use client'
import { useEffect } from 'react'
import { useTranslation } from 'react-i18next'
import useSWR from 'swr'
import { useRouter, useSearchParams } from 'next/navigation'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import Button from '@/app/components/base/button'
import { invitationCheck } from '@/service/common'
import Loading from '@/app/components/base/loading'
import useDocumentTitle from '@/hooks/use-document-title'
import { useInvitationCheck } from '@/service/use-common'
const ActivateForm = () => {
useDocumentTitle('')
@@ -26,21 +26,19 @@ const ActivateForm = () => {
token,
},
}
const { data: checkRes } = useInvitationCheck({
...checkParams.params,
token: token || undefined,
}, true)
useEffect(() => {
if (checkRes?.is_valid) {
const params = new URLSearchParams(searchParams)
const { email, workspace_id } = checkRes.data
params.set('email', encodeURIComponent(email))
params.set('workspace_id', encodeURIComponent(workspace_id))
params.set('invite_token', encodeURIComponent(token as string))
router.replace(`/signin?${params.toString()}`)
}
}, [checkRes, router, searchParams, token])
const { data: checkRes } = useSWR(checkParams, invitationCheck, {
revalidateOnFocus: false,
onSuccess(data) {
if (data.is_valid) {
const params = new URLSearchParams(searchParams)
const { email, workspace_id } = data.data
params.set('email', encodeURIComponent(email))
params.set('workspace_id', encodeURIComponent(workspace_id))
params.set('invite_token', encodeURIComponent(token as string))
router.replace(`/signin?${params.toString()}`)
}
},
})
return (
<div className={

View File

@@ -2,7 +2,7 @@
import React from 'react'
import Header from '../signin/_header'
import ActivateForm from './activateForm'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { useGlobalPublicStore } from '@/context/global-public-context'
const Activate = () => {

View File

@@ -29,7 +29,7 @@ import CardView from '@/app/(commonLayout)/app/(appDetailLayout)/[appId]/overvie
import type { Operation } from './app-operations'
import AppOperations from './app-operations'
import dynamic from 'next/dynamic'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { AppModeEnum } from '@/types/app'
const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), {

View File

@@ -16,7 +16,7 @@ import AppInfo from './app-info'
import NavLink from './navLink'
import { useStore as useAppStore } from '@/app/components/app/store'
import type { NavIcon } from './navLink'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { AppModeEnum } from '@/types/app'
type Props = {

View File

@@ -2,7 +2,7 @@ import React, { useCallback, useState } from 'react'
import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '../../base/portal-to-follow-elem'
import ActionButton from '../../base/action-button'
import { RiMoreFill } from '@remixicon/react'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import Menu from './menu'
import { useSelector as useAppContextWithSelector } from '@/context/app-context'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'

View File

@@ -1,379 +0,0 @@
import React from 'react'
import { render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import DatasetInfo from './index'
import Dropdown from './dropdown'
import Menu from './menu'
import MenuItem from './menu-item'
import type { DataSet } from '@/models/datasets'
import {
ChunkingMode,
DataSourceType,
DatasetPermission,
} from '@/models/datasets'
import { RETRIEVE_METHOD } from '@/types/app'
import { RiEditLine } from '@remixicon/react'
let mockDataset: DataSet
let mockIsDatasetOperator = false
const mockReplace = jest.fn()
const mockInvalidDatasetList = jest.fn()
const mockInvalidDatasetDetail = jest.fn()
const mockExportPipeline = jest.fn()
const mockCheckIsUsedInApp = jest.fn()
const mockDeleteDataset = jest.fn()
const createDataset = (overrides: Partial<DataSet> = {}): DataSet => ({
id: 'dataset-1',
name: 'Dataset Name',
indexing_status: 'completed',
icon_info: {
icon: '📙',
icon_background: '#FFF4ED',
icon_type: 'emoji',
icon_url: '',
},
description: 'Dataset description',
permission: DatasetPermission.onlyMe,
data_source_type: DataSourceType.FILE,
indexing_technique: 'high_quality' as DataSet['indexing_technique'],
created_by: 'user-1',
updated_by: 'user-1',
updated_at: 1690000000,
app_count: 0,
doc_form: ChunkingMode.text,
document_count: 1,
total_document_count: 1,
word_count: 1000,
provider: 'internal',
embedding_model: 'text-embedding-3',
embedding_model_provider: 'openai',
embedding_available: true,
retrieval_model_dict: {
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: false,
reranking_model: {
reranking_provider_name: '',
reranking_model_name: '',
},
top_k: 5,
score_threshold_enabled: false,
score_threshold: 0,
},
retrieval_model: {
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: false,
reranking_model: {
reranking_provider_name: '',
reranking_model_name: '',
},
top_k: 5,
score_threshold_enabled: false,
score_threshold: 0,
},
tags: [],
external_knowledge_info: {
external_knowledge_id: '',
external_knowledge_api_id: '',
external_knowledge_api_name: '',
external_knowledge_api_endpoint: '',
},
external_retrieval_model: {
top_k: 0,
score_threshold: 0,
score_threshold_enabled: false,
},
built_in_field_enabled: false,
runtime_mode: 'rag_pipeline',
enable_api: false,
is_multimodal: false,
...overrides,
})
jest.mock('next/navigation', () => ({
useRouter: () => ({
replace: mockReplace,
}),
}))
jest.mock('@/context/dataset-detail', () => ({
useDatasetDetailContextWithSelector: (selector: (state: { dataset?: DataSet }) => unknown) => selector({ dataset: mockDataset }),
}))
jest.mock('@/context/app-context', () => ({
useSelector: (selector: (state: { isCurrentWorkspaceDatasetOperator: boolean }) => unknown) =>
selector({ isCurrentWorkspaceDatasetOperator: mockIsDatasetOperator }),
}))
jest.mock('@/service/knowledge/use-dataset', () => ({
datasetDetailQueryKeyPrefix: ['dataset', 'detail'],
useInvalidDatasetList: () => mockInvalidDatasetList,
}))
jest.mock('@/service/use-base', () => ({
useInvalid: () => mockInvalidDatasetDetail,
}))
jest.mock('@/service/use-pipeline', () => ({
useExportPipelineDSL: () => ({
mutateAsync: mockExportPipeline,
}),
}))
jest.mock('@/service/datasets', () => ({
checkIsUsedInApp: (...args: unknown[]) => mockCheckIsUsedInApp(...args),
deleteDataset: (...args: unknown[]) => mockDeleteDataset(...args),
}))
jest.mock('@/hooks/use-knowledge', () => ({
useKnowledge: () => ({
formatIndexingTechniqueAndMethod: () => 'indexing-technique',
}),
}))
jest.mock('@/app/components/datasets/rename-modal', () => ({
__esModule: true,
default: ({
show,
onClose,
onSuccess,
}: {
show: boolean
onClose: () => void
onSuccess?: () => void
}) => {
if (!show)
return null
return (
<div data-testid="rename-modal">
<button type="button" onClick={onSuccess}>Success</button>
<button type="button" onClick={onClose}>Close</button>
</div>
)
},
}))
const openMenu = async (user: ReturnType<typeof userEvent.setup>) => {
const trigger = screen.getByRole('button')
await user.click(trigger)
}
describe('DatasetInfo', () => {
beforeEach(() => {
jest.clearAllMocks()
mockDataset = createDataset()
mockIsDatasetOperator = false
})
// Rendering of dataset summary details based on expand and dataset state.
describe('Rendering', () => {
it('should show dataset details when expanded', () => {
// Arrange
mockDataset = createDataset({ is_published: true })
render(<DatasetInfo expand />)
// Assert
expect(screen.getByText('Dataset Name')).toBeInTheDocument()
expect(screen.getByText('Dataset description')).toBeInTheDocument()
expect(screen.getByText('dataset.chunkingMode.general')).toBeInTheDocument()
expect(screen.getByText('indexing-technique')).toBeInTheDocument()
})
it('should show external tag when provider is external', () => {
// Arrange
mockDataset = createDataset({ provider: 'external', is_published: false })
render(<DatasetInfo expand />)
// Assert
expect(screen.getByText('dataset.externalTag')).toBeInTheDocument()
expect(screen.queryByText('dataset.chunkingMode.general')).not.toBeInTheDocument()
})
it('should hide detailed fields when collapsed', () => {
// Arrange
render(<DatasetInfo expand={false} />)
// Assert
expect(screen.queryByText('Dataset Name')).not.toBeInTheDocument()
expect(screen.queryByText('Dataset description')).not.toBeInTheDocument()
})
})
})
describe('MenuItem', () => {
beforeEach(() => {
jest.clearAllMocks()
})
// Event handling for menu item interactions.
describe('Interactions', () => {
it('should call handler when clicked', async () => {
const user = userEvent.setup()
const handleClick = jest.fn()
// Arrange
render(<MenuItem name="Edit" Icon={RiEditLine} handleClick={handleClick} />)
// Act
await user.click(screen.getByText('Edit'))
// Assert
expect(handleClick).toHaveBeenCalledTimes(1)
})
})
})
describe('Menu', () => {
beforeEach(() => {
jest.clearAllMocks()
mockDataset = createDataset()
})
// Rendering of menu options based on runtime mode and delete visibility.
describe('Rendering', () => {
it('should show edit, export, and delete options when rag pipeline and deletable', () => {
// Arrange
mockDataset = createDataset({ runtime_mode: 'rag_pipeline' })
render(
<Menu
showDelete
openRenameModal={jest.fn()}
handleExportPipeline={jest.fn()}
detectIsUsedByApp={jest.fn()}
/>,
)
// Assert
expect(screen.getByText('common.operation.edit')).toBeInTheDocument()
expect(screen.getByText('datasetPipeline.operations.exportPipeline')).toBeInTheDocument()
expect(screen.getByText('common.operation.delete')).toBeInTheDocument()
})
it('should hide export and delete options when not rag pipeline and not deletable', () => {
// Arrange
mockDataset = createDataset({ runtime_mode: 'general' })
render(
<Menu
showDelete={false}
openRenameModal={jest.fn()}
handleExportPipeline={jest.fn()}
detectIsUsedByApp={jest.fn()}
/>,
)
// Assert
expect(screen.getByText('common.operation.edit')).toBeInTheDocument()
expect(screen.queryByText('datasetPipeline.operations.exportPipeline')).not.toBeInTheDocument()
expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument()
})
})
})
describe('Dropdown', () => {
beforeEach(() => {
jest.clearAllMocks()
mockDataset = createDataset({ pipeline_id: 'pipeline-1', runtime_mode: 'rag_pipeline' })
mockIsDatasetOperator = false
mockExportPipeline.mockResolvedValue({ data: 'pipeline-content' })
mockCheckIsUsedInApp.mockResolvedValue({ is_using: false })
mockDeleteDataset.mockResolvedValue({})
if (!('createObjectURL' in URL)) {
Object.defineProperty(URL, 'createObjectURL', {
value: jest.fn(),
writable: true,
})
}
if (!('revokeObjectURL' in URL)) {
Object.defineProperty(URL, 'revokeObjectURL', {
value: jest.fn(),
writable: true,
})
}
})
// Rendering behavior based on workspace role.
describe('Rendering', () => {
it('should hide delete option when user is dataset operator', async () => {
const user = userEvent.setup()
// Arrange
mockIsDatasetOperator = true
render(<Dropdown expand />)
// Act
await openMenu(user)
// Assert
expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument()
})
})
// User interactions that trigger modals and exports.
describe('Interactions', () => {
it('should open rename modal when edit is clicked', async () => {
const user = userEvent.setup()
// Arrange
render(<Dropdown expand />)
// Act
await openMenu(user)
await user.click(screen.getByText('common.operation.edit'))
// Assert
expect(screen.getByTestId('rename-modal')).toBeInTheDocument()
})
it('should export pipeline when export is clicked', async () => {
const user = userEvent.setup()
const anchorClickSpy = jest.spyOn(HTMLAnchorElement.prototype, 'click')
const createObjectURLSpy = jest.spyOn(URL, 'createObjectURL')
// Arrange
render(<Dropdown expand />)
// Act
await openMenu(user)
await user.click(screen.getByText('datasetPipeline.operations.exportPipeline'))
// Assert
await waitFor(() => {
expect(mockExportPipeline).toHaveBeenCalledWith({
pipelineId: 'pipeline-1',
include: false,
})
})
expect(createObjectURLSpy).toHaveBeenCalledTimes(1)
expect(anchorClickSpy).toHaveBeenCalledTimes(1)
})
it('should show delete confirmation when delete is clicked', async () => {
const user = userEvent.setup()
// Arrange
render(<Dropdown expand />)
// Act
await openMenu(user)
await user.click(screen.getByText('common.operation.delete'))
// Assert
await waitFor(() => {
expect(screen.getByText('dataset.deleteDatasetConfirmContent')).toBeInTheDocument()
})
})
it('should delete dataset and redirect when confirm is clicked', async () => {
const user = userEvent.setup()
// Arrange
render(<Dropdown expand />)
// Act
await openMenu(user)
await user.click(screen.getByText('common.operation.delete'))
await user.click(await screen.findByRole('button', { name: 'common.operation.confirm' }))
// Assert
await waitFor(() => {
expect(mockDeleteDataset).toHaveBeenCalledWith('dataset-1')
})
expect(mockInvalidDatasetList).toHaveBeenCalledTimes(1)
expect(mockReplace).toHaveBeenCalledWith('/datasets')
})
})
})

View File

@@ -8,7 +8,7 @@ import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import type { DataSet } from '@/models/datasets'
import { DOC_FORM_TEXT } from '@/models/datasets'
import { useKnowledge } from '@/hooks/use-knowledge'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import Dropdown from './dropdown'
type DatasetInfoProps = {

View File

@@ -11,7 +11,7 @@ import AppIcon from '../base/app-icon'
import Divider from '../base/divider'
import NavLink from './navLink'
import type { NavIcon } from './navLink'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import Effect from '../base/effect'
import Dropdown from './dataset-info/dropdown'

View File

@@ -9,7 +9,7 @@ import AppSidebarDropdown from './app-sidebar-dropdown'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import { useStore as useAppStore } from '@/app/components/app/store'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import Divider from '../base/divider'
import { useHover, useKeyPress } from 'ahooks'
import ToggleButton from './toggle-button'

View File

@@ -2,7 +2,7 @@
import React from 'react'
import { useSelectedLayoutSegment } from 'next/navigation'
import Link from 'next/link'
import { cn } from '@/utils/classnames'
import classNames from '@/utils/classnames'
import type { RemixiconComponentType } from '@remixicon/react'
export type NavIcon = React.ComponentType<
@@ -42,7 +42,7 @@ const NavLink = ({
const NavIcon = isActive ? iconMap.selected : iconMap.normal
const renderIcon = () => (
<div className={cn(mode !== 'expand' && '-ml-1')}>
<div className={classNames(mode !== 'expand' && '-ml-1')}>
<NavIcon className="h-4 w-4 shrink-0" aria-hidden="true" />
</div>
)
@@ -53,17 +53,21 @@ const NavLink = ({
key={name}
type='button'
disabled
className={cn('system-sm-medium flex h-8 cursor-not-allowed items-center rounded-lg text-components-menu-item-text opacity-30 hover:bg-components-menu-item-bg-hover',
'pl-3 pr-1')}
className={classNames(
'system-sm-medium flex h-8 cursor-not-allowed items-center rounded-lg text-components-menu-item-text opacity-30 hover:bg-components-menu-item-bg-hover',
'pl-3 pr-1',
)}
title={mode === 'collapse' ? name : ''}
aria-disabled
>
{renderIcon()}
<span
className={cn('overflow-hidden whitespace-nowrap transition-all duration-200 ease-in-out',
className={classNames(
'overflow-hidden whitespace-nowrap transition-all duration-200 ease-in-out',
mode === 'expand'
? 'ml-2 max-w-none opacity-100'
: 'ml-0 max-w-0 opacity-0')}
: 'ml-0 max-w-0 opacity-0',
)}
>
{name}
</span>
@@ -75,18 +79,22 @@ const NavLink = ({
<Link
key={name}
href={href}
className={cn(isActive
? 'system-sm-semibold border-b-[0.25px] border-l-[0.75px] border-r-[0.25px] border-t-[0.75px] border-effects-highlight-lightmode-off bg-components-menu-item-bg-active text-text-accent-light-mode-only'
: 'system-sm-medium text-components-menu-item-text hover:bg-components-menu-item-bg-hover hover:text-components-menu-item-text-hover',
'flex h-8 items-center rounded-lg pl-3 pr-1')}
className={classNames(
isActive
? 'system-sm-semibold border-b-[0.25px] border-l-[0.75px] border-r-[0.25px] border-t-[0.75px] border-effects-highlight-lightmode-off bg-components-menu-item-bg-active text-text-accent-light-mode-only'
: 'system-sm-medium text-components-menu-item-text hover:bg-components-menu-item-bg-hover hover:text-components-menu-item-text-hover',
'flex h-8 items-center rounded-lg pl-3 pr-1',
)}
title={mode === 'collapse' ? name : ''}
>
{renderIcon()}
<span
className={cn('overflow-hidden whitespace-nowrap transition-all duration-200 ease-in-out',
className={classNames(
'overflow-hidden whitespace-nowrap transition-all duration-200 ease-in-out',
mode === 'expand'
? 'ml-2 max-w-none opacity-100'
: 'ml-0 max-w-0 opacity-0')}
: 'ml-0 max-w-0 opacity-0',
)}
>
{name}
</span>

View File

@@ -1,7 +1,7 @@
import React from 'react'
import Button from '../base/button'
import { RiArrowLeftSLine, RiArrowRightSLine } from '@remixicon/react'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import Tooltip from '../base/tooltip'
import { useTranslation } from 'react-i18next'
import { getKeyboardKeyNameBySystem } from '../workflow/utils'

View File

@@ -1,42 +0,0 @@
import React from 'react'
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import BatchAction from './batch-action'
describe('BatchAction', () => {
const baseProps = {
selectedIds: ['1', '2', '3'],
onBatchDelete: jest.fn(),
onCancel: jest.fn(),
}
beforeEach(() => {
jest.clearAllMocks()
})
it('should show the selected count and trigger cancel action', () => {
render(<BatchAction {...baseProps} className='custom-class' />)
expect(screen.getByText('3')).toBeInTheDocument()
expect(screen.getByText('appAnnotation.batchAction.selected')).toBeInTheDocument()
fireEvent.click(screen.getByRole('button', { name: 'common.operation.cancel' }))
expect(baseProps.onCancel).toHaveBeenCalledTimes(1)
})
it('should confirm before running batch delete', async () => {
const onBatchDelete = jest.fn().mockResolvedValue(undefined)
render(<BatchAction {...baseProps} onBatchDelete={onBatchDelete} />)
fireEvent.click(screen.getByRole('button', { name: 'common.operation.delete' }))
await screen.findByText('appAnnotation.list.delete.title')
await act(async () => {
fireEvent.click(screen.getAllByRole('button', { name: 'common.operation.delete' })[1])
})
await waitFor(() => {
expect(onBatchDelete).toHaveBeenCalledTimes(1)
})
})
})

View File

@@ -3,7 +3,7 @@ import { RiDeleteBinLine } from '@remixicon/react'
import { useTranslation } from 'react-i18next'
import { useBoolean } from 'ahooks'
import Divider from '@/app/components/base/divider'
import { cn } from '@/utils/classnames'
import classNames from '@/utils/classnames'
import Confirm from '@/app/components/base/confirm'
const i18nPrefix = 'appAnnotation.batchAction'
@@ -38,7 +38,7 @@ const BatchAction: FC<IBatchActionProps> = ({
setIsNotDeleting()
}
return (
<div className={cn('pointer-events-none flex w-full justify-center', className)}>
<div className={classNames('pointer-events-none flex w-full justify-center', className)}>
<div className='pointer-events-auto flex items-center gap-x-1 rounded-[10px] border border-components-actionbar-border-accent bg-components-actionbar-bg-accent p-1 shadow-xl shadow-shadow-shadow-5 backdrop-blur-[5px]'>
<div className='inline-flex items-center gap-x-2 py-1 pl-2 pr-3'>
<span className='flex h-5 w-5 items-center justify-center rounded-md bg-text-accent px-1 py-0.5 text-xs font-medium text-text-primary-on-surface'>

View File

@@ -1,72 +0,0 @@
import React from 'react'
import { render, screen } from '@testing-library/react'
import CSVDownload from './csv-downloader'
import I18nContext from '@/context/i18n'
import { LanguagesSupported } from '@/i18n-config/language'
import type { Locale } from '@/i18n-config'
const downloaderProps: any[] = []
jest.mock('react-papaparse', () => ({
useCSVDownloader: jest.fn(() => ({
CSVDownloader: ({ children, ...props }: any) => {
downloaderProps.push(props)
return <div data-testid="mock-csv-downloader">{children}</div>
},
Type: { Link: 'link' },
})),
}))
const renderWithLocale = (locale: Locale) => {
return render(
<I18nContext.Provider value={{
locale,
i18n: {},
setLocaleOnClient: jest.fn().mockResolvedValue(undefined),
}}
>
<CSVDownload />
</I18nContext.Provider>,
)
}
describe('CSVDownload', () => {
const englishTemplate = [
['question', 'answer'],
['question1', 'answer1'],
['question2', 'answer2'],
]
const chineseTemplate = [
['问题', '答案'],
['问题 1', '答案 1'],
['问题 2', '答案 2'],
]
beforeEach(() => {
downloaderProps.length = 0
})
it('should render the structure preview and pass English template data by default', () => {
renderWithLocale('en-US' as Locale)
expect(screen.getByText('share.generation.csvStructureTitle')).toBeInTheDocument()
expect(screen.getByText('appAnnotation.batchModal.template')).toBeInTheDocument()
expect(downloaderProps[0]).toMatchObject({
filename: 'template-en-US',
type: 'link',
bom: true,
data: englishTemplate,
})
})
it('should switch to the Chinese template when locale matches the secondary language', () => {
const locale = LanguagesSupported[1] as Locale
renderWithLocale(locale)
expect(downloaderProps[0]).toMatchObject({
filename: `template-${locale}`,
data: chineseTemplate,
})
})
})

View File

@@ -4,7 +4,7 @@ import React, { useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
import { RiDeleteBinLine } from '@remixicon/react'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { Csv as CSVIcon } from '@/app/components/base/icons/src/public/files'
import { ToastContext } from '@/app/components/base/toast'
import Button from '@/app/components/base/button'

View File

@@ -1,164 +0,0 @@
import React from 'react'
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import BatchModal, { ProcessStatus } from './index'
import { useProviderContext } from '@/context/provider-context'
import { annotationBatchImport, checkAnnotationBatchImportProgress } from '@/service/annotation'
import type { IBatchModalProps } from './index'
import Toast from '@/app/components/base/toast'
jest.mock('@/app/components/base/toast', () => ({
__esModule: true,
default: {
notify: jest.fn(),
},
}))
jest.mock('@/service/annotation', () => ({
annotationBatchImport: jest.fn(),
checkAnnotationBatchImportProgress: jest.fn(),
}))
jest.mock('@/context/provider-context', () => ({
useProviderContext: jest.fn(),
}))
jest.mock('./csv-downloader', () => ({
__esModule: true,
default: () => <div data-testid="csv-downloader-stub" />,
}))
let lastUploadedFile: File | undefined
jest.mock('./csv-uploader', () => ({
__esModule: true,
default: ({ file, updateFile }: { file?: File; updateFile: (file?: File) => void }) => (
<div>
<button
data-testid="mock-uploader"
onClick={() => {
lastUploadedFile = new File(['question,answer'], 'batch.csv', { type: 'text/csv' })
updateFile(lastUploadedFile)
}}
>
upload
</button>
{file && <span data-testid="selected-file">{file.name}</span>}
</div>
),
}))
jest.mock('@/app/components/billing/annotation-full', () => ({
__esModule: true,
default: () => <div data-testid="annotation-full" />,
}))
const mockNotify = Toast.notify as jest.Mock
const useProviderContextMock = useProviderContext as jest.Mock
const annotationBatchImportMock = annotationBatchImport as jest.Mock
const checkAnnotationBatchImportProgressMock = checkAnnotationBatchImportProgress as jest.Mock
const renderComponent = (props: Partial<IBatchModalProps> = {}) => {
const mergedProps: IBatchModalProps = {
appId: 'app-id',
isShow: true,
onCancel: jest.fn(),
onAdded: jest.fn(),
...props,
}
return {
...render(<BatchModal {...mergedProps} />),
props: mergedProps,
}
}
describe('BatchModal', () => {
beforeEach(() => {
jest.clearAllMocks()
lastUploadedFile = undefined
useProviderContextMock.mockReturnValue({
plan: {
usage: { annotatedResponse: 0 },
total: { annotatedResponse: 10 },
},
enableBilling: false,
})
})
it('should disable run action and show billing hint when annotation quota is full', () => {
useProviderContextMock.mockReturnValue({
plan: {
usage: { annotatedResponse: 10 },
total: { annotatedResponse: 10 },
},
enableBilling: true,
})
renderComponent()
expect(screen.getByTestId('annotation-full')).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'appAnnotation.batchModal.run' })).toBeDisabled()
})
it('should reset uploader state when modal closes and allow manual cancellation', () => {
const { rerender, props } = renderComponent()
fireEvent.click(screen.getByTestId('mock-uploader'))
expect(screen.getByTestId('selected-file')).toHaveTextContent('batch.csv')
rerender(<BatchModal {...props} isShow={false} />)
rerender(<BatchModal {...props} isShow />)
expect(screen.queryByTestId('selected-file')).toBeNull()
fireEvent.click(screen.getByRole('button', { name: 'appAnnotation.batchModal.cancel' }))
expect(props.onCancel).toHaveBeenCalledTimes(1)
})
it('should submit the csv file, poll status, and notify when import completes', async () => {
jest.useFakeTimers()
const { props } = renderComponent()
const fileTrigger = screen.getByTestId('mock-uploader')
fireEvent.click(fileTrigger)
const runButton = screen.getByRole('button', { name: 'appAnnotation.batchModal.run' })
expect(runButton).not.toBeDisabled()
annotationBatchImportMock.mockResolvedValue({ job_id: 'job-1', job_status: ProcessStatus.PROCESSING })
checkAnnotationBatchImportProgressMock
.mockResolvedValueOnce({ job_id: 'job-1', job_status: ProcessStatus.PROCESSING })
.mockResolvedValueOnce({ job_id: 'job-1', job_status: ProcessStatus.COMPLETED })
await act(async () => {
fireEvent.click(runButton)
})
await waitFor(() => {
expect(annotationBatchImportMock).toHaveBeenCalledTimes(1)
})
const formData = annotationBatchImportMock.mock.calls[0][0].body as FormData
expect(formData.get('file')).toBe(lastUploadedFile)
await waitFor(() => {
expect(checkAnnotationBatchImportProgressMock).toHaveBeenCalledTimes(1)
})
await act(async () => {
jest.runOnlyPendingTimers()
})
await waitFor(() => {
expect(checkAnnotationBatchImportProgressMock).toHaveBeenCalledTimes(2)
})
await waitFor(() => {
expect(mockNotify).toHaveBeenCalledWith({
type: 'success',
message: 'appAnnotation.batchModal.completed',
})
expect(props.onAdded).toHaveBeenCalledTimes(1)
expect(props.onCancel).toHaveBeenCalledTimes(1)
})
jest.useRealTimers()
})
})

View File

@@ -245,7 +245,7 @@ describe('EditItem', () => {
expect(mockSave).toHaveBeenCalledWith('Test save content')
})
it('should show delete option and restore original content when delete is clicked', async () => {
it('should show delete option when content changes', async () => {
// Arrange
const mockSave = jest.fn().mockResolvedValue(undefined)
const props = {
@@ -267,13 +267,7 @@ describe('EditItem', () => {
await user.click(screen.getByRole('button', { name: 'common.operation.save' }))
// Assert
expect(mockSave).toHaveBeenNthCalledWith(1, 'Modified content')
expect(await screen.findByText('common.operation.delete')).toBeInTheDocument()
await user.click(screen.getByText('common.operation.delete'))
expect(mockSave).toHaveBeenNthCalledWith(2, 'Test content')
expect(screen.queryByText('common.operation.delete')).not.toBeInTheDocument()
expect(mockSave).toHaveBeenCalledWith('Modified content')
})
it('should handle keyboard interactions in edit mode', async () => {
@@ -399,68 +393,5 @@ describe('EditItem', () => {
expect(screen.queryByRole('textbox')).not.toBeInTheDocument()
expect(screen.getByText('Test content')).toBeInTheDocument()
})
it('should handle save failure gracefully in edit mode', async () => {
// Arrange
const mockSave = jest.fn().mockRejectedValueOnce(new Error('Save failed'))
const props = {
...defaultProps,
onSave: mockSave,
}
const user = userEvent.setup()
// Act
render(<EditItem {...props} />)
// Enter edit mode and save (should fail)
await user.click(screen.getByText('common.operation.edit'))
const textarea = screen.getByRole('textbox')
await user.type(textarea, 'New content')
// Save should fail but not throw
await user.click(screen.getByRole('button', { name: 'common.operation.save' }))
// Assert - Should remain in edit mode when save fails
expect(screen.getByRole('textbox')).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument()
expect(mockSave).toHaveBeenCalledWith('New content')
})
it('should handle delete action failure gracefully', async () => {
// Arrange
const mockSave = jest.fn()
.mockResolvedValueOnce(undefined) // First save succeeds
.mockRejectedValueOnce(new Error('Delete failed')) // Delete fails
const props = {
...defaultProps,
onSave: mockSave,
}
const user = userEvent.setup()
// Act
render(<EditItem {...props} />)
// Edit content to show delete button
await user.click(screen.getByText('common.operation.edit'))
const textarea = screen.getByRole('textbox')
await user.clear(textarea)
await user.type(textarea, 'Modified content')
// Save to create new content
await user.click(screen.getByRole('button', { name: 'common.operation.save' }))
await screen.findByText('common.operation.delete')
// Click delete (should fail but not throw)
await user.click(screen.getByText('common.operation.delete'))
// Assert - Delete action should handle error gracefully
expect(mockSave).toHaveBeenCalledTimes(2)
expect(mockSave).toHaveBeenNthCalledWith(1, 'Modified content')
expect(mockSave).toHaveBeenNthCalledWith(2, 'Test content')
// When delete fails, the delete button should still be visible (state not changed)
expect(screen.getByText('common.operation.delete')).toBeInTheDocument()
expect(screen.getByText('Modified content')).toBeInTheDocument()
})
})
})

View File

@@ -6,7 +6,7 @@ import { RiDeleteBinLine, RiEditFill, RiEditLine } from '@remixicon/react'
import { Robot, User } from '@/app/components/base/icons/src/public/avatar'
import Textarea from '@/app/components/base/textarea'
import Button from '@/app/components/base/button'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
export enum EditItemType {
Query = 'query',
@@ -52,14 +52,8 @@ const EditItem: FC<Props> = ({
}, [content])
const handleSave = async () => {
try {
await onSave(newContent)
setIsEdit(false)
}
catch {
// Keep edit mode open when save fails
// Error notification is handled by the parent component
}
await onSave(newContent)
setIsEdit(false)
}
const handleCancel = () => {
@@ -102,16 +96,9 @@ const EditItem: FC<Props> = ({
<div className='mr-2'>·</div>
<div
className='flex cursor-pointer items-center space-x-1'
onClick={async () => {
try {
await onSave(content)
// Only update UI state after successful delete
setNewContent(content)
}
catch {
// Delete action failed - error is already handled by parent
// UI state remains unchanged, user can retry
}
onClick={() => {
setNewContent(content)
onSave(content)
}}
>
<div className='h-3.5 w-3.5'>

View File

@@ -1,4 +1,4 @@
import { render, screen, waitFor } from '@testing-library/react'
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import Toast, { type IToastProps, type ToastHandle } from '@/app/components/base/toast'
import EditAnnotationModal from './index'
@@ -405,276 +405,4 @@ describe('EditAnnotationModal', () => {
expect(editLinks).toHaveLength(1) // Only answer should have edit button
})
})
// Error Handling (CRITICAL for coverage)
describe('Error Handling', () => {
it('should show error toast and skip callbacks when addAnnotation fails', async () => {
// Arrange
const mockOnAdded = jest.fn()
const props = {
...defaultProps,
onAdded: mockOnAdded,
}
const user = userEvent.setup()
// Mock API failure
mockAddAnnotation.mockRejectedValueOnce(new Error('API Error'))
// Act
render(<EditAnnotationModal {...props} />)
// Find and click edit link for query
const editLinks = screen.getAllByText(/common\.operation\.edit/i)
await user.click(editLinks[0])
// Find textarea and enter new content
const textarea = screen.getByRole('textbox')
await user.clear(textarea)
await user.type(textarea, 'New query content')
// Click save button
const saveButton = screen.getByRole('button', { name: 'common.operation.save' })
await user.click(saveButton)
// Assert
await waitFor(() => {
expect(toastNotifySpy).toHaveBeenCalledWith({
message: 'API Error',
type: 'error',
})
})
expect(mockOnAdded).not.toHaveBeenCalled()
// Verify edit mode remains open (textarea should still be visible)
expect(screen.getByRole('textbox')).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument()
})
it('should show fallback error message when addAnnotation error has no message', async () => {
// Arrange
const mockOnAdded = jest.fn()
const props = {
...defaultProps,
onAdded: mockOnAdded,
}
const user = userEvent.setup()
mockAddAnnotation.mockRejectedValueOnce({})
// Act
render(<EditAnnotationModal {...props} />)
const editLinks = screen.getAllByText(/common\.operation\.edit/i)
await user.click(editLinks[0])
const textarea = screen.getByRole('textbox')
await user.clear(textarea)
await user.type(textarea, 'New query content')
const saveButton = screen.getByRole('button', { name: 'common.operation.save' })
await user.click(saveButton)
// Assert
await waitFor(() => {
expect(toastNotifySpy).toHaveBeenCalledWith({
message: 'common.api.actionFailed',
type: 'error',
})
})
expect(mockOnAdded).not.toHaveBeenCalled()
// Verify edit mode remains open (textarea should still be visible)
expect(screen.getByRole('textbox')).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument()
})
it('should show error toast and skip callbacks when editAnnotation fails', async () => {
// Arrange
const mockOnEdited = jest.fn()
const props = {
...defaultProps,
annotationId: 'test-annotation-id',
messageId: 'test-message-id',
onEdited: mockOnEdited,
}
const user = userEvent.setup()
// Mock API failure
mockEditAnnotation.mockRejectedValueOnce(new Error('API Error'))
// Act
render(<EditAnnotationModal {...props} />)
// Edit query content
const editLinks = screen.getAllByText(/common\.operation\.edit/i)
await user.click(editLinks[0])
const textarea = screen.getByRole('textbox')
await user.clear(textarea)
await user.type(textarea, 'Modified query')
const saveButton = screen.getByRole('button', { name: 'common.operation.save' })
await user.click(saveButton)
// Assert
await waitFor(() => {
expect(toastNotifySpy).toHaveBeenCalledWith({
message: 'API Error',
type: 'error',
})
})
expect(mockOnEdited).not.toHaveBeenCalled()
// Verify edit mode remains open (textarea should still be visible)
expect(screen.getByRole('textbox')).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument()
})
it('should show fallback error message when editAnnotation error is not an Error instance', async () => {
// Arrange
const mockOnEdited = jest.fn()
const props = {
...defaultProps,
annotationId: 'test-annotation-id',
messageId: 'test-message-id',
onEdited: mockOnEdited,
}
const user = userEvent.setup()
mockEditAnnotation.mockRejectedValueOnce('oops')
// Act
render(<EditAnnotationModal {...props} />)
const editLinks = screen.getAllByText(/common\.operation\.edit/i)
await user.click(editLinks[0])
const textarea = screen.getByRole('textbox')
await user.clear(textarea)
await user.type(textarea, 'Modified query')
const saveButton = screen.getByRole('button', { name: 'common.operation.save' })
await user.click(saveButton)
// Assert
await waitFor(() => {
expect(toastNotifySpy).toHaveBeenCalledWith({
message: 'common.api.actionFailed',
type: 'error',
})
})
expect(mockOnEdited).not.toHaveBeenCalled()
// Verify edit mode remains open (textarea should still be visible)
expect(screen.getByRole('textbox')).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument()
})
})
// Billing & Plan Features
describe('Billing & Plan Features', () => {
it('should show createdAt time when provided', () => {
// Arrange
const props = {
...defaultProps,
annotationId: 'test-annotation-id',
createdAt: 1701381000, // 2023-12-01 10:30:00
}
// Act
render(<EditAnnotationModal {...props} />)
// Assert - Check that the formatted time appears somewhere in the component
const container = screen.getByRole('dialog')
expect(container).toHaveTextContent('2023-12-01 10:30:00')
})
it('should not show createdAt when not provided', () => {
// Arrange
const props = {
...defaultProps,
annotationId: 'test-annotation-id',
// createdAt is undefined
}
// Act
render(<EditAnnotationModal {...props} />)
// Assert - Should not contain any timestamp
const container = screen.getByRole('dialog')
expect(container).not.toHaveTextContent('2023-12-01 10:30:00')
})
it('should display remove section when annotationId exists', () => {
// Arrange
const props = {
...defaultProps,
annotationId: 'test-annotation-id',
}
// Act
render(<EditAnnotationModal {...props} />)
// Assert - Should have remove functionality
expect(screen.getByText('appAnnotation.editModal.removeThisCache')).toBeInTheDocument()
})
})
// Toast Notifications (Success)
describe('Toast Notifications', () => {
it('should show success notification when save operation completes', async () => {
// Arrange
const props = { ...defaultProps }
const user = userEvent.setup()
// Act
render(<EditAnnotationModal {...props} />)
const editLinks = screen.getAllByText(/common\.operation\.edit/i)
await user.click(editLinks[0])
const textarea = screen.getByRole('textbox')
await user.clear(textarea)
await user.type(textarea, 'Updated query')
const saveButton = screen.getByRole('button', { name: 'common.operation.save' })
await user.click(saveButton)
// Assert
await waitFor(() => {
expect(toastNotifySpy).toHaveBeenCalledWith({
message: 'common.api.actionSuccess',
type: 'success',
})
})
})
})
// React.memo Performance Testing
describe('React.memo Performance', () => {
it('should not re-render when props are the same', () => {
// Arrange
const props = { ...defaultProps }
const { rerender } = render(<EditAnnotationModal {...props} />)
// Act - Re-render with same props
rerender(<EditAnnotationModal {...props} />)
// Assert - Component should still be visible (no errors thrown)
expect(screen.getByText('appAnnotation.editModal.title')).toBeInTheDocument()
})
it('should re-render when props change', () => {
// Arrange
const props = { ...defaultProps }
const { rerender } = render(<EditAnnotationModal {...props} />)
// Act - Re-render with different props
const newProps = { ...props, query: 'New query content' }
rerender(<EditAnnotationModal {...newProps} />)
// Assert - Should show new content
expect(screen.getByText('New query content')).toBeInTheDocument()
})
})
})

View File

@@ -53,39 +53,27 @@ const EditAnnotationModal: FC<Props> = ({
postQuery = editedContent
else
postAnswer = editedContent
try {
if (!isAdd) {
await editAnnotation(appId, annotationId, {
message_id: messageId,
question: postQuery,
answer: postAnswer,
})
onEdited(postQuery, postAnswer)
}
else {
const res = await addAnnotation(appId, {
question: postQuery,
answer: postAnswer,
message_id: messageId,
})
onAdded(res.id, res.account?.name ?? '', postQuery, postAnswer)
}
if (!isAdd) {
await editAnnotation(appId, annotationId, {
message_id: messageId,
question: postQuery,
answer: postAnswer,
})
onEdited(postQuery, postAnswer)
}
else {
const res: any = await addAnnotation(appId, {
question: postQuery,
answer: postAnswer,
message_id: messageId,
})
onAdded(res.id, res.account?.name, postQuery, postAnswer)
}
Toast.notify({
message: t('common.api.actionSuccess') as string,
type: 'success',
})
}
catch (error) {
const fallbackMessage = t('common.api.actionFailed') as string
const message = error instanceof Error && error.message ? error.message : fallbackMessage
Toast.notify({
message,
type: 'error',
})
// Re-throw to preserve edit mode behavior for UI components
throw error
}
Toast.notify({
message: t('common.api.actionSuccess') as string,
type: 'success',
})
}
const [showModal, setShowModal] = useState(false)

View File

@@ -1,13 +0,0 @@
import React from 'react'
import { render, screen } from '@testing-library/react'
import EmptyElement from './empty-element'
describe('EmptyElement', () => {
it('should render the empty state copy and supporting icon', () => {
const { container } = render(<EmptyElement />)
expect(screen.getByText('appAnnotation.noData.title')).toBeInTheDocument()
expect(screen.getByText('appAnnotation.noData.description')).toBeInTheDocument()
expect(container.querySelector('svg')).not.toBeNull()
})
})

View File

@@ -1,70 +0,0 @@
import React from 'react'
import { fireEvent, render, screen } from '@testing-library/react'
import Filter, { type QueryParam } from './filter'
import useSWR from 'swr'
jest.mock('swr', () => ({
__esModule: true,
default: jest.fn(),
}))
jest.mock('@/service/log', () => ({
fetchAnnotationsCount: jest.fn(),
}))
const mockUseSWR = useSWR as unknown as jest.Mock
describe('Filter', () => {
const appId = 'app-1'
const childContent = 'child-content'
beforeEach(() => {
jest.clearAllMocks()
})
it('should render nothing until annotation count is fetched', () => {
mockUseSWR.mockReturnValue({ data: undefined })
const { container } = render(
<Filter
appId={appId}
queryParams={{ keyword: '' }}
setQueryParams={jest.fn()}
>
<div>{childContent}</div>
</Filter>,
)
expect(container.firstChild).toBeNull()
expect(mockUseSWR).toHaveBeenCalledWith(
{ url: `/apps/${appId}/annotations/count` },
expect.any(Function),
)
})
it('should propagate keyword changes and clearing behavior', () => {
mockUseSWR.mockReturnValue({ data: { total: 20 } })
const queryParams: QueryParam = { keyword: 'prefill' }
const setQueryParams = jest.fn()
const { container } = render(
<Filter
appId={appId}
queryParams={queryParams}
setQueryParams={setQueryParams}
>
<div>{childContent}</div>
</Filter>,
)
const input = screen.getByPlaceholderText('common.operation.search') as HTMLInputElement
fireEvent.change(input, { target: { value: 'updated' } })
expect(setQueryParams).toHaveBeenCalledWith({ ...queryParams, keyword: 'updated' })
const clearButton = input.parentElement?.querySelector('div.cursor-pointer') as HTMLElement
fireEvent.click(clearButton)
expect(setQueryParams).toHaveBeenCalledWith({ ...queryParams, keyword: '' })
expect(container).toHaveTextContent(childContent)
})
})

View File

@@ -1,439 +0,0 @@
import * as React from 'react'
import { render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import type { ComponentProps } from 'react'
import HeaderOptions from './index'
import I18NContext from '@/context/i18n'
import { LanguagesSupported } from '@/i18n-config/language'
import type { AnnotationItemBasic } from '../type'
import { clearAllAnnotations, fetchExportAnnotationList } from '@/service/annotation'
jest.mock('@headlessui/react', () => {
type PopoverContextValue = { open: boolean; setOpen: (open: boolean) => void }
type MenuContextValue = { open: boolean; setOpen: (open: boolean) => void }
const PopoverContext = React.createContext<PopoverContextValue | null>(null)
const MenuContext = React.createContext<MenuContextValue | null>(null)
const Popover = ({ children }: { children: React.ReactNode | ((props: { open: boolean }) => React.ReactNode) }) => {
const [open, setOpen] = React.useState(false)
const value = React.useMemo(() => ({ open, setOpen }), [open])
return (
<PopoverContext.Provider value={value}>
{typeof children === 'function' ? children({ open }) : children}
</PopoverContext.Provider>
)
}
const PopoverButton = React.forwardRef(({ onClick, children, ...props }: { onClick?: () => void; children?: React.ReactNode }, ref: React.Ref<HTMLButtonElement>) => {
const context = React.useContext(PopoverContext)
const handleClick = () => {
context?.setOpen(!context.open)
onClick?.()
}
return (
<button
ref={ref}
type="button"
aria-expanded={context?.open ?? false}
onClick={handleClick}
{...props}
>
{children}
</button>
)
})
const PopoverPanel = React.forwardRef(({ children, ...props }: { children: React.ReactNode | ((props: { close: () => void }) => React.ReactNode) }, ref: React.Ref<HTMLDivElement>) => {
const context = React.useContext(PopoverContext)
if (!context?.open) return null
const content = typeof children === 'function' ? children({ close: () => context.setOpen(false) }) : children
return (
<div ref={ref} {...props}>
{content}
</div>
)
})
const Menu = ({ children }: { children: React.ReactNode }) => {
const [open, setOpen] = React.useState(false)
const value = React.useMemo(() => ({ open, setOpen }), [open])
return (
<MenuContext.Provider value={value}>
{children}
</MenuContext.Provider>
)
}
const MenuButton = ({ onClick, children, ...props }: { onClick?: () => void; children?: React.ReactNode }) => {
const context = React.useContext(MenuContext)
const handleClick = () => {
context?.setOpen(!context.open)
onClick?.()
}
return (
<button type="button" aria-expanded={context?.open ?? false} onClick={handleClick} {...props}>
{children}
</button>
)
}
const MenuItems = ({ children, ...props }: { children: React.ReactNode }) => {
const context = React.useContext(MenuContext)
if (!context?.open) return null
return (
<div {...props}>
{children}
</div>
)
}
return {
Dialog: ({ open, children, className }: { open?: boolean; children: React.ReactNode; className?: string }) => {
if (open === false) return null
return (
<div role="dialog" className={className}>
{children}
</div>
)
},
DialogBackdrop: ({ children, className, onClick }: { children?: React.ReactNode; className?: string; onClick?: () => void }) => (
<div className={className} onClick={onClick}>
{children}
</div>
),
DialogPanel: ({ children, className, ...props }: { children: React.ReactNode; className?: string }) => (
<div className={className} {...props}>
{children}
</div>
),
DialogTitle: ({ children, className, ...props }: { children: React.ReactNode; className?: string }) => (
<div className={className} {...props}>
{children}
</div>
),
Popover,
PopoverButton,
PopoverPanel,
Menu,
MenuButton,
MenuItems,
Transition: ({ show = true, children }: { show?: boolean; children: React.ReactNode }) => (show ? <>{children}</> : null),
TransitionChild: ({ children }: { children: React.ReactNode }) => <>{children}</>,
}
})
let lastCSVDownloaderProps: Record<string, unknown> | undefined
const mockCSVDownloader = jest.fn(({ children, ...props }) => {
lastCSVDownloaderProps = props
return (
<div data-testid="csv-downloader">
{children}
</div>
)
})
jest.mock('react-papaparse', () => ({
useCSVDownloader: () => ({
CSVDownloader: (props: any) => mockCSVDownloader(props),
Type: { Link: 'link' },
}),
}))
jest.mock('@/service/annotation', () => ({
fetchExportAnnotationList: jest.fn(),
clearAllAnnotations: jest.fn(),
}))
jest.mock('@/context/provider-context', () => ({
useProviderContext: () => ({
plan: {
usage: { annotatedResponse: 0 },
total: { annotatedResponse: 10 },
},
enableBilling: false,
}),
}))
jest.mock('@/app/components/billing/annotation-full', () => ({
__esModule: true,
default: () => <div data-testid="annotation-full" />,
}))
type HeaderOptionsProps = ComponentProps<typeof HeaderOptions>
const renderComponent = (
props: Partial<HeaderOptionsProps> = {},
locale: string = LanguagesSupported[0] as string,
) => {
const defaultProps: HeaderOptionsProps = {
appId: 'test-app-id',
onAdd: jest.fn(),
onAdded: jest.fn(),
controlUpdateList: 0,
...props,
}
return render(
<I18NContext.Provider
value={{
locale,
i18n: {},
setLocaleOnClient: jest.fn(),
}}
>
<HeaderOptions {...defaultProps} />
</I18NContext.Provider>,
)
}
const openOperationsPopover = async (user: ReturnType<typeof userEvent.setup>) => {
const trigger = document.querySelector('button.btn.btn-secondary') as HTMLButtonElement
expect(trigger).toBeTruthy()
await user.click(trigger)
}
const expandExportMenu = async (user: ReturnType<typeof userEvent.setup>) => {
await openOperationsPopover(user)
const exportLabel = await screen.findByText('appAnnotation.table.header.bulkExport')
const exportButton = exportLabel.closest('button') as HTMLButtonElement
expect(exportButton).toBeTruthy()
await user.click(exportButton)
}
const getExportButtons = async () => {
const csvLabel = await screen.findByText('CSV')
const jsonLabel = await screen.findByText('JSONL')
const csvButton = csvLabel.closest('button') as HTMLButtonElement
const jsonButton = jsonLabel.closest('button') as HTMLButtonElement
expect(csvButton).toBeTruthy()
expect(jsonButton).toBeTruthy()
return {
csvButton,
jsonButton,
}
}
const clickOperationAction = async (
user: ReturnType<typeof userEvent.setup>,
translationKey: string,
) => {
const label = await screen.findByText(translationKey)
const button = label.closest('button') as HTMLButtonElement
expect(button).toBeTruthy()
await user.click(button)
}
const mockAnnotations: AnnotationItemBasic[] = [
{
question: 'Question 1',
answer: 'Answer 1',
},
]
const mockedFetchAnnotations = jest.mocked(fetchExportAnnotationList)
const mockedClearAllAnnotations = jest.mocked(clearAllAnnotations)
describe('HeaderOptions', () => {
beforeEach(() => {
jest.clearAllMocks()
jest.useRealTimers()
mockCSVDownloader.mockClear()
lastCSVDownloaderProps = undefined
mockedFetchAnnotations.mockResolvedValue({ data: [] })
})
it('should fetch annotations on mount and render enabled export actions when data exist', async () => {
mockedFetchAnnotations.mockResolvedValue({ data: mockAnnotations })
const user = userEvent.setup()
renderComponent()
await waitFor(() => {
expect(mockedFetchAnnotations).toHaveBeenCalledWith('test-app-id')
})
await expandExportMenu(user)
const { csvButton, jsonButton } = await getExportButtons()
expect(csvButton).not.toBeDisabled()
expect(jsonButton).not.toBeDisabled()
await waitFor(() => {
expect(lastCSVDownloaderProps).toMatchObject({
bom: true,
filename: 'annotations-en-US',
type: 'link',
data: [
['Question', 'Answer'],
['Question 1', 'Answer 1'],
],
})
})
})
it('should disable export actions when there are no annotations', async () => {
const user = userEvent.setup()
renderComponent()
await expandExportMenu(user)
const { csvButton, jsonButton } = await getExportButtons()
expect(csvButton).toBeDisabled()
expect(jsonButton).toBeDisabled()
expect(lastCSVDownloaderProps).toMatchObject({
data: [['Question', 'Answer']],
})
})
it('should open the add annotation modal and forward the onAdd callback', async () => {
mockedFetchAnnotations.mockResolvedValue({ data: mockAnnotations })
const user = userEvent.setup()
const onAdd = jest.fn().mockResolvedValue(undefined)
renderComponent({ onAdd })
await waitFor(() => expect(mockedFetchAnnotations).toHaveBeenCalled())
await user.click(
screen.getByRole('button', { name: 'appAnnotation.table.header.addAnnotation' }),
)
await screen.findByText('appAnnotation.addModal.title')
const questionInput = screen.getByPlaceholderText('appAnnotation.addModal.queryPlaceholder')
const answerInput = screen.getByPlaceholderText('appAnnotation.addModal.answerPlaceholder')
await user.type(questionInput, 'Integration question')
await user.type(answerInput, 'Integration answer')
await user.click(screen.getByRole('button', { name: 'common.operation.add' }))
await waitFor(() => {
expect(onAdd).toHaveBeenCalledWith({
question: 'Integration question',
answer: 'Integration answer',
})
})
})
it('should allow bulk import through the batch modal', async () => {
const user = userEvent.setup()
const onAdded = jest.fn()
renderComponent({ onAdded })
await openOperationsPopover(user)
await clickOperationAction(user, 'appAnnotation.table.header.bulkImport')
expect(await screen.findByText('appAnnotation.batchModal.title')).toBeInTheDocument()
await user.click(
screen.getByRole('button', { name: 'appAnnotation.batchModal.cancel' }),
)
expect(onAdded).not.toHaveBeenCalled()
})
it('should trigger JSONL download with locale-specific filename', async () => {
mockedFetchAnnotations.mockResolvedValue({ data: mockAnnotations })
const user = userEvent.setup()
const originalCreateElement = document.createElement.bind(document)
const anchor = originalCreateElement('a') as HTMLAnchorElement
const clickSpy = jest.spyOn(anchor, 'click').mockImplementation(jest.fn())
const createElementSpy = jest
.spyOn(document, 'createElement')
.mockImplementation((tagName: Parameters<Document['createElement']>[0]) => {
if (tagName === 'a')
return anchor
return originalCreateElement(tagName)
})
const objectURLSpy = jest
.spyOn(URL, 'createObjectURL')
.mockReturnValue('blob://mock-url')
const revokeSpy = jest.spyOn(URL, 'revokeObjectURL').mockImplementation(jest.fn())
renderComponent({}, LanguagesSupported[1] as string)
await expandExportMenu(user)
await waitFor(() => expect(mockCSVDownloader).toHaveBeenCalled())
const { jsonButton } = await getExportButtons()
await user.click(jsonButton)
expect(createElementSpy).toHaveBeenCalled()
expect(anchor.download).toBe(`annotations-${LanguagesSupported[1]}.jsonl`)
expect(clickSpy).toHaveBeenCalled()
expect(revokeSpy).toHaveBeenCalledWith('blob://mock-url')
const blobArg = objectURLSpy.mock.calls[0][0] as Blob
await expect(blobArg.text()).resolves.toContain('"Question 1"')
clickSpy.mockRestore()
createElementSpy.mockRestore()
objectURLSpy.mockRestore()
revokeSpy.mockRestore()
})
it('should clear all annotations when confirmation succeeds', async () => {
mockedClearAllAnnotations.mockResolvedValue(undefined)
const user = userEvent.setup()
const onAdded = jest.fn()
renderComponent({ onAdded })
await openOperationsPopover(user)
await clickOperationAction(user, 'appAnnotation.table.header.clearAll')
await screen.findByText('appAnnotation.table.header.clearAllConfirm')
const confirmButton = screen.getByRole('button', { name: 'common.operation.confirm' })
await user.click(confirmButton)
await waitFor(() => {
expect(mockedClearAllAnnotations).toHaveBeenCalledWith('test-app-id')
expect(onAdded).toHaveBeenCalled()
})
})
it('should handle clear all failures gracefully', async () => {
const consoleSpy = jest.spyOn(console, 'error').mockImplementation(jest.fn())
mockedClearAllAnnotations.mockRejectedValue(new Error('network'))
const user = userEvent.setup()
const onAdded = jest.fn()
renderComponent({ onAdded })
await openOperationsPopover(user)
await clickOperationAction(user, 'appAnnotation.table.header.clearAll')
await screen.findByText('appAnnotation.table.header.clearAllConfirm')
const confirmButton = screen.getByRole('button', { name: 'common.operation.confirm' })
await user.click(confirmButton)
await waitFor(() => {
expect(mockedClearAllAnnotations).toHaveBeenCalled()
expect(onAdded).not.toHaveBeenCalled()
expect(consoleSpy).toHaveBeenCalled()
})
consoleSpy.mockRestore()
})
it('should refetch annotations when controlUpdateList changes', async () => {
const view = renderComponent({ controlUpdateList: 0 })
await waitFor(() => expect(mockedFetchAnnotations).toHaveBeenCalledTimes(1))
view.rerender(
<I18NContext.Provider
value={{
locale: LanguagesSupported[0] as string,
i18n: {},
setLocaleOnClient: jest.fn(),
}}
>
<HeaderOptions
appId="test-app-id"
onAdd={jest.fn()}
onAdded={jest.fn()}
controlUpdateList={1}
/>
</I18NContext.Provider>,
)
await waitFor(() => expect(mockedFetchAnnotations).toHaveBeenCalledTimes(2))
})
})

View File

@@ -17,7 +17,7 @@ import Button from '../../../base/button'
import AddAnnotationModal from '../add-annotation-modal'
import type { AnnotationItemBasic } from '../type'
import BatchAddModal from '../batch-add-annotation-modal'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import CustomPopover from '@/app/components/base/popover'
import { FileDownload02, FilePlus02 } from '@/app/components/base/icons/src/vender/line/files'
import { ChevronRight } from '@/app/components/base/icons/src/vender/line/arrows'

View File

@@ -1,233 +0,0 @@
import React from 'react'
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import Annotation from './index'
import type { AnnotationItem } from './type'
import { JobStatus } from './type'
import { type App, AppModeEnum } from '@/types/app'
import {
addAnnotation,
delAnnotation,
delAnnotations,
fetchAnnotationConfig,
fetchAnnotationList,
queryAnnotationJobStatus,
} from '@/service/annotation'
import { useProviderContext } from '@/context/provider-context'
import Toast from '@/app/components/base/toast'
jest.mock('@/app/components/base/toast', () => ({
__esModule: true,
default: { notify: jest.fn() },
}))
jest.mock('ahooks', () => ({
useDebounce: (value: any) => value,
}))
jest.mock('@/service/annotation', () => ({
addAnnotation: jest.fn(),
delAnnotation: jest.fn(),
delAnnotations: jest.fn(),
fetchAnnotationConfig: jest.fn(),
editAnnotation: jest.fn(),
fetchAnnotationList: jest.fn(),
queryAnnotationJobStatus: jest.fn(),
updateAnnotationScore: jest.fn(),
updateAnnotationStatus: jest.fn(),
}))
jest.mock('@/context/provider-context', () => ({
useProviderContext: jest.fn(),
}))
jest.mock('./filter', () => ({ children }: { children: React.ReactNode }) => (
<div data-testid="filter">{children}</div>
))
jest.mock('./empty-element', () => () => <div data-testid="empty-element" />)
jest.mock('./header-opts', () => (props: any) => (
<div data-testid="header-opts">
<button data-testid="trigger-add" onClick={() => props.onAdd({ question: 'new question', answer: 'new answer' })}>
add
</button>
</div>
))
let latestListProps: any
jest.mock('./list', () => (props: any) => {
latestListProps = props
if (!props.list.length)
return <div data-testid="list-empty" />
return (
<div data-testid="list">
<button data-testid="list-view" onClick={() => props.onView(props.list[0])}>view</button>
<button data-testid="list-remove" onClick={() => props.onRemove(props.list[0].id)}>remove</button>
<button data-testid="list-batch-delete" onClick={() => props.onBatchDelete()}>batch-delete</button>
</div>
)
})
jest.mock('./view-annotation-modal', () => (props: any) => {
if (!props.isShow)
return null
return (
<div data-testid="view-modal">
<div>{props.item.question}</div>
<button data-testid="view-modal-remove" onClick={props.onRemove}>remove</button>
<button data-testid="view-modal-close" onClick={props.onHide}>close</button>
</div>
)
})
jest.mock('@/app/components/base/pagination', () => () => <div data-testid="pagination" />)
jest.mock('@/app/components/base/loading', () => () => <div data-testid="loading" />)
jest.mock('@/app/components/base/features/new-feature-panel/annotation-reply/config-param-modal', () => (props: any) => props.isShow ? <div data-testid="config-modal" /> : null)
jest.mock('@/app/components/billing/annotation-full/modal', () => (props: any) => props.show ? <div data-testid="annotation-full-modal" /> : null)
const mockNotify = Toast.notify as jest.Mock
const addAnnotationMock = addAnnotation as jest.Mock
const delAnnotationMock = delAnnotation as jest.Mock
const delAnnotationsMock = delAnnotations as jest.Mock
const fetchAnnotationConfigMock = fetchAnnotationConfig as jest.Mock
const fetchAnnotationListMock = fetchAnnotationList as jest.Mock
const queryAnnotationJobStatusMock = queryAnnotationJobStatus as jest.Mock
const useProviderContextMock = useProviderContext as jest.Mock
const appDetail = {
id: 'app-id',
mode: AppModeEnum.CHAT,
} as App
const createAnnotation = (overrides: Partial<AnnotationItem> = {}): AnnotationItem => ({
id: overrides.id ?? 'annotation-1',
question: overrides.question ?? 'Question 1',
answer: overrides.answer ?? 'Answer 1',
created_at: overrides.created_at ?? 1700000000,
hit_count: overrides.hit_count ?? 0,
})
const renderComponent = () => render(<Annotation appDetail={appDetail} />)
describe('Annotation', () => {
beforeEach(() => {
jest.clearAllMocks()
latestListProps = undefined
fetchAnnotationConfigMock.mockResolvedValue({
id: 'config-id',
enabled: false,
embedding_model: {
embedding_model_name: 'model',
embedding_provider_name: 'provider',
},
score_threshold: 0.5,
})
fetchAnnotationListMock.mockResolvedValue({ data: [], total: 0 })
queryAnnotationJobStatusMock.mockResolvedValue({ job_status: JobStatus.completed })
useProviderContextMock.mockReturnValue({
plan: {
usage: { annotatedResponse: 0 },
total: { annotatedResponse: 10 },
},
enableBilling: false,
})
})
it('should render empty element when no annotations are returned', async () => {
renderComponent()
expect(await screen.findByTestId('empty-element')).toBeInTheDocument()
expect(fetchAnnotationListMock).toHaveBeenCalledWith(appDetail.id, expect.objectContaining({
page: 1,
keyword: '',
}))
})
it('should handle annotation creation and refresh list data', async () => {
const annotation = createAnnotation()
fetchAnnotationListMock.mockResolvedValue({ data: [annotation], total: 1 })
addAnnotationMock.mockResolvedValue(undefined)
renderComponent()
await screen.findByTestId('list')
fireEvent.click(screen.getByTestId('trigger-add'))
await waitFor(() => {
expect(addAnnotationMock).toHaveBeenCalledWith(appDetail.id, { question: 'new question', answer: 'new answer' })
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
message: 'common.api.actionSuccess',
type: 'success',
}))
})
expect(fetchAnnotationListMock).toHaveBeenCalledTimes(2)
})
it('should support viewing items and running batch deletion success flow', async () => {
const annotation = createAnnotation()
fetchAnnotationListMock.mockResolvedValue({ data: [annotation], total: 1 })
delAnnotationsMock.mockResolvedValue(undefined)
delAnnotationMock.mockResolvedValue(undefined)
renderComponent()
await screen.findByTestId('list')
await act(async () => {
latestListProps.onSelectedIdsChange([annotation.id])
})
await waitFor(() => {
expect(latestListProps.selectedIds).toEqual([annotation.id])
})
await act(async () => {
await latestListProps.onBatchDelete()
})
await waitFor(() => {
expect(delAnnotationsMock).toHaveBeenCalledWith(appDetail.id, [annotation.id])
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
type: 'success',
}))
expect(latestListProps.selectedIds).toEqual([])
})
fireEvent.click(screen.getByTestId('list-view'))
expect(screen.getByTestId('view-modal')).toBeInTheDocument()
await act(async () => {
fireEvent.click(screen.getByTestId('view-modal-remove'))
})
await waitFor(() => {
expect(delAnnotationMock).toHaveBeenCalledWith(appDetail.id, annotation.id)
})
})
it('should show an error notification when batch deletion fails', async () => {
const annotation = createAnnotation()
fetchAnnotationListMock.mockResolvedValue({ data: [annotation], total: 1 })
const error = new Error('failed')
delAnnotationsMock.mockRejectedValue(error)
renderComponent()
await screen.findByTestId('list')
await act(async () => {
latestListProps.onSelectedIdsChange([annotation.id])
})
await waitFor(() => {
expect(latestListProps.selectedIds).toEqual([annotation.id])
})
await act(async () => {
await latestListProps.onBatchDelete()
})
await waitFor(() => {
expect(mockNotify).toHaveBeenCalledWith({
type: 'error',
message: error.message,
})
expect(latestListProps.selectedIds).toEqual([annotation.id])
})
})
})

View File

@@ -25,7 +25,7 @@ import { sleep } from '@/utils'
import { useProviderContext } from '@/context/provider-context'
import AnnotationFullModal from '@/app/components/billing/annotation-full/modal'
import { type App, AppModeEnum } from '@/types/app'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { delAnnotations } from '@/service/annotation'
type Props = {

View File

@@ -1,116 +0,0 @@
import React from 'react'
import { fireEvent, render, screen, within } from '@testing-library/react'
import List from './list'
import type { AnnotationItem } from './type'
const mockFormatTime = jest.fn(() => 'formatted-time')
jest.mock('@/hooks/use-timestamp', () => ({
__esModule: true,
default: () => ({
formatTime: mockFormatTime,
}),
}))
const createAnnotation = (overrides: Partial<AnnotationItem> = {}): AnnotationItem => ({
id: overrides.id ?? 'annotation-id',
question: overrides.question ?? 'question 1',
answer: overrides.answer ?? 'answer 1',
created_at: overrides.created_at ?? 1700000000,
hit_count: overrides.hit_count ?? 2,
})
const getCheckboxes = (container: HTMLElement) => container.querySelectorAll('[data-testid^="checkbox"]')
describe('List', () => {
beforeEach(() => {
jest.clearAllMocks()
})
it('should render annotation rows and call onView when clicking a row', () => {
const item = createAnnotation()
const onView = jest.fn()
render(
<List
list={[item]}
onView={onView}
onRemove={jest.fn()}
selectedIds={[]}
onSelectedIdsChange={jest.fn()}
onBatchDelete={jest.fn()}
onCancel={jest.fn()}
/>,
)
fireEvent.click(screen.getByText(item.question))
expect(onView).toHaveBeenCalledWith(item)
expect(mockFormatTime).toHaveBeenCalledWith(item.created_at, 'appLog.dateTimeFormat')
})
it('should toggle single and bulk selection states', () => {
const list = [createAnnotation({ id: 'a', question: 'A' }), createAnnotation({ id: 'b', question: 'B' })]
const onSelectedIdsChange = jest.fn()
const { container, rerender } = render(
<List
list={list}
onView={jest.fn()}
onRemove={jest.fn()}
selectedIds={[]}
onSelectedIdsChange={onSelectedIdsChange}
onBatchDelete={jest.fn()}
onCancel={jest.fn()}
/>,
)
const checkboxes = getCheckboxes(container)
fireEvent.click(checkboxes[1])
expect(onSelectedIdsChange).toHaveBeenCalledWith(['a'])
rerender(
<List
list={list}
onView={jest.fn()}
onRemove={jest.fn()}
selectedIds={['a']}
onSelectedIdsChange={onSelectedIdsChange}
onBatchDelete={jest.fn()}
onCancel={jest.fn()}
/>,
)
const updatedCheckboxes = getCheckboxes(container)
fireEvent.click(updatedCheckboxes[1])
expect(onSelectedIdsChange).toHaveBeenCalledWith([])
fireEvent.click(updatedCheckboxes[0])
expect(onSelectedIdsChange).toHaveBeenCalledWith(['a', 'b'])
})
it('should confirm before removing an annotation and expose batch actions', async () => {
const item = createAnnotation({ id: 'to-delete', question: 'Delete me' })
const onRemove = jest.fn()
render(
<List
list={[item]}
onView={jest.fn()}
onRemove={onRemove}
selectedIds={[item.id]}
onSelectedIdsChange={jest.fn()}
onBatchDelete={jest.fn()}
onCancel={jest.fn()}
/>,
)
const row = screen.getByText(item.question).closest('tr') as HTMLTableRowElement
const actionButtons = within(row).getAllByRole('button')
fireEvent.click(actionButtons[1])
expect(await screen.findByText('appDebug.feature.annotation.removeConfirm')).toBeInTheDocument()
const confirmButton = await screen.findByRole('button', { name: 'common.operation.confirm' })
fireEvent.click(confirmButton)
expect(onRemove).toHaveBeenCalledWith(item.id)
expect(screen.getByText('appAnnotation.batchAction.selected')).toBeInTheDocument()
})
})

View File

@@ -7,7 +7,7 @@ import type { AnnotationItem } from './type'
import RemoveAnnotationConfirmModal from './remove-annotation-confirm-modal'
import ActionButton from '@/app/components/base/action-button'
import useTimestamp from '@/hooks/use-timestamp'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import Checkbox from '@/app/components/base/checkbox'
import BatchAction from './batch-action'

View File

@@ -12,12 +12,6 @@ export type AnnotationItem = {
hit_count: number
}
export type AnnotationCreateResponse = AnnotationItem & {
account?: {
name?: string
}
}
export type HitHistoryItem = {
id: string
question: string

View File

@@ -1,158 +0,0 @@
import React from 'react'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import ViewAnnotationModal from './index'
import type { AnnotationItem, HitHistoryItem } from '../type'
import { fetchHitHistoryList } from '@/service/annotation'
const mockFormatTime = jest.fn(() => 'formatted-time')
jest.mock('@/hooks/use-timestamp', () => ({
__esModule: true,
default: () => ({
formatTime: mockFormatTime,
}),
}))
jest.mock('@/service/annotation', () => ({
fetchHitHistoryList: jest.fn(),
}))
jest.mock('../edit-annotation-modal/edit-item', () => {
const EditItemType = {
Query: 'query',
Answer: 'answer',
}
return {
__esModule: true,
default: ({ type, content, onSave }: { type: string; content: string; onSave: (value: string) => void }) => (
<div>
<div data-testid={`content-${type}`}>{content}</div>
<button data-testid={`edit-${type}`} onClick={() => onSave(`${type}-updated`)}>edit-{type}</button>
</div>
),
EditItemType,
}
})
const fetchHitHistoryListMock = fetchHitHistoryList as jest.Mock
const createAnnotationItem = (overrides: Partial<AnnotationItem> = {}): AnnotationItem => ({
id: overrides.id ?? 'annotation-id',
question: overrides.question ?? 'question',
answer: overrides.answer ?? 'answer',
created_at: overrides.created_at ?? 1700000000,
hit_count: overrides.hit_count ?? 0,
})
const createHitHistoryItem = (overrides: Partial<HitHistoryItem> = {}): HitHistoryItem => ({
id: overrides.id ?? 'hit-id',
question: overrides.question ?? 'query',
match: overrides.match ?? 'match',
response: overrides.response ?? 'response',
source: overrides.source ?? 'source',
score: overrides.score ?? 0.42,
created_at: overrides.created_at ?? 1700000000,
})
const renderComponent = (props?: Partial<React.ComponentProps<typeof ViewAnnotationModal>>) => {
const item = createAnnotationItem()
const mergedProps: React.ComponentProps<typeof ViewAnnotationModal> = {
appId: 'app-id',
isShow: true,
onHide: jest.fn(),
item,
onSave: jest.fn().mockResolvedValue(undefined),
onRemove: jest.fn().mockResolvedValue(undefined),
...props,
}
return {
...render(<ViewAnnotationModal {...mergedProps} />),
props: mergedProps,
}
}
describe('ViewAnnotationModal', () => {
beforeEach(() => {
jest.clearAllMocks()
fetchHitHistoryListMock.mockResolvedValue({ data: [], total: 0 })
})
it('should render annotation tab and allow saving updated query', async () => {
// Arrange
const { props } = renderComponent()
await waitFor(() => {
expect(fetchHitHistoryListMock).toHaveBeenCalled()
})
// Act
fireEvent.click(screen.getByTestId('edit-query'))
// Assert
await waitFor(() => {
expect(props.onSave).toHaveBeenCalledWith('query-updated', props.item.answer)
})
})
it('should render annotation tab and allow saving updated answer', async () => {
// Arrange
const { props } = renderComponent()
await waitFor(() => {
expect(fetchHitHistoryListMock).toHaveBeenCalled()
})
// Act
fireEvent.click(screen.getByTestId('edit-answer'))
// Assert
await waitFor(() => {
expect(props.onSave).toHaveBeenCalledWith(props.item.question, 'answer-updated')
},
)
})
it('should switch to hit history tab and show no data message', async () => {
// Arrange
const { props } = renderComponent()
await waitFor(() => {
expect(fetchHitHistoryListMock).toHaveBeenCalled()
})
// Act
fireEvent.click(screen.getByText('appAnnotation.viewModal.hitHistory'))
// Assert
expect(await screen.findByText('appAnnotation.viewModal.noHitHistory')).toBeInTheDocument()
expect(mockFormatTime).toHaveBeenCalledWith(props.item.created_at, 'appLog.dateTimeFormat')
})
it('should render hit history entries with pagination badge when data exists', async () => {
const hits = [createHitHistoryItem({ question: 'user input' }), createHitHistoryItem({ id: 'hit-2', question: 'second' })]
fetchHitHistoryListMock.mockResolvedValue({ data: hits, total: 15 })
renderComponent()
fireEvent.click(await screen.findByText('appAnnotation.viewModal.hitHistory'))
expect(await screen.findByText('user input')).toBeInTheDocument()
expect(screen.getByText('15 appAnnotation.viewModal.hits')).toBeInTheDocument()
expect(mockFormatTime).toHaveBeenCalledWith(hits[0].created_at, 'appLog.dateTimeFormat')
})
it('should confirm before removing the annotation and hide on success', async () => {
const { props } = renderComponent()
fireEvent.click(screen.getByText('appAnnotation.editModal.removeThisCache'))
expect(await screen.findByText('appDebug.feature.annotation.removeConfirm')).toBeInTheDocument()
const confirmButton = await screen.findByRole('button', { name: 'common.operation.confirm' })
fireEvent.click(confirmButton)
await waitFor(() => {
expect(props.onRemove).toHaveBeenCalledTimes(1)
expect(props.onHide).toHaveBeenCalledTimes(1)
})
})
})

View File

@@ -14,7 +14,7 @@ import TabSlider from '@/app/components/base/tab-slider-plain'
import { fetchHitHistoryList } from '@/service/annotation'
import { APP_PAGE_LIMIT } from '@/config'
import useTimestamp from '@/hooks/use-timestamp'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
type Props = {
appId: string

View File

@@ -2,7 +2,7 @@ import { Fragment, useCallback } from 'react'
import type { ReactNode } from 'react'
import { Dialog, Transition } from '@headlessui/react'
import { RiCloseLine } from '@remixicon/react'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
type DialogProps = {
className?: string

View File

@@ -181,7 +181,7 @@ describe('AccessControlItem', () => {
expect(useAccessControlStore.getState().currentMenu).toBe(AccessMode.ORGANIZATION)
})
it('should keep current menu when clicking the selected access type', () => {
it('should render selected styles when the current menu matches the type', () => {
useAccessControlStore.setState({ currentMenu: AccessMode.ORGANIZATION })
render(
<AccessControlItem type={AccessMode.ORGANIZATION}>
@@ -190,9 +190,8 @@ describe('AccessControlItem', () => {
)
const option = screen.getByText('Organization Only').parentElement as HTMLElement
fireEvent.click(option)
expect(useAccessControlStore.getState().currentMenu).toBe(AccessMode.ORGANIZATION)
expect(option.className).toContain('border-[1.5px]')
expect(option.className).not.toContain('cursor-pointer')
})
})

View File

@@ -11,7 +11,7 @@ import Input from '../../base/input'
import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '../../base/portal-to-follow-elem'
import Loading from '../../base/loading'
import useAccessControlStore from '../../../../context/access-control-store'
import { cn } from '@/utils/classnames'
import classNames from '@/utils/classnames'
import { useSearchForWhiteListCandidates } from '@/service/access-control'
import type { AccessControlAccount, AccessControlGroup, Subject, SubjectAccount, SubjectGroup } from '@/models/access-control'
import { SubjectType } from '@/models/access-control'
@@ -106,7 +106,7 @@ function SelectedGroupsBreadCrumb() {
setSelectedGroupsForBreadcrumb([])
}, [setSelectedGroupsForBreadcrumb])
return <div className='flex h-7 items-center gap-x-0.5 px-2 py-0.5'>
<span className={cn('system-xs-regular text-text-tertiary', selectedGroupsForBreadcrumb.length > 0 && 'cursor-pointer text-text-accent')} onClick={handleReset}>{t('app.accessControlDialog.operateGroupAndMember.allMembers')}</span>
<span className={classNames('system-xs-regular text-text-tertiary', selectedGroupsForBreadcrumb.length > 0 && 'cursor-pointer text-text-accent')} onClick={handleReset}>{t('app.accessControlDialog.operateGroupAndMember.allMembers')}</span>
{selectedGroupsForBreadcrumb.map((group, index) => {
return <div key={index} className='system-xs-regular flex items-center gap-x-0.5 text-text-tertiary'>
<span>/</span>
@@ -198,7 +198,7 @@ type BaseItemProps = {
children: React.ReactNode
}
function BaseItem({ children, className }: BaseItemProps) {
return <div className={cn('flex cursor-pointer items-center space-x-2 p-1 pl-2 hover:rounded-lg hover:bg-state-base-hover', className)}>
return <div className={classNames('flex cursor-pointer items-center space-x-2 p-1 pl-2 hover:rounded-lg hover:bg-state-base-hover', className)}>
{children}
</div>
}

View File

@@ -1,6 +1,6 @@
import type { HTMLProps, PropsWithChildren } from 'react'
import { RiArrowRightUpLine } from '@remixicon/react'
import { cn } from '@/utils/classnames'
import classNames from '@/utils/classnames'
export type SuggestedActionProps = PropsWithChildren<HTMLProps<HTMLAnchorElement> & {
icon?: React.ReactNode
@@ -19,9 +19,11 @@ const SuggestedAction = ({ icon, link, disabled, children, className, onClick, .
href={disabled ? undefined : link}
target='_blank'
rel='noreferrer'
className={cn('flex items-center justify-start gap-2 rounded-lg bg-background-section-burn px-2.5 py-2 text-text-secondary transition-colors [&:not(:first-child)]:mt-1',
className={classNames(
'flex items-center justify-start gap-2 rounded-lg bg-background-section-burn px-2.5 py-2 text-text-secondary transition-colors [&:not(:first-child)]:mt-1',
disabled ? 'cursor-not-allowed opacity-30 shadow-xs' : 'cursor-pointer text-text-secondary hover:bg-state-accent-hover hover:text-text-accent',
className)}
className,
)}
onClick={handleClick}
{...props}
>

View File

@@ -1,7 +1,7 @@
'use client'
import type { FC, ReactNode } from 'react'
import React from 'react'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
export type IFeaturePanelProps = {
className?: string

View File

@@ -6,7 +6,7 @@ import {
RiAddLine,
RiEditLine,
} from '@remixicon/react'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { noop } from 'lodash-es'
export type IOperationBtnProps = {

View File

@@ -14,7 +14,7 @@ import s from './style.module.css'
import MessageTypeSelector from './message-type-selector'
import ConfirmAddVar from './confirm-add-var'
import PromptEditorHeightResizeWrap from './prompt-editor-height-resize-wrap'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import type { PromptRole, PromptVariable } from '@/models/debug'
import {
Copy,

View File

@@ -2,7 +2,7 @@
import type { FC } from 'react'
import React from 'react'
import { useBoolean, useClickAway } from 'ahooks'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { PromptRole } from '@/models/debug'
import { ChevronSelectorVertical } from '@/app/components/base/icons/src/vender/line/arrows'
type Props = {

View File

@@ -2,7 +2,7 @@
import React, { useCallback, useEffect, useState } from 'react'
import type { FC } from 'react'
import { useDebounceFn } from 'ahooks'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
type Props = {
className?: string

View File

@@ -7,7 +7,7 @@ import { produce } from 'immer'
import { useContext } from 'use-context-selector'
import ConfirmAddVar from './confirm-add-var'
import PromptEditorHeightResizeWrap from './prompt-editor-height-resize-wrap'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import type { PromptVariable } from '@/models/debug'
import Tooltip from '@/app/components/base/tooltip'
import { AppModeEnum } from '@/types/app'

View File

@@ -1,7 +1,7 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import { useTranslation } from 'react-i18next'
type Props = {

View File

@@ -2,6 +2,7 @@
import type { FC } from 'react'
import React, { useState } from 'react'
import { ChevronDownIcon } from '@heroicons/react/20/solid'
import classNames from '@/utils/classnames'
import {
PortalToFollowElem,
PortalToFollowElemContent,
@@ -9,7 +10,7 @@ import {
} from '@/app/components/base/portal-to-follow-elem'
import InputVarTypeIcon from '@/app/components/workflow/nodes/_base/components/input-var-type-icon'
import type { InputVarType } from '@/app/components/workflow/types'
import { cn } from '@/utils/classnames'
import cn from '@/utils/classnames'
import Badge from '@/app/components/base/badge'
import { inputVarTypeToVarType } from '@/app/components/workflow/nodes/_base/components/variable/utils'
@@ -46,7 +47,7 @@ const TypeSelector: FC<Props> = ({
>
<PortalToFollowElemTrigger onClick={() => !readonly && setOpen(v => !v)} className='w-full'>
<div
className={cn(`group flex h-9 items-center justify-between rounded-lg border-0 bg-components-input-bg-normal px-2 text-sm hover:bg-state-base-hover-alt ${readonly ? 'cursor-not-allowed' : 'cursor-pointer'}`)}
className={classNames(`group flex h-9 items-center justify-between rounded-lg border-0 bg-components-input-bg-normal px-2 text-sm hover:bg-state-base-hover-alt ${readonly ? 'cursor-not-allowed' : 'cursor-pointer'}`)}
title={selectedItem?.name}
>
<div className='flex items-center'>
@@ -68,7 +69,7 @@ const TypeSelector: FC<Props> = ({
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className='z-[61]'>
<div
className={cn('w-[432px] rounded-md border-[0.5px] border-components-panel-border bg-components-panel-bg px-1 py-1 text-base shadow-lg focus:outline-none sm:text-sm', popupInnerClassName)}
className={classNames('w-[432px] rounded-md border-[0.5px] border-components-panel-border bg-components-panel-bg px-1 py-1 text-base shadow-lg focus:outline-none sm:text-sm', popupInnerClassName)}
>
{items.map((item: Item) => (
<div

Some files were not shown because too many files have changed in this diff Show More