mirror of
https://github.com/langgenius/dify.git
synced 2026-03-19 22:52:01 +00:00
Compare commits
20 Commits
3-16-rever
...
1.13.2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c8560bacb3 | ||
|
|
0f1b8bf5f9 | ||
|
|
652211ad96 | ||
|
|
c049249bc1 | ||
|
|
138083dfc8 | ||
|
|
d1961c261e | ||
|
|
a717519822 | ||
|
|
a592c53573 | ||
|
|
239e09473e | ||
|
|
18ff5d9288 | ||
|
|
18af5fc8c7 | ||
|
|
9e2123c655 | ||
|
|
7e34faaf51 | ||
|
|
569748189e | ||
|
|
f198f5b0ab | ||
|
|
49e0e1b939 | ||
|
|
f886f11094 | ||
|
|
fa82a0f708 | ||
|
|
0a3275fbe8 | ||
|
|
e445f69604 |
324
.github/workflows/web-tests.yml
vendored
324
.github/workflows/web-tests.yml
vendored
@@ -89,14 +89,24 @@ jobs:
|
||||
- name: Merge reports
|
||||
run: vp test --merge-reports --reporter=json --reporter=agent --coverage
|
||||
|
||||
- name: Check app/components diff coverage
|
||||
- name: Report app/components baseline coverage
|
||||
run: node ./scripts/report-components-coverage-baseline.mjs
|
||||
|
||||
- name: Report app/components test touch
|
||||
env:
|
||||
BASE_SHA: ${{ inputs.base_sha }}
|
||||
DIFF_RANGE_MODE: ${{ inputs.diff_range_mode }}
|
||||
HEAD_SHA: ${{ inputs.head_sha }}
|
||||
run: node ./scripts/report-components-test-touch.mjs
|
||||
|
||||
- name: Check app/components pure diff coverage
|
||||
env:
|
||||
BASE_SHA: ${{ inputs.base_sha }}
|
||||
DIFF_RANGE_MODE: ${{ inputs.diff_range_mode }}
|
||||
HEAD_SHA: ${{ inputs.head_sha }}
|
||||
run: node ./scripts/check-components-diff-coverage.mjs
|
||||
|
||||
- name: Coverage Summary
|
||||
- name: Check Coverage Summary
|
||||
if: always()
|
||||
id: coverage-summary
|
||||
run: |
|
||||
@@ -105,313 +115,15 @@ jobs:
|
||||
COVERAGE_FILE="coverage/coverage-final.json"
|
||||
COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json"
|
||||
|
||||
if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then
|
||||
echo "has_coverage=false" >> "$GITHUB_OUTPUT"
|
||||
echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo "Coverage data not found. Ensure Vitest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
|
||||
if [ -f "$COVERAGE_FILE" ] || [ -f "$COVERAGE_SUMMARY_FILE" ]; then
|
||||
echo "has_coverage=true" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "has_coverage=true" >> "$GITHUB_OUTPUT"
|
||||
|
||||
node <<'NODE' >> "$GITHUB_STEP_SUMMARY"
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
let libCoverage = null;
|
||||
|
||||
try {
|
||||
libCoverage = require('istanbul-lib-coverage');
|
||||
} catch (error) {
|
||||
libCoverage = null;
|
||||
}
|
||||
|
||||
const summaryPath = path.join('coverage', 'coverage-summary.json');
|
||||
const finalPath = path.join('coverage', 'coverage-final.json');
|
||||
|
||||
const hasSummary = fs.existsSync(summaryPath);
|
||||
const hasFinal = fs.existsSync(finalPath);
|
||||
|
||||
if (!hasSummary && !hasFinal) {
|
||||
console.log('### Test Coverage Summary :test_tube:');
|
||||
console.log('');
|
||||
console.log('No coverage data found.');
|
||||
process.exit(0);
|
||||
}
|
||||
|
||||
const summary = hasSummary
|
||||
? JSON.parse(fs.readFileSync(summaryPath, 'utf8'))
|
||||
: null;
|
||||
const coverage = hasFinal
|
||||
? JSON.parse(fs.readFileSync(finalPath, 'utf8'))
|
||||
: null;
|
||||
|
||||
const getLineCoverageFromStatements = (statementMap, statementHits) => {
|
||||
const lineHits = {};
|
||||
|
||||
if (!statementMap || !statementHits) {
|
||||
return lineHits;
|
||||
}
|
||||
|
||||
Object.entries(statementMap).forEach(([key, statement]) => {
|
||||
const line = statement?.start?.line;
|
||||
if (!line) {
|
||||
return;
|
||||
}
|
||||
const hits = statementHits[key] ?? 0;
|
||||
const previous = lineHits[line];
|
||||
lineHits[line] = previous === undefined ? hits : Math.max(previous, hits);
|
||||
});
|
||||
|
||||
return lineHits;
|
||||
};
|
||||
|
||||
const getFileCoverage = (entry) => (
|
||||
libCoverage ? libCoverage.createFileCoverage(entry) : null
|
||||
);
|
||||
|
||||
const getLineHits = (entry, fileCoverage) => {
|
||||
const lineHits = entry.l ?? {};
|
||||
if (Object.keys(lineHits).length > 0) {
|
||||
return lineHits;
|
||||
}
|
||||
if (fileCoverage) {
|
||||
return fileCoverage.getLineCoverage();
|
||||
}
|
||||
return getLineCoverageFromStatements(entry.statementMap ?? {}, entry.s ?? {});
|
||||
};
|
||||
|
||||
const getUncoveredLines = (entry, fileCoverage, lineHits) => {
|
||||
if (lineHits && Object.keys(lineHits).length > 0) {
|
||||
return Object.entries(lineHits)
|
||||
.filter(([, count]) => count === 0)
|
||||
.map(([line]) => Number(line))
|
||||
.sort((a, b) => a - b);
|
||||
}
|
||||
if (fileCoverage) {
|
||||
return fileCoverage.getUncoveredLines();
|
||||
}
|
||||
return [];
|
||||
};
|
||||
|
||||
const totals = {
|
||||
lines: { covered: 0, total: 0 },
|
||||
statements: { covered: 0, total: 0 },
|
||||
branches: { covered: 0, total: 0 },
|
||||
functions: { covered: 0, total: 0 },
|
||||
};
|
||||
const fileSummaries = [];
|
||||
|
||||
if (summary) {
|
||||
const totalEntry = summary.total ?? {};
|
||||
['lines', 'statements', 'branches', 'functions'].forEach((key) => {
|
||||
if (totalEntry[key]) {
|
||||
totals[key].covered = totalEntry[key].covered ?? 0;
|
||||
totals[key].total = totalEntry[key].total ?? 0;
|
||||
}
|
||||
});
|
||||
|
||||
Object.entries(summary)
|
||||
.filter(([file]) => file !== 'total')
|
||||
.forEach(([file, data]) => {
|
||||
fileSummaries.push({
|
||||
file,
|
||||
pct: data.lines?.pct ?? data.statements?.pct ?? 0,
|
||||
lines: {
|
||||
covered: data.lines?.covered ?? 0,
|
||||
total: data.lines?.total ?? 0,
|
||||
},
|
||||
});
|
||||
});
|
||||
} else if (coverage) {
|
||||
Object.entries(coverage).forEach(([file, entry]) => {
|
||||
const fileCoverage = getFileCoverage(entry);
|
||||
const lineHits = getLineHits(entry, fileCoverage);
|
||||
const statementHits = entry.s ?? {};
|
||||
const branchHits = entry.b ?? {};
|
||||
const functionHits = entry.f ?? {};
|
||||
|
||||
const lineTotal = Object.keys(lineHits).length;
|
||||
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
|
||||
|
||||
const statementTotal = Object.keys(statementHits).length;
|
||||
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
|
||||
|
||||
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
|
||||
const branchCovered = Object.values(branchHits).reduce(
|
||||
(acc, branches) => acc + branches.filter((n) => n > 0).length,
|
||||
0,
|
||||
);
|
||||
|
||||
const functionTotal = Object.keys(functionHits).length;
|
||||
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
|
||||
|
||||
totals.lines.total += lineTotal;
|
||||
totals.lines.covered += lineCovered;
|
||||
totals.statements.total += statementTotal;
|
||||
totals.statements.covered += statementCovered;
|
||||
totals.branches.total += branchTotal;
|
||||
totals.branches.covered += branchCovered;
|
||||
totals.functions.total += functionTotal;
|
||||
totals.functions.covered += functionCovered;
|
||||
|
||||
const pct = (covered, tot) => (tot > 0 ? (covered / tot) * 100 : 0);
|
||||
|
||||
fileSummaries.push({
|
||||
file,
|
||||
pct: pct(lineCovered || statementCovered, lineTotal || statementTotal),
|
||||
lines: {
|
||||
covered: lineCovered || statementCovered,
|
||||
total: lineTotal || statementTotal,
|
||||
},
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
const pct = (covered, tot) => (tot > 0 ? ((covered / tot) * 100).toFixed(2) : '0.00');
|
||||
|
||||
console.log('### Test Coverage Summary :test_tube:');
|
||||
console.log('');
|
||||
console.log('| Metric | Coverage | Covered / Total |');
|
||||
console.log('|--------|----------|-----------------|');
|
||||
console.log(`| Lines | ${pct(totals.lines.covered, totals.lines.total)}% | ${totals.lines.covered} / ${totals.lines.total} |`);
|
||||
console.log(`| Statements | ${pct(totals.statements.covered, totals.statements.total)}% | ${totals.statements.covered} / ${totals.statements.total} |`);
|
||||
console.log(`| Branches | ${pct(totals.branches.covered, totals.branches.total)}% | ${totals.branches.covered} / ${totals.branches.total} |`);
|
||||
console.log(`| Functions | ${pct(totals.functions.covered, totals.functions.total)}% | ${totals.functions.covered} / ${totals.functions.total} |`);
|
||||
|
||||
console.log('');
|
||||
console.log('<details><summary>File coverage (lowest lines first)</summary>');
|
||||
console.log('');
|
||||
console.log('```');
|
||||
fileSummaries
|
||||
.sort((a, b) => (a.pct - b.pct) || (b.lines.total - a.lines.total))
|
||||
.slice(0, 25)
|
||||
.forEach(({ file, pct, lines }) => {
|
||||
console.log(`${pct.toFixed(2)}%\t${lines.covered}/${lines.total}\t${file}`);
|
||||
});
|
||||
console.log('```');
|
||||
console.log('</details>');
|
||||
|
||||
if (coverage) {
|
||||
const pctValue = (covered, tot) => {
|
||||
if (tot === 0) {
|
||||
return '0';
|
||||
}
|
||||
return ((covered / tot) * 100)
|
||||
.toFixed(2)
|
||||
.replace(/\.?0+$/, '');
|
||||
};
|
||||
|
||||
const formatLineRanges = (lines) => {
|
||||
if (lines.length === 0) {
|
||||
return '';
|
||||
}
|
||||
const ranges = [];
|
||||
let start = lines[0];
|
||||
let end = lines[0];
|
||||
|
||||
for (let i = 1; i < lines.length; i += 1) {
|
||||
const current = lines[i];
|
||||
if (current === end + 1) {
|
||||
end = current;
|
||||
continue;
|
||||
}
|
||||
ranges.push(start === end ? `${start}` : `${start}-${end}`);
|
||||
start = current;
|
||||
end = current;
|
||||
}
|
||||
ranges.push(start === end ? `${start}` : `${start}-${end}`);
|
||||
return ranges.join(',');
|
||||
};
|
||||
|
||||
const tableTotals = {
|
||||
statements: { covered: 0, total: 0 },
|
||||
branches: { covered: 0, total: 0 },
|
||||
functions: { covered: 0, total: 0 },
|
||||
lines: { covered: 0, total: 0 },
|
||||
};
|
||||
const tableRows = Object.entries(coverage)
|
||||
.map(([file, entry]) => {
|
||||
const fileCoverage = getFileCoverage(entry);
|
||||
const lineHits = getLineHits(entry, fileCoverage);
|
||||
const statementHits = entry.s ?? {};
|
||||
const branchHits = entry.b ?? {};
|
||||
const functionHits = entry.f ?? {};
|
||||
|
||||
const lineTotal = Object.keys(lineHits).length;
|
||||
const lineCovered = Object.values(lineHits).filter((n) => n > 0).length;
|
||||
const statementTotal = Object.keys(statementHits).length;
|
||||
const statementCovered = Object.values(statementHits).filter((n) => n > 0).length;
|
||||
const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0);
|
||||
const branchCovered = Object.values(branchHits).reduce(
|
||||
(acc, branches) => acc + branches.filter((n) => n > 0).length,
|
||||
0,
|
||||
);
|
||||
const functionTotal = Object.keys(functionHits).length;
|
||||
const functionCovered = Object.values(functionHits).filter((n) => n > 0).length;
|
||||
|
||||
tableTotals.lines.total += lineTotal;
|
||||
tableTotals.lines.covered += lineCovered;
|
||||
tableTotals.statements.total += statementTotal;
|
||||
tableTotals.statements.covered += statementCovered;
|
||||
tableTotals.branches.total += branchTotal;
|
||||
tableTotals.branches.covered += branchCovered;
|
||||
tableTotals.functions.total += functionTotal;
|
||||
tableTotals.functions.covered += functionCovered;
|
||||
|
||||
const uncoveredLines = getUncoveredLines(entry, fileCoverage, lineHits);
|
||||
|
||||
const filePath = entry.path ?? file;
|
||||
const relativePath = path.isAbsolute(filePath)
|
||||
? path.relative(process.cwd(), filePath)
|
||||
: filePath;
|
||||
|
||||
return {
|
||||
file: relativePath || file,
|
||||
statements: pctValue(statementCovered, statementTotal),
|
||||
branches: pctValue(branchCovered, branchTotal),
|
||||
functions: pctValue(functionCovered, functionTotal),
|
||||
lines: pctValue(lineCovered, lineTotal),
|
||||
uncovered: formatLineRanges(uncoveredLines),
|
||||
};
|
||||
})
|
||||
.sort((a, b) => a.file.localeCompare(b.file));
|
||||
|
||||
const columns = [
|
||||
{ key: 'file', header: 'File', align: 'left' },
|
||||
{ key: 'statements', header: '% Stmts', align: 'right' },
|
||||
{ key: 'branches', header: '% Branch', align: 'right' },
|
||||
{ key: 'functions', header: '% Funcs', align: 'right' },
|
||||
{ key: 'lines', header: '% Lines', align: 'right' },
|
||||
{ key: 'uncovered', header: 'Uncovered Line #s', align: 'left' },
|
||||
];
|
||||
|
||||
const allFilesRow = {
|
||||
file: 'All files',
|
||||
statements: pctValue(tableTotals.statements.covered, tableTotals.statements.total),
|
||||
branches: pctValue(tableTotals.branches.covered, tableTotals.branches.total),
|
||||
functions: pctValue(tableTotals.functions.covered, tableTotals.functions.total),
|
||||
lines: pctValue(tableTotals.lines.covered, tableTotals.lines.total),
|
||||
uncovered: '',
|
||||
};
|
||||
|
||||
const rowsForOutput = [allFilesRow, ...tableRows];
|
||||
const formatRow = (row) => `| ${columns
|
||||
.map(({ key }) => String(row[key] ?? ''))
|
||||
.join(' | ')} |`;
|
||||
const headerRow = `| ${columns.map(({ header }) => header).join(' | ')} |`;
|
||||
const dividerRow = `| ${columns
|
||||
.map(({ align }) => (align === 'right' ? '---:' : ':---'))
|
||||
.join(' | ')} |`;
|
||||
|
||||
console.log('');
|
||||
console.log('<details><summary>Vitest coverage table</summary>');
|
||||
console.log('');
|
||||
console.log(headerRow);
|
||||
console.log(dividerRow);
|
||||
rowsForOutput.forEach((row) => console.log(formatRow(row)));
|
||||
console.log('</details>');
|
||||
}
|
||||
NODE
|
||||
echo "has_coverage=false" >> "$GITHUB_OUTPUT"
|
||||
echo "### 🚨 app/components Diff Coverage" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo "" >> "$GITHUB_STEP_SUMMARY"
|
||||
echo "Coverage artifacts not found. Ensure Vitest merge reports ran with coverage enabled." >> "$GITHUB_STEP_SUMMARY"
|
||||
|
||||
- name: Upload Coverage Artifact
|
||||
if: steps.coverage-summary.outputs.has_coverage == 'true'
|
||||
|
||||
@@ -78,7 +78,7 @@ class UserProfile(TypedDict):
|
||||
nickname: NotRequired[str]
|
||||
```
|
||||
|
||||
- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
|
||||
- For classes, declare all member variables explicitly with types at the top of the class body (before `__init__`), even when the class is not a dataclass or Pydantic model, so the class shape is obvious at a glance:
|
||||
|
||||
```python
|
||||
from datetime import datetime
|
||||
|
||||
@@ -88,6 +88,8 @@ def clean_workflow_runs(
|
||||
"""
|
||||
Clean workflow runs and related workflow data for free tenants.
|
||||
"""
|
||||
from extensions.otel.runtime import flush_telemetry
|
||||
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
raise click.UsageError("--start-from and --end-before must be provided together.")
|
||||
|
||||
@@ -104,16 +106,27 @@ def clean_workflow_runs(
|
||||
end_before = now - datetime.timedelta(days=to_days_ago)
|
||||
before_days = 0
|
||||
|
||||
if from_days_ago is not None and to_days_ago is not None:
|
||||
task_label = f"{from_days_ago}to{to_days_ago}"
|
||||
elif start_from is None:
|
||||
task_label = f"before-{before_days}"
|
||||
else:
|
||||
task_label = "custom"
|
||||
|
||||
start_time = datetime.datetime.now(datetime.UTC)
|
||||
click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white"))
|
||||
|
||||
WorkflowRunCleanup(
|
||||
days=before_days,
|
||||
batch_size=batch_size,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
dry_run=dry_run,
|
||||
).run()
|
||||
try:
|
||||
WorkflowRunCleanup(
|
||||
days=before_days,
|
||||
batch_size=batch_size,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
).run()
|
||||
finally:
|
||||
flush_telemetry()
|
||||
|
||||
end_time = datetime.datetime.now(datetime.UTC)
|
||||
elapsed = end_time - start_time
|
||||
@@ -659,6 +672,8 @@ def clean_expired_messages(
|
||||
"""
|
||||
Clean expired messages and related data for tenants based on clean policy.
|
||||
"""
|
||||
from extensions.otel.runtime import flush_telemetry
|
||||
|
||||
click.echo(click.style("clean_messages: start clean messages.", fg="green"))
|
||||
|
||||
start_at = time.perf_counter()
|
||||
@@ -698,6 +713,13 @@ def clean_expired_messages(
|
||||
# NOTE: graceful_period will be ignored when billing is disabled.
|
||||
policy = create_message_clean_policy(graceful_period_days=graceful_period)
|
||||
|
||||
if from_days_ago is not None and before_days is not None:
|
||||
task_label = f"{from_days_ago}to{before_days}"
|
||||
elif start_from is None and before_days is not None:
|
||||
task_label = f"before-{before_days}"
|
||||
else:
|
||||
task_label = "custom"
|
||||
|
||||
# Create and run the cleanup service
|
||||
if abs_mode:
|
||||
assert start_from is not None
|
||||
@@ -708,6 +730,7 @@ def clean_expired_messages(
|
||||
end_before=end_before,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
)
|
||||
elif from_days_ago is None:
|
||||
assert before_days is not None
|
||||
@@ -716,6 +739,7 @@ def clean_expired_messages(
|
||||
days=before_days,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
)
|
||||
else:
|
||||
assert before_days is not None
|
||||
@@ -727,6 +751,7 @@ def clean_expired_messages(
|
||||
end_before=now - datetime.timedelta(days=before_days),
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
)
|
||||
stats = service.run()
|
||||
|
||||
@@ -752,6 +777,8 @@ def clean_expired_messages(
|
||||
)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
flush_telemetry()
|
||||
|
||||
click.echo(click.style("messages cleanup completed.", fg="green"))
|
||||
|
||||
|
||||
12
api/configs/middleware/cache/redis_config.py
vendored
12
api/configs/middleware/cache/redis_config.py
vendored
@@ -1,4 +1,4 @@
|
||||
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt
|
||||
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
@@ -116,3 +116,13 @@ class RedisConfig(BaseSettings):
|
||||
description="Maximum connections in the Redis connection pool (unset for library default)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
@field_validator("REDIS_MAX_CONNECTIONS", mode="before")
|
||||
@classmethod
|
||||
def _empty_string_to_none_for_max_conns(cls, v):
|
||||
"""Allow empty string in env/.env to mean 'unset' (None)."""
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, str) and v.strip() == "":
|
||||
return None
|
||||
return v
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Literal, Protocol
|
||||
from typing import Literal, Protocol, cast
|
||||
from urllib.parse import quote_plus, urlunparse
|
||||
|
||||
from pydantic import AliasChoices, Field
|
||||
@@ -12,16 +12,13 @@ class RedisConfigDefaults(Protocol):
|
||||
REDIS_PASSWORD: str | None
|
||||
REDIS_DB: int
|
||||
REDIS_USE_SSL: bool
|
||||
REDIS_USE_SENTINEL: bool | None
|
||||
REDIS_USE_CLUSTERS: bool
|
||||
|
||||
|
||||
class RedisConfigDefaultsMixin:
|
||||
def _redis_defaults(self: RedisConfigDefaults) -> RedisConfigDefaults:
|
||||
return self
|
||||
def _redis_defaults(config: object) -> RedisConfigDefaults:
|
||||
return cast(RedisConfigDefaults, config)
|
||||
|
||||
|
||||
class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
|
||||
class RedisPubSubConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for event transport between API and workers.
|
||||
|
||||
@@ -74,7 +71,7 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
|
||||
)
|
||||
|
||||
def _build_default_pubsub_url(self) -> str:
|
||||
defaults = self._redis_defaults()
|
||||
defaults = _redis_defaults(self)
|
||||
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
|
||||
raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed")
|
||||
|
||||
@@ -91,11 +88,9 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
|
||||
if userinfo:
|
||||
userinfo = f"{userinfo}@"
|
||||
|
||||
host = defaults.REDIS_HOST
|
||||
port = defaults.REDIS_PORT
|
||||
db = defaults.REDIS_DB
|
||||
|
||||
netloc = f"{userinfo}{host}:{port}"
|
||||
netloc = f"{userinfo}{defaults.REDIS_HOST}:{defaults.REDIS_PORT}"
|
||||
return urlunparse((scheme, netloc, f"/{db}", "", "", ""))
|
||||
|
||||
@property
|
||||
|
||||
@@ -473,9 +473,21 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
|
||||
else:
|
||||
# some historical data may have a provider record but not be set as valid
|
||||
provider_record.is_valid = True
|
||||
|
||||
if provider_record.credential_id is None:
|
||||
provider_record.credential_id = new_record.id
|
||||
provider_record.updated_at = naive_utc_now()
|
||||
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=provider_record.id,
|
||||
cache_type=ProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
provider_model_credentials_cache.delete()
|
||||
|
||||
self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
|
||||
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
|
||||
@@ -5,6 +5,7 @@ import re
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask import Flask, current_app
|
||||
@@ -37,7 +38,7 @@ from extensions.ext_storage import storage
|
||||
from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account
|
||||
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
|
||||
from models.dataset import AutomaticRulesConfig, ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.model import UploadFile
|
||||
from services.feature_service import FeatureService
|
||||
@@ -265,7 +266,7 @@ class IndexingRunner:
|
||||
self,
|
||||
tenant_id: str,
|
||||
extract_settings: list[ExtractSetting],
|
||||
tmp_processing_rule: dict,
|
||||
tmp_processing_rule: Mapping[str, Any],
|
||||
doc_form: str | None = None,
|
||||
doc_language: str = "English",
|
||||
dataset_id: str | None = None,
|
||||
@@ -376,7 +377,7 @@ class IndexingRunner:
|
||||
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
|
||||
|
||||
def _extract(
|
||||
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
|
||||
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: Mapping[str, Any]
|
||||
) -> list[Document]:
|
||||
data_source_info = dataset_document.data_source_info_dict
|
||||
text_docs = []
|
||||
@@ -543,6 +544,7 @@ class IndexingRunner:
|
||||
"""
|
||||
Clean the document text according to the processing rules.
|
||||
"""
|
||||
rules: AutomaticRulesConfig | dict[str, Any]
|
||||
if processing_rule.mode == "automatic":
|
||||
rules = DatasetProcessRule.AUTOMATIC_RULES
|
||||
else:
|
||||
@@ -756,7 +758,7 @@ class IndexingRunner:
|
||||
dataset: Dataset,
|
||||
text_docs: list[Document],
|
||||
doc_language: str,
|
||||
process_rule: dict,
|
||||
process_rule: Mapping[str, Any],
|
||||
current_user: Account | None = None,
|
||||
) -> list[Document]:
|
||||
# get embedding model instance
|
||||
|
||||
@@ -196,6 +196,8 @@ class ProviderManager:
|
||||
|
||||
if preferred_provider_type_record:
|
||||
preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type)
|
||||
elif dify_config.EDITION == "CLOUD" and system_configuration.enabled:
|
||||
preferred_provider_type = ProviderType.SYSTEM
|
||||
elif custom_configuration.provider or custom_configuration.models:
|
||||
preferred_provider_type = ProviderType.CUSTOM
|
||||
elif system_configuration.enabled:
|
||||
|
||||
@@ -5,6 +5,7 @@ This module provides integration with Weaviate vector database for storing and r
|
||||
document embeddings used in retrieval-augmented generation workflows.
|
||||
"""
|
||||
|
||||
import atexit
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
@@ -37,6 +38,32 @@ _weaviate_client: weaviate.WeaviateClient | None = None
|
||||
_weaviate_client_lock = threading.Lock()
|
||||
|
||||
|
||||
def _shutdown_weaviate_client() -> None:
|
||||
"""
|
||||
Best-effort shutdown hook to close the module-level Weaviate client.
|
||||
|
||||
This is registered with atexit so that HTTP/gRPC resources are released
|
||||
when the Python interpreter exits.
|
||||
"""
|
||||
global _weaviate_client
|
||||
|
||||
# Ensure thread-safety when accessing the shared client instance
|
||||
with _weaviate_client_lock:
|
||||
client = _weaviate_client
|
||||
_weaviate_client = None
|
||||
|
||||
if client is not None:
|
||||
try:
|
||||
client.close()
|
||||
except Exception:
|
||||
# Best-effort cleanup; log at debug level and ignore errors.
|
||||
logger.debug("Failed to close Weaviate client during shutdown", exc_info=True)
|
||||
|
||||
|
||||
# Register the shutdown hook once per process.
|
||||
atexit.register(_shutdown_weaviate_client)
|
||||
|
||||
|
||||
class WeaviateConfig(BaseModel):
|
||||
"""
|
||||
Configuration model for Weaviate connection settings.
|
||||
@@ -85,18 +112,6 @@ class WeaviateVector(BaseVector):
|
||||
self._client = self._init_client(config)
|
||||
self._attributes = attributes
|
||||
|
||||
def __del__(self):
|
||||
"""
|
||||
Destructor to properly close the Weaviate client connection.
|
||||
Prevents connection leaks and resource warnings.
|
||||
"""
|
||||
if hasattr(self, "_client") and self._client is not None:
|
||||
try:
|
||||
self._client.close()
|
||||
except Exception as e:
|
||||
# Ignore errors during cleanup as object is being destroyed
|
||||
logger.warning("Error closing Weaviate client %s", e, exc_info=True)
|
||||
|
||||
def _init_client(self, config: WeaviateConfig) -> weaviate.WeaviateClient:
|
||||
"""
|
||||
Initializes and returns a connected Weaviate client.
|
||||
|
||||
@@ -6,6 +6,5 @@ of responses based on upstream node outputs and constants.
|
||||
"""
|
||||
|
||||
from .coordinator import ResponseStreamCoordinator
|
||||
from .session import RESPONSE_SESSION_NODE_TYPES
|
||||
|
||||
__all__ = ["RESPONSE_SESSION_NODE_TYPES", "ResponseStreamCoordinator"]
|
||||
__all__ = ["ResponseStreamCoordinator"]
|
||||
|
||||
@@ -3,10 +3,6 @@ Internal response session management for response coordinator.
|
||||
|
||||
This module contains the private ResponseSession class used internally
|
||||
by ResponseStreamCoordinator to manage streaming sessions.
|
||||
|
||||
`RESPONSE_SESSION_NODE_TYPES` is intentionally mutable so downstream applications
|
||||
can opt additional response-capable node types into session creation without
|
||||
patching the coordinator.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -14,7 +10,6 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol, cast
|
||||
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
from dify_graph.nodes.base.template import Template
|
||||
from dify_graph.runtime.graph_runtime_state import NodeProtocol
|
||||
|
||||
@@ -25,12 +20,6 @@ class _ResponseSessionNodeProtocol(NodeProtocol, Protocol):
|
||||
def get_streaming_template(self) -> Template: ...
|
||||
|
||||
|
||||
RESPONSE_SESSION_NODE_TYPES: list[NodeType] = [
|
||||
BuiltinNodeTypes.ANSWER,
|
||||
BuiltinNodeTypes.END,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResponseSession:
|
||||
"""
|
||||
@@ -49,8 +38,8 @@ class ResponseSession:
|
||||
Create a ResponseSession from a response-capable node.
|
||||
|
||||
The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer.
|
||||
At runtime this must be a node whose `node_type` is listed in `RESPONSE_SESSION_NODE_TYPES`
|
||||
and which implements `get_streaming_template()`.
|
||||
At runtime this must be a node that implements `get_streaming_template()`. The coordinator decides which
|
||||
graph nodes should be treated as response-capable before they reach this factory.
|
||||
|
||||
Args:
|
||||
node: Node from the materialized workflow graph.
|
||||
@@ -59,15 +48,8 @@ class ResponseSession:
|
||||
ResponseSession configured with the node's streaming template
|
||||
|
||||
Raises:
|
||||
TypeError: If node is not a supported response node type.
|
||||
TypeError: If node does not implement the response-session streaming contract.
|
||||
"""
|
||||
if node.node_type not in RESPONSE_SESSION_NODE_TYPES:
|
||||
supported_node_types = ", ".join(RESPONSE_SESSION_NODE_TYPES)
|
||||
raise TypeError(
|
||||
"ResponseSession.from_node only supports node types in "
|
||||
f"RESPONSE_SESSION_NODE_TYPES: {supported_node_types}"
|
||||
)
|
||||
|
||||
response_node = cast(_ResponseSessionNodeProtocol, node)
|
||||
try:
|
||||
template = response_node.get_streaming_template()
|
||||
|
||||
@@ -101,7 +101,6 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
timeout=self._get_request_timeout(self.node_data),
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
http_request_config=self._http_request_config,
|
||||
max_retries=0,
|
||||
ssl_verify=self.node_data.ssl_verify,
|
||||
http_client=self._http_client,
|
||||
file_manager=self._file_manager,
|
||||
|
||||
@@ -256,9 +256,13 @@ def fetch_prompt_messages(
|
||||
):
|
||||
continue
|
||||
prompt_message_content.append(content_item)
|
||||
if prompt_message_content:
|
||||
if not prompt_message_content:
|
||||
continue
|
||||
if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
|
||||
prompt_message.content = prompt_message_content[0].data
|
||||
else:
|
||||
prompt_message.content = prompt_message_content
|
||||
filtered_prompt_messages.append(prompt_message)
|
||||
filtered_prompt_messages.append(prompt_message)
|
||||
elif not prompt_message.is_empty():
|
||||
filtered_prompt_messages.append(prompt_message)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
|
||||
from dify_graph.file.models import File
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
from dify_graph.variables.segments import Segment
|
||||
|
||||
|
||||
class ArrayValidation(StrEnum):
|
||||
@@ -219,7 +219,7 @@ class SegmentType(StrEnum):
|
||||
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
|
||||
|
||||
@staticmethod
|
||||
def get_zero_value(t: SegmentType):
|
||||
def get_zero_value(t: SegmentType) -> Segment:
|
||||
# Lazy import to avoid circular dependency
|
||||
from factories import variable_factory
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Protocol, cast
|
||||
|
||||
from fastopenapi.routers import FlaskRouter
|
||||
from flask_cors import CORS
|
||||
|
||||
@@ -9,6 +11,10 @@ from extensions.ext_blueprints import AUTHENTICATED_HEADERS, EXPOSED_HEADERS
|
||||
DOCS_PREFIX = "/fastopenapi"
|
||||
|
||||
|
||||
class SupportsIncludeRouter(Protocol):
|
||||
def include_router(self, router: object, *, prefix: str = "") -> None: ...
|
||||
|
||||
|
||||
def init_app(app: DifyApp) -> None:
|
||||
docs_enabled = dify_config.SWAGGER_UI_ENABLED
|
||||
docs_url = f"{DOCS_PREFIX}/docs" if docs_enabled else None
|
||||
@@ -36,7 +42,7 @@ def init_app(app: DifyApp) -> None:
|
||||
_ = remote_files
|
||||
_ = setup
|
||||
|
||||
router.include_router(console_router, prefix="/console/api")
|
||||
cast(SupportsIncludeRouter, router).include_router(console_router, prefix="/console/api")
|
||||
CORS(
|
||||
app,
|
||||
resources={r"/console/api/.*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Union
|
||||
|
||||
from celery.signals import worker_init
|
||||
from flask_login import user_loaded_from_request, user_logged_in
|
||||
from opentelemetry import trace
|
||||
from opentelemetry import metrics, trace
|
||||
from opentelemetry.propagate import set_global_textmap
|
||||
from opentelemetry.propagators.b3 import B3MultiFormat
|
||||
from opentelemetry.propagators.composite import CompositePropagator
|
||||
@@ -31,9 +31,29 @@ def setup_context_propagation() -> None:
|
||||
|
||||
|
||||
def shutdown_tracer() -> None:
|
||||
flush_telemetry()
|
||||
|
||||
|
||||
def flush_telemetry() -> None:
|
||||
"""
|
||||
Best-effort flush for telemetry providers.
|
||||
|
||||
This is mainly used by short-lived command processes (e.g. Kubernetes CronJob)
|
||||
so counters/histograms are exported before the process exits.
|
||||
"""
|
||||
provider = trace.get_tracer_provider()
|
||||
if hasattr(provider, "force_flush"):
|
||||
provider.force_flush()
|
||||
try:
|
||||
provider.force_flush()
|
||||
except Exception:
|
||||
logger.exception("otel: failed to flush trace provider")
|
||||
|
||||
metric_provider = metrics.get_meter_provider()
|
||||
if hasattr(metric_provider, "force_flush"):
|
||||
try:
|
||||
metric_provider.force_flush()
|
||||
except Exception:
|
||||
logger.exception("otel: failed to flush metric provider")
|
||||
|
||||
|
||||
def is_celery_worker():
|
||||
|
||||
@@ -55,7 +55,7 @@ class TypeMismatchError(Exception):
|
||||
|
||||
|
||||
# Define the constant
|
||||
SEGMENT_TO_VARIABLE_MAP = {
|
||||
SEGMENT_TO_VARIABLE_MAP: Mapping[type[Segment], type[VariableBase]] = {
|
||||
ArrayAnySegment: ArrayAnyVariable,
|
||||
ArrayBooleanSegment: ArrayBooleanVariable,
|
||||
ArrayFileSegment: ArrayFileVariable,
|
||||
@@ -296,13 +296,11 @@ def segment_to_variable(
|
||||
raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}")
|
||||
|
||||
variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
|
||||
return cast(
|
||||
VariableBase,
|
||||
variable_class(
|
||||
id=id,
|
||||
name=name,
|
||||
description=description,
|
||||
value=segment.value,
|
||||
selector=list(selector),
|
||||
),
|
||||
return variable_class(
|
||||
id=id,
|
||||
name=name,
|
||||
description=description,
|
||||
value_type=segment.value_type,
|
||||
value=segment.value,
|
||||
selector=list(selector),
|
||||
)
|
||||
|
||||
@@ -32,6 +32,11 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _stream_with_request_context(response: object) -> Any:
|
||||
"""Bridge Flask's loosely-typed streaming helper without leaking casts into callers."""
|
||||
return cast(Any, stream_with_context)(response)
|
||||
|
||||
|
||||
def escape_like_pattern(pattern: str) -> str:
|
||||
"""
|
||||
Escape special characters in a string for safe use in SQL LIKE patterns.
|
||||
@@ -286,22 +291,32 @@ def generate_text_hash(text: str) -> str:
|
||||
return sha256(hash_text.encode()).hexdigest()
|
||||
|
||||
|
||||
def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
|
||||
if isinstance(response, dict):
|
||||
def compact_generate_response(
|
||||
response: Mapping[str, Any] | Generator[str, None, None] | RateLimitGenerator,
|
||||
) -> Response:
|
||||
if isinstance(response, Mapping):
|
||||
return Response(
|
||||
response=json.dumps(jsonable_encoder(response)),
|
||||
status=200,
|
||||
content_type="application/json; charset=utf-8",
|
||||
)
|
||||
else:
|
||||
stream_response = response
|
||||
|
||||
def generate() -> Generator:
|
||||
yield from response
|
||||
def generate() -> Generator[str, None, None]:
|
||||
yield from stream_response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
|
||||
return Response(
|
||||
_stream_with_request_context(generate()),
|
||||
status=200,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
def length_prefixed_response(magic_number: int, response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
|
||||
def length_prefixed_response(
|
||||
magic_number: int,
|
||||
response: Mapping[str, Any] | BaseModel | Generator[str | bytes, None, None] | RateLimitGenerator,
|
||||
) -> Response:
|
||||
"""
|
||||
This function is used to return a response with a length prefix.
|
||||
Magic number is a one byte number that indicates the type of the response.
|
||||
@@ -332,7 +347,7 @@ def length_prefixed_response(magic_number: int, response: Union[Mapping, Generat
|
||||
# | Magic Number 1byte | Reserved 1byte | Header Length 2bytes | Data Length 4bytes | Reserved 6bytes | Data
|
||||
return struct.pack("<BBHI", magic_number, 0, header_length, data_length) + b"\x00" * 6 + response
|
||||
|
||||
if isinstance(response, dict):
|
||||
if isinstance(response, Mapping):
|
||||
return Response(
|
||||
response=pack_response_with_length_prefix(json.dumps(jsonable_encoder(response)).encode("utf-8")),
|
||||
status=200,
|
||||
@@ -345,14 +360,20 @@ def length_prefixed_response(magic_number: int, response: Union[Mapping, Generat
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
stream_response = response
|
||||
|
||||
def generate() -> Generator[bytes, None, None]:
|
||||
for chunk in stream_response:
|
||||
if isinstance(chunk, str):
|
||||
yield pack_response_with_length_prefix(chunk.encode("utf-8"))
|
||||
else:
|
||||
yield pack_response_with_length_prefix(chunk)
|
||||
|
||||
return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
|
||||
return Response(
|
||||
_stream_with_request_context(generate()),
|
||||
status=200,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
class TokenManager:
|
||||
|
||||
@@ -77,12 +77,14 @@ def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]
|
||||
@wraps(func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue:
|
||||
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
|
||||
pass
|
||||
elif current_user is not None and not current_user.is_authenticated:
|
||||
return current_app.ensure_sync(func)(*args, **kwargs)
|
||||
|
||||
user = _get_user()
|
||||
if user is None or not user.is_authenticated:
|
||||
return current_app.login_manager.unauthorized() # type: ignore
|
||||
# we put csrf validation here for less conflicts
|
||||
# TODO: maybe find a better place for it.
|
||||
check_csrf_token(request, current_user.id)
|
||||
check_csrf_token(request, user.id)
|
||||
return current_app.ensure_sync(func)(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
@@ -7,9 +7,10 @@ https://github.com/django/django/blob/main/django/utils/module_loading.py
|
||||
|
||||
import sys
|
||||
from importlib import import_module
|
||||
from typing import Any
|
||||
|
||||
|
||||
def cached_import(module_path: str, class_name: str):
|
||||
def cached_import(module_path: str, class_name: str) -> Any:
|
||||
"""
|
||||
Import a module and return the named attribute/class from it, with caching.
|
||||
|
||||
@@ -20,16 +21,14 @@ def cached_import(module_path: str, class_name: str):
|
||||
Returns:
|
||||
The imported attribute/class
|
||||
"""
|
||||
if not (
|
||||
(module := sys.modules.get(module_path))
|
||||
and (spec := getattr(module, "__spec__", None))
|
||||
and getattr(spec, "_initializing", False) is False
|
||||
):
|
||||
module = sys.modules.get(module_path)
|
||||
spec = getattr(module, "__spec__", None) if module is not None else None
|
||||
if module is None or getattr(spec, "_initializing", False):
|
||||
module = import_module(module_path)
|
||||
return getattr(module, class_name)
|
||||
|
||||
|
||||
def import_string(dotted_path: str):
|
||||
def import_string(dotted_path: str) -> Any:
|
||||
"""
|
||||
Import a dotted module path and return the attribute/class designated by
|
||||
the last name in the path. Raise ImportError if the import failed.
|
||||
|
||||
@@ -1,7 +1,48 @@
|
||||
import sys
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
from typing import NotRequired
|
||||
|
||||
import httpx
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
JsonObject = dict[str, object]
|
||||
JsonObjectList = list[JsonObject]
|
||||
|
||||
JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject)
|
||||
JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList)
|
||||
|
||||
|
||||
class AccessTokenResponse(TypedDict, total=False):
|
||||
access_token: str
|
||||
|
||||
|
||||
class GitHubEmailRecord(TypedDict, total=False):
|
||||
email: str
|
||||
primary: bool
|
||||
|
||||
|
||||
class GitHubRawUserInfo(TypedDict):
|
||||
id: int | str
|
||||
login: str
|
||||
name: NotRequired[str]
|
||||
email: NotRequired[str]
|
||||
|
||||
|
||||
class GoogleRawUserInfo(TypedDict):
|
||||
sub: str
|
||||
email: str
|
||||
|
||||
|
||||
ACCESS_TOKEN_RESPONSE_ADAPTER = TypeAdapter(AccessTokenResponse)
|
||||
GITHUB_RAW_USER_INFO_ADAPTER = TypeAdapter(GitHubRawUserInfo)
|
||||
GITHUB_EMAIL_RECORDS_ADAPTER = TypeAdapter(list[GitHubEmailRecord])
|
||||
GOOGLE_RAW_USER_INFO_ADAPTER = TypeAdapter(GoogleRawUserInfo)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -11,26 +52,38 @@ class OAuthUserInfo:
|
||||
email: str
|
||||
|
||||
|
||||
def _json_object(response: httpx.Response) -> JsonObject:
|
||||
return JSON_OBJECT_ADAPTER.validate_python(response.json())
|
||||
|
||||
|
||||
def _json_list(response: httpx.Response) -> JsonObjectList:
|
||||
return JSON_OBJECT_LIST_ADAPTER.validate_python(response.json())
|
||||
|
||||
|
||||
class OAuth:
|
||||
client_id: str
|
||||
client_secret: str
|
||||
redirect_uri: str
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
|
||||
def get_authorization_url(self):
|
||||
def get_authorization_url(self, invite_token: str | None = None) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_access_token(self, code: str):
|
||||
def get_access_token(self, code: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_raw_user_info(self, token: str):
|
||||
def get_raw_user_info(self, token: str) -> JsonObject:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_user_info(self, token: str) -> OAuthUserInfo:
|
||||
raw_info = self.get_raw_user_info(token)
|
||||
return self._transform_user_info(raw_info)
|
||||
|
||||
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
|
||||
def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -40,7 +93,7 @@ class GitHubOAuth(OAuth):
|
||||
_USER_INFO_URL = "https://api.github.com/user"
|
||||
_EMAIL_INFO_URL = "https://api.github.com/user/emails"
|
||||
|
||||
def get_authorization_url(self, invite_token: str | None = None):
|
||||
def get_authorization_url(self, invite_token: str | None = None) -> str:
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
@@ -50,7 +103,7 @@ class GitHubOAuth(OAuth):
|
||||
params["state"] = invite_token
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
def get_access_token(self, code: str):
|
||||
def get_access_token(self, code: str) -> str:
|
||||
data = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
@@ -60,7 +113,7 @@ class GitHubOAuth(OAuth):
|
||||
headers = {"Accept": "application/json"}
|
||||
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
|
||||
|
||||
response_json = response.json()
|
||||
response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response))
|
||||
access_token = response_json.get("access_token")
|
||||
|
||||
if not access_token:
|
||||
@@ -68,23 +121,24 @@ class GitHubOAuth(OAuth):
|
||||
|
||||
return access_token
|
||||
|
||||
def get_raw_user_info(self, token: str):
|
||||
def get_raw_user_info(self, token: str) -> JsonObject:
|
||||
headers = {"Authorization": f"token {token}"}
|
||||
response = httpx.get(self._USER_INFO_URL, headers=headers)
|
||||
response.raise_for_status()
|
||||
user_info = response.json()
|
||||
user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response))
|
||||
|
||||
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
|
||||
email_info = email_response.json()
|
||||
primary_email: dict = next((email for email in email_info if email["primary"] == True), {})
|
||||
email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response))
|
||||
primary_email = next((email for email in email_info if email.get("primary") is True), None)
|
||||
|
||||
return {**user_info, "email": primary_email.get("email", "")}
|
||||
return {**user_info, "email": primary_email.get("email", "") if primary_email else ""}
|
||||
|
||||
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
|
||||
email = raw_info.get("email")
|
||||
def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
|
||||
payload = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
|
||||
email = payload.get("email")
|
||||
if not email:
|
||||
email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com"
|
||||
return OAuthUserInfo(id=str(raw_info["id"]), name=raw_info["name"], email=email)
|
||||
email = f"{payload['id']}+{payload['login']}@users.noreply.github.com"
|
||||
return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name", "")), email=email)
|
||||
|
||||
|
||||
class GoogleOAuth(OAuth):
|
||||
@@ -92,7 +146,7 @@ class GoogleOAuth(OAuth):
|
||||
_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
_USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
|
||||
|
||||
def get_authorization_url(self, invite_token: str | None = None):
|
||||
def get_authorization_url(self, invite_token: str | None = None) -> str:
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"response_type": "code",
|
||||
@@ -103,7 +157,7 @@ class GoogleOAuth(OAuth):
|
||||
params["state"] = invite_token
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
def get_access_token(self, code: str):
|
||||
def get_access_token(self, code: str) -> str:
|
||||
data = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
@@ -114,7 +168,7 @@ class GoogleOAuth(OAuth):
|
||||
headers = {"Accept": "application/json"}
|
||||
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
|
||||
|
||||
response_json = response.json()
|
||||
response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response))
|
||||
access_token = response_json.get("access_token")
|
||||
|
||||
if not access_token:
|
||||
@@ -122,11 +176,12 @@ class GoogleOAuth(OAuth):
|
||||
|
||||
return access_token
|
||||
|
||||
def get_raw_user_info(self, token: str):
|
||||
def get_raw_user_info(self, token: str) -> JsonObject:
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
response = httpx.get(self._USER_INFO_URL, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
return _json_object(response)
|
||||
|
||||
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
|
||||
return OAuthUserInfo(id=str(raw_info["sub"]), name="", email=raw_info["email"])
|
||||
def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
|
||||
payload = GOOGLE_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
|
||||
return OAuthUserInfo(id=str(payload["sub"]), name="", email=payload["email"])
|
||||
|
||||
@@ -1,25 +1,57 @@
|
||||
import sys
|
||||
import urllib.parse
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
import httpx
|
||||
from flask_login import current_user
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.source import DataSourceOauthBinding
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class NotionPageSummary(TypedDict):
|
||||
page_id: str
|
||||
page_name: str
|
||||
page_icon: dict[str, str] | None
|
||||
parent_id: str
|
||||
type: Literal["page", "database"]
|
||||
|
||||
|
||||
class NotionSourceInfo(TypedDict):
|
||||
workspace_name: str | None
|
||||
workspace_icon: str | None
|
||||
workspace_id: str | None
|
||||
pages: list[NotionPageSummary]
|
||||
total: int
|
||||
|
||||
|
||||
SOURCE_INFO_STORAGE_ADAPTER = TypeAdapter(dict[str, object])
|
||||
NOTION_SOURCE_INFO_ADAPTER = TypeAdapter(NotionSourceInfo)
|
||||
NOTION_PAGE_SUMMARY_ADAPTER = TypeAdapter(NotionPageSummary)
|
||||
|
||||
|
||||
class OAuthDataSource:
|
||||
client_id: str
|
||||
client_secret: str
|
||||
redirect_uri: str
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
|
||||
def get_authorization_url(self):
|
||||
def get_authorization_url(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_access_token(self, code: str):
|
||||
def get_access_token(self, code: str) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -30,7 +62,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
_NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks"
|
||||
_NOTION_BOT_USER = "https://api.notion.com/v1/users/me"
|
||||
|
||||
def get_authorization_url(self):
|
||||
def get_authorization_url(self) -> str:
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"response_type": "code",
|
||||
@@ -39,7 +71,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
}
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
def get_access_token(self, code: str):
|
||||
def get_access_token(self, code: str) -> None:
|
||||
data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri}
|
||||
headers = {"Accept": "application/json"}
|
||||
auth = (self.client_id, self.client_secret)
|
||||
@@ -54,13 +86,12 @@ class NotionOAuth(OAuthDataSource):
|
||||
workspace_id = response_json.get("workspace_id")
|
||||
# get all authorized pages
|
||||
pages = self.get_authorized_pages(access_token)
|
||||
source_info = {
|
||||
"workspace_name": workspace_name,
|
||||
"workspace_icon": workspace_icon,
|
||||
"workspace_id": workspace_id,
|
||||
"pages": pages,
|
||||
"total": len(pages),
|
||||
}
|
||||
source_info = self._build_source_info(
|
||||
workspace_name=workspace_name,
|
||||
workspace_icon=workspace_icon,
|
||||
workspace_id=workspace_id,
|
||||
pages=pages,
|
||||
)
|
||||
# save data source binding
|
||||
data_source_binding = db.session.scalar(
|
||||
select(DataSourceOauthBinding).where(
|
||||
@@ -70,7 +101,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
)
|
||||
)
|
||||
if data_source_binding:
|
||||
data_source_binding.source_info = source_info
|
||||
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
@@ -78,25 +109,24 @@ class NotionOAuth(OAuthDataSource):
|
||||
new_data_source_binding = DataSourceOauthBinding(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
access_token=access_token,
|
||||
source_info=source_info,
|
||||
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
|
||||
provider="notion",
|
||||
)
|
||||
db.session.add(new_data_source_binding)
|
||||
db.session.commit()
|
||||
|
||||
def save_internal_access_token(self, access_token: str):
|
||||
def save_internal_access_token(self, access_token: str) -> None:
|
||||
workspace_name = self.notion_workspace_name(access_token)
|
||||
workspace_icon = None
|
||||
workspace_id = current_user.current_tenant_id
|
||||
# get all authorized pages
|
||||
pages = self.get_authorized_pages(access_token)
|
||||
source_info = {
|
||||
"workspace_name": workspace_name,
|
||||
"workspace_icon": workspace_icon,
|
||||
"workspace_id": workspace_id,
|
||||
"pages": pages,
|
||||
"total": len(pages),
|
||||
}
|
||||
source_info = self._build_source_info(
|
||||
workspace_name=workspace_name,
|
||||
workspace_icon=workspace_icon,
|
||||
workspace_id=workspace_id,
|
||||
pages=pages,
|
||||
)
|
||||
# save data source binding
|
||||
data_source_binding = db.session.scalar(
|
||||
select(DataSourceOauthBinding).where(
|
||||
@@ -106,7 +136,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
)
|
||||
)
|
||||
if data_source_binding:
|
||||
data_source_binding.source_info = source_info
|
||||
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
@@ -114,13 +144,13 @@ class NotionOAuth(OAuthDataSource):
|
||||
new_data_source_binding = DataSourceOauthBinding(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
access_token=access_token,
|
||||
source_info=source_info,
|
||||
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
|
||||
provider="notion",
|
||||
)
|
||||
db.session.add(new_data_source_binding)
|
||||
db.session.commit()
|
||||
|
||||
def sync_data_source(self, binding_id: str):
|
||||
def sync_data_source(self, binding_id: str) -> None:
|
||||
# save data source binding
|
||||
data_source_binding = db.session.scalar(
|
||||
select(DataSourceOauthBinding).where(
|
||||
@@ -134,23 +164,22 @@ class NotionOAuth(OAuthDataSource):
|
||||
if data_source_binding:
|
||||
# get all authorized pages
|
||||
pages = self.get_authorized_pages(data_source_binding.access_token)
|
||||
source_info = data_source_binding.source_info
|
||||
new_source_info = {
|
||||
"workspace_name": source_info["workspace_name"],
|
||||
"workspace_icon": source_info["workspace_icon"],
|
||||
"workspace_id": source_info["workspace_id"],
|
||||
"pages": pages,
|
||||
"total": len(pages),
|
||||
}
|
||||
data_source_binding.source_info = new_source_info
|
||||
source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
|
||||
new_source_info = self._build_source_info(
|
||||
workspace_name=source_info["workspace_name"],
|
||||
workspace_icon=source_info["workspace_icon"],
|
||||
workspace_id=source_info["workspace_id"],
|
||||
pages=pages,
|
||||
)
|
||||
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
else:
|
||||
raise ValueError("Data source binding not found")
|
||||
|
||||
def get_authorized_pages(self, access_token: str):
|
||||
pages = []
|
||||
def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]:
|
||||
pages: list[NotionPageSummary] = []
|
||||
page_results = self.notion_page_search(access_token)
|
||||
database_results = self.notion_database_search(access_token)
|
||||
# get page detail
|
||||
@@ -187,7 +216,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
"parent_id": parent_id,
|
||||
"type": "page",
|
||||
}
|
||||
pages.append(page)
|
||||
pages.append(NOTION_PAGE_SUMMARY_ADAPTER.validate_python(page))
|
||||
# get database detail
|
||||
for database_result in database_results:
|
||||
page_id = database_result["id"]
|
||||
@@ -220,11 +249,11 @@ class NotionOAuth(OAuthDataSource):
|
||||
"parent_id": parent_id,
|
||||
"type": "database",
|
||||
}
|
||||
pages.append(page)
|
||||
pages.append(NOTION_PAGE_SUMMARY_ADAPTER.validate_python(page))
|
||||
return pages
|
||||
|
||||
def notion_page_search(self, access_token: str):
|
||||
results = []
|
||||
def notion_page_search(self, access_token: str) -> list[dict[str, Any]]:
|
||||
results: list[dict[str, Any]] = []
|
||||
next_cursor = None
|
||||
has_more = True
|
||||
|
||||
@@ -249,7 +278,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
|
||||
return results
|
||||
|
||||
def notion_block_parent_page_id(self, access_token: str, block_id: str):
|
||||
def notion_block_parent_page_id(self, access_token: str, block_id: str) -> str:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Notion-Version": "2022-06-28",
|
||||
@@ -265,7 +294,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
return self.notion_block_parent_page_id(access_token, parent[parent_type])
|
||||
return parent[parent_type]
|
||||
|
||||
def notion_workspace_name(self, access_token: str):
|
||||
def notion_workspace_name(self, access_token: str) -> str:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Notion-Version": "2022-06-28",
|
||||
@@ -279,8 +308,8 @@ class NotionOAuth(OAuthDataSource):
|
||||
return user_info["workspace_name"]
|
||||
return "workspace"
|
||||
|
||||
def notion_database_search(self, access_token: str):
|
||||
results = []
|
||||
def notion_database_search(self, access_token: str) -> list[dict[str, Any]]:
|
||||
results: list[dict[str, Any]] = []
|
||||
next_cursor = None
|
||||
has_more = True
|
||||
|
||||
@@ -303,3 +332,19 @@ class NotionOAuth(OAuthDataSource):
|
||||
next_cursor = response_json.get("next_cursor", None)
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _build_source_info(
|
||||
*,
|
||||
workspace_name: str | None,
|
||||
workspace_icon: str | None,
|
||||
workspace_id: str | None,
|
||||
pages: list[NotionPageSummary],
|
||||
) -> NotionSourceInfo:
|
||||
return {
|
||||
"workspace_name": workspace_name,
|
||||
"workspace_icon": workspace_icon,
|
||||
"workspace_id": workspace_id,
|
||||
"pages": pages,
|
||||
"total": len(pages),
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, cast
|
||||
from typing import Any, TypedDict, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
@@ -37,6 +37,61 @@ from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adj
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PreProcessingRuleItem(TypedDict):
|
||||
id: str
|
||||
enabled: bool
|
||||
|
||||
|
||||
class SegmentationConfig(TypedDict):
|
||||
delimiter: str
|
||||
max_tokens: int
|
||||
chunk_overlap: int
|
||||
|
||||
|
||||
class AutomaticRulesConfig(TypedDict):
|
||||
pre_processing_rules: list[PreProcessingRuleItem]
|
||||
segmentation: SegmentationConfig
|
||||
|
||||
|
||||
class ProcessRuleDict(TypedDict):
|
||||
id: str
|
||||
dataset_id: str
|
||||
mode: str
|
||||
rules: dict[str, Any] | None
|
||||
|
||||
|
||||
class DocMetadataDetailItem(TypedDict):
|
||||
id: str
|
||||
name: str
|
||||
type: str
|
||||
value: Any
|
||||
|
||||
|
||||
class AttachmentItem(TypedDict):
|
||||
id: str
|
||||
name: str
|
||||
size: int
|
||||
extension: str
|
||||
mime_type: str
|
||||
source_url: str
|
||||
|
||||
|
||||
class DatasetBindingItem(TypedDict):
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class ExternalKnowledgeApiDict(TypedDict):
|
||||
id: str
|
||||
tenant_id: str
|
||||
name: str
|
||||
description: str
|
||||
settings: dict[str, Any] | None
|
||||
dataset_bindings: list[DatasetBindingItem]
|
||||
created_by: str
|
||||
created_at: str
|
||||
|
||||
|
||||
class DatasetPermissionEnum(enum.StrEnum):
|
||||
ONLY_ME = "only_me"
|
||||
ALL_TEAM = "all_team_members"
|
||||
@@ -334,7 +389,7 @@ class DatasetProcessRule(Base): # bug
|
||||
|
||||
MODES = ["automatic", "custom", "hierarchical"]
|
||||
PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
|
||||
AUTOMATIC_RULES: dict[str, Any] = {
|
||||
AUTOMATIC_RULES: AutomaticRulesConfig = {
|
||||
"pre_processing_rules": [
|
||||
{"id": "remove_extra_spaces", "enabled": True},
|
||||
{"id": "remove_urls_emails", "enabled": False},
|
||||
@@ -342,7 +397,7 @@ class DatasetProcessRule(Base): # bug
|
||||
"segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50},
|
||||
}
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> ProcessRuleDict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"dataset_id": self.dataset_id,
|
||||
@@ -531,7 +586,7 @@ class Document(Base):
|
||||
return self.updated_at
|
||||
|
||||
@property
|
||||
def doc_metadata_details(self) -> list[dict[str, Any]] | None:
|
||||
def doc_metadata_details(self) -> list[DocMetadataDetailItem] | None:
|
||||
if self.doc_metadata:
|
||||
document_metadatas = (
|
||||
db.session.query(DatasetMetadata)
|
||||
@@ -541,9 +596,9 @@ class Document(Base):
|
||||
)
|
||||
.all()
|
||||
)
|
||||
metadata_list: list[dict[str, Any]] = []
|
||||
metadata_list: list[DocMetadataDetailItem] = []
|
||||
for metadata in document_metadatas:
|
||||
metadata_dict: dict[str, Any] = {
|
||||
metadata_dict: DocMetadataDetailItem = {
|
||||
"id": metadata.id,
|
||||
"name": metadata.name,
|
||||
"type": metadata.type,
|
||||
@@ -557,13 +612,13 @@ class Document(Base):
|
||||
return None
|
||||
|
||||
@property
|
||||
def process_rule_dict(self) -> dict[str, Any] | None:
|
||||
def process_rule_dict(self) -> ProcessRuleDict | None:
|
||||
if self.dataset_process_rule_id and self.dataset_process_rule:
|
||||
return self.dataset_process_rule.to_dict()
|
||||
return None
|
||||
|
||||
def get_built_in_fields(self) -> list[dict[str, Any]]:
|
||||
built_in_fields: list[dict[str, Any]] = []
|
||||
def get_built_in_fields(self) -> list[DocMetadataDetailItem]:
|
||||
built_in_fields: list[DocMetadataDetailItem] = []
|
||||
built_in_fields.append(
|
||||
{
|
||||
"id": "built-in",
|
||||
@@ -877,7 +932,7 @@ class DocumentSegment(Base):
|
||||
return text
|
||||
|
||||
@property
|
||||
def attachments(self) -> list[dict[str, Any]]:
|
||||
def attachments(self) -> list[AttachmentItem]:
|
||||
# Use JOIN to fetch attachments in a single query instead of two separate queries
|
||||
attachments_with_bindings = db.session.execute(
|
||||
select(SegmentAttachmentBinding, UploadFile)
|
||||
@@ -891,7 +946,7 @@ class DocumentSegment(Base):
|
||||
).all()
|
||||
if not attachments_with_bindings:
|
||||
return []
|
||||
attachment_list = []
|
||||
attachment_list: list[AttachmentItem] = []
|
||||
for _, attachment in attachments_with_bindings:
|
||||
upload_file_id = attachment.id
|
||||
nonce = os.urandom(16).hex()
|
||||
@@ -1261,7 +1316,7 @@ class ExternalKnowledgeApis(TypeBase):
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> ExternalKnowledgeApiDict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"tenant_id": self.tenant_id,
|
||||
@@ -1281,13 +1336,13 @@ class ExternalKnowledgeApis(TypeBase):
|
||||
return None
|
||||
|
||||
@property
|
||||
def dataset_bindings(self) -> list[dict[str, Any]]:
|
||||
def dataset_bindings(self) -> list[DatasetBindingItem]:
|
||||
external_knowledge_bindings = db.session.scalars(
|
||||
select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
|
||||
).all()
|
||||
dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
|
||||
datasets = db.session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all()
|
||||
dataset_bindings: list[dict[str, Any]] = []
|
||||
dataset_bindings: list[DatasetBindingItem] = []
|
||||
for dataset in datasets:
|
||||
dataset_bindings.append({"id": dataset.id, "name": dataset.name})
|
||||
|
||||
|
||||
@@ -11,6 +11,13 @@ class CreatorUserRole(StrEnum):
|
||||
ACCOUNT = "account"
|
||||
END_USER = "end_user"
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value):
|
||||
if value == "end-user":
|
||||
return cls.END_USER
|
||||
else:
|
||||
return super()._missing_(value)
|
||||
|
||||
|
||||
class WorkflowRunTriggeredFrom(StrEnum):
|
||||
DEBUGGING = "debugging"
|
||||
|
||||
@@ -3,7 +3,7 @@ import time
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from functools import cached_property
|
||||
from typing import Any, cast
|
||||
from typing import Any, TypedDict, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
@@ -23,6 +23,47 @@ from .enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTr
|
||||
from .model import Account
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
TriggerJsonObject = dict[str, object]
|
||||
TriggerCredentials = dict[str, str]
|
||||
|
||||
|
||||
class WorkflowTriggerLogDict(TypedDict):
|
||||
id: str
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
workflow_id: str
|
||||
workflow_run_id: str | None
|
||||
root_node_id: str | None
|
||||
trigger_metadata: Any
|
||||
trigger_type: str
|
||||
trigger_data: Any
|
||||
inputs: Any
|
||||
outputs: Any
|
||||
status: str
|
||||
error: str | None
|
||||
queue_name: str
|
||||
celery_task_id: str | None
|
||||
retry_count: int
|
||||
elapsed_time: float | None
|
||||
total_tokens: int | None
|
||||
created_by_role: str
|
||||
created_by: str
|
||||
created_at: str | None
|
||||
triggered_at: str | None
|
||||
finished_at: str | None
|
||||
|
||||
|
||||
class WorkflowSchedulePlanDict(TypedDict):
|
||||
id: str
|
||||
app_id: str
|
||||
node_id: str
|
||||
tenant_id: str
|
||||
cron_expression: str
|
||||
timezone: str
|
||||
next_run_at: str | None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class TriggerSubscription(TypeBase):
|
||||
"""
|
||||
@@ -51,10 +92,14 @@ class TriggerSubscription(TypeBase):
|
||||
String(255), nullable=False, comment="Provider identifier (e.g., plugin_id/provider_name)"
|
||||
)
|
||||
endpoint_id: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription endpoint")
|
||||
parameters: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription parameters JSON")
|
||||
properties: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription properties JSON")
|
||||
parameters: Mapped[TriggerJsonObject] = mapped_column(
|
||||
sa.JSON, nullable=False, comment="Subscription parameters JSON"
|
||||
)
|
||||
properties: Mapped[TriggerJsonObject] = mapped_column(
|
||||
sa.JSON, nullable=False, comment="Subscription properties JSON"
|
||||
)
|
||||
|
||||
credentials: Mapped[dict[str, Any]] = mapped_column(
|
||||
credentials: Mapped[TriggerCredentials] = mapped_column(
|
||||
sa.JSON, nullable=False, comment="Subscription credentials JSON"
|
||||
)
|
||||
credential_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="oauth or api_key")
|
||||
@@ -162,8 +207,8 @@ class TriggerOAuthTenantClient(TypeBase):
|
||||
)
|
||||
|
||||
@property
|
||||
def oauth_params(self) -> Mapping[str, Any]:
|
||||
return cast(Mapping[str, Any], json.loads(self.encrypted_oauth_params or "{}"))
|
||||
def oauth_params(self) -> Mapping[str, object]:
|
||||
return cast(TriggerJsonObject, json.loads(self.encrypted_oauth_params or "{}"))
|
||||
|
||||
|
||||
class WorkflowTriggerLog(TypeBase):
|
||||
@@ -250,7 +295,7 @@ class WorkflowTriggerLog(TypeBase):
|
||||
created_by_role = CreatorUserRole(self.created_by_role)
|
||||
return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> WorkflowTriggerLogDict:
|
||||
"""Convert to dictionary for API responses"""
|
||||
return {
|
||||
"id": self.id,
|
||||
@@ -481,7 +526,7 @@ class WorkflowSchedulePlan(TypeBase):
|
||||
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
def to_dict(self) -> WorkflowSchedulePlanDict:
|
||||
"""Convert to dictionary representation"""
|
||||
return {
|
||||
"id": self.id,
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
@@ -19,7 +19,7 @@ from sqlalchemy import (
|
||||
orm,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
|
||||
@@ -33,7 +33,7 @@ from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus
|
||||
from dify_graph.file.constants import maybe_file_object
|
||||
from dify_graph.file.models import File
|
||||
from dify_graph.variables import utils as variable_utils
|
||||
from dify_graph.variables.variables import FloatVariable, IntegerVariable, StringVariable
|
||||
from dify_graph.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable
|
||||
from extensions.ext_storage import Storage
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@@ -59,6 +59,25 @@ from .types import EnumText, LongText, StringUUID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SerializedWorkflowValue = dict[str, Any]
|
||||
SerializedWorkflowVariables = dict[str, SerializedWorkflowValue]
|
||||
|
||||
|
||||
class WorkflowContentDict(TypedDict):
|
||||
graph: Mapping[str, Any]
|
||||
features: dict[str, Any]
|
||||
environment_variables: list[dict[str, Any]]
|
||||
conversation_variables: list[dict[str, Any]]
|
||||
rag_pipeline_variables: list[dict[str, Any]]
|
||||
|
||||
|
||||
class WorkflowRunSummaryDict(TypedDict):
|
||||
id: str
|
||||
status: str
|
||||
triggered_from: str
|
||||
elapsed_time: float
|
||||
total_tokens: int
|
||||
|
||||
|
||||
class WorkflowType(StrEnum):
|
||||
"""
|
||||
@@ -389,7 +408,7 @@ class Workflow(Base): # bug
|
||||
|
||||
def rag_pipeline_user_input_form(self) -> list:
|
||||
# get user_input_form from start node
|
||||
variables: list[Any] = self.rag_pipeline_variables
|
||||
variables: list[SerializedWorkflowValue] = self.rag_pipeline_variables
|
||||
|
||||
return variables
|
||||
|
||||
@@ -432,17 +451,13 @@ class Workflow(Base): # bug
|
||||
def environment_variables(
|
||||
self,
|
||||
) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
|
||||
# TODO: find some way to init `self._environment_variables` when instance created.
|
||||
if self._environment_variables is None:
|
||||
self._environment_variables = "{}"
|
||||
|
||||
# Use workflow.tenant_id to avoid relying on request user in background threads
|
||||
tenant_id = self.tenant_id
|
||||
|
||||
if not tenant_id:
|
||||
return []
|
||||
|
||||
environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables or "{}")
|
||||
environment_variables_dict = cast(SerializedWorkflowVariables, json.loads(self._environment_variables or "{}"))
|
||||
results = [
|
||||
variable_factory.build_environment_variable_from_mapping(v) for v in environment_variables_dict.values()
|
||||
]
|
||||
@@ -502,14 +517,14 @@ class Workflow(Base): # bug
|
||||
)
|
||||
self._environment_variables = environment_variables_json
|
||||
|
||||
def to_dict(self, *, include_secret: bool = False) -> Mapping[str, Any]:
|
||||
def to_dict(self, *, include_secret: bool = False) -> WorkflowContentDict:
|
||||
environment_variables = list(self.environment_variables)
|
||||
environment_variables = [
|
||||
v if not isinstance(v, SecretVariable) or include_secret else v.model_copy(update={"value": ""})
|
||||
for v in environment_variables
|
||||
]
|
||||
|
||||
result = {
|
||||
result: WorkflowContentDict = {
|
||||
"graph": self.graph_dict,
|
||||
"features": self.features_dict,
|
||||
"environment_variables": [var.model_dump(mode="json") for var in environment_variables],
|
||||
@@ -520,11 +535,7 @@ class Workflow(Base): # bug
|
||||
|
||||
@property
|
||||
def conversation_variables(self) -> Sequence[VariableBase]:
|
||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||
if self._conversation_variables is None:
|
||||
self._conversation_variables = "{}"
|
||||
|
||||
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
|
||||
variables_dict = cast(SerializedWorkflowVariables, json.loads(self._conversation_variables or "{}"))
|
||||
results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()]
|
||||
return results
|
||||
|
||||
@@ -536,19 +547,20 @@ class Workflow(Base): # bug
|
||||
)
|
||||
|
||||
@property
|
||||
def rag_pipeline_variables(self) -> list[dict]:
|
||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||
if self._rag_pipeline_variables is None:
|
||||
self._rag_pipeline_variables = "{}"
|
||||
|
||||
variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables)
|
||||
results = list(variables_dict.values())
|
||||
return results
|
||||
def rag_pipeline_variables(self) -> list[SerializedWorkflowValue]:
|
||||
variables_dict = cast(SerializedWorkflowVariables, json.loads(self._rag_pipeline_variables or "{}"))
|
||||
return [RAGPipelineVariable.model_validate(item).model_dump(mode="json") for item in variables_dict.values()]
|
||||
|
||||
@rag_pipeline_variables.setter
|
||||
def rag_pipeline_variables(self, values: list[dict]) -> None:
|
||||
def rag_pipeline_variables(self, values: Sequence[Mapping[str, Any] | RAGPipelineVariable]) -> None:
|
||||
self._rag_pipeline_variables = json.dumps(
|
||||
{item["variable"]: item for item in values},
|
||||
{
|
||||
rag_pipeline_variable.variable: rag_pipeline_variable.model_dump(mode="json")
|
||||
for rag_pipeline_variable in (
|
||||
item if isinstance(item, RAGPipelineVariable) else RAGPipelineVariable.model_validate(item)
|
||||
for item in values
|
||||
)
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@@ -786,44 +798,36 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
||||
|
||||
__tablename__ = "workflow_node_executions"
|
||||
|
||||
@declared_attr.directive
|
||||
@classmethod
|
||||
def __table_args__(cls) -> Any:
|
||||
return (
|
||||
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
|
||||
Index(
|
||||
"workflow_node_execution_workflow_run_id_idx",
|
||||
"workflow_run_id",
|
||||
),
|
||||
Index(
|
||||
"workflow_node_execution_node_run_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_id",
|
||||
),
|
||||
Index(
|
||||
"workflow_node_execution_id_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_execution_id",
|
||||
),
|
||||
Index(
|
||||
# The first argument is the index name,
|
||||
# which we leave as `None`` to allow auto-generation by the ORM.
|
||||
None,
|
||||
cls.tenant_id,
|
||||
cls.workflow_id,
|
||||
cls.node_id,
|
||||
# MyPy may flag the following line because it doesn't recognize that
|
||||
# the `declared_attr` decorator passes the receiving class as the first
|
||||
# argument to this method, allowing us to reference class attributes.
|
||||
cls.created_at.desc(),
|
||||
),
|
||||
)
|
||||
__table_args__ = (
|
||||
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
|
||||
Index(
|
||||
"workflow_node_execution_workflow_run_id_idx",
|
||||
"workflow_run_id",
|
||||
),
|
||||
Index(
|
||||
"workflow_node_execution_node_run_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_id",
|
||||
),
|
||||
Index(
|
||||
"workflow_node_execution_id_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_execution_id",
|
||||
),
|
||||
Index(
|
||||
None,
|
||||
"tenant_id",
|
||||
"workflow_id",
|
||||
"node_id",
|
||||
sa.desc("created_at"),
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
||||
@@ -1231,7 +1235,7 @@ class WorkflowArchiveLog(TypeBase):
|
||||
)
|
||||
|
||||
@property
|
||||
def workflow_run_summary(self) -> dict[str, Any]:
|
||||
def workflow_run_summary(self) -> WorkflowRunSummaryDict:
|
||||
return {
|
||||
"id": self.workflow_run_id,
|
||||
"status": self.run_status,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "dify-api"
|
||||
version = "1.13.0"
|
||||
version = "1.13.2"
|
||||
requires-python = ">=3.11,<3.13"
|
||||
|
||||
dependencies = [
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
configs/middleware/cache/redis_pubsub_config.py
|
||||
controllers/console/app/annotation.py
|
||||
controllers/console/app/app.py
|
||||
controllers/console/app/app_import.py
|
||||
@@ -138,8 +137,6 @@ dify_graph/nodes/trigger_webhook/node.py
|
||||
dify_graph/nodes/variable_aggregator/variable_aggregator_node.py
|
||||
dify_graph/nodes/variable_assigner/v1/node.py
|
||||
dify_graph/nodes/variable_assigner/v2/node.py
|
||||
dify_graph/variables/types.py
|
||||
extensions/ext_fastopenapi.py
|
||||
extensions/logstore/repositories/logstore_api_workflow_run_repository.py
|
||||
extensions/otel/instrumentation.py
|
||||
extensions/otel/runtime.py
|
||||
@@ -156,19 +153,7 @@ extensions/storage/oracle_oci_storage.py
|
||||
extensions/storage/supabase_storage.py
|
||||
extensions/storage/tencent_cos_storage.py
|
||||
extensions/storage/volcengine_tos_storage.py
|
||||
factories/variable_factory.py
|
||||
libs/external_api.py
|
||||
libs/gmpy2_pkcs10aep_cipher.py
|
||||
libs/helper.py
|
||||
libs/login.py
|
||||
libs/module_loading.py
|
||||
libs/oauth.py
|
||||
libs/oauth_data_source.py
|
||||
models/trigger.py
|
||||
models/workflow.py
|
||||
repositories/sqlalchemy_api_workflow_node_execution_repository.py
|
||||
repositories/sqlalchemy_api_workflow_run_repository.py
|
||||
repositories/sqlalchemy_execution_extra_content_repository.py
|
||||
schedule/queue_monitor_task.py
|
||||
services/account_service.py
|
||||
services/audio_service.py
|
||||
|
||||
@@ -8,7 +8,7 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
from typing import Protocol, cast
|
||||
|
||||
from sqlalchemy import asc, delete, desc, func, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
@@ -22,6 +22,20 @@ from repositories.api_workflow_node_execution_repository import (
|
||||
)
|
||||
|
||||
|
||||
class _WorkflowNodeExecutionSnapshotRow(Protocol):
|
||||
id: str
|
||||
node_execution_id: str | None
|
||||
node_id: str
|
||||
node_type: str
|
||||
title: str
|
||||
index: int
|
||||
status: WorkflowNodeExecutionStatus
|
||||
elapsed_time: float | None
|
||||
created_at: datetime
|
||||
finished_at: datetime | None
|
||||
execution_metadata: str | None
|
||||
|
||||
|
||||
class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
|
||||
"""
|
||||
SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository.
|
||||
@@ -40,6 +54,8 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
||||
- Thread-safe database operations using session-per-request pattern
|
||||
"""
|
||||
|
||||
_session_maker: sessionmaker[Session]
|
||||
|
||||
def __init__(self, session_maker: sessionmaker[Session]):
|
||||
"""
|
||||
Initialize the repository with a sessionmaker.
|
||||
@@ -156,12 +172,12 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
||||
)
|
||||
|
||||
with self._session_maker() as session:
|
||||
rows = session.execute(stmt).all()
|
||||
rows = cast(Sequence[_WorkflowNodeExecutionSnapshotRow], session.execute(stmt).all())
|
||||
|
||||
return [self._row_to_snapshot(row) for row in rows]
|
||||
|
||||
@staticmethod
|
||||
def _row_to_snapshot(row: object) -> WorkflowNodeExecutionSnapshot:
|
||||
def _row_to_snapshot(row: _WorkflowNodeExecutionSnapshotRow) -> WorkflowNodeExecutionSnapshot:
|
||||
metadata: dict[str, object] = {}
|
||||
execution_metadata = getattr(row, "execution_metadata", None)
|
||||
if execution_metadata:
|
||||
|
||||
@@ -18,7 +18,7 @@ from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.enums import CreatorUserRole, WorkflowTriggerStatus
|
||||
from models.model import App, EndUser
|
||||
from models.trigger import WorkflowTriggerLog
|
||||
from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict
|
||||
from models.workflow import Workflow
|
||||
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
|
||||
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
|
||||
@@ -224,7 +224,9 @@ class AsyncWorkflowService:
|
||||
return cls.trigger_workflow_async(session, user, trigger_data)
|
||||
|
||||
@classmethod
|
||||
def get_trigger_log(cls, workflow_trigger_log_id: str, tenant_id: str | None = None) -> dict[str, Any] | None:
|
||||
def get_trigger_log(
|
||||
cls, workflow_trigger_log_id: str, tenant_id: str | None = None
|
||||
) -> WorkflowTriggerLogDict | None:
|
||||
"""
|
||||
Get trigger log by ID
|
||||
|
||||
@@ -247,7 +249,7 @@ class AsyncWorkflowService:
|
||||
@classmethod
|
||||
def get_recent_logs(
|
||||
cls, tenant_id: str, app_id: str, hours: int = 24, limit: int = 100, offset: int = 0
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> list[WorkflowTriggerLogDict]:
|
||||
"""
|
||||
Get recent trigger logs
|
||||
|
||||
@@ -272,7 +274,7 @@ class AsyncWorkflowService:
|
||||
@classmethod
|
||||
def get_failed_logs_for_retry(
|
||||
cls, tenant_id: str, max_retry_count: int = 3, limit: int = 100
|
||||
) -> list[dict[str, Any]]:
|
||||
) -> list[WorkflowTriggerLogDict]:
|
||||
"""
|
||||
Get failed logs eligible for retry
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ from core.plugin.impl.debugging import PluginDebuggingClient
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider import Provider, ProviderCredential
|
||||
from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider
|
||||
from models.provider_ids import GenericProviderID
|
||||
from services.enterprise.plugin_manager_service import (
|
||||
PluginManagerService,
|
||||
@@ -534,6 +534,13 @@ class PluginService:
|
||||
plugin_id = plugin.plugin_id
|
||||
logger.info("Deleting credentials for plugin: %s", plugin_id)
|
||||
|
||||
session.execute(
|
||||
delete(TenantPreferredModelProvider).where(
|
||||
TenantPreferredModelProvider.tenant_id == tenant_id,
|
||||
TenantPreferredModelProvider.provider_name.like(f"{plugin_id}/%"),
|
||||
)
|
||||
)
|
||||
|
||||
# Delete provider credentials that match this plugin
|
||||
credential_ids = session.scalars(
|
||||
select(ProviderCredential.id).where(
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import delete, select, tuple_
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.model import (
|
||||
@@ -33,6 +33,131 @@ from services.retention.conversation.messages_clean_policy import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.metrics import Counter, Histogram
|
||||
|
||||
|
||||
class MessagesCleanupMetrics:
|
||||
"""
|
||||
Records low-cardinality OpenTelemetry metrics for expired message cleanup jobs.
|
||||
|
||||
We keep labels stable (dry_run/window_mode/task_label/status) so these metrics remain
|
||||
dashboard-friendly for long-running CronJob executions.
|
||||
"""
|
||||
|
||||
_job_runs_total: "Counter | None"
|
||||
_batches_total: "Counter | None"
|
||||
_messages_scanned_total: "Counter | None"
|
||||
_messages_filtered_total: "Counter | None"
|
||||
_messages_deleted_total: "Counter | None"
|
||||
_job_duration_seconds: "Histogram | None"
|
||||
_batch_duration_seconds: "Histogram | None"
|
||||
_base_attributes: dict[str, str]
|
||||
|
||||
def __init__(self, *, dry_run: bool, has_window: bool, task_label: str) -> None:
|
||||
self._job_runs_total = None
|
||||
self._batches_total = None
|
||||
self._messages_scanned_total = None
|
||||
self._messages_filtered_total = None
|
||||
self._messages_deleted_total = None
|
||||
self._job_duration_seconds = None
|
||||
self._batch_duration_seconds = None
|
||||
self._base_attributes = {
|
||||
"job_name": "messages_cleanup",
|
||||
"dry_run": str(dry_run).lower(),
|
||||
"window_mode": "between" if has_window else "before_cutoff",
|
||||
"task_label": task_label,
|
||||
}
|
||||
self._init_instruments()
|
||||
|
||||
def _init_instruments(self) -> None:
|
||||
if not dify_config.ENABLE_OTEL:
|
||||
return
|
||||
|
||||
try:
|
||||
from opentelemetry.metrics import get_meter
|
||||
|
||||
meter = get_meter("messages_cleanup", version=dify_config.project.version)
|
||||
self._job_runs_total = meter.create_counter(
|
||||
"messages_cleanup_jobs_total",
|
||||
description="Total number of expired message cleanup jobs by status.",
|
||||
unit="{job}",
|
||||
)
|
||||
self._batches_total = meter.create_counter(
|
||||
"messages_cleanup_batches_total",
|
||||
description="Total number of message cleanup batches processed.",
|
||||
unit="{batch}",
|
||||
)
|
||||
self._messages_scanned_total = meter.create_counter(
|
||||
"messages_cleanup_scanned_messages_total",
|
||||
description="Total messages scanned by cleanup jobs.",
|
||||
unit="{message}",
|
||||
)
|
||||
self._messages_filtered_total = meter.create_counter(
|
||||
"messages_cleanup_filtered_messages_total",
|
||||
description="Total messages selected by cleanup policy.",
|
||||
unit="{message}",
|
||||
)
|
||||
self._messages_deleted_total = meter.create_counter(
|
||||
"messages_cleanup_deleted_messages_total",
|
||||
description="Total messages deleted by cleanup jobs.",
|
||||
unit="{message}",
|
||||
)
|
||||
self._job_duration_seconds = meter.create_histogram(
|
||||
"messages_cleanup_job_duration_seconds",
|
||||
description="Duration of expired message cleanup jobs in seconds.",
|
||||
unit="s",
|
||||
)
|
||||
self._batch_duration_seconds = meter.create_histogram(
|
||||
"messages_cleanup_batch_duration_seconds",
|
||||
description="Duration of expired message cleanup batch processing in seconds.",
|
||||
unit="s",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("messages_cleanup_metrics: failed to initialize instruments")
|
||||
|
||||
def _attrs(self, **extra: str) -> dict[str, str]:
|
||||
return {**self._base_attributes, **extra}
|
||||
|
||||
@staticmethod
|
||||
def _add(counter: "Counter | None", value: int, attributes: dict[str, str]) -> None:
|
||||
if not counter or value <= 0:
|
||||
return
|
||||
try:
|
||||
counter.add(value, attributes)
|
||||
except Exception:
|
||||
logger.exception("messages_cleanup_metrics: failed to add counter value")
|
||||
|
||||
@staticmethod
|
||||
def _record(histogram: "Histogram | None", value: float, attributes: dict[str, str]) -> None:
|
||||
if not histogram:
|
||||
return
|
||||
try:
|
||||
histogram.record(value, attributes)
|
||||
except Exception:
|
||||
logger.exception("messages_cleanup_metrics: failed to record histogram value")
|
||||
|
||||
def record_batch(
|
||||
self,
|
||||
*,
|
||||
scanned_messages: int,
|
||||
filtered_messages: int,
|
||||
deleted_messages: int,
|
||||
batch_duration_seconds: float,
|
||||
) -> None:
|
||||
attributes = self._attrs()
|
||||
self._add(self._batches_total, 1, attributes)
|
||||
self._add(self._messages_scanned_total, scanned_messages, attributes)
|
||||
self._add(self._messages_filtered_total, filtered_messages, attributes)
|
||||
self._add(self._messages_deleted_total, deleted_messages, attributes)
|
||||
self._record(self._batch_duration_seconds, batch_duration_seconds, attributes)
|
||||
|
||||
def record_completion(self, *, status: str, job_duration_seconds: float) -> None:
|
||||
attributes = self._attrs(status=status)
|
||||
self._add(self._job_runs_total, 1, attributes)
|
||||
self._record(self._job_duration_seconds, job_duration_seconds, attributes)
|
||||
|
||||
|
||||
class MessagesCleanService:
|
||||
"""
|
||||
Service for cleaning expired messages based on retention policies.
|
||||
@@ -48,6 +173,7 @@ class MessagesCleanService:
|
||||
start_from: datetime.datetime | None = None,
|
||||
batch_size: int = 1000,
|
||||
dry_run: bool = False,
|
||||
task_label: str = "custom",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the service with cleanup parameters.
|
||||
@@ -58,12 +184,18 @@ class MessagesCleanService:
|
||||
start_from: Optional start time (inclusive) of the range
|
||||
batch_size: Number of messages to process per batch
|
||||
dry_run: Whether to perform a dry run (no actual deletion)
|
||||
task_label: Optional task label for retention metrics
|
||||
"""
|
||||
self._policy = policy
|
||||
self._end_before = end_before
|
||||
self._start_from = start_from
|
||||
self._batch_size = batch_size
|
||||
self._dry_run = dry_run
|
||||
self._metrics = MessagesCleanupMetrics(
|
||||
dry_run=dry_run,
|
||||
has_window=bool(start_from),
|
||||
task_label=task_label,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_time_range(
|
||||
@@ -73,6 +205,7 @@ class MessagesCleanService:
|
||||
end_before: datetime.datetime,
|
||||
batch_size: int = 1000,
|
||||
dry_run: bool = False,
|
||||
task_label: str = "custom",
|
||||
) -> "MessagesCleanService":
|
||||
"""
|
||||
Create a service instance for cleaning messages within a specific time range.
|
||||
@@ -85,6 +218,7 @@ class MessagesCleanService:
|
||||
end_before: End time (exclusive) of the range
|
||||
batch_size: Number of messages to process per batch
|
||||
dry_run: Whether to perform a dry run (no actual deletion)
|
||||
task_label: Optional task label for retention metrics
|
||||
|
||||
Returns:
|
||||
MessagesCleanService instance
|
||||
@@ -112,6 +246,7 @@ class MessagesCleanService:
|
||||
start_from=start_from,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -121,6 +256,7 @@ class MessagesCleanService:
|
||||
days: int = 30,
|
||||
batch_size: int = 1000,
|
||||
dry_run: bool = False,
|
||||
task_label: str = "custom",
|
||||
) -> "MessagesCleanService":
|
||||
"""
|
||||
Create a service instance for cleaning messages older than specified days.
|
||||
@@ -130,6 +266,7 @@ class MessagesCleanService:
|
||||
days: Number of days to look back from now
|
||||
batch_size: Number of messages to process per batch
|
||||
dry_run: Whether to perform a dry run (no actual deletion)
|
||||
task_label: Optional task label for retention metrics
|
||||
|
||||
Returns:
|
||||
MessagesCleanService instance
|
||||
@@ -153,7 +290,14 @@ class MessagesCleanService:
|
||||
policy.__class__.__name__,
|
||||
)
|
||||
|
||||
return cls(policy=policy, end_before=end_before, start_from=None, batch_size=batch_size, dry_run=dry_run)
|
||||
return cls(
|
||||
policy=policy,
|
||||
end_before=end_before,
|
||||
start_from=None,
|
||||
batch_size=batch_size,
|
||||
dry_run=dry_run,
|
||||
task_label=task_label,
|
||||
)
|
||||
|
||||
def run(self) -> dict[str, int]:
|
||||
"""
|
||||
@@ -162,7 +306,18 @@ class MessagesCleanService:
|
||||
Returns:
|
||||
Dict with statistics: batches, filtered_messages, total_deleted
|
||||
"""
|
||||
return self._clean_messages_by_time_range()
|
||||
status = "success"
|
||||
run_start = time.monotonic()
|
||||
try:
|
||||
return self._clean_messages_by_time_range()
|
||||
except Exception:
|
||||
status = "failed"
|
||||
raise
|
||||
finally:
|
||||
self._metrics.record_completion(
|
||||
status=status,
|
||||
job_duration_seconds=time.monotonic() - run_start,
|
||||
)
|
||||
|
||||
def _clean_messages_by_time_range(self) -> dict[str, int]:
|
||||
"""
|
||||
@@ -197,11 +352,14 @@ class MessagesCleanService:
|
||||
self._end_before,
|
||||
)
|
||||
|
||||
max_batch_interval_ms = int(os.environ.get("SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL", 200))
|
||||
max_batch_interval_ms = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL
|
||||
|
||||
while True:
|
||||
stats["batches"] += 1
|
||||
batch_start = time.monotonic()
|
||||
batch_scanned_messages = 0
|
||||
batch_filtered_messages = 0
|
||||
batch_deleted_messages = 0
|
||||
|
||||
# Step 1: Fetch a batch of messages using cursor
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
@@ -240,9 +398,16 @@ class MessagesCleanService:
|
||||
|
||||
# Track total messages fetched across all batches
|
||||
stats["total_messages"] += len(messages)
|
||||
batch_scanned_messages = len(messages)
|
||||
|
||||
if not messages:
|
||||
logger.info("clean_messages (batch %s): no more messages to process", stats["batches"])
|
||||
self._metrics.record_batch(
|
||||
scanned_messages=batch_scanned_messages,
|
||||
filtered_messages=batch_filtered_messages,
|
||||
deleted_messages=batch_deleted_messages,
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
break
|
||||
|
||||
# Update cursor to the last message's (created_at, id)
|
||||
@@ -268,6 +433,12 @@ class MessagesCleanService:
|
||||
|
||||
if not apps:
|
||||
logger.info("clean_messages (batch %s): no apps found, skip", stats["batches"])
|
||||
self._metrics.record_batch(
|
||||
scanned_messages=batch_scanned_messages,
|
||||
filtered_messages=batch_filtered_messages,
|
||||
deleted_messages=batch_deleted_messages,
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
continue
|
||||
|
||||
# Build app_id -> tenant_id mapping
|
||||
@@ -286,9 +457,16 @@ class MessagesCleanService:
|
||||
|
||||
if not message_ids_to_delete:
|
||||
logger.info("clean_messages (batch %s): no messages to delete, skip", stats["batches"])
|
||||
self._metrics.record_batch(
|
||||
scanned_messages=batch_scanned_messages,
|
||||
filtered_messages=batch_filtered_messages,
|
||||
deleted_messages=batch_deleted_messages,
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
continue
|
||||
|
||||
stats["filtered_messages"] += len(message_ids_to_delete)
|
||||
batch_filtered_messages = len(message_ids_to_delete)
|
||||
|
||||
# Step 4: Batch delete messages and their relations
|
||||
if not self._dry_run:
|
||||
@@ -309,6 +487,7 @@ class MessagesCleanService:
|
||||
commit_ms = int((time.monotonic() - commit_start) * 1000)
|
||||
|
||||
stats["total_deleted"] += messages_deleted
|
||||
batch_deleted_messages = messages_deleted
|
||||
|
||||
logger.info(
|
||||
"clean_messages (batch %s): processed %s messages, deleted %s messages",
|
||||
@@ -343,6 +522,13 @@ class MessagesCleanService:
|
||||
for msg_id in sampled_ids:
|
||||
logger.info("clean_messages (batch %s, dry_run) sample: message_id=%s", stats["batches"], msg_id)
|
||||
|
||||
self._metrics.record_batch(
|
||||
scanned_messages=batch_scanned_messages,
|
||||
filtered_messages=batch_filtered_messages,
|
||||
deleted_messages=batch_deleted_messages,
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"clean_messages completed: total batches: %s, total messages: %s, filtered messages: %s, total deleted: %s",
|
||||
stats["batches"],
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import click
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
@@ -20,6 +20,159 @@ from services.billing_service import BillingService, SubscriptionPlan
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.metrics import Counter, Histogram
|
||||
|
||||
|
||||
class WorkflowRunCleanupMetrics:
|
||||
"""
|
||||
Records low-cardinality OpenTelemetry metrics for workflow run cleanup jobs.
|
||||
|
||||
Metrics are emitted with stable labels only (dry_run/window_mode/task_label/status)
|
||||
to keep dashboard and alert cardinality predictable in production clusters.
|
||||
"""
|
||||
|
||||
_job_runs_total: "Counter | None"
|
||||
_batches_total: "Counter | None"
|
||||
_runs_scanned_total: "Counter | None"
|
||||
_runs_targeted_total: "Counter | None"
|
||||
_runs_deleted_total: "Counter | None"
|
||||
_runs_skipped_total: "Counter | None"
|
||||
_related_records_total: "Counter | None"
|
||||
_job_duration_seconds: "Histogram | None"
|
||||
_batch_duration_seconds: "Histogram | None"
|
||||
_base_attributes: dict[str, str]
|
||||
|
||||
def __init__(self, *, dry_run: bool, has_window: bool, task_label: str) -> None:
|
||||
self._job_runs_total = None
|
||||
self._batches_total = None
|
||||
self._runs_scanned_total = None
|
||||
self._runs_targeted_total = None
|
||||
self._runs_deleted_total = None
|
||||
self._runs_skipped_total = None
|
||||
self._related_records_total = None
|
||||
self._job_duration_seconds = None
|
||||
self._batch_duration_seconds = None
|
||||
self._base_attributes = {
|
||||
"job_name": "workflow_run_cleanup",
|
||||
"dry_run": str(dry_run).lower(),
|
||||
"window_mode": "between" if has_window else "before_cutoff",
|
||||
"task_label": task_label,
|
||||
}
|
||||
self._init_instruments()
|
||||
|
||||
def _init_instruments(self) -> None:
|
||||
if not dify_config.ENABLE_OTEL:
|
||||
return
|
||||
|
||||
try:
|
||||
from opentelemetry.metrics import get_meter
|
||||
|
||||
meter = get_meter("workflow_run_cleanup", version=dify_config.project.version)
|
||||
self._job_runs_total = meter.create_counter(
|
||||
"workflow_run_cleanup_jobs_total",
|
||||
description="Total number of workflow run cleanup jobs by status.",
|
||||
unit="{job}",
|
||||
)
|
||||
self._batches_total = meter.create_counter(
|
||||
"workflow_run_cleanup_batches_total",
|
||||
description="Total number of processed cleanup batches.",
|
||||
unit="{batch}",
|
||||
)
|
||||
self._runs_scanned_total = meter.create_counter(
|
||||
"workflow_run_cleanup_scanned_runs_total",
|
||||
description="Total workflow runs scanned by cleanup jobs.",
|
||||
unit="{run}",
|
||||
)
|
||||
self._runs_targeted_total = meter.create_counter(
|
||||
"workflow_run_cleanup_targeted_runs_total",
|
||||
description="Total workflow runs targeted by cleanup policy.",
|
||||
unit="{run}",
|
||||
)
|
||||
self._runs_deleted_total = meter.create_counter(
|
||||
"workflow_run_cleanup_deleted_runs_total",
|
||||
description="Total workflow runs deleted by cleanup jobs.",
|
||||
unit="{run}",
|
||||
)
|
||||
self._runs_skipped_total = meter.create_counter(
|
||||
"workflow_run_cleanup_skipped_runs_total",
|
||||
description="Total workflow runs skipped because tenant is paid/unknown.",
|
||||
unit="{run}",
|
||||
)
|
||||
self._related_records_total = meter.create_counter(
|
||||
"workflow_run_cleanup_related_records_total",
|
||||
description="Total related records processed by cleanup jobs.",
|
||||
unit="{record}",
|
||||
)
|
||||
self._job_duration_seconds = meter.create_histogram(
|
||||
"workflow_run_cleanup_job_duration_seconds",
|
||||
description="Duration of workflow run cleanup jobs in seconds.",
|
||||
unit="s",
|
||||
)
|
||||
self._batch_duration_seconds = meter.create_histogram(
|
||||
"workflow_run_cleanup_batch_duration_seconds",
|
||||
description="Duration of workflow run cleanup batch processing in seconds.",
|
||||
unit="s",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("workflow_run_cleanup_metrics: failed to initialize instruments")
|
||||
|
||||
def _attrs(self, **extra: str) -> dict[str, str]:
|
||||
return {**self._base_attributes, **extra}
|
||||
|
||||
@staticmethod
|
||||
def _add(counter: "Counter | None", value: int, attributes: dict[str, str]) -> None:
|
||||
if not counter or value <= 0:
|
||||
return
|
||||
try:
|
||||
counter.add(value, attributes)
|
||||
except Exception:
|
||||
logger.exception("workflow_run_cleanup_metrics: failed to add counter value")
|
||||
|
||||
@staticmethod
|
||||
def _record(histogram: "Histogram | None", value: float, attributes: dict[str, str]) -> None:
|
||||
if not histogram:
|
||||
return
|
||||
try:
|
||||
histogram.record(value, attributes)
|
||||
except Exception:
|
||||
logger.exception("workflow_run_cleanup_metrics: failed to record histogram value")
|
||||
|
||||
def record_batch(
|
||||
self,
|
||||
*,
|
||||
batch_rows: int,
|
||||
targeted_runs: int,
|
||||
skipped_runs: int,
|
||||
deleted_runs: int,
|
||||
related_counts: dict[str, int] | None,
|
||||
related_action: str | None,
|
||||
batch_duration_seconds: float,
|
||||
) -> None:
|
||||
attributes = self._attrs()
|
||||
self._add(self._batches_total, 1, attributes)
|
||||
self._add(self._runs_scanned_total, batch_rows, attributes)
|
||||
self._add(self._runs_targeted_total, targeted_runs, attributes)
|
||||
self._add(self._runs_skipped_total, skipped_runs, attributes)
|
||||
self._add(self._runs_deleted_total, deleted_runs, attributes)
|
||||
self._record(self._batch_duration_seconds, batch_duration_seconds, attributes)
|
||||
|
||||
if not related_counts or not related_action:
|
||||
return
|
||||
|
||||
for record_type, count in related_counts.items():
|
||||
self._add(
|
||||
self._related_records_total,
|
||||
count,
|
||||
self._attrs(action=related_action, record_type=record_type),
|
||||
)
|
||||
|
||||
def record_completion(self, *, status: str, job_duration_seconds: float) -> None:
|
||||
attributes = self._attrs(status=status)
|
||||
self._add(self._job_runs_total, 1, attributes)
|
||||
self._record(self._job_duration_seconds, job_duration_seconds, attributes)
|
||||
|
||||
|
||||
class WorkflowRunCleanup:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -29,6 +182,7 @@ class WorkflowRunCleanup:
|
||||
end_before: datetime.datetime | None = None,
|
||||
workflow_run_repo: APIWorkflowRunRepository | None = None,
|
||||
dry_run: bool = False,
|
||||
task_label: str = "custom",
|
||||
):
|
||||
if (start_from is None) ^ (end_before is None):
|
||||
raise ValueError("start_from and end_before must be both set or both omitted.")
|
||||
@@ -46,6 +200,11 @@ class WorkflowRunCleanup:
|
||||
self.batch_size = batch_size
|
||||
self._cleanup_whitelist: set[str] | None = None
|
||||
self.dry_run = dry_run
|
||||
self._metrics = WorkflowRunCleanupMetrics(
|
||||
dry_run=dry_run,
|
||||
has_window=bool(start_from),
|
||||
task_label=task_label,
|
||||
)
|
||||
self.free_plan_grace_period_days = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD
|
||||
self.workflow_run_repo: APIWorkflowRunRepository
|
||||
if workflow_run_repo:
|
||||
@@ -74,153 +233,193 @@ class WorkflowRunCleanup:
|
||||
related_totals = self._empty_related_counts() if self.dry_run else None
|
||||
batch_index = 0
|
||||
last_seen: tuple[datetime.datetime, str] | None = None
|
||||
status = "success"
|
||||
run_start = time.monotonic()
|
||||
max_batch_interval_ms = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL
|
||||
|
||||
max_batch_interval_ms = int(os.environ.get("SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL", 200))
|
||||
try:
|
||||
while True:
|
||||
batch_start = time.monotonic()
|
||||
|
||||
while True:
|
||||
batch_start = time.monotonic()
|
||||
|
||||
fetch_start = time.monotonic()
|
||||
run_rows = self.workflow_run_repo.get_runs_batch_by_time_range(
|
||||
start_from=self.window_start,
|
||||
end_before=self.window_end,
|
||||
last_seen=last_seen,
|
||||
batch_size=self.batch_size,
|
||||
)
|
||||
if not run_rows:
|
||||
logger.info("workflow_run_cleanup (batch #%s): no more rows to process", batch_index + 1)
|
||||
break
|
||||
|
||||
batch_index += 1
|
||||
last_seen = (run_rows[-1].created_at, run_rows[-1].id)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s): fetched %s rows in %sms",
|
||||
batch_index,
|
||||
len(run_rows),
|
||||
int((time.monotonic() - fetch_start) * 1000),
|
||||
)
|
||||
|
||||
tenant_ids = {row.tenant_id for row in run_rows}
|
||||
|
||||
filter_start = time.monotonic()
|
||||
free_tenants = self._filter_free_tenants(tenant_ids)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s): filtered %s free tenants from %s tenants in %sms",
|
||||
batch_index,
|
||||
len(free_tenants),
|
||||
len(tenant_ids),
|
||||
int((time.monotonic() - filter_start) * 1000),
|
||||
)
|
||||
|
||||
free_runs = [row for row in run_rows if row.tenant_id in free_tenants]
|
||||
paid_or_skipped = len(run_rows) - len(free_runs)
|
||||
|
||||
if not free_runs:
|
||||
skipped_message = (
|
||||
f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)"
|
||||
fetch_start = time.monotonic()
|
||||
run_rows = self.workflow_run_repo.get_runs_batch_by_time_range(
|
||||
start_from=self.window_start,
|
||||
end_before=self.window_end,
|
||||
last_seen=last_seen,
|
||||
batch_size=self.batch_size,
|
||||
)
|
||||
click.echo(
|
||||
click.style(
|
||||
skipped_message,
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
continue
|
||||
if not run_rows:
|
||||
logger.info("workflow_run_cleanup (batch #%s): no more rows to process", batch_index + 1)
|
||||
break
|
||||
|
||||
total_runs_targeted += len(free_runs)
|
||||
|
||||
if self.dry_run:
|
||||
count_start = time.monotonic()
|
||||
batch_counts = self.workflow_run_repo.count_runs_with_related(
|
||||
free_runs,
|
||||
count_node_executions=self._count_node_executions,
|
||||
count_trigger_logs=self._count_trigger_logs,
|
||||
)
|
||||
batch_index += 1
|
||||
last_seen = (run_rows[-1].created_at, run_rows[-1].id)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s, dry_run): counted related records in %sms",
|
||||
"workflow_run_cleanup (batch #%s): fetched %s rows in %sms",
|
||||
batch_index,
|
||||
int((time.monotonic() - count_start) * 1000),
|
||||
len(run_rows),
|
||||
int((time.monotonic() - fetch_start) * 1000),
|
||||
)
|
||||
if related_totals is not None:
|
||||
for key in related_totals:
|
||||
related_totals[key] += batch_counts.get(key, 0)
|
||||
sample_ids = ", ".join(run.id for run in free_runs[:5])
|
||||
|
||||
tenant_ids = {row.tenant_id for row in run_rows}
|
||||
|
||||
filter_start = time.monotonic()
|
||||
free_tenants = self._filter_free_tenants(tenant_ids)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s): filtered %s free tenants from %s tenants in %sms",
|
||||
batch_index,
|
||||
len(free_tenants),
|
||||
len(tenant_ids),
|
||||
int((time.monotonic() - filter_start) * 1000),
|
||||
)
|
||||
|
||||
free_runs = [row for row in run_rows if row.tenant_id in free_tenants]
|
||||
paid_or_skipped = len(run_rows) - len(free_runs)
|
||||
|
||||
if not free_runs:
|
||||
skipped_message = (
|
||||
f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)"
|
||||
)
|
||||
click.echo(
|
||||
click.style(
|
||||
skipped_message,
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
self._metrics.record_batch(
|
||||
batch_rows=len(run_rows),
|
||||
targeted_runs=0,
|
||||
skipped_runs=paid_or_skipped,
|
||||
deleted_runs=0,
|
||||
related_counts=None,
|
||||
related_action=None,
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
continue
|
||||
|
||||
total_runs_targeted += len(free_runs)
|
||||
|
||||
if self.dry_run:
|
||||
count_start = time.monotonic()
|
||||
batch_counts = self.workflow_run_repo.count_runs_with_related(
|
||||
free_runs,
|
||||
count_node_executions=self._count_node_executions,
|
||||
count_trigger_logs=self._count_trigger_logs,
|
||||
)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s, dry_run): counted related records in %sms",
|
||||
batch_index,
|
||||
int((time.monotonic() - count_start) * 1000),
|
||||
)
|
||||
if related_totals is not None:
|
||||
for key in related_totals:
|
||||
related_totals[key] += batch_counts.get(key, 0)
|
||||
sample_ids = ", ".join(run.id for run in free_runs[:5])
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[batch #{batch_index}] would delete {len(free_runs)} runs "
|
||||
f"(sample ids: {sample_ids}) and skip {paid_or_skipped} paid/unknown",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s, dry_run): batch total %sms",
|
||||
batch_index,
|
||||
int((time.monotonic() - batch_start) * 1000),
|
||||
)
|
||||
self._metrics.record_batch(
|
||||
batch_rows=len(run_rows),
|
||||
targeted_runs=len(free_runs),
|
||||
skipped_runs=paid_or_skipped,
|
||||
deleted_runs=0,
|
||||
related_counts={key: batch_counts.get(key, 0) for key in self._empty_related_counts()},
|
||||
related_action="would_delete",
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
delete_start = time.monotonic()
|
||||
counts = self.workflow_run_repo.delete_runs_with_related(
|
||||
free_runs,
|
||||
delete_node_executions=self._delete_node_executions,
|
||||
delete_trigger_logs=self._delete_trigger_logs,
|
||||
)
|
||||
delete_ms = int((time.monotonic() - delete_start) * 1000)
|
||||
except Exception:
|
||||
logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0])
|
||||
raise
|
||||
|
||||
total_runs_deleted += counts["runs"]
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[batch #{batch_index}] would delete {len(free_runs)} runs "
|
||||
f"(sample ids: {sample_ids}) and skip {paid_or_skipped} paid/unknown",
|
||||
fg="yellow",
|
||||
f"[batch #{batch_index}] deleted runs: {counts['runs']} "
|
||||
f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, "
|
||||
f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, "
|
||||
f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); "
|
||||
f"skipped {paid_or_skipped} paid/unknown",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s, dry_run): batch total %sms",
|
||||
"workflow_run_cleanup (batch #%s): delete %sms, batch total %sms",
|
||||
batch_index,
|
||||
delete_ms,
|
||||
int((time.monotonic() - batch_start) * 1000),
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
delete_start = time.monotonic()
|
||||
counts = self.workflow_run_repo.delete_runs_with_related(
|
||||
free_runs,
|
||||
delete_node_executions=self._delete_node_executions,
|
||||
delete_trigger_logs=self._delete_trigger_logs,
|
||||
self._metrics.record_batch(
|
||||
batch_rows=len(run_rows),
|
||||
targeted_runs=len(free_runs),
|
||||
skipped_runs=paid_or_skipped,
|
||||
deleted_runs=counts["runs"],
|
||||
related_counts={key: counts.get(key, 0) for key in self._empty_related_counts()},
|
||||
related_action="deleted",
|
||||
batch_duration_seconds=time.monotonic() - batch_start,
|
||||
)
|
||||
delete_ms = int((time.monotonic() - delete_start) * 1000)
|
||||
except Exception:
|
||||
logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0])
|
||||
raise
|
||||
|
||||
total_runs_deleted += counts["runs"]
|
||||
click.echo(
|
||||
click.style(
|
||||
f"[batch #{batch_index}] deleted runs: {counts['runs']} "
|
||||
f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, "
|
||||
f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, "
|
||||
f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); "
|
||||
f"skipped {paid_or_skipped} paid/unknown",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"workflow_run_cleanup (batch #%s): delete %sms, batch total %sms",
|
||||
batch_index,
|
||||
delete_ms,
|
||||
int((time.monotonic() - batch_start) * 1000),
|
||||
)
|
||||
# Random sleep between batches to avoid overwhelming the database
|
||||
sleep_ms = random.uniform(0, max_batch_interval_ms) # noqa: S311
|
||||
logger.info("workflow_run_cleanup (batch #%s): sleeping for %.2fms", batch_index, sleep_ms)
|
||||
time.sleep(sleep_ms / 1000)
|
||||
|
||||
# Random sleep between batches to avoid overwhelming the database
|
||||
sleep_ms = random.uniform(0, max_batch_interval_ms) # noqa: S311
|
||||
logger.info("workflow_run_cleanup (batch #%s): sleeping for %.2fms", batch_index, sleep_ms)
|
||||
time.sleep(sleep_ms / 1000)
|
||||
|
||||
if self.dry_run:
|
||||
if self.window_start:
|
||||
summary_message = (
|
||||
f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
|
||||
f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
|
||||
)
|
||||
if self.dry_run:
|
||||
if self.window_start:
|
||||
summary_message = (
|
||||
f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
|
||||
f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
|
||||
)
|
||||
else:
|
||||
summary_message = (
|
||||
f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
|
||||
f"before {self.window_end.isoformat()}"
|
||||
)
|
||||
if related_totals is not None:
|
||||
summary_message = (
|
||||
f"{summary_message}; related records: {self._format_related_counts(related_totals)}"
|
||||
)
|
||||
summary_color = "yellow"
|
||||
else:
|
||||
summary_message = (
|
||||
f"Dry run complete. Would delete {total_runs_targeted} workflow runs "
|
||||
f"before {self.window_end.isoformat()}"
|
||||
)
|
||||
if related_totals is not None:
|
||||
summary_message = f"{summary_message}; related records: {self._format_related_counts(related_totals)}"
|
||||
summary_color = "yellow"
|
||||
else:
|
||||
if self.window_start:
|
||||
summary_message = (
|
||||
f"Cleanup complete. Deleted {total_runs_deleted} workflow runs "
|
||||
f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
|
||||
)
|
||||
else:
|
||||
summary_message = (
|
||||
f"Cleanup complete. Deleted {total_runs_deleted} workflow runs before {self.window_end.isoformat()}"
|
||||
)
|
||||
summary_color = "white"
|
||||
if self.window_start:
|
||||
summary_message = (
|
||||
f"Cleanup complete. Deleted {total_runs_deleted} workflow runs "
|
||||
f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}"
|
||||
)
|
||||
else:
|
||||
summary_message = (
|
||||
f"Cleanup complete. Deleted {total_runs_deleted} workflow runs "
|
||||
f"before {self.window_end.isoformat()}"
|
||||
)
|
||||
summary_color = "white"
|
||||
|
||||
click.echo(click.style(summary_message, fg=summary_color))
|
||||
click.echo(click.style(summary_message, fg=summary_color))
|
||||
except Exception:
|
||||
status = "failed"
|
||||
raise
|
||||
finally:
|
||||
self._metrics.record_completion(
|
||||
status=status,
|
||||
job_duration_seconds=time.monotonic() - run_start,
|
||||
)
|
||||
|
||||
def _filter_free_tenants(self, tenant_ids: Iterable[str]) -> set[str]:
|
||||
tenant_id_list = list(tenant_ids)
|
||||
|
||||
@@ -156,7 +156,8 @@ class VectorService:
|
||||
)
|
||||
# use full doc mode to generate segment's child chunk
|
||||
processing_rule_dict = processing_rule.to_dict()
|
||||
processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC
|
||||
if processing_rule_dict["rules"] is not None:
|
||||
processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC
|
||||
documents = index_processor.transform(
|
||||
[document],
|
||||
embedding_model_instance=embedding_model_instance,
|
||||
|
||||
@@ -46,6 +46,7 @@ def test_absolute_mode_calls_from_time_range():
|
||||
end_before=end_before,
|
||||
batch_size=200,
|
||||
dry_run=True,
|
||||
task_label="custom",
|
||||
)
|
||||
mock_from_days.assert_not_called()
|
||||
|
||||
@@ -74,6 +75,7 @@ def test_relative_mode_before_days_only_calls_from_days():
|
||||
days=30,
|
||||
batch_size=500,
|
||||
dry_run=False,
|
||||
task_label="before-30",
|
||||
)
|
||||
mock_from_time_range.assert_not_called()
|
||||
|
||||
@@ -105,6 +107,7 @@ def test_relative_mode_with_from_days_ago_calls_from_time_range():
|
||||
end_before=fixed_now - datetime.timedelta(days=30),
|
||||
batch_size=1000,
|
||||
dry_run=False,
|
||||
task_label="60to30",
|
||||
)
|
||||
mock_from_days.assert_not_called()
|
||||
|
||||
|
||||
@@ -734,7 +734,7 @@ def test_create_provider_credential_creates_provider_record_when_missing() -> No
|
||||
def test_create_provider_credential_marks_existing_provider_as_valid() -> None:
|
||||
configuration = _build_provider_configuration()
|
||||
session = Mock()
|
||||
provider_record = SimpleNamespace(is_valid=False)
|
||||
provider_record = SimpleNamespace(id="provider-1", is_valid=False, credential_id="existing-cred")
|
||||
|
||||
with _patched_session(session):
|
||||
with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False):
|
||||
@@ -743,6 +743,25 @@ def test_create_provider_credential_marks_existing_provider_as_valid() -> None:
|
||||
configuration.create_provider_credential({"api_key": "raw"}, "Main")
|
||||
|
||||
assert provider_record.is_valid is True
|
||||
assert provider_record.credential_id == "existing-cred"
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_create_provider_credential_auto_activates_when_no_active_credential() -> None:
|
||||
configuration = _build_provider_configuration()
|
||||
session = Mock()
|
||||
provider_record = SimpleNamespace(id="provider-1", is_valid=False, credential_id=None, updated_at=None)
|
||||
|
||||
with _patched_session(session):
|
||||
with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False):
|
||||
with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}):
|
||||
with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record):
|
||||
with patch("core.entities.provider_configuration.ProviderCredentialsCache"):
|
||||
with patch.object(ProviderConfiguration, "switch_preferred_provider_type"):
|
||||
configuration.create_provider_credential({"api_key": "raw"}, "Main")
|
||||
|
||||
assert provider_record.is_valid is True
|
||||
assert provider_record.credential_id is not None
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
|
||||
@@ -4,9 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
import dify_graph.graph_engine.response_coordinator.session as response_session_module
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState, NodeType
|
||||
from dify_graph.graph_engine.response_coordinator import RESPONSE_SESSION_NODE_TYPES
|
||||
from dify_graph.graph_engine.response_coordinator.session import ResponseSession
|
||||
from dify_graph.nodes.base.template import Template, TextSegment
|
||||
|
||||
@@ -35,28 +33,14 @@ class DummyNodeWithoutStreamingTemplate:
|
||||
self.state = NodeState.UNKNOWN
|
||||
|
||||
|
||||
def test_response_session_from_node_rejects_node_types_outside_allowlist() -> None:
|
||||
"""Unsupported node types are rejected even if they expose a template."""
|
||||
def test_response_session_from_node_accepts_nodes_outside_previous_allowlist() -> None:
|
||||
"""Session creation depends on the streaming-template contract rather than node type."""
|
||||
node = DummyResponseNode(
|
||||
node_id="llm-node",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
template=Template(segments=[TextSegment(text="hello")]),
|
||||
)
|
||||
|
||||
with pytest.raises(TypeError, match="RESPONSE_SESSION_NODE_TYPES"):
|
||||
ResponseSession.from_node(node)
|
||||
|
||||
|
||||
def test_response_session_from_node_supports_downstream_allowlist_extension(monkeypatch) -> None:
|
||||
"""Downstream applications can extend the supported node-type list."""
|
||||
node = DummyResponseNode(
|
||||
node_id="llm-node",
|
||||
node_type=BuiltinNodeTypes.LLM,
|
||||
template=Template(segments=[TextSegment(text="hello")]),
|
||||
)
|
||||
extended_node_types = [*RESPONSE_SESSION_NODE_TYPES, BuiltinNodeTypes.LLM]
|
||||
monkeypatch.setattr(response_session_module, "RESPONSE_SESSION_NODE_TYPES", extended_node_types)
|
||||
|
||||
session = ResponseSession.from_node(node)
|
||||
|
||||
assert session.node_id == "llm-node"
|
||||
|
||||
106
api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
Normal file
106
api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from dify_graph.model_runtime.entities import ImagePromptMessageContent, PromptMessageRole, TextPromptMessageContent
|
||||
from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage
|
||||
from dify_graph.nodes.llm import llm_utils
|
||||
from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage
|
||||
from dify_graph.nodes.llm.exc import NoPromptFoundError
|
||||
from dify_graph.runtime import VariablePool
|
||||
|
||||
|
||||
def _fetch_prompt_messages_with_mocked_content(content):
|
||||
variable_pool = VariablePool.empty()
|
||||
model_instance = mock.MagicMock(spec=ModelInstance)
|
||||
prompt_template = [
|
||||
LLMNodeChatModelMessage(
|
||||
text="You are a classifier.",
|
||||
role=PromptMessageRole.SYSTEM,
|
||||
edition_type="basic",
|
||||
)
|
||||
]
|
||||
|
||||
with (
|
||||
mock.patch(
|
||||
"dify_graph.nodes.llm.llm_utils.fetch_model_schema",
|
||||
return_value=mock.MagicMock(features=[]),
|
||||
),
|
||||
mock.patch(
|
||||
"dify_graph.nodes.llm.llm_utils.handle_list_messages",
|
||||
return_value=[SystemPromptMessage(content=content)],
|
||||
),
|
||||
mock.patch(
|
||||
"dify_graph.nodes.llm.llm_utils.handle_memory_chat_mode",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
return llm_utils.fetch_prompt_messages(
|
||||
sys_query=None,
|
||||
sys_files=[],
|
||||
context=None,
|
||||
memory=None,
|
||||
model_instance=model_instance,
|
||||
prompt_template=prompt_template,
|
||||
stop=["END"],
|
||||
memory_config=None,
|
||||
vision_enabled=False,
|
||||
vision_detail=ImagePromptMessageContent.DETAIL.HIGH,
|
||||
variable_pool=variable_pool,
|
||||
jinja2_variables=[],
|
||||
template_renderer=None,
|
||||
)
|
||||
|
||||
|
||||
def test_fetch_prompt_messages_skips_messages_when_all_contents_are_filtered_out():
|
||||
with pytest.raises(NoPromptFoundError):
|
||||
_fetch_prompt_messages_with_mocked_content(
|
||||
[
|
||||
ImagePromptMessageContent(
|
||||
format="url",
|
||||
url="https://example.com/image.png",
|
||||
mime_type="image/png",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_fetch_prompt_messages_flattens_single_text_content_after_filtering_unsupported_multimodal_items():
|
||||
prompt_messages, stop = _fetch_prompt_messages_with_mocked_content(
|
||||
[
|
||||
TextPromptMessageContent(data="You are a classifier."),
|
||||
ImagePromptMessageContent(
|
||||
format="url",
|
||||
url="https://example.com/image.png",
|
||||
mime_type="image/png",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
assert stop == ["END"]
|
||||
assert prompt_messages == [SystemPromptMessage(content="You are a classifier.")]
|
||||
|
||||
|
||||
def test_fetch_prompt_messages_keeps_list_content_when_multiple_supported_items_remain():
|
||||
prompt_messages, stop = _fetch_prompt_messages_with_mocked_content(
|
||||
[
|
||||
TextPromptMessageContent(data="You are"),
|
||||
TextPromptMessageContent(data=" a classifier."),
|
||||
ImagePromptMessageContent(
|
||||
format="url",
|
||||
url="https://example.com/image.png",
|
||||
mime_type="image/png",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
assert stop == ["END"]
|
||||
assert prompt_messages == [
|
||||
SystemPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data="You are"),
|
||||
TextPromptMessageContent(data=" a classifier."),
|
||||
]
|
||||
)
|
||||
]
|
||||
19
api/tests/unit_tests/models/test_enums_creator_user_role.py
Normal file
19
api/tests/unit_tests/models/test_enums_creator_user_role.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import pytest
|
||||
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
|
||||
def test_creator_user_role_missing_maps_hyphen_to_enum():
|
||||
# given an alias with hyphen
|
||||
value = "end-user"
|
||||
|
||||
# when converting to enum (invokes StrEnum._missing_ override)
|
||||
role = CreatorUserRole(value)
|
||||
|
||||
# then it should map to END_USER
|
||||
assert role is CreatorUserRole.END_USER
|
||||
|
||||
|
||||
def test_creator_user_role_missing_raises_for_unknown():
|
||||
with pytest.raises(ValueError):
|
||||
CreatorUserRole("unknown")
|
||||
@@ -1,5 +1,4 @@
|
||||
import datetime
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -282,7 +281,6 @@ class TestMessagesCleanService:
|
||||
MessagesCleanService._batch_delete_message_relations(mock_db_session, ["msg1", "msg2"])
|
||||
assert mock_db_session.execute.call_count == 8 # 8 tables to clean up
|
||||
|
||||
@patch.dict(os.environ, {"SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL": "500"})
|
||||
def test_clean_messages_interval_from_env(self, mock_db_session, mock_policy):
|
||||
service = MessagesCleanService(
|
||||
policy=mock_policy,
|
||||
@@ -301,9 +299,13 @@ class TestMessagesCleanService:
|
||||
mock_db_session.execute.side_effect = mock_returns
|
||||
mock_policy.filter_message_ids.return_value = ["msg1"]
|
||||
|
||||
with patch("services.retention.conversation.messages_clean_service.time.sleep") as mock_sleep:
|
||||
with patch("services.retention.conversation.messages_clean_service.random.uniform") as mock_uniform:
|
||||
mock_uniform.return_value = 300.0
|
||||
service.run()
|
||||
mock_uniform.assert_called_with(0, 500)
|
||||
mock_sleep.assert_called_with(0.3)
|
||||
with patch(
|
||||
"services.retention.conversation.messages_clean_service.dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL",
|
||||
500,
|
||||
):
|
||||
with patch("services.retention.conversation.messages_clean_service.time.sleep") as mock_sleep:
|
||||
with patch("services.retention.conversation.messages_clean_service.random.uniform") as mock_uniform:
|
||||
mock_uniform.return_value = 300.0
|
||||
service.run()
|
||||
mock_uniform.assert_called_with(0, 500)
|
||||
mock_sleep.assert_called_with(0.3)
|
||||
|
||||
@@ -80,7 +80,13 @@ class TestWorkflowRunCleanupInit:
|
||||
cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0
|
||||
cfg.BILLING_ENABLED = False
|
||||
with pytest.raises(ValueError):
|
||||
WorkflowRunCleanup(days=30, batch_size=10, start_from=dt, end_before=dt, workflow_run_repo=mock_repo)
|
||||
WorkflowRunCleanup(
|
||||
days=30,
|
||||
batch_size=10,
|
||||
start_from=dt,
|
||||
end_before=dt,
|
||||
workflow_run_repo=mock_repo,
|
||||
)
|
||||
|
||||
def test_zero_batch_size_raises(self, mock_repo):
|
||||
with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg:
|
||||
@@ -102,10 +108,24 @@ class TestWorkflowRunCleanupInit:
|
||||
cfg.BILLING_ENABLED = False
|
||||
start = datetime.datetime(2024, 1, 1)
|
||||
end = datetime.datetime(2024, 6, 1)
|
||||
c = WorkflowRunCleanup(days=30, batch_size=5, start_from=start, end_before=end, workflow_run_repo=mock_repo)
|
||||
c = WorkflowRunCleanup(
|
||||
days=30,
|
||||
batch_size=5,
|
||||
start_from=start,
|
||||
end_before=end,
|
||||
workflow_run_repo=mock_repo,
|
||||
)
|
||||
assert c.window_start == start
|
||||
assert c.window_end == end
|
||||
|
||||
def test_default_task_label_is_custom(self, mock_repo):
|
||||
with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg:
|
||||
cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0
|
||||
cfg.BILLING_ENABLED = False
|
||||
c = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo)
|
||||
|
||||
assert c._metrics._base_attributes["task_label"] == "custom"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _empty_related_counts / _format_related_counts
|
||||
@@ -393,7 +413,12 @@ class TestRunDryRunMode:
|
||||
with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg:
|
||||
cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0
|
||||
cfg.BILLING_ENABLED = False
|
||||
return WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo, dry_run=True)
|
||||
return WorkflowRunCleanup(
|
||||
days=30,
|
||||
batch_size=10,
|
||||
workflow_run_repo=mock_repo,
|
||||
dry_run=True,
|
||||
)
|
||||
|
||||
def test_dry_run_no_delete_called(self, mock_repo):
|
||||
run = make_run("t1")
|
||||
|
||||
@@ -265,6 +265,61 @@ def test_run_exits_on_empty_batch(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
cleanup.run()
|
||||
|
||||
|
||||
def test_run_records_metrics_on_success(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
cutoff = datetime.datetime.now()
|
||||
repo = FakeRepo(
|
||||
batches=[[FakeRun("run-free", "t_free", cutoff)]],
|
||||
delete_result={
|
||||
"runs": 0,
|
||||
"node_executions": 2,
|
||||
"offloads": 1,
|
||||
"app_logs": 3,
|
||||
"trigger_logs": 4,
|
||||
"pauses": 5,
|
||||
"pause_reasons": 6,
|
||||
},
|
||||
)
|
||||
cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10)
|
||||
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False)
|
||||
|
||||
batch_calls: list[dict[str, object]] = []
|
||||
completion_calls: list[dict[str, object]] = []
|
||||
monkeypatch.setattr(cleanup._metrics, "record_batch", lambda **kwargs: batch_calls.append(kwargs))
|
||||
monkeypatch.setattr(cleanup._metrics, "record_completion", lambda **kwargs: completion_calls.append(kwargs))
|
||||
|
||||
cleanup.run()
|
||||
|
||||
assert len(batch_calls) == 1
|
||||
assert batch_calls[0]["batch_rows"] == 1
|
||||
assert batch_calls[0]["targeted_runs"] == 1
|
||||
assert batch_calls[0]["deleted_runs"] == 1
|
||||
assert batch_calls[0]["related_action"] == "deleted"
|
||||
assert len(completion_calls) == 1
|
||||
assert completion_calls[0]["status"] == "success"
|
||||
|
||||
|
||||
def test_run_records_failed_metrics(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
class FailingRepo(FakeRepo):
|
||||
def delete_runs_with_related(
|
||||
self, runs: list[FakeRun], delete_node_executions=None, delete_trigger_logs=None
|
||||
) -> dict[str, int]:
|
||||
raise RuntimeError("delete failed")
|
||||
|
||||
cutoff = datetime.datetime.now()
|
||||
repo = FailingRepo(batches=[[FakeRun("run-free", "t_free", cutoff)]])
|
||||
cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10)
|
||||
monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False)
|
||||
|
||||
completion_calls: list[dict[str, object]] = []
|
||||
monkeypatch.setattr(cleanup._metrics, "record_completion", lambda **kwargs: completion_calls.append(kwargs))
|
||||
|
||||
with pytest.raises(RuntimeError, match="delete failed"):
|
||||
cleanup.run()
|
||||
|
||||
assert len(completion_calls) == 1
|
||||
assert completion_calls[0]["status"] == "failed"
|
||||
|
||||
|
||||
def test_run_dry_run_skips_deletions(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
cutoff = datetime.datetime.now()
|
||||
repo = FakeRepo(
|
||||
|
||||
@@ -540,6 +540,20 @@ class TestMessagesCleanServiceFromTimeRange:
|
||||
assert service._batch_size == 1000 # default
|
||||
assert service._dry_run is False # default
|
||||
|
||||
def test_explicit_task_label(self):
|
||||
start_from = datetime.datetime(2024, 1, 1)
|
||||
end_before = datetime.datetime(2024, 1, 2)
|
||||
policy = BillingDisabledPolicy()
|
||||
|
||||
service = MessagesCleanService.from_time_range(
|
||||
policy=policy,
|
||||
start_from=start_from,
|
||||
end_before=end_before,
|
||||
task_label="60to30",
|
||||
)
|
||||
|
||||
assert service._metrics._base_attributes["task_label"] == "60to30"
|
||||
|
||||
|
||||
class TestMessagesCleanServiceFromDays:
|
||||
"""Unit tests for MessagesCleanService.from_days factory method."""
|
||||
@@ -619,3 +633,54 @@ class TestMessagesCleanServiceFromDays:
|
||||
assert service._end_before == expected_end_before
|
||||
assert service._batch_size == 1000 # default
|
||||
assert service._dry_run is False # default
|
||||
assert service._metrics._base_attributes["task_label"] == "custom"
|
||||
|
||||
|
||||
class TestMessagesCleanServiceRun:
|
||||
"""Unit tests for MessagesCleanService.run instrumentation behavior."""
|
||||
|
||||
def test_run_records_completion_metrics_on_success(self):
|
||||
# Arrange
|
||||
service = MessagesCleanService(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=datetime.datetime(2024, 1, 1),
|
||||
end_before=datetime.datetime(2024, 1, 2),
|
||||
batch_size=100,
|
||||
dry_run=False,
|
||||
)
|
||||
expected_stats = {
|
||||
"batches": 1,
|
||||
"total_messages": 10,
|
||||
"filtered_messages": 5,
|
||||
"total_deleted": 5,
|
||||
}
|
||||
service._clean_messages_by_time_range = MagicMock(return_value=expected_stats) # type: ignore[method-assign]
|
||||
completion_calls: list[dict[str, object]] = []
|
||||
service._metrics.record_completion = lambda **kwargs: completion_calls.append(kwargs) # type: ignore[method-assign]
|
||||
|
||||
# Act
|
||||
result = service.run()
|
||||
|
||||
# Assert
|
||||
assert result == expected_stats
|
||||
assert len(completion_calls) == 1
|
||||
assert completion_calls[0]["status"] == "success"
|
||||
|
||||
def test_run_records_completion_metrics_on_failure(self):
|
||||
# Arrange
|
||||
service = MessagesCleanService(
|
||||
policy=BillingDisabledPolicy(),
|
||||
start_from=datetime.datetime(2024, 1, 1),
|
||||
end_before=datetime.datetime(2024, 1, 2),
|
||||
batch_size=100,
|
||||
dry_run=False,
|
||||
)
|
||||
service._clean_messages_by_time_range = MagicMock(side_effect=RuntimeError("clean failed")) # type: ignore[method-assign]
|
||||
completion_calls: list[dict[str, object]] = []
|
||||
service._metrics.record_completion = lambda **kwargs: completion_calls.append(kwargs) # type: ignore[method-assign]
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(RuntimeError, match="clean failed"):
|
||||
service.run()
|
||||
assert len(completion_calls) == 1
|
||||
assert completion_calls[0]["status"] == "failed"
|
||||
|
||||
8
api/uv.lock
generated
8
api/uv.lock
generated
@@ -457,14 +457,14 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "authlib"
|
||||
version = "1.6.7"
|
||||
version = "1.6.9"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "cryptography" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/49/dc/ed1681bf1339dd6ea1ce56136bad4baabc6f7ad466e375810702b0237047/authlib-1.6.7.tar.gz", hash = "sha256:dbf10100011d1e1b34048c9d120e83f13b35d69a826ae762b93d2fb5aafc337b", size = 164950, upload-time = "2026-02-06T14:04:14.171Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/af/98/00d3dd826d46959ad8e32af2dbb2398868fd9fd0683c26e56d0789bd0e68/authlib-1.6.9.tar.gz", hash = "sha256:d8f2421e7e5980cc1ddb4e32d3f5fa659cfaf60d8eaf3281ebed192e4ab74f04", size = 165134, upload-time = "2026-03-02T07:44:01.998Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f8/00/3ed12264094ec91f534fae429945efbaa9f8c666f3aa7061cc3b2a26a0cd/authlib-1.6.7-py2.py3-none-any.whl", hash = "sha256:c637340d9a02789d2efa1d003a7437d10d3e565237bcb5fcbc6c134c7b95bab0", size = 244115, upload-time = "2026-02-06T14:04:12.141Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/53/23/b65f568ed0c22f1efacb744d2db1a33c8068f384b8c9b482b52ebdbc3ef6/authlib-1.6.9-py2.py3-none-any.whl", hash = "sha256:f08b4c14e08f0861dc18a32357b33fbcfd2ea86cfe3fe149484b4d764c4a0ac3", size = 244197, upload-time = "2026-03-02T07:44:00.307Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1533,7 +1533,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "dify-api"
|
||||
version = "1.13.0"
|
||||
version = "1.13.2"
|
||||
source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "aliyun-log-python-sdk" },
|
||||
|
||||
@@ -21,7 +21,7 @@ services:
|
||||
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.13.0
|
||||
image: langgenius/dify-api:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -63,7 +63,7 @@ services:
|
||||
# worker service
|
||||
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
|
||||
worker:
|
||||
image: langgenius/dify-api:1.13.0
|
||||
image: langgenius/dify-api:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -102,7 +102,7 @@ services:
|
||||
# worker_beat service
|
||||
# Celery beat for scheduling periodic tasks.
|
||||
worker_beat:
|
||||
image: langgenius/dify-api:1.13.0
|
||||
image: langgenius/dify-api:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -132,7 +132,7 @@ services:
|
||||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.13.0
|
||||
image: langgenius/dify-web:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
|
||||
@@ -728,7 +728,7 @@ services:
|
||||
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:1.13.0
|
||||
image: langgenius/dify-api:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -770,7 +770,7 @@ services:
|
||||
# worker service
|
||||
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
|
||||
worker:
|
||||
image: langgenius/dify-api:1.13.0
|
||||
image: langgenius/dify-api:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -809,7 +809,7 @@ services:
|
||||
# worker_beat service
|
||||
# Celery beat for scheduling periodic tasks.
|
||||
worker_beat:
|
||||
image: langgenius/dify-api:1.13.0
|
||||
image: langgenius/dify-api:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
# Use the shared environment variables.
|
||||
@@ -839,7 +839,7 @@ services:
|
||||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:1.13.0
|
||||
image: langgenius/dify-web:1.13.2
|
||||
restart: always
|
||||
environment:
|
||||
CONSOLE_API_URL: ${CONSOLE_API_URL:-}
|
||||
|
||||
@@ -163,9 +163,38 @@ describe('check-components-diff-coverage helpers', () => {
|
||||
|
||||
expect(coverage).toEqual({
|
||||
covered: 0,
|
||||
total: 2,
|
||||
total: 1,
|
||||
uncoveredBranches: [
|
||||
{ armIndex: 0, line: 33 },
|
||||
],
|
||||
})
|
||||
})
|
||||
|
||||
it('should require all branch arms when the branch condition changes', () => {
|
||||
const entry = {
|
||||
b: {
|
||||
0: [0, 0],
|
||||
},
|
||||
branchMap: {
|
||||
0: {
|
||||
line: 30,
|
||||
loc: { start: { line: 30 }, end: { line: 35 } },
|
||||
locations: [
|
||||
{ start: { line: 31 }, end: { line: 34 } },
|
||||
{ start: { line: 35 }, end: { line: 38 } },
|
||||
],
|
||||
type: 'if',
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
const coverage = getChangedBranchCoverage(entry, new Set([30]))
|
||||
|
||||
expect(coverage).toEqual({
|
||||
covered: 0,
|
||||
total: 2,
|
||||
uncoveredBranches: [
|
||||
{ armIndex: 0, line: 31 },
|
||||
{ armIndex: 1, line: 35 },
|
||||
],
|
||||
})
|
||||
|
||||
72
web/__tests__/components-coverage-common.test.ts
Normal file
72
web/__tests__/components-coverage-common.test.ts
Normal file
@@ -0,0 +1,72 @@
|
||||
import {
|
||||
getCoverageStats,
|
||||
isRelevantTestFile,
|
||||
isTrackedComponentSourceFile,
|
||||
loadTrackedCoverageEntries,
|
||||
} from '../scripts/components-coverage-common.mjs'
|
||||
|
||||
describe('components coverage common helpers', () => {
|
||||
it('should identify tracked component source files and relevant tests', () => {
|
||||
const excludedComponentCoverageFiles = new Set([
|
||||
'web/app/components/share/types.ts',
|
||||
])
|
||||
|
||||
expect(isTrackedComponentSourceFile('web/app/components/share/index.tsx', excludedComponentCoverageFiles)).toBe(true)
|
||||
expect(isTrackedComponentSourceFile('web/app/components/share/types.ts', excludedComponentCoverageFiles)).toBe(false)
|
||||
expect(isTrackedComponentSourceFile('web/app/components/provider/index.tsx', excludedComponentCoverageFiles)).toBe(false)
|
||||
|
||||
expect(isRelevantTestFile('web/__tests__/share/text-generation-run-once-flow.test.tsx')).toBe(true)
|
||||
expect(isRelevantTestFile('web/app/components/share/__tests__/index.spec.tsx')).toBe(true)
|
||||
expect(isRelevantTestFile('web/utils/format.spec.ts')).toBe(false)
|
||||
})
|
||||
|
||||
it('should load only tracked coverage entries from mixed coverage paths', () => {
|
||||
const context = {
|
||||
excludedComponentCoverageFiles: new Set([
|
||||
'web/app/components/share/types.ts',
|
||||
]),
|
||||
repoRoot: '/repo',
|
||||
webRoot: '/repo/web',
|
||||
}
|
||||
const coverage = {
|
||||
'/repo/web/app/components/provider/index.tsx': {
|
||||
path: '/repo/web/app/components/provider/index.tsx',
|
||||
statementMap: { 0: { start: { line: 1 }, end: { line: 1 } } },
|
||||
s: { 0: 1 },
|
||||
},
|
||||
'app/components/share/index.tsx': {
|
||||
path: 'app/components/share/index.tsx',
|
||||
statementMap: { 0: { start: { line: 2 }, end: { line: 2 } } },
|
||||
s: { 0: 1 },
|
||||
},
|
||||
'app/components/share/types.ts': {
|
||||
path: 'app/components/share/types.ts',
|
||||
statementMap: { 0: { start: { line: 3 }, end: { line: 3 } } },
|
||||
s: { 0: 1 },
|
||||
},
|
||||
}
|
||||
|
||||
expect([...loadTrackedCoverageEntries(coverage, context).keys()]).toEqual([
|
||||
'web/app/components/share/index.tsx',
|
||||
])
|
||||
})
|
||||
|
||||
it('should calculate coverage stats using statement-derived line hits', () => {
|
||||
const entry = {
|
||||
b: { 0: [1, 0] },
|
||||
f: { 0: 1, 1: 0 },
|
||||
s: { 0: 1, 1: 0 },
|
||||
statementMap: {
|
||||
0: { start: { line: 10 }, end: { line: 10 } },
|
||||
1: { start: { line: 12 }, end: { line: 13 } },
|
||||
},
|
||||
}
|
||||
|
||||
expect(getCoverageStats(entry)).toEqual({
|
||||
branches: { covered: 1, total: 2 },
|
||||
functions: { covered: 1, total: 2 },
|
||||
lines: { covered: 1, total: 2 },
|
||||
statements: { covered: 1, total: 2 },
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -137,4 +137,31 @@ describe('SelectDataSet', () => {
|
||||
expect(screen.getByRole('link', { name: 'appDebug.feature.dataSet.toCreate' })).toHaveAttribute('href', '/datasets/create')
|
||||
expect(screen.getByRole('button', { name: 'common.operation.add' })).toBeDisabled()
|
||||
})
|
||||
|
||||
it('uses selectedIds as the initial modal selection', async () => {
|
||||
const datasetOne = makeDataset({
|
||||
id: 'set-1',
|
||||
name: 'Dataset One',
|
||||
})
|
||||
mockUseInfiniteDatasets.mockReturnValue({
|
||||
data: { pages: [{ data: [datasetOne] }] },
|
||||
isLoading: false,
|
||||
isFetchingNextPage: false,
|
||||
fetchNextPage: vi.fn(),
|
||||
hasNextPage: false,
|
||||
})
|
||||
|
||||
const onSelect = vi.fn()
|
||||
await act(async () => {
|
||||
render(<SelectDataSet {...baseProps} onSelect={onSelect} selectedIds={['set-1']} />)
|
||||
})
|
||||
|
||||
expect(screen.getByText('1 appDebug.feature.dataSet.selected')).toBeInTheDocument()
|
||||
|
||||
await act(async () => {
|
||||
fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' }))
|
||||
})
|
||||
|
||||
expect(onSelect).toHaveBeenCalledWith([datasetOne])
|
||||
})
|
||||
})
|
||||
|
||||
@@ -4,7 +4,7 @@ import type { DataSet } from '@/models/datasets'
|
||||
import { useInfiniteScroll } from 'ahooks'
|
||||
import Link from 'next/link'
|
||||
import * as React from 'react'
|
||||
import { useEffect, useMemo, useRef, useState } from 'react'
|
||||
import { useMemo, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import AppIcon from '@/app/components/base/app-icon'
|
||||
import Badge from '@/app/components/base/badge'
|
||||
@@ -31,17 +31,21 @@ const SelectDataSet: FC<ISelectDataSetProps> = ({
|
||||
onSelect,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const [selected, setSelected] = useState<DataSet[]>([])
|
||||
const [selectedIdsInModal, setSelectedIdsInModal] = useState(() => selectedIds)
|
||||
const canSelectMulti = true
|
||||
const { formatIndexingTechniqueAndMethod } = useKnowledge()
|
||||
const { data, isLoading, isFetchingNextPage, fetchNextPage, hasNextPage } = useInfiniteDatasets(
|
||||
{ page: 1 },
|
||||
{ enabled: isShow, staleTime: 0, refetchOnMount: 'always' },
|
||||
)
|
||||
const pages = data?.pages || []
|
||||
const datasets = useMemo(() => {
|
||||
const pages = data?.pages || []
|
||||
return pages.flatMap(page => page.data.filter(item => item.indexing_technique || item.provider === 'external'))
|
||||
}, [pages])
|
||||
}, [data])
|
||||
const datasetMap = useMemo(() => new Map(datasets.map(item => [item.id, item])), [datasets])
|
||||
const selected = useMemo(() => {
|
||||
return selectedIdsInModal.map(id => datasetMap.get(id) || ({ id } as DataSet))
|
||||
}, [datasetMap, selectedIdsInModal])
|
||||
const hasNoData = !isLoading && datasets.length === 0
|
||||
|
||||
const listRef = useRef<HTMLDivElement>(null)
|
||||
@@ -61,50 +65,14 @@ const SelectDataSet: FC<ISelectDataSetProps> = ({
|
||||
},
|
||||
)
|
||||
|
||||
const prevSelectedIdsRef = useRef<string[]>([])
|
||||
const hasUserModifiedSelectionRef = useRef(false)
|
||||
useEffect(() => {
|
||||
if (isShow)
|
||||
hasUserModifiedSelectionRef.current = false
|
||||
}, [isShow])
|
||||
useEffect(() => {
|
||||
const prevSelectedIds = prevSelectedIdsRef.current
|
||||
const idsChanged = selectedIds.length !== prevSelectedIds.length
|
||||
|| selectedIds.some((id, idx) => id !== prevSelectedIds[idx])
|
||||
|
||||
if (!selectedIds.length && (!hasUserModifiedSelectionRef.current || idsChanged)) {
|
||||
setSelected([])
|
||||
prevSelectedIdsRef.current = selectedIds
|
||||
hasUserModifiedSelectionRef.current = false
|
||||
return
|
||||
}
|
||||
|
||||
if (!idsChanged && hasUserModifiedSelectionRef.current)
|
||||
return
|
||||
|
||||
setSelected((prev) => {
|
||||
const prevMap = new Map(prev.map(item => [item.id, item]))
|
||||
const nextSelected = selectedIds
|
||||
.map(id => datasets.find(item => item.id === id) || prevMap.get(id))
|
||||
.filter(Boolean) as DataSet[]
|
||||
return nextSelected
|
||||
})
|
||||
prevSelectedIdsRef.current = selectedIds
|
||||
hasUserModifiedSelectionRef.current = false
|
||||
}, [datasets, selectedIds])
|
||||
|
||||
const toggleSelect = (dataSet: DataSet) => {
|
||||
hasUserModifiedSelectionRef.current = true
|
||||
const isSelected = selected.some(item => item.id === dataSet.id)
|
||||
if (isSelected) {
|
||||
setSelected(selected.filter(item => item.id !== dataSet.id))
|
||||
}
|
||||
else {
|
||||
if (canSelectMulti)
|
||||
setSelected([...selected, dataSet])
|
||||
else
|
||||
setSelected([dataSet])
|
||||
}
|
||||
setSelectedIdsInModal((prev) => {
|
||||
const isSelected = prev.includes(dataSet.id)
|
||||
if (isSelected)
|
||||
return prev.filter(id => id !== dataSet.id)
|
||||
|
||||
return canSelectMulti ? [...prev, dataSet.id] : [dataSet.id]
|
||||
})
|
||||
}
|
||||
|
||||
const handleSelect = () => {
|
||||
@@ -126,7 +94,7 @@ const SelectDataSet: FC<ISelectDataSetProps> = ({
|
||||
|
||||
{hasNoData && (
|
||||
<div
|
||||
className="mt-6 flex h-[128px] items-center justify-center space-x-1 rounded-lg border text-[13px]"
|
||||
className="mt-6 flex h-[128px] items-center justify-center space-x-1 rounded-lg border text-[13px]"
|
||||
style={{
|
||||
background: 'rgba(0, 0, 0, 0.02)',
|
||||
borderColor: 'rgba(0, 0, 0, 0.02',
|
||||
@@ -145,7 +113,7 @@ const SelectDataSet: FC<ISelectDataSetProps> = ({
|
||||
key={item.id}
|
||||
className={cn(
|
||||
'flex h-10 cursor-pointer items-center rounded-lg border-[0.5px] border-components-panel-border-subtle bg-components-panel-on-panel-item-bg px-2 shadow-xs hover:border-components-panel-border hover:bg-components-panel-on-panel-item-bg-hover hover:shadow-sm',
|
||||
selected.some(i => i.id === item.id) && 'border-[1.5px] border-components-option-card-option-selected-border bg-state-accent-hover shadow-xs hover:border-components-option-card-option-selected-border hover:bg-state-accent-hover hover:shadow-xs',
|
||||
selectedIdsInModal.includes(item.id) && 'border-[1.5px] border-components-option-card-option-selected-border bg-state-accent-hover shadow-xs hover:border-components-option-card-option-selected-border hover:bg-state-accent-hover hover:shadow-xs',
|
||||
!item.embedding_available && 'hover:border-components-panel-border-subtle hover:bg-components-panel-on-panel-item-bg hover:shadow-xs',
|
||||
)}
|
||||
onClick={() => {
|
||||
@@ -195,7 +163,7 @@ const SelectDataSet: FC<ISelectDataSetProps> = ({
|
||||
)}
|
||||
{!isLoading && (
|
||||
<div className="mt-8 flex items-center justify-between">
|
||||
<div className="text-sm font-medium text-text-secondary">
|
||||
<div className="text-sm font-medium text-text-secondary">
|
||||
{selected.length > 0 && `${selected.length} ${t('feature.dataSet.selected', { ns: 'appDebug' })}`}
|
||||
</div>
|
||||
<div className="flex space-x-2">
|
||||
|
||||
@@ -137,5 +137,11 @@ describe('ScoreThresholdItem', () => {
|
||||
const input = screen.getByRole('textbox')
|
||||
expect(input).toHaveValue('1')
|
||||
})
|
||||
|
||||
it('should fall back to default value when value is undefined', () => {
|
||||
render(<ScoreThresholdItem {...defaultProps} value={undefined} />)
|
||||
const input = screen.getByRole('textbox')
|
||||
expect(input).toHaveValue('0.7')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -6,7 +6,7 @@ import ParamItem from '.'
|
||||
|
||||
type Props = {
|
||||
className?: string
|
||||
value: number
|
||||
value?: number
|
||||
onChange: (key: string, value: number) => void
|
||||
enable: boolean
|
||||
hasSwitch?: boolean
|
||||
@@ -20,6 +20,18 @@ const VALUE_LIMIT = {
|
||||
max: 1,
|
||||
}
|
||||
|
||||
const normalizeScoreThreshold = (value?: number): number => {
|
||||
const normalizedValue = typeof value === 'number' && Number.isFinite(value)
|
||||
? value
|
||||
: VALUE_LIMIT.default
|
||||
const roundedValue = Number.parseFloat(normalizedValue.toFixed(2))
|
||||
|
||||
return Math.min(
|
||||
VALUE_LIMIT.max,
|
||||
Math.max(VALUE_LIMIT.min, roundedValue),
|
||||
)
|
||||
}
|
||||
|
||||
const ScoreThresholdItem: FC<Props> = ({
|
||||
className,
|
||||
value,
|
||||
@@ -29,16 +41,10 @@ const ScoreThresholdItem: FC<Props> = ({
|
||||
onSwitchChange,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const handleParamChange = (key: string, value: number) => {
|
||||
let notOutRangeValue = Number.parseFloat(value.toFixed(2))
|
||||
notOutRangeValue = Math.max(VALUE_LIMIT.min, notOutRangeValue)
|
||||
notOutRangeValue = Math.min(VALUE_LIMIT.max, notOutRangeValue)
|
||||
onChange(key, notOutRangeValue)
|
||||
const handleParamChange = (key: string, nextValue: number) => {
|
||||
onChange(key, normalizeScoreThreshold(nextValue))
|
||||
}
|
||||
const safeValue = Math.min(
|
||||
VALUE_LIMIT.max,
|
||||
Math.max(VALUE_LIMIT.min, Number.parseFloat(value.toFixed(2))),
|
||||
)
|
||||
const safeValue = normalizeScoreThreshold(value)
|
||||
|
||||
return (
|
||||
<ParamItem
|
||||
|
||||
@@ -328,7 +328,7 @@ describe('createWorkflowStreamHandlers', () => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
const setupHandlers = (overrides: { isTimedOut?: () => boolean } = {}) => {
|
||||
const setupHandlers = (overrides: { isPublicAPI?: boolean, isTimedOut?: () => boolean } = {}) => {
|
||||
let completionRes = ''
|
||||
let currentTaskId: string | null = null
|
||||
let isStopping = false
|
||||
@@ -359,6 +359,7 @@ describe('createWorkflowStreamHandlers', () => {
|
||||
const handlers = createWorkflowStreamHandlers({
|
||||
getCompletionRes: () => completionRes,
|
||||
getWorkflowProcessData: () => workflowProcessData,
|
||||
isPublicAPI: overrides.isPublicAPI ?? false,
|
||||
isTimedOut: overrides.isTimedOut ?? (() => false),
|
||||
markEnded,
|
||||
notify,
|
||||
@@ -391,7 +392,7 @@ describe('createWorkflowStreamHandlers', () => {
|
||||
}
|
||||
|
||||
it('should process workflow success and paused events', () => {
|
||||
const setup = setupHandlers()
|
||||
const setup = setupHandlers({ isPublicAPI: true })
|
||||
const handlers = setup.handlers as Required<Pick<IOtherOptions, 'onWorkflowStarted' | 'onTextChunk' | 'onHumanInputRequired' | 'onHumanInputFormFilled' | 'onHumanInputFormTimeout' | 'onWorkflowPaused' | 'onWorkflowFinished' | 'onNodeStarted' | 'onNodeFinished' | 'onIterationStart' | 'onIterationNext' | 'onIterationFinish' | 'onLoopStart' | 'onLoopNext' | 'onLoopFinish'>>
|
||||
|
||||
act(() => {
|
||||
@@ -546,7 +547,11 @@ describe('createWorkflowStreamHandlers', () => {
|
||||
resultText: 'Hello',
|
||||
status: WorkflowRunningStatus.Succeeded,
|
||||
}))
|
||||
expect(sseGetMock).toHaveBeenCalledWith('/workflow/run-1/events', {}, expect.any(Object))
|
||||
expect(sseGetMock).toHaveBeenCalledWith(
|
||||
'/workflow/run-1/events',
|
||||
{},
|
||||
expect.objectContaining({ isPublicAPI: true }),
|
||||
)
|
||||
expect(setup.messageId()).toBe('run-1')
|
||||
expect(setup.onCompleted).toHaveBeenCalledWith('{"answer":"Hello"}', 3, true)
|
||||
expect(setup.setRespondingFalse).toHaveBeenCalled()
|
||||
@@ -647,6 +652,7 @@ describe('createWorkflowStreamHandlers', () => {
|
||||
const handlers = createWorkflowStreamHandlers({
|
||||
getCompletionRes: () => '',
|
||||
getWorkflowProcessData: () => existingProcess,
|
||||
isPublicAPI: false,
|
||||
isTimedOut: () => false,
|
||||
markEnded: vi.fn(),
|
||||
notify: setup.notify,
|
||||
|
||||
@@ -351,6 +351,7 @@ describe('useResultSender', () => {
|
||||
await waitFor(() => {
|
||||
expect(createWorkflowStreamHandlersMock).toHaveBeenCalledWith(expect.objectContaining({
|
||||
getCompletionRes: harness.runState.getCompletionRes,
|
||||
isPublicAPI: true,
|
||||
resetRunState: harness.runState.resetRunState,
|
||||
setWorkflowProcessData: harness.runState.setWorkflowProcessData,
|
||||
}))
|
||||
@@ -373,6 +374,30 @@ describe('useResultSender', () => {
|
||||
expect(harness.runState.clearMoreLikeThis).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should configure workflow handlers for installed apps as non-public', async () => {
|
||||
const harness = createRunStateHarness()
|
||||
|
||||
const { result } = renderSender({
|
||||
appSourceType: AppSourceTypeEnum.installedApp,
|
||||
isWorkflow: true,
|
||||
runState: harness.runState,
|
||||
})
|
||||
|
||||
await act(async () => {
|
||||
expect(await result.current.handleSend()).toBe(true)
|
||||
})
|
||||
|
||||
expect(createWorkflowStreamHandlersMock).toHaveBeenCalledWith(expect.objectContaining({
|
||||
isPublicAPI: false,
|
||||
}))
|
||||
expect(sendWorkflowMessageMock).toHaveBeenCalledWith(
|
||||
{ inputs: { name: 'Alice' } },
|
||||
expect.any(Object),
|
||||
AppSourceTypeEnum.installedApp,
|
||||
'app-1',
|
||||
)
|
||||
})
|
||||
|
||||
it('should stringify non-Error workflow failures', async () => {
|
||||
const harness = createRunStateHarness()
|
||||
sendWorkflowMessageMock.mockRejectedValue('workflow failed')
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import type { ResultInputValue } from '../result-request'
|
||||
import type { ResultRunStateController } from './use-result-run-state'
|
||||
import type { PromptConfig } from '@/models/debug'
|
||||
import type { AppSourceType } from '@/service/share'
|
||||
import type { VisionFile, VisionSettings } from '@/types/app'
|
||||
import { useCallback, useEffect, useRef } from 'react'
|
||||
import { TEXT_GENERATION_TIMEOUT_MS } from '@/config'
|
||||
import {
|
||||
AppSourceType,
|
||||
sendCompletionMessage,
|
||||
sendWorkflowMessage,
|
||||
} from '@/service/share'
|
||||
@@ -117,6 +117,7 @@ export const useResultSender = ({
|
||||
const otherOptions = createWorkflowStreamHandlers({
|
||||
getCompletionRes: runState.getCompletionRes,
|
||||
getWorkflowProcessData: runState.getWorkflowProcessData,
|
||||
isPublicAPI: appSourceType === AppSourceType.webApp,
|
||||
isTimedOut: () => isTimeout,
|
||||
markEnded: () => {
|
||||
isEnd = true
|
||||
|
||||
@@ -13,6 +13,7 @@ type Translate = (key: string, options?: Record<string, unknown>) => string
|
||||
type CreateWorkflowStreamHandlersParams = {
|
||||
getCompletionRes: () => string
|
||||
getWorkflowProcessData: () => WorkflowProcess | undefined
|
||||
isPublicAPI: boolean
|
||||
isTimedOut: () => boolean
|
||||
markEnded: () => void
|
||||
notify: Notify
|
||||
@@ -255,6 +256,7 @@ const serializeWorkflowOutputs = (outputs: WorkflowFinishedResponse['data']['out
|
||||
export const createWorkflowStreamHandlers = ({
|
||||
getCompletionRes,
|
||||
getWorkflowProcessData,
|
||||
isPublicAPI,
|
||||
isTimedOut,
|
||||
markEnded,
|
||||
notify,
|
||||
@@ -287,6 +289,7 @@ export const createWorkflowStreamHandlers = ({
|
||||
}
|
||||
|
||||
const otherOptions: IOtherOptions = {
|
||||
isPublicAPI,
|
||||
onWorkflowStarted: ({ workflow_run_id, task_id }) => {
|
||||
const workflowProcessData = getWorkflowProcessData()
|
||||
if (workflowProcessData?.tracing.length) {
|
||||
@@ -378,6 +381,7 @@ export const createWorkflowStreamHandlers = ({
|
||||
},
|
||||
onWorkflowPaused: ({ data }) => {
|
||||
tempMessageId = data.workflow_run_id
|
||||
// WebApp workflows must keep using the public API namespace after pause/resume.
|
||||
void sseGet(`/workflow/${data.workflow_run_id}/events`, {}, otherOptions)
|
||||
setWorkflowProcessData(applyWorkflowPaused(getWorkflowProcessData()))
|
||||
},
|
||||
|
||||
@@ -203,7 +203,7 @@ const NodeSelector: FC<NodeSelectorProps> = ({
|
||||
)
|
||||
}
|
||||
</PortalToFollowElemTrigger>
|
||||
<PortalToFollowElemContent className="z-[1000]">
|
||||
<PortalToFollowElemContent className="z-[1002]">
|
||||
<div className={`rounded-lg border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-lg ${popupClassName}`}>
|
||||
<Tabs
|
||||
tabs={tabs}
|
||||
|
||||
@@ -169,7 +169,7 @@ const ToolPicker: FC<Props> = ({
|
||||
{trigger}
|
||||
</PortalToFollowElemTrigger>
|
||||
|
||||
<PortalToFollowElemContent className="z-[1000]">
|
||||
<PortalToFollowElemContent className="z-[1002]">
|
||||
<div className={cn('relative min-h-20 rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur shadow-lg backdrop-blur-sm', panelClassName)}>
|
||||
<div className="p-2 pb-1">
|
||||
<SearchBox
|
||||
|
||||
@@ -117,7 +117,7 @@ const CodeEditor: FC<CodeEditorProps> = ({
|
||||
<div className={cn('flex h-full flex-col overflow-hidden bg-components-input-bg-normal', hideTopMenu && 'pt-2', className)}>
|
||||
{!hideTopMenu && (
|
||||
<div className="flex items-center justify-between pl-2 pr-1 pt-1">
|
||||
<div className="py-0.5 text-text-secondary system-xs-semibold-uppercase">
|
||||
<div className="system-xs-semibold-uppercase py-0.5 text-text-secondary">
|
||||
<span className="px-1 py-0.5">JSON</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-x-0.5">
|
||||
|
||||
@@ -928,12 +928,6 @@
|
||||
"app/components/app/configuration/dataset-config/select-dataset/index.tsx": {
|
||||
"no-restricted-imports": {
|
||||
"count": 1
|
||||
},
|
||||
"react-hooks-extra/no-direct-set-state-in-use-effect": {
|
||||
"count": 2
|
||||
},
|
||||
"tailwindcss/no-unnecessary-whitespace": {
|
||||
"count": 2
|
||||
}
|
||||
},
|
||||
"app/components/app/configuration/dataset-config/settings-modal/index.tsx": {
|
||||
@@ -7881,6 +7875,9 @@
|
||||
"no-restricted-imports": {
|
||||
"count": 1
|
||||
},
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 1
|
||||
},
|
||||
"ts/no-explicit-any": {
|
||||
"count": 4
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"name": "dify-web",
|
||||
"type": "module",
|
||||
"version": "1.13.0",
|
||||
"version": "1.13.2",
|
||||
"private": true,
|
||||
"packageManager": "pnpm@10.32.1",
|
||||
"imports": {
|
||||
|
||||
@@ -2,6 +2,11 @@ import fs from 'node:fs'
|
||||
import path from 'node:path'
|
||||
|
||||
const DIFF_COVERAGE_IGNORE_LINE_TOKEN = 'diff-coverage-ignore-line:'
|
||||
const DEFAULT_BRANCH_REF_CANDIDATES = ['origin/main', 'main']
|
||||
|
||||
export function normalizeDiffRangeMode(mode) {
|
||||
return mode === 'exact' ? 'exact' : 'merge-base'
|
||||
}
|
||||
|
||||
export function buildGitDiffRevisionArgs(base, head, mode = 'merge-base') {
|
||||
return mode === 'exact'
|
||||
@@ -9,6 +14,46 @@ export function buildGitDiffRevisionArgs(base, head, mode = 'merge-base') {
|
||||
: [`${base}...${head}`]
|
||||
}
|
||||
|
||||
export function resolveGitDiffContext({
|
||||
base,
|
||||
head,
|
||||
mode = 'merge-base',
|
||||
execGit,
|
||||
}) {
|
||||
const requestedMode = normalizeDiffRangeMode(mode)
|
||||
const context = {
|
||||
base,
|
||||
head,
|
||||
mode: requestedMode,
|
||||
requestedMode,
|
||||
reason: null,
|
||||
useCombinedMergeDiff: false,
|
||||
}
|
||||
|
||||
if (requestedMode !== 'exact' || !base || !head || !execGit)
|
||||
return context
|
||||
|
||||
const baseCommit = resolveCommitSha(base, execGit) ?? base
|
||||
const headCommit = resolveCommitSha(head, execGit) ?? head
|
||||
const parents = getCommitParents(headCommit, execGit)
|
||||
if (parents.length < 2)
|
||||
return context
|
||||
|
||||
const [firstParent, secondParent] = parents
|
||||
if (firstParent !== baseCommit)
|
||||
return context
|
||||
|
||||
const defaultBranchRef = resolveDefaultBranchRef(execGit)
|
||||
if (!defaultBranchRef || !isAncestor(secondParent, defaultBranchRef, execGit))
|
||||
return context
|
||||
|
||||
return {
|
||||
...context,
|
||||
reason: `ignored merge from ${defaultBranchRef}`,
|
||||
useCombinedMergeDiff: true,
|
||||
}
|
||||
}
|
||||
|
||||
export function parseChangedLineMap(diff, isTrackedComponentSourceFile) {
|
||||
const lineMap = new Map()
|
||||
let currentFile = null
|
||||
@@ -22,7 +67,7 @@ export function parseChangedLineMap(diff, isTrackedComponentSourceFile) {
|
||||
if (!currentFile || !isTrackedComponentSourceFile(currentFile))
|
||||
continue
|
||||
|
||||
const match = line.match(/^@@ -\d+(?:,\d+)? \+(\d+)(?:,(\d+))? @@/)
|
||||
const match = line.match(/^@{2,}(?: -\d+(?:,\d+)?)+ \+(\d+)(?:,(\d+))? @{2,}/)
|
||||
if (!match)
|
||||
continue
|
||||
|
||||
@@ -131,14 +176,15 @@ export function getChangedBranchCoverage(entry, changedLines) {
|
||||
let total = 0
|
||||
|
||||
for (const [branchId, branch] of Object.entries(entry.branchMap ?? {})) {
|
||||
if (!branchIntersectsChangedLines(branch, changedLines))
|
||||
continue
|
||||
|
||||
const hits = Array.isArray(entry.b?.[branchId]) ? entry.b[branchId] : []
|
||||
const locations = getBranchLocations(branch)
|
||||
const armCount = Math.max(locations.length, hits.length)
|
||||
const impactedArmIndexes = getImpactedBranchArmIndexes(branch, changedLines, armCount)
|
||||
|
||||
for (let armIndex = 0; armIndex < armCount; armIndex += 1) {
|
||||
if (impactedArmIndexes.length === 0)
|
||||
continue
|
||||
|
||||
for (const armIndex of impactedArmIndexes) {
|
||||
total += 1
|
||||
if ((hits[armIndex] ?? 0) > 0) {
|
||||
covered += 1
|
||||
@@ -219,24 +265,99 @@ function emptyIgnoreResult(changedLines = []) {
|
||||
}
|
||||
}
|
||||
|
||||
function branchIntersectsChangedLines(branch, changedLines) {
|
||||
if (!changedLines || changedLines.size === 0)
|
||||
function getCommitParents(ref, execGit) {
|
||||
const output = tryExecGit(execGit, ['rev-list', '--parents', '-n', '1', ref])
|
||||
if (!output)
|
||||
return []
|
||||
|
||||
return output
|
||||
.trim()
|
||||
.split(/\s+/)
|
||||
.slice(1)
|
||||
}
|
||||
|
||||
function resolveCommitSha(ref, execGit) {
|
||||
return tryExecGit(execGit, ['rev-parse', '--verify', ref])?.trim() ?? null
|
||||
}
|
||||
|
||||
function resolveDefaultBranchRef(execGit) {
|
||||
const originHeadRef = tryExecGit(execGit, ['symbolic-ref', '--quiet', '--short', 'refs/remotes/origin/HEAD'])?.trim()
|
||||
if (originHeadRef)
|
||||
return originHeadRef
|
||||
|
||||
for (const ref of DEFAULT_BRANCH_REF_CANDIDATES) {
|
||||
if (tryExecGit(execGit, ['rev-parse', '--verify', '-q', ref]))
|
||||
return ref
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
function isAncestor(ancestorRef, descendantRef, execGit) {
|
||||
try {
|
||||
execGit(['merge-base', '--is-ancestor', ancestorRef, descendantRef])
|
||||
return true
|
||||
}
|
||||
catch {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if (rangeIntersectsChangedLines(branch.loc, changedLines))
|
||||
return true
|
||||
|
||||
const locations = getBranchLocations(branch)
|
||||
if (locations.some(location => rangeIntersectsChangedLines(location, changedLines)))
|
||||
return true
|
||||
|
||||
return branch.line ? changedLines.has(branch.line) : false
|
||||
function tryExecGit(execGit, args) {
|
||||
try {
|
||||
return execGit(args)
|
||||
}
|
||||
catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
function getBranchLocations(branch) {
|
||||
return Array.isArray(branch?.locations) ? branch.locations.filter(Boolean) : []
|
||||
}
|
||||
|
||||
function getImpactedBranchArmIndexes(branch, changedLines, armCount) {
|
||||
if (!changedLines || changedLines.size === 0 || armCount === 0)
|
||||
return []
|
||||
|
||||
const locations = getBranchLocations(branch)
|
||||
if (isWholeBranchTouched(branch, changedLines, locations, armCount))
|
||||
return Array.from({ length: armCount }, (_, armIndex) => armIndex)
|
||||
|
||||
const impactedArmIndexes = []
|
||||
for (let armIndex = 0; armIndex < armCount; armIndex += 1) {
|
||||
const location = locations[armIndex]
|
||||
if (rangeIntersectsChangedLines(location, changedLines))
|
||||
impactedArmIndexes.push(armIndex)
|
||||
}
|
||||
|
||||
return impactedArmIndexes
|
||||
}
|
||||
|
||||
function isWholeBranchTouched(branch, changedLines, locations, armCount) {
|
||||
if (!changedLines || changedLines.size === 0)
|
||||
return false
|
||||
|
||||
if (branch.line && changedLines.has(branch.line))
|
||||
return true
|
||||
|
||||
const branchRange = branch.loc ?? branch
|
||||
if (!rangeIntersectsChangedLines(branchRange, changedLines))
|
||||
return false
|
||||
|
||||
if (locations.length === 0 || locations.length < armCount)
|
||||
return true
|
||||
|
||||
for (const lineNumber of changedLines) {
|
||||
if (!lineTouchesLocation(lineNumber, branchRange))
|
||||
continue
|
||||
if (!locations.some(location => lineTouchesLocation(lineNumber, location)))
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
function rangeIntersectsChangedLines(location, changedLines) {
|
||||
if (!location || !changedLines || changedLines.size === 0)
|
||||
return false
|
||||
@@ -268,6 +389,15 @@ function getFirstChangedLineInRange(location, changedLines, fallbackLine = 1) {
|
||||
return startLine ?? fallbackLine
|
||||
}
|
||||
|
||||
function lineTouchesLocation(lineNumber, location) {
|
||||
const startLine = getLocationStartLine(location)
|
||||
const endLine = getLocationEndLine(location) ?? startLine
|
||||
if (!startLine || !endLine)
|
||||
return false
|
||||
|
||||
return lineNumber >= startLine && lineNumber <= endLine
|
||||
}
|
||||
|
||||
function getLocationStartLine(location) {
|
||||
return location?.start?.line ?? location?.line ?? null
|
||||
}
|
||||
|
||||
118
web/scripts/check-components-diff-coverage-lib.spec.ts
Normal file
118
web/scripts/check-components-diff-coverage-lib.spec.ts
Normal file
@@ -0,0 +1,118 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { parseChangedLineMap, resolveGitDiffContext } from './check-components-diff-coverage-lib.mjs'
|
||||
|
||||
function createExecGitMock(responses: Record<string, string | Error>) {
|
||||
return vi.fn((args: string[]) => {
|
||||
const key = args.join(' ')
|
||||
const response = responses[key]
|
||||
|
||||
if (response instanceof Error)
|
||||
throw response
|
||||
|
||||
if (response === undefined)
|
||||
throw new Error(`Unexpected git args: ${key}`)
|
||||
|
||||
return response
|
||||
})
|
||||
}
|
||||
|
||||
describe('resolveGitDiffContext', () => {
|
||||
it('switches exact diff to combined merge diff when head merges origin/main into the branch', () => {
|
||||
const execGit = createExecGitMock({
|
||||
'rev-parse --verify feature-parent-sha': 'feature-parent-sha\n',
|
||||
'rev-parse --verify merge-sha': 'merge-sha\n',
|
||||
'rev-list --parents -n 1 merge-sha': 'merge-sha feature-parent-sha main-parent-sha\n',
|
||||
'symbolic-ref --quiet --short refs/remotes/origin/HEAD': 'origin/main\n',
|
||||
'merge-base --is-ancestor main-parent-sha origin/main': '',
|
||||
})
|
||||
|
||||
expect(resolveGitDiffContext({
|
||||
base: 'feature-parent-sha',
|
||||
head: 'merge-sha',
|
||||
mode: 'exact',
|
||||
execGit,
|
||||
})).toEqual({
|
||||
base: 'feature-parent-sha',
|
||||
head: 'merge-sha',
|
||||
mode: 'exact',
|
||||
requestedMode: 'exact',
|
||||
reason: 'ignored merge from origin/main',
|
||||
useCombinedMergeDiff: true,
|
||||
})
|
||||
})
|
||||
|
||||
it('falls back to origin/main when origin/HEAD is unavailable', () => {
|
||||
const execGit = createExecGitMock({
|
||||
'rev-parse --verify feature-parent-sha': 'feature-parent-sha\n',
|
||||
'rev-parse --verify merge-sha': 'merge-sha\n',
|
||||
'rev-list --parents -n 1 merge-sha': 'merge-sha feature-parent-sha main-parent-sha\n',
|
||||
'symbolic-ref --quiet --short refs/remotes/origin/HEAD': new Error('missing origin/HEAD'),
|
||||
'rev-parse --verify -q origin/main': 'main-tip-sha\n',
|
||||
'merge-base --is-ancestor main-parent-sha origin/main': '',
|
||||
})
|
||||
|
||||
expect(resolveGitDiffContext({
|
||||
base: 'feature-parent-sha',
|
||||
head: 'merge-sha',
|
||||
mode: 'exact',
|
||||
execGit,
|
||||
})).toEqual({
|
||||
base: 'feature-parent-sha',
|
||||
head: 'merge-sha',
|
||||
mode: 'exact',
|
||||
requestedMode: 'exact',
|
||||
reason: 'ignored merge from origin/main',
|
||||
useCombinedMergeDiff: true,
|
||||
})
|
||||
})
|
||||
|
||||
it('keeps exact diff when the second parent is not the default branch', () => {
|
||||
const execGit = createExecGitMock({
|
||||
'rev-parse --verify feature-parent-sha': 'feature-parent-sha\n',
|
||||
'rev-parse --verify merge-sha': 'merge-sha\n',
|
||||
'rev-list --parents -n 1 merge-sha': 'merge-sha feature-parent-sha topic-parent-sha\n',
|
||||
'symbolic-ref --quiet --short refs/remotes/origin/HEAD': 'origin/main\n',
|
||||
'merge-base --is-ancestor topic-parent-sha origin/main': new Error('not ancestor'),
|
||||
})
|
||||
|
||||
expect(resolveGitDiffContext({
|
||||
base: 'feature-parent-sha',
|
||||
head: 'merge-sha',
|
||||
mode: 'exact',
|
||||
execGit,
|
||||
})).toEqual({
|
||||
base: 'feature-parent-sha',
|
||||
head: 'merge-sha',
|
||||
mode: 'exact',
|
||||
requestedMode: 'exact',
|
||||
reason: null,
|
||||
useCombinedMergeDiff: false,
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('parseChangedLineMap', () => {
|
||||
it('parses regular diff hunks', () => {
|
||||
const diff = [
|
||||
'diff --git a/web/app/components/example.tsx b/web/app/components/example.tsx',
|
||||
'+++ b/web/app/components/example.tsx',
|
||||
'@@ -10,0 +11,2 @@',
|
||||
].join('\n')
|
||||
|
||||
const changedLineMap = parseChangedLineMap(diff, () => true)
|
||||
|
||||
expect([...changedLineMap.get('web/app/components/example.tsx') ?? []]).toEqual([11, 12])
|
||||
})
|
||||
|
||||
it('parses combined merge diff hunks', () => {
|
||||
const diff = [
|
||||
'diff --cc web/app/components/example.tsx',
|
||||
'+++ b/web/app/components/example.tsx',
|
||||
'@@@ -10,0 -10,0 +11,3 @@@',
|
||||
].join('\n')
|
||||
|
||||
const changedLineMap = parseChangedLineMap(diff, () => true)
|
||||
|
||||
expect([...changedLineMap.get('web/app/components/example.tsx') ?? []]).toEqual([11, 12, 13])
|
||||
})
|
||||
})
|
||||
@@ -6,41 +6,36 @@ import {
|
||||
getChangedBranchCoverage,
|
||||
getChangedStatementCoverage,
|
||||
getIgnoredChangedLinesFromFile,
|
||||
getLineHits,
|
||||
normalizeToRepoRelative,
|
||||
normalizeDiffRangeMode,
|
||||
parseChangedLineMap,
|
||||
resolveGitDiffContext,
|
||||
} from './check-components-diff-coverage-lib.mjs'
|
||||
import { COMPONENT_COVERAGE_EXCLUDE_LABEL } from './component-coverage-filters.mjs'
|
||||
import {
|
||||
collectComponentCoverageExcludedFiles,
|
||||
COMPONENT_COVERAGE_EXCLUDE_LABEL,
|
||||
} from './component-coverage-filters.mjs'
|
||||
import {
|
||||
COMPONENTS_GLOBAL_THRESHOLDS,
|
||||
EXCLUDED_COMPONENT_MODULES,
|
||||
getComponentModuleThreshold,
|
||||
} from './components-coverage-thresholds.mjs'
|
||||
APP_COMPONENTS_PREFIX,
|
||||
createComponentCoverageContext,
|
||||
getModuleName,
|
||||
isAnyComponentSourceFile,
|
||||
isExcludedComponentSourceFile,
|
||||
isTrackedComponentSourceFile,
|
||||
loadTrackedCoverageEntries,
|
||||
} from './components-coverage-common.mjs'
|
||||
import { EXCLUDED_COMPONENT_MODULES } from './components-coverage-thresholds.mjs'
|
||||
|
||||
const APP_COMPONENTS_PREFIX = 'web/app/components/'
|
||||
const APP_COMPONENTS_COVERAGE_PREFIX = 'app/components/'
|
||||
const SHARED_TEST_PREFIX = 'web/__tests__/'
|
||||
const STRICT_TEST_FILE_TOUCH = process.env.STRICT_COMPONENT_TEST_TOUCH === 'true'
|
||||
const REQUESTED_DIFF_RANGE_MODE = normalizeDiffRangeMode(process.env.DIFF_RANGE_MODE)
|
||||
const EXCLUDED_MODULES_LABEL = [...EXCLUDED_COMPONENT_MODULES].sort().join(', ')
|
||||
const DIFF_RANGE_MODE = process.env.DIFF_RANGE_MODE === 'exact' ? 'exact' : 'merge-base'
|
||||
|
||||
const repoRoot = repoRootFromCwd()
|
||||
const webRoot = path.join(repoRoot, 'web')
|
||||
const excludedComponentCoverageFiles = new Set(
|
||||
collectComponentCoverageExcludedFiles(path.join(webRoot, 'app/components'), { pathPrefix: 'web/app/components' }),
|
||||
)
|
||||
const context = createComponentCoverageContext(repoRoot)
|
||||
const baseSha = process.env.BASE_SHA?.trim()
|
||||
const headSha = process.env.HEAD_SHA?.trim() || 'HEAD'
|
||||
const coverageFinalPath = path.join(webRoot, 'coverage', 'coverage-final.json')
|
||||
const coverageFinalPath = path.join(context.webRoot, 'coverage', 'coverage-final.json')
|
||||
|
||||
if (!baseSha || /^0+$/.test(baseSha)) {
|
||||
appendSummary([
|
||||
'### app/components Diff Coverage',
|
||||
'### app/components Pure Diff Coverage',
|
||||
'',
|
||||
'Skipped diff coverage check because `BASE_SHA` was not available.',
|
||||
'Skipped pure diff coverage check because `BASE_SHA` was not available.',
|
||||
])
|
||||
process.exit(0)
|
||||
}
|
||||
@@ -50,55 +45,36 @@ if (!fs.existsSync(coverageFinalPath)) {
|
||||
process.exit(1)
|
||||
}
|
||||
|
||||
const diffContext = resolveGitDiffContext({
|
||||
base: baseSha,
|
||||
head: headSha,
|
||||
mode: REQUESTED_DIFF_RANGE_MODE,
|
||||
execGit,
|
||||
})
|
||||
const coverage = JSON.parse(fs.readFileSync(coverageFinalPath, 'utf8'))
|
||||
const changedFiles = getChangedFiles(baseSha, headSha)
|
||||
const changedFiles = getChangedFiles(diffContext)
|
||||
const changedComponentSourceFiles = changedFiles.filter(isAnyComponentSourceFile)
|
||||
const changedSourceFiles = changedComponentSourceFiles.filter(isTrackedComponentSourceFile)
|
||||
const changedExcludedSourceFiles = changedComponentSourceFiles.filter(isExcludedComponentSourceFile)
|
||||
const changedTestFiles = changedFiles.filter(isRelevantTestFile)
|
||||
const changedSourceFiles = changedComponentSourceFiles.filter(filePath => isTrackedComponentSourceFile(filePath, context.excludedComponentCoverageFiles))
|
||||
const changedExcludedSourceFiles = changedComponentSourceFiles.filter(filePath => isExcludedComponentSourceFile(filePath, context.excludedComponentCoverageFiles))
|
||||
|
||||
if (changedSourceFiles.length === 0) {
|
||||
appendSummary(buildSkipSummary(changedExcludedSourceFiles))
|
||||
process.exit(0)
|
||||
}
|
||||
|
||||
const coverageEntries = new Map()
|
||||
for (const [file, entry] of Object.entries(coverage)) {
|
||||
const repoRelativePath = normalizeToRepoRelative(entry.path ?? file, {
|
||||
appComponentsCoveragePrefix: APP_COMPONENTS_COVERAGE_PREFIX,
|
||||
appComponentsPrefix: APP_COMPONENTS_PREFIX,
|
||||
repoRoot,
|
||||
sharedTestPrefix: SHARED_TEST_PREFIX,
|
||||
webRoot,
|
||||
})
|
||||
if (!isTrackedComponentSourceFile(repoRelativePath))
|
||||
continue
|
||||
|
||||
coverageEntries.set(repoRelativePath, entry)
|
||||
}
|
||||
|
||||
const fileCoverageRows = []
|
||||
const moduleCoverageMap = new Map()
|
||||
|
||||
for (const [file, entry] of coverageEntries.entries()) {
|
||||
const stats = getCoverageStats(entry)
|
||||
const moduleName = getModuleName(file)
|
||||
fileCoverageRows.push({ file, moduleName, ...stats })
|
||||
mergeCoverageStats(moduleCoverageMap, moduleName, stats)
|
||||
}
|
||||
|
||||
const overallCoverage = sumCoverageStats(fileCoverageRows)
|
||||
const diffChanges = getChangedLineMap(baseSha, headSha)
|
||||
const coverageEntries = loadTrackedCoverageEntries(coverage, context)
|
||||
const diffChanges = getChangedLineMap(diffContext)
|
||||
const diffRows = []
|
||||
const ignoredDiffLines = []
|
||||
const invalidIgnorePragmas = []
|
||||
|
||||
for (const [file, changedLines] of diffChanges.entries()) {
|
||||
if (!isTrackedComponentSourceFile(file))
|
||||
if (!isTrackedComponentSourceFile(file, context.excludedComponentCoverageFiles))
|
||||
continue
|
||||
|
||||
const entry = coverageEntries.get(file)
|
||||
const ignoreInfo = getIgnoredChangedLinesFromFile(path.join(repoRoot, file), changedLines)
|
||||
|
||||
for (const [line, reason] of ignoreInfo.ignoredLines.entries()) {
|
||||
ignoredDiffLines.push({
|
||||
file,
|
||||
@@ -106,6 +82,7 @@ for (const [file, changedLines] of diffChanges.entries()) {
|
||||
reason,
|
||||
})
|
||||
}
|
||||
|
||||
for (const invalidPragma of ignoreInfo.invalidPragmas) {
|
||||
invalidIgnorePragmas.push({
|
||||
file,
|
||||
@@ -137,40 +114,16 @@ const diffTotals = diffRows.reduce((acc, row) => {
|
||||
|
||||
const diffStatementFailures = diffRows.filter(row => row.statements.uncoveredLines.length > 0)
|
||||
const diffBranchFailures = diffRows.filter(row => row.branches.uncoveredBranches.length > 0)
|
||||
const overallThresholdFailures = getThresholdFailures(overallCoverage, COMPONENTS_GLOBAL_THRESHOLDS)
|
||||
const moduleCoverageRows = [...moduleCoverageMap.entries()]
|
||||
.map(([moduleName, stats]) => ({
|
||||
moduleName,
|
||||
stats,
|
||||
thresholds: getComponentModuleThreshold(moduleName),
|
||||
}))
|
||||
.map(row => ({
|
||||
...row,
|
||||
failures: row.thresholds ? getThresholdFailures(row.stats, row.thresholds) : [],
|
||||
}))
|
||||
const moduleThresholdFailures = moduleCoverageRows
|
||||
.filter(row => row.failures.length > 0)
|
||||
.flatMap(row => row.failures.map(failure => ({
|
||||
moduleName: row.moduleName,
|
||||
...failure,
|
||||
})))
|
||||
const hasRelevantTestChanges = changedTestFiles.length > 0
|
||||
const missingTestTouch = !hasRelevantTestChanges
|
||||
|
||||
appendSummary(buildSummary({
|
||||
overallCoverage,
|
||||
overallThresholdFailures,
|
||||
moduleCoverageRows,
|
||||
moduleThresholdFailures,
|
||||
changedSourceFiles,
|
||||
diffContext,
|
||||
diffBranchFailures,
|
||||
diffRows,
|
||||
diffStatementFailures,
|
||||
diffTotals,
|
||||
changedSourceFiles,
|
||||
changedTestFiles,
|
||||
ignoredDiffLines,
|
||||
invalidIgnorePragmas,
|
||||
missingTestTouch,
|
||||
}))
|
||||
|
||||
if (process.env.CI) {
|
||||
@@ -178,107 +131,51 @@ if (process.env.CI) {
|
||||
const firstLine = failure.statements.uncoveredLines[0] ?? 1
|
||||
console.log(`::error file=${failure.file},line=${firstLine}::Uncovered changed statements: ${formatLineRanges(failure.statements.uncoveredLines)}`)
|
||||
}
|
||||
|
||||
for (const failure of diffBranchFailures.slice(0, 20)) {
|
||||
const firstBranch = failure.branches.uncoveredBranches[0]
|
||||
const line = firstBranch?.line ?? 1
|
||||
console.log(`::error file=${failure.file},line=${line}::Uncovered changed branches: ${formatBranchRefs(failure.branches.uncoveredBranches)}`)
|
||||
}
|
||||
|
||||
for (const invalidPragma of invalidIgnorePragmas.slice(0, 20)) {
|
||||
console.log(`::error file=${invalidPragma.file},line=${invalidPragma.line}::Invalid diff coverage ignore pragma: ${invalidPragma.reason}`)
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
overallThresholdFailures.length > 0
|
||||
|| moduleThresholdFailures.length > 0
|
||||
|| diffStatementFailures.length > 0
|
||||
diffStatementFailures.length > 0
|
||||
|| diffBranchFailures.length > 0
|
||||
|| invalidIgnorePragmas.length > 0
|
||||
|| (STRICT_TEST_FILE_TOUCH && missingTestTouch)
|
||||
) {
|
||||
process.exit(1)
|
||||
}
|
||||
|
||||
function buildSummary({
|
||||
overallCoverage,
|
||||
overallThresholdFailures,
|
||||
moduleCoverageRows,
|
||||
moduleThresholdFailures,
|
||||
changedSourceFiles,
|
||||
diffContext,
|
||||
diffBranchFailures,
|
||||
diffRows,
|
||||
diffStatementFailures,
|
||||
diffTotals,
|
||||
changedSourceFiles,
|
||||
changedTestFiles,
|
||||
ignoredDiffLines,
|
||||
invalidIgnorePragmas,
|
||||
missingTestTouch,
|
||||
}) {
|
||||
const lines = [
|
||||
'### app/components Diff Coverage',
|
||||
'### app/components Pure Diff Coverage',
|
||||
'',
|
||||
`Compared \`${baseSha.slice(0, 12)}\` -> \`${headSha.slice(0, 12)}\``,
|
||||
`Diff range mode: \`${DIFF_RANGE_MODE}\``,
|
||||
...buildDiffContextSummary(diffContext),
|
||||
'',
|
||||
`Excluded modules: \`${EXCLUDED_MODULES_LABEL}\``,
|
||||
`Excluded file kinds: \`${COMPONENT_COVERAGE_EXCLUDE_LABEL}\``,
|
||||
'',
|
||||
'| Check | Result | Details |',
|
||||
'|---|---:|---|',
|
||||
`| Overall tracked lines | ${formatPercent(overallCoverage.lines)} | ${overallCoverage.lines.covered}/${overallCoverage.lines.total}; threshold ${COMPONENTS_GLOBAL_THRESHOLDS.lines}% |`,
|
||||
`| Overall tracked statements | ${formatPercent(overallCoverage.statements)} | ${overallCoverage.statements.covered}/${overallCoverage.statements.total}; threshold ${COMPONENTS_GLOBAL_THRESHOLDS.statements}% |`,
|
||||
`| Overall tracked functions | ${formatPercent(overallCoverage.functions)} | ${overallCoverage.functions.covered}/${overallCoverage.functions.total}; threshold ${COMPONENTS_GLOBAL_THRESHOLDS.functions}% |`,
|
||||
`| Overall tracked branches | ${formatPercent(overallCoverage.branches)} | ${overallCoverage.branches.covered}/${overallCoverage.branches.total}; threshold ${COMPONENTS_GLOBAL_THRESHOLDS.branches}% |`,
|
||||
`| Changed statements | ${formatDiffPercent(diffTotals.statements)} | ${diffTotals.statements.covered}/${diffTotals.statements.total} |`,
|
||||
`| Changed branches | ${formatDiffPercent(diffTotals.branches)} | ${diffTotals.branches.covered}/${diffTotals.branches.total} |`,
|
||||
'',
|
||||
]
|
||||
|
||||
if (overallThresholdFailures.length > 0) {
|
||||
lines.push('Overall thresholds failed:')
|
||||
for (const failure of overallThresholdFailures)
|
||||
lines.push(`- ${failure.metric}: ${failure.actual.toFixed(2)}% < ${failure.expected}%`)
|
||||
lines.push('')
|
||||
}
|
||||
|
||||
if (moduleThresholdFailures.length > 0) {
|
||||
lines.push('Module thresholds failed:')
|
||||
for (const failure of moduleThresholdFailures)
|
||||
lines.push(`- ${failure.moduleName} ${failure.metric}: ${failure.actual.toFixed(2)}% < ${failure.expected}%`)
|
||||
lines.push('')
|
||||
}
|
||||
|
||||
const moduleRows = moduleCoverageRows
|
||||
.map(({ moduleName, stats, thresholds, failures }) => ({
|
||||
moduleName,
|
||||
lines: percentage(stats.lines.covered, stats.lines.total),
|
||||
statements: percentage(stats.statements.covered, stats.statements.total),
|
||||
functions: percentage(stats.functions.covered, stats.functions.total),
|
||||
branches: percentage(stats.branches.covered, stats.branches.total),
|
||||
thresholds,
|
||||
failures,
|
||||
}))
|
||||
.sort((a, b) => {
|
||||
if (a.failures.length !== b.failures.length)
|
||||
return b.failures.length - a.failures.length
|
||||
|
||||
return a.lines - b.lines || a.moduleName.localeCompare(b.moduleName)
|
||||
})
|
||||
|
||||
lines.push('<details><summary>Module coverage</summary>')
|
||||
lines.push('')
|
||||
lines.push('| Module | Lines | Statements | Functions | Branches | Thresholds | Status |')
|
||||
lines.push('|---|---:|---:|---:|---:|---|---|')
|
||||
for (const row of moduleRows) {
|
||||
const thresholdLabel = row.thresholds
|
||||
? `L${row.thresholds.lines}/S${row.thresholds.statements}/F${row.thresholds.functions}/B${row.thresholds.branches}`
|
||||
: 'n/a'
|
||||
const status = row.thresholds ? (row.failures.length > 0 ? 'fail' : 'pass') : 'info'
|
||||
lines.push(`| ${row.moduleName} | ${row.lines.toFixed(2)}% | ${row.statements.toFixed(2)}% | ${row.functions.toFixed(2)}% | ${row.branches.toFixed(2)}% | ${thresholdLabel} | ${status} |`)
|
||||
}
|
||||
lines.push('</details>')
|
||||
lines.push('')
|
||||
|
||||
const changedRows = diffRows
|
||||
.filter(row => row.statements.total > 0 || row.branches.total > 0)
|
||||
.sort((a, b) => {
|
||||
@@ -297,59 +194,45 @@ function buildSummary({
|
||||
lines.push('</details>')
|
||||
lines.push('')
|
||||
|
||||
if (missingTestTouch) {
|
||||
lines.push(`Warning: tracked source files changed under \`web/app/components/\`, but no test files changed under \`web/app/components/**\` or \`web/__tests__/\`.`)
|
||||
if (STRICT_TEST_FILE_TOUCH)
|
||||
lines.push('`STRICT_COMPONENT_TEST_TOUCH=true` is enabled, so this warning fails the check.')
|
||||
lines.push('')
|
||||
}
|
||||
else {
|
||||
lines.push(`Relevant test files changed: ${changedTestFiles.length}`)
|
||||
lines.push('')
|
||||
}
|
||||
|
||||
if (diffStatementFailures.length > 0) {
|
||||
lines.push('Uncovered changed statements:')
|
||||
for (const row of diffStatementFailures) {
|
||||
for (const row of diffStatementFailures)
|
||||
lines.push(`- ${row.file.replace('web/', '')}: ${formatLineRanges(row.statements.uncoveredLines)}`)
|
||||
}
|
||||
lines.push('')
|
||||
}
|
||||
|
||||
if (diffBranchFailures.length > 0) {
|
||||
lines.push('Uncovered changed branches:')
|
||||
for (const row of diffBranchFailures) {
|
||||
for (const row of diffBranchFailures)
|
||||
lines.push(`- ${row.file.replace('web/', '')}: ${formatBranchRefs(row.branches.uncoveredBranches)}`)
|
||||
}
|
||||
lines.push('')
|
||||
}
|
||||
|
||||
if (ignoredDiffLines.length > 0) {
|
||||
lines.push('Ignored changed lines via pragma:')
|
||||
for (const ignoredLine of ignoredDiffLines) {
|
||||
for (const ignoredLine of ignoredDiffLines)
|
||||
lines.push(`- ${ignoredLine.file.replace('web/', '')}:${ignoredLine.line} - ${ignoredLine.reason}`)
|
||||
}
|
||||
lines.push('')
|
||||
}
|
||||
|
||||
if (invalidIgnorePragmas.length > 0) {
|
||||
lines.push('Invalid diff coverage ignore pragmas:')
|
||||
for (const invalidPragma of invalidIgnorePragmas) {
|
||||
for (const invalidPragma of invalidIgnorePragmas)
|
||||
lines.push(`- ${invalidPragma.file.replace('web/', '')}:${invalidPragma.line} - ${invalidPragma.reason}`)
|
||||
}
|
||||
lines.push('')
|
||||
}
|
||||
|
||||
lines.push(`Changed source files checked: ${changedSourceFiles.length}`)
|
||||
lines.push(`Changed statement coverage: ${formatDiffPercent(diffTotals.statements)}`)
|
||||
lines.push(`Changed branch coverage: ${formatDiffPercent(diffTotals.branches)}`)
|
||||
lines.push('Blocking rules: uncovered changed statements, uncovered changed branches, invalid ignore pragmas.')
|
||||
|
||||
return lines
|
||||
}
|
||||
|
||||
function buildSkipSummary(changedExcludedSourceFiles) {
|
||||
const lines = [
|
||||
'### app/components Diff Coverage',
|
||||
'### app/components Pure Diff Coverage',
|
||||
'',
|
||||
...buildDiffContextSummary(diffContext),
|
||||
'',
|
||||
`Excluded modules: \`${EXCLUDED_MODULES_LABEL}\``,
|
||||
`Excluded file kinds: \`${COMPONENT_COVERAGE_EXCLUDE_LABEL}\``,
|
||||
@@ -357,146 +240,60 @@ function buildSkipSummary(changedExcludedSourceFiles) {
|
||||
]
|
||||
|
||||
if (changedExcludedSourceFiles.length > 0) {
|
||||
lines.push('Only excluded component modules or type-only files changed, so diff coverage check was skipped.')
|
||||
lines.push('Only excluded component modules or type-only files changed, so pure diff coverage was skipped.')
|
||||
lines.push(`Skipped files: ${changedExcludedSourceFiles.length}`)
|
||||
}
|
||||
else {
|
||||
lines.push('No source changes under tracked `web/app/components/`. Diff coverage check skipped.')
|
||||
lines.push('No tracked source changes under `web/app/components/`. Pure diff coverage skipped.')
|
||||
}
|
||||
|
||||
return lines
|
||||
}
|
||||
|
||||
function getChangedFiles(base, head) {
|
||||
const output = execGit(['diff', '--name-only', '--diff-filter=ACMR', ...buildGitDiffRevisionArgs(base, head, DIFF_RANGE_MODE), '--', 'web/app/components', 'web/__tests__'])
|
||||
function buildDiffContextSummary(diffContext) {
|
||||
const lines = [
|
||||
`Compared \`${diffContext.base.slice(0, 12)}\` -> \`${diffContext.head.slice(0, 12)}\``,
|
||||
]
|
||||
|
||||
if (diffContext.useCombinedMergeDiff) {
|
||||
lines.push(`Requested diff range mode: \`${diffContext.requestedMode}\``)
|
||||
lines.push(`Effective diff strategy: \`combined-merge\` (${diffContext.reason})`)
|
||||
}
|
||||
else if (diffContext.reason) {
|
||||
lines.push(`Requested diff range mode: \`${diffContext.requestedMode}\``)
|
||||
lines.push(`Effective diff range mode: \`${diffContext.mode}\` (${diffContext.reason})`)
|
||||
}
|
||||
else {
|
||||
lines.push(`Diff range mode: \`${diffContext.mode}\``)
|
||||
}
|
||||
|
||||
return lines
|
||||
}
|
||||
|
||||
function getChangedFiles(diffContext) {
|
||||
if (diffContext.useCombinedMergeDiff) {
|
||||
const output = execGit(['diff-tree', '--cc', '--no-commit-id', '--name-only', '-r', diffContext.head, '--', APP_COMPONENTS_PREFIX])
|
||||
return output
|
||||
.split('\n')
|
||||
.map(line => line.trim())
|
||||
.filter(Boolean)
|
||||
}
|
||||
|
||||
const output = execGit(['diff', '--name-only', '--diff-filter=ACMR', ...buildGitDiffRevisionArgs(diffContext.base, diffContext.head, diffContext.mode), '--', APP_COMPONENTS_PREFIX])
|
||||
return output
|
||||
.split('\n')
|
||||
.map(line => line.trim())
|
||||
.filter(Boolean)
|
||||
}
|
||||
|
||||
function getChangedLineMap(base, head) {
|
||||
const diff = execGit(['diff', '--unified=0', '--no-color', '--diff-filter=ACMR', ...buildGitDiffRevisionArgs(base, head, DIFF_RANGE_MODE), '--', 'web/app/components'])
|
||||
return parseChangedLineMap(diff, isTrackedComponentSourceFile)
|
||||
}
|
||||
|
||||
function isAnyComponentSourceFile(filePath) {
|
||||
return filePath.startsWith(APP_COMPONENTS_PREFIX)
|
||||
&& /\.(?:ts|tsx)$/.test(filePath)
|
||||
&& !isTestLikePath(filePath)
|
||||
}
|
||||
|
||||
function isTrackedComponentSourceFile(filePath) {
|
||||
return isAnyComponentSourceFile(filePath)
|
||||
&& !isExcludedComponentSourceFile(filePath)
|
||||
}
|
||||
|
||||
function isExcludedComponentSourceFile(filePath) {
|
||||
return isAnyComponentSourceFile(filePath)
|
||||
&& (
|
||||
EXCLUDED_COMPONENT_MODULES.has(getModuleName(filePath))
|
||||
|| excludedComponentCoverageFiles.has(filePath)
|
||||
)
|
||||
}
|
||||
|
||||
function isRelevantTestFile(filePath) {
|
||||
return filePath.startsWith(SHARED_TEST_PREFIX)
|
||||
|| (filePath.startsWith(APP_COMPONENTS_PREFIX) && isTestLikePath(filePath) && !isExcludedComponentTestFile(filePath))
|
||||
}
|
||||
|
||||
function isExcludedComponentTestFile(filePath) {
|
||||
if (!filePath.startsWith(APP_COMPONENTS_PREFIX))
|
||||
return false
|
||||
|
||||
return EXCLUDED_COMPONENT_MODULES.has(getModuleName(filePath))
|
||||
}
|
||||
|
||||
function isTestLikePath(filePath) {
|
||||
return /(?:^|\/)__tests__\//.test(filePath)
|
||||
|| /(?:^|\/)__mocks__\//.test(filePath)
|
||||
|| /\.(?:spec|test)\.(?:ts|tsx)$/.test(filePath)
|
||||
|| /\.stories\.(?:ts|tsx)$/.test(filePath)
|
||||
|| /\.d\.ts$/.test(filePath)
|
||||
}
|
||||
|
||||
function getCoverageStats(entry) {
|
||||
const lineHits = getLineHits(entry)
|
||||
const statementHits = Object.values(entry.s ?? {})
|
||||
const functionHits = Object.values(entry.f ?? {})
|
||||
const branchHits = Object.values(entry.b ?? {}).flat()
|
||||
|
||||
return {
|
||||
lines: {
|
||||
covered: Object.values(lineHits).filter(count => count > 0).length,
|
||||
total: Object.keys(lineHits).length,
|
||||
},
|
||||
statements: {
|
||||
covered: statementHits.filter(count => count > 0).length,
|
||||
total: statementHits.length,
|
||||
},
|
||||
functions: {
|
||||
covered: functionHits.filter(count => count > 0).length,
|
||||
total: functionHits.length,
|
||||
},
|
||||
branches: {
|
||||
covered: branchHits.filter(count => count > 0).length,
|
||||
total: branchHits.length,
|
||||
},
|
||||
function getChangedLineMap(diffContext) {
|
||||
if (diffContext.useCombinedMergeDiff) {
|
||||
const diff = execGit(['diff-tree', '--cc', '--no-commit-id', '-r', '--unified=0', diffContext.head, '--', APP_COMPONENTS_PREFIX])
|
||||
return parseChangedLineMap(diff, filePath => isTrackedComponentSourceFile(filePath, context.excludedComponentCoverageFiles))
|
||||
}
|
||||
}
|
||||
|
||||
function sumCoverageStats(rows) {
|
||||
const total = createEmptyCoverageStats()
|
||||
for (const row of rows)
|
||||
addCoverageStats(total, row)
|
||||
return total
|
||||
}
|
||||
|
||||
function mergeCoverageStats(map, moduleName, stats) {
|
||||
const existing = map.get(moduleName) ?? createEmptyCoverageStats()
|
||||
addCoverageStats(existing, stats)
|
||||
map.set(moduleName, existing)
|
||||
}
|
||||
|
||||
function addCoverageStats(target, source) {
|
||||
for (const metric of ['lines', 'statements', 'functions', 'branches']) {
|
||||
target[metric].covered += source[metric].covered
|
||||
target[metric].total += source[metric].total
|
||||
}
|
||||
}
|
||||
|
||||
function createEmptyCoverageStats() {
|
||||
return {
|
||||
lines: { covered: 0, total: 0 },
|
||||
statements: { covered: 0, total: 0 },
|
||||
functions: { covered: 0, total: 0 },
|
||||
branches: { covered: 0, total: 0 },
|
||||
}
|
||||
}
|
||||
|
||||
function getThresholdFailures(stats, thresholds) {
|
||||
const failures = []
|
||||
for (const metric of ['lines', 'statements', 'functions', 'branches']) {
|
||||
const actual = percentage(stats[metric].covered, stats[metric].total)
|
||||
const expected = thresholds[metric]
|
||||
if (actual < expected) {
|
||||
failures.push({
|
||||
metric,
|
||||
actual,
|
||||
expected,
|
||||
})
|
||||
}
|
||||
}
|
||||
return failures
|
||||
}
|
||||
|
||||
function getModuleName(filePath) {
|
||||
const relativePath = filePath.slice(APP_COMPONENTS_PREFIX.length)
|
||||
if (!relativePath)
|
||||
return '(root)'
|
||||
|
||||
const segments = relativePath.split('/')
|
||||
return segments.length === 1 ? '(root)' : segments[0]
|
||||
const diff = execGit(['diff', '--unified=0', '--no-color', '--diff-filter=ACMR', ...buildGitDiffRevisionArgs(diffContext.base, diffContext.head, diffContext.mode), '--', APP_COMPONENTS_PREFIX])
|
||||
return parseChangedLineMap(diff, filePath => isTrackedComponentSourceFile(filePath, context.excludedComponentCoverageFiles))
|
||||
}
|
||||
|
||||
function formatLineRanges(lines) {
|
||||
@@ -536,10 +333,6 @@ function percentage(covered, total) {
|
||||
return (covered / total) * 100
|
||||
}
|
||||
|
||||
function formatPercent(metric) {
|
||||
return `${percentage(metric.covered, metric.total).toFixed(2)}%`
|
||||
}
|
||||
|
||||
function formatDiffPercent(metric) {
|
||||
if (metric.total === 0)
|
||||
return 'n/a'
|
||||
|
||||
195
web/scripts/components-coverage-common.mjs
Normal file
195
web/scripts/components-coverage-common.mjs
Normal file
@@ -0,0 +1,195 @@
|
||||
import fs from 'node:fs'
|
||||
import path from 'node:path'
|
||||
import { getLineHits, normalizeToRepoRelative } from './check-components-diff-coverage-lib.mjs'
|
||||
import { collectComponentCoverageExcludedFiles } from './component-coverage-filters.mjs'
|
||||
import { EXCLUDED_COMPONENT_MODULES } from './components-coverage-thresholds.mjs'
|
||||
|
||||
export const APP_COMPONENTS_ROOT = 'web/app/components'
|
||||
export const APP_COMPONENTS_PREFIX = `${APP_COMPONENTS_ROOT}/`
|
||||
export const APP_COMPONENTS_COVERAGE_PREFIX = 'app/components/'
|
||||
export const SHARED_TEST_PREFIX = 'web/__tests__/'
|
||||
|
||||
export function createComponentCoverageContext(repoRoot) {
|
||||
const webRoot = path.join(repoRoot, 'web')
|
||||
const excludedComponentCoverageFiles = new Set(
|
||||
collectComponentCoverageExcludedFiles(path.join(webRoot, 'app/components'), { pathPrefix: APP_COMPONENTS_ROOT }),
|
||||
)
|
||||
|
||||
return {
|
||||
excludedComponentCoverageFiles,
|
||||
repoRoot,
|
||||
webRoot,
|
||||
}
|
||||
}
|
||||
|
||||
export function loadTrackedCoverageEntries(coverage, context) {
|
||||
const coverageEntries = new Map()
|
||||
|
||||
for (const [file, entry] of Object.entries(coverage)) {
|
||||
const repoRelativePath = normalizeToRepoRelative(entry.path ?? file, {
|
||||
appComponentsCoveragePrefix: APP_COMPONENTS_COVERAGE_PREFIX,
|
||||
appComponentsPrefix: APP_COMPONENTS_PREFIX,
|
||||
repoRoot: context.repoRoot,
|
||||
sharedTestPrefix: SHARED_TEST_PREFIX,
|
||||
webRoot: context.webRoot,
|
||||
})
|
||||
|
||||
if (!isTrackedComponentSourceFile(repoRelativePath, context.excludedComponentCoverageFiles))
|
||||
continue
|
||||
|
||||
coverageEntries.set(repoRelativePath, entry)
|
||||
}
|
||||
|
||||
return coverageEntries
|
||||
}
|
||||
|
||||
export function collectTrackedComponentSourceFiles(context) {
|
||||
const trackedFiles = []
|
||||
|
||||
walkComponentSourceFiles(path.join(context.webRoot, 'app/components'), (absolutePath) => {
|
||||
const repoRelativePath = path.relative(context.repoRoot, absolutePath).split(path.sep).join('/')
|
||||
if (isTrackedComponentSourceFile(repoRelativePath, context.excludedComponentCoverageFiles))
|
||||
trackedFiles.push(repoRelativePath)
|
||||
})
|
||||
|
||||
trackedFiles.sort((a, b) => a.localeCompare(b))
|
||||
return trackedFiles
|
||||
}
|
||||
|
||||
export function isTestLikePath(filePath) {
|
||||
return /(?:^|\/)__tests__\//.test(filePath)
|
||||
|| /(?:^|\/)__mocks__\//.test(filePath)
|
||||
|| /\.(?:spec|test)\.(?:ts|tsx)$/.test(filePath)
|
||||
|| /\.stories\.(?:ts|tsx)$/.test(filePath)
|
||||
|| /\.d\.ts$/.test(filePath)
|
||||
}
|
||||
|
||||
export function getModuleName(filePath) {
|
||||
const relativePath = filePath.slice(APP_COMPONENTS_PREFIX.length)
|
||||
if (!relativePath)
|
||||
return '(root)'
|
||||
|
||||
const segments = relativePath.split('/')
|
||||
return segments.length === 1 ? '(root)' : segments[0]
|
||||
}
|
||||
|
||||
export function isAnyComponentSourceFile(filePath) {
|
||||
return filePath.startsWith(APP_COMPONENTS_PREFIX)
|
||||
&& /\.(?:ts|tsx)$/.test(filePath)
|
||||
&& !isTestLikePath(filePath)
|
||||
}
|
||||
|
||||
export function isExcludedComponentSourceFile(filePath, excludedComponentCoverageFiles) {
|
||||
return isAnyComponentSourceFile(filePath)
|
||||
&& (
|
||||
EXCLUDED_COMPONENT_MODULES.has(getModuleName(filePath))
|
||||
|| excludedComponentCoverageFiles.has(filePath)
|
||||
)
|
||||
}
|
||||
|
||||
export function isTrackedComponentSourceFile(filePath, excludedComponentCoverageFiles) {
|
||||
return isAnyComponentSourceFile(filePath)
|
||||
&& !isExcludedComponentSourceFile(filePath, excludedComponentCoverageFiles)
|
||||
}
|
||||
|
||||
export function isTrackedComponentTestFile(filePath) {
|
||||
return filePath.startsWith(APP_COMPONENTS_PREFIX)
|
||||
&& isTestLikePath(filePath)
|
||||
&& !EXCLUDED_COMPONENT_MODULES.has(getModuleName(filePath))
|
||||
}
|
||||
|
||||
export function isRelevantTestFile(filePath) {
|
||||
return filePath.startsWith(SHARED_TEST_PREFIX)
|
||||
|| isTrackedComponentTestFile(filePath)
|
||||
}
|
||||
|
||||
export function isAnyWebTestFile(filePath) {
|
||||
return filePath.startsWith('web/')
|
||||
&& isTestLikePath(filePath)
|
||||
}
|
||||
|
||||
export function getCoverageStats(entry) {
|
||||
const lineHits = getLineHits(entry)
|
||||
const statementHits = Object.values(entry.s ?? {})
|
||||
const functionHits = Object.values(entry.f ?? {})
|
||||
const branchHits = Object.values(entry.b ?? {}).flat()
|
||||
|
||||
return {
|
||||
lines: {
|
||||
covered: Object.values(lineHits).filter(count => count > 0).length,
|
||||
total: Object.keys(lineHits).length,
|
||||
},
|
||||
statements: {
|
||||
covered: statementHits.filter(count => count > 0).length,
|
||||
total: statementHits.length,
|
||||
},
|
||||
functions: {
|
||||
covered: functionHits.filter(count => count > 0).length,
|
||||
total: functionHits.length,
|
||||
},
|
||||
branches: {
|
||||
covered: branchHits.filter(count => count > 0).length,
|
||||
total: branchHits.length,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
export function sumCoverageStats(rows) {
|
||||
const total = createEmptyCoverageStats()
|
||||
for (const row of rows)
|
||||
addCoverageStats(total, row)
|
||||
return total
|
||||
}
|
||||
|
||||
export function mergeCoverageStats(map, moduleName, stats) {
|
||||
const existing = map.get(moduleName) ?? createEmptyCoverageStats()
|
||||
addCoverageStats(existing, stats)
|
||||
map.set(moduleName, existing)
|
||||
}
|
||||
|
||||
export function percentage(covered, total) {
|
||||
if (total === 0)
|
||||
return 100
|
||||
return (covered / total) * 100
|
||||
}
|
||||
|
||||
export function formatPercent(metric) {
|
||||
return `${percentage(metric.covered, metric.total).toFixed(2)}%`
|
||||
}
|
||||
|
||||
function createEmptyCoverageStats() {
|
||||
return {
|
||||
lines: { covered: 0, total: 0 },
|
||||
statements: { covered: 0, total: 0 },
|
||||
functions: { covered: 0, total: 0 },
|
||||
branches: { covered: 0, total: 0 },
|
||||
}
|
||||
}
|
||||
|
||||
function addCoverageStats(target, source) {
|
||||
for (const metric of ['lines', 'statements', 'functions', 'branches']) {
|
||||
target[metric].covered += source[metric].covered
|
||||
target[metric].total += source[metric].total
|
||||
}
|
||||
}
|
||||
|
||||
function walkComponentSourceFiles(currentDir, onFile) {
|
||||
if (!fs.existsSync(currentDir))
|
||||
return
|
||||
|
||||
const entries = fs.readdirSync(currentDir, { withFileTypes: true })
|
||||
for (const entry of entries) {
|
||||
const entryPath = path.join(currentDir, entry.name)
|
||||
if (entry.isDirectory()) {
|
||||
if (entry.name === '__tests__' || entry.name === '__mocks__')
|
||||
continue
|
||||
walkComponentSourceFiles(entryPath, onFile)
|
||||
continue
|
||||
}
|
||||
|
||||
if (!/\.(?:ts|tsx)$/.test(entry.name) || isTestLikePath(entry.name))
|
||||
continue
|
||||
|
||||
onFile(entryPath)
|
||||
}
|
||||
}
|
||||
165
web/scripts/report-components-coverage-baseline.mjs
Normal file
165
web/scripts/report-components-coverage-baseline.mjs
Normal file
@@ -0,0 +1,165 @@
|
||||
import { execFileSync } from 'node:child_process'
|
||||
import fs from 'node:fs'
|
||||
import path from 'node:path'
|
||||
import { COMPONENT_COVERAGE_EXCLUDE_LABEL } from './component-coverage-filters.mjs'
|
||||
import {
|
||||
collectTrackedComponentSourceFiles,
|
||||
createComponentCoverageContext,
|
||||
formatPercent,
|
||||
getCoverageStats,
|
||||
getModuleName,
|
||||
loadTrackedCoverageEntries,
|
||||
mergeCoverageStats,
|
||||
percentage,
|
||||
sumCoverageStats,
|
||||
} from './components-coverage-common.mjs'
|
||||
import {
|
||||
COMPONENTS_GLOBAL_THRESHOLDS,
|
||||
EXCLUDED_COMPONENT_MODULES,
|
||||
getComponentModuleThreshold,
|
||||
} from './components-coverage-thresholds.mjs'
|
||||
|
||||
const EXCLUDED_MODULES_LABEL = [...EXCLUDED_COMPONENT_MODULES].sort().join(', ')
|
||||
|
||||
const repoRoot = repoRootFromCwd()
|
||||
const context = createComponentCoverageContext(repoRoot)
|
||||
const coverageFinalPath = path.join(context.webRoot, 'coverage', 'coverage-final.json')
|
||||
|
||||
if (!fs.existsSync(coverageFinalPath)) {
|
||||
console.error(`Coverage report not found at ${coverageFinalPath}`)
|
||||
process.exit(1)
|
||||
}
|
||||
|
||||
const coverage = JSON.parse(fs.readFileSync(coverageFinalPath, 'utf8'))
|
||||
const trackedSourceFiles = collectTrackedComponentSourceFiles(context)
|
||||
const coverageEntries = loadTrackedCoverageEntries(coverage, context)
|
||||
const fileCoverageRows = []
|
||||
const moduleCoverageMap = new Map()
|
||||
|
||||
for (const [file, entry] of coverageEntries.entries()) {
|
||||
const stats = getCoverageStats(entry)
|
||||
const moduleName = getModuleName(file)
|
||||
fileCoverageRows.push({ file, moduleName, ...stats })
|
||||
mergeCoverageStats(moduleCoverageMap, moduleName, stats)
|
||||
}
|
||||
|
||||
const overallCoverage = sumCoverageStats(fileCoverageRows)
|
||||
const overallTargetGaps = getTargetGaps(overallCoverage, COMPONENTS_GLOBAL_THRESHOLDS)
|
||||
const moduleCoverageRows = [...moduleCoverageMap.entries()]
|
||||
.map(([moduleName, stats]) => ({
|
||||
moduleName,
|
||||
stats,
|
||||
targets: getComponentModuleThreshold(moduleName),
|
||||
}))
|
||||
.map(row => ({
|
||||
...row,
|
||||
targetGaps: row.targets ? getTargetGaps(row.stats, row.targets) : [],
|
||||
}))
|
||||
.sort((a, b) => {
|
||||
const aWorst = Math.min(...a.targetGaps.map(gap => gap.delta), Number.POSITIVE_INFINITY)
|
||||
const bWorst = Math.min(...b.targetGaps.map(gap => gap.delta), Number.POSITIVE_INFINITY)
|
||||
return aWorst - bWorst || a.moduleName.localeCompare(b.moduleName)
|
||||
})
|
||||
|
||||
appendSummary(buildSummary({
|
||||
coverageEntriesCount: coverageEntries.size,
|
||||
moduleCoverageRows,
|
||||
overallCoverage,
|
||||
overallTargetGaps,
|
||||
trackedSourceFilesCount: trackedSourceFiles.length,
|
||||
}))
|
||||
|
||||
function buildSummary({
|
||||
coverageEntriesCount,
|
||||
moduleCoverageRows,
|
||||
overallCoverage,
|
||||
overallTargetGaps,
|
||||
trackedSourceFilesCount,
|
||||
}) {
|
||||
const lines = [
|
||||
'### app/components Baseline Coverage',
|
||||
'',
|
||||
`Excluded modules: \`${EXCLUDED_MODULES_LABEL}\``,
|
||||
`Excluded file kinds: \`${COMPONENT_COVERAGE_EXCLUDE_LABEL}\``,
|
||||
'',
|
||||
`Coverage entries: ${coverageEntriesCount}/${trackedSourceFilesCount} tracked source files`,
|
||||
'',
|
||||
'| Metric | Current | Target | Delta |',
|
||||
'|---|---:|---:|---:|',
|
||||
`| Lines | ${formatPercent(overallCoverage.lines)} | ${COMPONENTS_GLOBAL_THRESHOLDS.lines}% | ${formatDelta(overallCoverage.lines, COMPONENTS_GLOBAL_THRESHOLDS.lines)} |`,
|
||||
`| Statements | ${formatPercent(overallCoverage.statements)} | ${COMPONENTS_GLOBAL_THRESHOLDS.statements}% | ${formatDelta(overallCoverage.statements, COMPONENTS_GLOBAL_THRESHOLDS.statements)} |`,
|
||||
`| Functions | ${formatPercent(overallCoverage.functions)} | ${COMPONENTS_GLOBAL_THRESHOLDS.functions}% | ${formatDelta(overallCoverage.functions, COMPONENTS_GLOBAL_THRESHOLDS.functions)} |`,
|
||||
`| Branches | ${formatPercent(overallCoverage.branches)} | ${COMPONENTS_GLOBAL_THRESHOLDS.branches}% | ${formatDelta(overallCoverage.branches, COMPONENTS_GLOBAL_THRESHOLDS.branches)} |`,
|
||||
'',
|
||||
]
|
||||
|
||||
if (coverageEntriesCount !== trackedSourceFilesCount) {
|
||||
lines.push('Warning: coverage report did not include every tracked component source file. CI should set `VITEST_COVERAGE_SCOPE=app-components` before collecting coverage.')
|
||||
lines.push('')
|
||||
}
|
||||
|
||||
if (overallTargetGaps.length > 0) {
|
||||
lines.push('Below baseline targets:')
|
||||
for (const gap of overallTargetGaps)
|
||||
lines.push(`- overall ${gap.metric}: ${gap.actual.toFixed(2)}% < ${gap.target}%`)
|
||||
lines.push('')
|
||||
}
|
||||
|
||||
lines.push('<details><summary>Module baseline coverage</summary>')
|
||||
lines.push('')
|
||||
lines.push('| Module | Lines | Statements | Functions | Branches | Targets | Status |')
|
||||
lines.push('|---|---:|---:|---:|---:|---|---|')
|
||||
for (const row of moduleCoverageRows) {
|
||||
const targetsLabel = row.targets
|
||||
? `L${row.targets.lines}/S${row.targets.statements}/F${row.targets.functions}/B${row.targets.branches}`
|
||||
: 'n/a'
|
||||
const status = row.targets
|
||||
? (row.targetGaps.length > 0 ? 'below-target' : 'at-target')
|
||||
: 'unconfigured'
|
||||
lines.push(`| ${row.moduleName} | ${percentage(row.stats.lines.covered, row.stats.lines.total).toFixed(2)}% | ${percentage(row.stats.statements.covered, row.stats.statements.total).toFixed(2)}% | ${percentage(row.stats.functions.covered, row.stats.functions.total).toFixed(2)}% | ${percentage(row.stats.branches.covered, row.stats.branches.total).toFixed(2)}% | ${targetsLabel} | ${status} |`)
|
||||
}
|
||||
lines.push('</details>')
|
||||
lines.push('')
|
||||
lines.push('Report only: baseline targets no longer gate CI. The blocking rule is the pure diff coverage step.')
|
||||
|
||||
return lines
|
||||
}
|
||||
|
||||
function getTargetGaps(stats, targets) {
|
||||
const gaps = []
|
||||
for (const metric of ['lines', 'statements', 'functions', 'branches']) {
|
||||
const actual = percentage(stats[metric].covered, stats[metric].total)
|
||||
const target = targets[metric]
|
||||
const delta = actual - target
|
||||
if (delta < 0) {
|
||||
gaps.push({
|
||||
actual,
|
||||
delta,
|
||||
metric,
|
||||
target,
|
||||
})
|
||||
}
|
||||
}
|
||||
return gaps
|
||||
}
|
||||
|
||||
function formatDelta(metric, target) {
|
||||
const actual = percentage(metric.covered, metric.total)
|
||||
const delta = actual - target
|
||||
const sign = delta >= 0 ? '+' : ''
|
||||
return `${sign}${delta.toFixed(2)}%`
|
||||
}
|
||||
|
||||
function appendSummary(lines) {
|
||||
const content = `${lines.join('\n')}\n`
|
||||
if (process.env.GITHUB_STEP_SUMMARY)
|
||||
fs.appendFileSync(process.env.GITHUB_STEP_SUMMARY, content)
|
||||
console.log(content)
|
||||
}
|
||||
|
||||
function repoRootFromCwd() {
|
||||
return execFileSync('git', ['rev-parse', '--show-toplevel'], {
|
||||
cwd: process.cwd(),
|
||||
encoding: 'utf8',
|
||||
}).trim()
|
||||
}
|
||||
168
web/scripts/report-components-test-touch.mjs
Normal file
168
web/scripts/report-components-test-touch.mjs
Normal file
@@ -0,0 +1,168 @@
|
||||
import { execFileSync } from 'node:child_process'
|
||||
import fs from 'node:fs'
|
||||
import {
|
||||
buildGitDiffRevisionArgs,
|
||||
normalizeDiffRangeMode,
|
||||
resolveGitDiffContext,
|
||||
} from './check-components-diff-coverage-lib.mjs'
|
||||
import {
|
||||
createComponentCoverageContext,
|
||||
isAnyWebTestFile,
|
||||
isRelevantTestFile,
|
||||
isTrackedComponentSourceFile,
|
||||
} from './components-coverage-common.mjs'
|
||||
|
||||
const REQUESTED_DIFF_RANGE_MODE = normalizeDiffRangeMode(process.env.DIFF_RANGE_MODE)
|
||||
|
||||
const repoRoot = repoRootFromCwd()
|
||||
const context = createComponentCoverageContext(repoRoot)
|
||||
const baseSha = process.env.BASE_SHA?.trim()
|
||||
const headSha = process.env.HEAD_SHA?.trim() || 'HEAD'
|
||||
|
||||
if (!baseSha || /^0+$/.test(baseSha)) {
|
||||
appendSummary([
|
||||
'### app/components Test Touch',
|
||||
'',
|
||||
'Skipped test-touch report because `BASE_SHA` was not available.',
|
||||
])
|
||||
process.exit(0)
|
||||
}
|
||||
|
||||
const diffContext = resolveGitDiffContext({
|
||||
base: baseSha,
|
||||
head: headSha,
|
||||
mode: REQUESTED_DIFF_RANGE_MODE,
|
||||
execGit,
|
||||
})
|
||||
const changedFiles = getChangedFiles(diffContext)
|
||||
const changedSourceFiles = changedFiles.filter(filePath => isTrackedComponentSourceFile(filePath, context.excludedComponentCoverageFiles))
|
||||
|
||||
if (changedSourceFiles.length === 0) {
|
||||
appendSummary([
|
||||
'### app/components Test Touch',
|
||||
'',
|
||||
...buildDiffContextSummary(diffContext),
|
||||
'',
|
||||
'No tracked source changes under `web/app/components/`. Test-touch report skipped.',
|
||||
])
|
||||
process.exit(0)
|
||||
}
|
||||
|
||||
const changedRelevantTestFiles = changedFiles.filter(isRelevantTestFile)
|
||||
const changedOtherWebTestFiles = changedFiles.filter(filePath => isAnyWebTestFile(filePath) && !isRelevantTestFile(filePath))
|
||||
const totalChangedWebTests = [...new Set([...changedRelevantTestFiles, ...changedOtherWebTestFiles])]
|
||||
|
||||
appendSummary(buildSummary({
|
||||
changedOtherWebTestFiles,
|
||||
changedRelevantTestFiles,
|
||||
diffContext,
|
||||
changedSourceFiles,
|
||||
totalChangedWebTests,
|
||||
}))
|
||||
|
||||
function buildSummary({
|
||||
changedOtherWebTestFiles,
|
||||
changedRelevantTestFiles,
|
||||
diffContext,
|
||||
changedSourceFiles,
|
||||
totalChangedWebTests,
|
||||
}) {
|
||||
const lines = [
|
||||
'### app/components Test Touch',
|
||||
'',
|
||||
...buildDiffContextSummary(diffContext),
|
||||
'',
|
||||
`Tracked source files changed: ${changedSourceFiles.length}`,
|
||||
`Component-local or shared integration tests changed: ${changedRelevantTestFiles.length}`,
|
||||
`Other web tests changed: ${changedOtherWebTestFiles.length}`,
|
||||
`Total changed web tests: ${totalChangedWebTests.length}`,
|
||||
'',
|
||||
]
|
||||
|
||||
if (totalChangedWebTests.length === 0) {
|
||||
lines.push('Warning: no frontend test files changed alongside tracked component source changes.')
|
||||
lines.push('')
|
||||
}
|
||||
|
||||
if (changedRelevantTestFiles.length > 0) {
|
||||
lines.push('<details><summary>Changed component-local or shared tests</summary>')
|
||||
lines.push('')
|
||||
for (const filePath of changedRelevantTestFiles.slice(0, 40))
|
||||
lines.push(`- ${filePath.replace('web/', '')}`)
|
||||
if (changedRelevantTestFiles.length > 40)
|
||||
lines.push(`- ... ${changedRelevantTestFiles.length - 40} more`)
|
||||
lines.push('</details>')
|
||||
lines.push('')
|
||||
}
|
||||
|
||||
if (changedOtherWebTestFiles.length > 0) {
|
||||
lines.push('<details><summary>Changed other web tests</summary>')
|
||||
lines.push('')
|
||||
for (const filePath of changedOtherWebTestFiles.slice(0, 40))
|
||||
lines.push(`- ${filePath.replace('web/', '')}`)
|
||||
if (changedOtherWebTestFiles.length > 40)
|
||||
lines.push(`- ... ${changedOtherWebTestFiles.length - 40} more`)
|
||||
lines.push('</details>')
|
||||
lines.push('')
|
||||
}
|
||||
|
||||
lines.push('Report only: test-touch is now advisory and no longer blocks the diff coverage gate.')
|
||||
return lines
|
||||
}
|
||||
|
||||
function buildDiffContextSummary(diffContext) {
|
||||
const lines = [
|
||||
`Compared \`${diffContext.base.slice(0, 12)}\` -> \`${diffContext.head.slice(0, 12)}\``,
|
||||
]
|
||||
|
||||
if (diffContext.useCombinedMergeDiff) {
|
||||
lines.push(`Requested diff range mode: \`${diffContext.requestedMode}\``)
|
||||
lines.push(`Effective diff strategy: \`combined-merge\` (${diffContext.reason})`)
|
||||
}
|
||||
else if (diffContext.reason) {
|
||||
lines.push(`Requested diff range mode: \`${diffContext.requestedMode}\``)
|
||||
lines.push(`Effective diff range mode: \`${diffContext.mode}\` (${diffContext.reason})`)
|
||||
}
|
||||
else {
|
||||
lines.push(`Diff range mode: \`${diffContext.mode}\``)
|
||||
}
|
||||
|
||||
return lines
|
||||
}
|
||||
|
||||
function getChangedFiles(diffContext) {
|
||||
if (diffContext.useCombinedMergeDiff) {
|
||||
const output = execGit(['diff-tree', '--cc', '--no-commit-id', '--name-only', '-r', diffContext.head, '--', 'web'])
|
||||
return output
|
||||
.split('\n')
|
||||
.map(line => line.trim())
|
||||
.filter(Boolean)
|
||||
}
|
||||
|
||||
const output = execGit(['diff', '--name-only', '--diff-filter=ACMR', ...buildGitDiffRevisionArgs(diffContext.base, diffContext.head, diffContext.mode), '--', 'web'])
|
||||
return output
|
||||
.split('\n')
|
||||
.map(line => line.trim())
|
||||
.filter(Boolean)
|
||||
}
|
||||
|
||||
function appendSummary(lines) {
|
||||
const content = `${lines.join('\n')}\n`
|
||||
if (process.env.GITHUB_STEP_SUMMARY)
|
||||
fs.appendFileSync(process.env.GITHUB_STEP_SUMMARY, content)
|
||||
console.log(content)
|
||||
}
|
||||
|
||||
function execGit(args) {
|
||||
return execFileSync('git', args, {
|
||||
cwd: repoRoot,
|
||||
encoding: 'utf8',
|
||||
})
|
||||
}
|
||||
|
||||
function repoRootFromCwd() {
|
||||
return execFileSync('git', ['rev-parse', '--show-toplevel'], {
|
||||
cwd: process.cwd(),
|
||||
encoding: 'utf8',
|
||||
}).trim()
|
||||
}
|
||||
Reference in New Issue
Block a user