Compare commits

..

6 Commits

47 changed files with 4700 additions and 4833 deletions

View File

@@ -1,168 +0,0 @@
---
name: backend-code-review
description: Review backend code for quality, security, maintainability, and best practices based on established checklist rules. Use when the user requests a review, analysis, or improvement of backend files (e.g., `.py`) under the `api/` directory. Do NOT use for frontend files (e.g., `.tsx`, `.ts`, `.js`). Supports pending-change review, code snippets review, and file-focused review.
---
# Backend Code Review
## When to use this skill
Use this skill whenever the user asks to **review, analyze, or improve** backend code (e.g., `.py`) under the `api/` directory. Supports the following review modes:
- **Pending-change review**: when the user asks to review current changes (inspect staged/working-tree files slated for commit to get the changes).
- **Code snippets review**: when the user pastes code snippets (e.g., a function/class/module excerpt) into the chat and asks for a review.
- **File-focused review**: when the user points to specific files and asks for a review of those files (one file or a small, explicit set of files, e.g., `api/...`, `api/app.py`).
Do NOT use this skill when:
- The request is about frontend code or UI (e.g., `.tsx`, `.ts`, `.js`, `web/`).
- The user is not asking for a review/analysis/improvement of backend code.
- The scope is not under `api/` (unless the user explicitly asks to review backend-related changes outside `api/`).
## How to use this skill
Follow these steps when using this skill:
1. **Identify the review mode** (pending-change vs snippet vs file-focused) based on the user's input. Keep the scope tight: review only what the user provided or explicitly referenced.
2. Follow the rules defined in **Checklist** to perform the review. If no Checklist rule matches, apply **General Review Rules** as a fallback and perform a best-effort review.
3. Compose the final output strictly following the **Required Output Format**.
Notes when using this skill:
- Always include actionable fixes or suggestions (including possible code snippets).
- Use best-effort `File:Line` references when a file path and line numbers are available; otherwise, use the most specific identifier you can.
## Checklist
- db schema design: if the review scope includes code/files under `api/models/` or `api/migrations/`, follow [references/db-schema-rule.md](references/db-schema-rule.md) to perform the review
- architecture: if the review scope involves controller/service/core-domain/libs/model layering, dependency direction, or moving responsibilities across modules, follow [references/architecture-rule.md](references/architecture-rule.md) to perform the review
- repositories abstraction: if the review scope contains table/model operations (e.g., `select(...)`, `session.execute(...)`, joins, CRUD) and is not under `api/repositories`, `api/core/repositories`, or `api/extensions/*/repositories/`, follow [references/repositories-rule.md](references/repositories-rule.md) to perform the review
- sqlalchemy patterns: if the review scope involves SQLAlchemy session/query usage, db transaction/crud usage, or raw SQL usage, follow [references/sqlalchemy-rule.md](references/sqlalchemy-rule.md) to perform the review
## General Review Rules
### 1. Security Review
Check for:
- SQL injection vulnerabilities (see the sketch after this list)
- Server-Side Request Forgery (SSRF)
- Command injection
- Insecure deserialization
- Hardcoded secrets/credentials
- Improper authentication/authorization
- Insecure direct object references
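For example, the SQL injection check flags string-built SQL and suggests bound parameters. A minimal sketch (function and table names are illustrative, not from the codebase):
```python
from sqlalchemy import text
# Bad: interpolating user input into the SQL string allows injection.
def find_account_unsafe(session, email: str):
    return session.execute(text(f"SELECT id FROM accounts WHERE email = '{email}'")).first()
# Good: bound parameters let the driver escape the value.
def find_account_safe(session, email: str):
    return session.execute(
        text("SELECT id FROM accounts WHERE email = :email"),
        {"email": email},
    ).first()
```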
### 2. Performance Review
Check for:
- N+1 queries (see the sketch after this list)
- Missing database indexes
- Memory leaks
- Blocking operations in async code
- Missing caching opportunities
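For example, the N+1 check looks for per-item queries inside loops. A minimal sketch, assuming an `App` model like the ones used in the rule catalogs:
```python
from sqlalchemy import select
# Bad: one query per conversation (N+1).
def app_names_slow(session, conversations):
    return [
        session.execute(select(App.name).where(App.id == c.app_id)).scalar_one()
        for c in conversations
    ]
# Good: batch-fetch the related rows in a single query.
def app_names_fast(session, conversations):
    app_ids = {c.app_id for c in conversations}
    rows = session.execute(select(App.id, App.name).where(App.id.in_(app_ids))).all()
    names = dict(rows)
    return [names[c.app_id] for c in conversations]
```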
### 3. Code Quality Review
Check for:
- Code forward compatibility
- Code duplication (DRY violations)
- Functions doing too much (SRP violations)
- Deep nesting / complex conditionals
- Magic numbers/strings (see the sketch after this list)
- Poor naming
- Missing error handling
- Incomplete type coverage
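For example, magic numbers and missing error handling often show up together; a hypothetical sketch of the kind of rewrite to suggest:
```python
# Bad: unexplained constant and no input validation.
def retry_delay_bad(attempt: int) -> float:
    return 0.5 * (2**attempt)
# Good: named constants plus explicit error handling.
BASE_RETRY_DELAY_SECONDS = 0.5
MAX_RETRY_ATTEMPTS = 8
def retry_delay(attempt: int) -> float:
    if not 0 <= attempt <= MAX_RETRY_ATTEMPTS:
        raise ValueError(f"attempt must be in [0, {MAX_RETRY_ATTEMPTS}], got {attempt}")
    return BASE_RETRY_DELAY_SECONDS * (2**attempt)
```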
### 4. Testing Review
Check for:
- Missing test coverage for new code
- Tests that don't test behavior
- Flaky test patterns
- Missing edge cases (see the sketch after this list)
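For example, happy-path-only tests miss boundaries and invalid input. A self-contained sketch (`parse_limit` is an illustrative helper, not from the codebase):
```python
import pytest
def parse_limit(raw: str) -> int:
    """Illustrative helper: parse a non-negative page-size limit."""
    value = int(raw)  # raises ValueError for non-numeric input
    if value < 0:
        raise ValueError("limit must be non-negative")
    return value
# Bad: only the happy path is exercised.
def test_parse_limit_happy():
    assert parse_limit("10") == 10
# Good: boundaries and invalid input are covered too.
@pytest.mark.parametrize(("raw", "expected"), [("0", 0), ("1", 1), ("100", 100)])
def test_parse_limit_boundaries(raw, expected):
    assert parse_limit(raw) == expected
def test_parse_limit_rejects_non_numeric():
    with pytest.raises(ValueError):
        parse_limit("ten")
```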
## Required Output Format
When this skill is invoked, the response must follow exactly one of the two templates below:
### Template A (any findings)
```markdown
# Code Review Summary
Found <X> critical issues that need to be fixed:
## 🔴 Critical (Must Fix)
### 1. <brief description of the issue>
FilePath: <path> line <line>
<relevant code snippet or pointer>
#### Explanation
<detailed explanation and references of the issue>
#### Suggested Fix
1. <brief description of suggested fix>
2. <code example> (optional, omit if not applicable)
---
... (repeat for each critical issue) ...
Found <Y> suggestions for improvement:
## 🟡 Suggestions (Should Consider)
### 1. <brief description of the suggestion>
FilePath: <path> line <line>
<relevant code snippet or pointer>
#### Explanation
<detailed explanation and references of the suggestion>
#### Suggested Fix
1. <brief description of suggested fix>
2. <code example> (optional, omit if not applicable)
---
... (repeat for each suggestion) ...
Found <Z> optional nits:
## 🟢 Nits (Optional)
### 1. <brief description of the nit>
FilePath: <path> line <line>
<relevant code snippet or pointer>
#### Explanation
<explanation and references of the optional nit>
#### Suggested Fix
- <minor suggestions>
---
... (repeat for each nit) ...
## ✅ What's Good
- <Positive feedback on good patterns>
```
- If there are no critical issues, suggestions, optional nits, or good points, omit the corresponding section.
- If a category contains more than 10 items, summarize it as "Found 10+ critical issues/suggestions/optional nits" and output only the first 10 items.
- Don't compress the blank lines between sections; keep them as-is for readability.
- If any issue requires code changes, append a brief follow-up question after the structured output asking whether the user wants to apply the fix(es). For example: "Would you like me to apply the suggested fix(es) to address these issues?"
### Template B (no issues)
```markdown
## Code Review Summary
✅ No issues found.
```

View File

@@ -1,91 +0,0 @@
# Rule Catalog — Architecture
## Scope
- Covers: controller/service/core-domain/libs/model layering, dependency direction, responsibility placement, observability-friendly flow.
## Rules
### Keep business logic out of controllers
- Category: maintainability
- Severity: critical
- Description: Controllers should parse input, call services, and return serialized responses. Business decisions inside controllers make behavior hard to reuse and test.
- Suggested fix: Move domain/business logic into the service or core/domain layer. Keep controller handlers thin and orchestration-focused.
- Example:
- Bad:
```python
@bp.post("/apps/<app_id>/publish")
def publish_app(app_id: str):
payload = request.get_json() or {}
if payload.get("force") and current_user.role != "admin":
raise ValueError("only admin can force publish")
app = App.query.get(app_id)
app.status = "published"
db.session.commit()
return {"result": "ok"}
```
- Good:
```python
@bp.post("/apps/<app_id>/publish")
def publish_app(app_id: str):
payload = PublishRequest.model_validate(request.get_json() or {})
app_service.publish_app(app_id=app_id, force=payload.force, actor_id=current_user.id)
return {"result": "ok"}
```
### Preserve layer dependency direction
- Category: best practices
- Severity: critical
- Description: Controllers may depend on services, and services may depend on core/domain abstractions. Reversing this direction (for example, core importing controller/web modules) creates cycles and leaks transport concerns into domain code.
- Suggested fix: Extract shared contracts into core/domain or service-level modules and make upper layers depend on lower, not the reverse.
- Example:
- Bad:
```python
# core/policy/publish_policy.py
from controllers.console.app import request_context
def can_publish() -> bool:
return request_context.current_user.is_admin
```
- Good:
```python
# core/policy/publish_policy.py
def can_publish(role: str) -> bool:
return role == "admin"
# service layer adapts web/user context to domain input
allowed = can_publish(role=current_user.role)
```
### Keep libs business-agnostic
- Category: maintainability
- Severity: critical
- Description: Modules under `api/libs/` should remain reusable, business-agnostic building blocks. They must not encode product/domain-specific rules, workflow orchestration, or business decisions.
- Suggested fix:
- If business logic appears in `api/libs/`, extract it into the appropriate `services/` or `core/` module and keep `libs` focused on generic, cross-cutting helpers.
- Keep `libs` dependencies clean: avoid importing service/controller/domain-specific modules into `api/libs/`.
- Example:
- Bad:
```python
# api/libs/conversation_filter.py
from services.conversation_service import ConversationService
def should_archive_conversation(conversation, tenant_id: str) -> bool:
# Domain policy and service dependency are leaking into libs.
service = ConversationService()
if service.has_paid_plan(tenant_id):
return conversation.idle_days > 90
return conversation.idle_days > 30
```
- Good:
```python
# api/libs/datetime_utils.py (business-agnostic helper)
def older_than_days(idle_days: int, threshold_days: int) -> bool:
return idle_days > threshold_days
# services/conversation_service.py (business logic stays in service/core)
from libs.datetime_utils import older_than_days
def should_archive_conversation(conversation, tenant_id: str) -> bool:
threshold_days = 90 if has_paid_plan(tenant_id) else 30
return older_than_days(conversation.idle_days, threshold_days)
```

View File

@@ -1,157 +0,0 @@
# Rule Catalog — DB Schema Design
## Scope
- Covers: model/base inheritance, schema boundaries in model properties, tenant-aware schema design, index redundancy checks, dialect portability in models, and cross-database compatibility in migrations.
- Does NOT cover: session lifecycle, transaction boundaries, and query execution patterns (handled by `sqlalchemy-rule.md`).
## Rules
### Do not query other tables inside `@property`
- Category: maintainability, performance
- Severity: critical
- Description: A model `@property` must not open sessions or query other tables. This hides dependencies across models, tightly couples schema objects to data access, and can cause N+1 query explosions when iterating collections.
- Suggested fix:
- Keep model properties pure and local to already-loaded fields.
- Move cross-table data fetching to service/repository methods.
- For list/batch reads, fetch required related data explicitly (join/preload/bulk query) before rendering derived values.
- Example:
- Bad:
```python
class Conversation(TypeBase):
__tablename__ = "conversations"
@property
def app_name(self) -> str:
with Session(db.engine, expire_on_commit=False) as session:
app = session.execute(select(App).where(App.id == self.app_id)).scalar_one()
return app.name
```
- Good:
```python
class Conversation(TypeBase):
__tablename__ = "conversations"
@property
def display_title(self) -> str:
return self.name or "Untitled"
# Service/repository layer performs explicit batch fetch for related App rows.
```
### Prefer including `tenant_id` in model definitions
- Category: maintainability
- Severity: suggestion
- Description: In multi-tenant domains, include `tenant_id` in schema definitions whenever the entity belongs to tenant-owned data. This improves data isolation safety and keeps future partitioning/sharding strategies practical as data volume grows.
- Suggested fix:
- Add a `tenant_id` column and ensure related unique/index constraints include tenant dimension when applicable.
- Propagate `tenant_id` through service/repository contracts to keep access paths tenant-aware.
- Exception: if a table is explicitly designed as non-tenant-scoped global metadata, document that design decision clearly.
- Example:
- Bad:
```python
from sqlalchemy.orm import Mapped
class Dataset(TypeBase):
__tablename__ = "datasets"
id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
```
- Good:
```python
from sqlalchemy.orm import Mapped
class Dataset(TypeBase):
__tablename__ = "datasets"
id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
```
### Detect and avoid duplicate/redundant indexes
- Category: performance
- Severity: suggestion
- Description: Review index definitions for leftmost-prefix redundancy. For example, an index on `(a, b, c)` already covers most lookups on the `(a, b)` prefix. Keeping both may increase write overhead and can mislead the optimizer into suboptimal execution plans.
- Suggested fix:
- Before adding an index, compare against existing composite indexes by leftmost-prefix rules.
- Drop or avoid creating redundant prefixes unless there is a proven query-pattern need.
- Apply the same review standard in both model `__table_args__` and migration index DDL.
- Example:
- Bad:
```python
__table_args__ = (
sa.Index("idx_msg_tenant_app", "tenant_id", "app_id"),
sa.Index("idx_msg_tenant_app_created", "tenant_id", "app_id", "created_at"),
)
```
- Good:
```python
__table_args__ = (
# Keep the wider index unless profiling proves a dedicated short index is needed.
sa.Index("idx_msg_tenant_app_created", "tenant_id", "app_id", "created_at"),
)
```
### Avoid PostgreSQL-only dialect usage in models; wrap in `models.types`
- Category: maintainability
- Severity: critical
- Description: Model/schema definitions should avoid PostgreSQL-only constructs directly in business models. When database-specific behavior is required, encapsulate it in `api/models/types.py` using both PostgreSQL and MySQL dialect implementations, then consume that abstraction from model code.
- Suggested fix:
- Do not directly place dialect-only types/operators in model columns when a portable wrapper can be used.
- Add or extend wrappers in `models.types` (for example, `AdjustedJSON`, `LongText`, `BinaryData`) to normalize behavior across PostgreSQL and MySQL.
- Example:
- Bad:
```python
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped
class ToolConfig(TypeBase):
__tablename__ = "tool_configs"
config: Mapped[dict] = mapped_column(JSONB, nullable=False)
```
- Good:
```python
from sqlalchemy.orm import Mapped
from models.types import AdjustedJSON
class ToolConfig(TypeBase):
__tablename__ = "tool_configs"
config: Mapped[dict] = mapped_column(AdjustedJSON(), nullable=False)
```
### Guard migration incompatibilities with dialect checks and shared types
- Category: maintainability
- Severity: critical
- Description: Migration scripts under `api/migrations/versions/` must account for PostgreSQL/MySQL incompatibilities explicitly. For dialect-sensitive DDL or defaults, branch on the active dialect (for example, `conn.dialect.name == "postgresql"`), and prefer reusable compatibility abstractions from `models.types` where applicable.
- Suggested fix:
- In migration upgrades/downgrades, bind connection and branch by dialect for incompatible SQL fragments.
- Reuse `models.types` wrappers in column definitions when that keeps behavior aligned with runtime models.
- Avoid one-dialect-only migration logic unless there is a documented, deliberate compatibility exception.
- Example:
- Bad:
```python
with op.batch_alter_table("dataset_keyword_tables") as batch_op:
batch_op.add_column(
sa.Column(
"data_source_type",
sa.String(255),
server_default=sa.text("'database'::character varying"),
nullable=False,
)
)
```
- Good:
```python
def _is_pg(conn) -> bool:
return conn.dialect.name == "postgresql"
conn = op.get_bind()
default_expr = sa.text("'database'::character varying") if _is_pg(conn) else sa.text("'database'")
with op.batch_alter_table("dataset_keyword_tables") as batch_op:
batch_op.add_column(
sa.Column("data_source_type", sa.String(255), server_default=default_expr, nullable=False)
)
```

View File

@@ -1,61 +0,0 @@
# Rule Catalog - Repositories Abstraction
## Scope
- Covers: when to reuse existing repository abstractions, when to introduce new repositories, and how to preserve dependency direction between service/core and infrastructure implementations.
- Does NOT cover: SQLAlchemy session lifecycle and query-shape specifics (handled by `sqlalchemy-rule.md`), and table schema/migration design (handled by `db-schema-rule.md`).
## Rules
### Reuse or introduce repository abstractions
- Category: maintainability
- Severity: suggestion
- Description: If a table/model already has a repository abstraction, all reads/writes/queries for that table should use the existing repository. If no repository exists, introduce one only when complexity justifies it, such as large/high-volume tables, repeated complex query logic, or likely storage-strategy variation.
- Suggested fix:
- First check `api/repositories`, `api/core/repositories`, and `api/extensions/*/repositories/` to verify whether the table/model already has a repository abstraction. If it exists, route all operations through it and add missing repository methods instead of bypassing it with ad-hoc SQLAlchemy access.
- If no repository exists, add one only when complexity warrants it (for example, repeated complex queries, large data domains, or multiple storage strategies), while preserving dependency direction (service/core depends on abstraction; infra provides implementation).
- Example:
- Bad:
```python
# Existing repository is ignored and service uses ad-hoc table queries.
class AppService:
def archive_app(self, app_id: str, tenant_id: str) -> None:
app = self.session.execute(
select(App).where(App.id == app_id, App.tenant_id == tenant_id)
).scalar_one()
app.archived = True
self.session.commit()
```
- Good:
```python
# Case A: Existing repository must be reused for all table operations.
class AppService:
def archive_app(self, app_id: str, tenant_id: str) -> None:
app = self.app_repo.get_by_id(app_id=app_id, tenant_id=tenant_id)
app.archived = True
self.app_repo.save(app)
# If the query is missing, extend the existing abstraction.
active_apps = self.app_repo.list_active_for_tenant(tenant_id=tenant_id)
```
- Bad:
```python
# No repository exists, but large-domain query logic is scattered in service code.
class ConversationService:
def list_recent_for_app(self, app_id: str, tenant_id: str, limit: int) -> list[Conversation]:
...
# many filters/joins/pagination variants duplicated across services
```
- Good:
```python
# Case B: Introduce repository for large/complex domains or storage variation.
class ConversationRepository(Protocol):
def list_recent_for_app(self, app_id: str, tenant_id: str, limit: int) -> list[Conversation]: ...
class SqlAlchemyConversationRepository:
def list_recent_for_app(self, app_id: str, tenant_id: str, limit: int) -> list[Conversation]:
...
class ConversationService:
def __init__(self, conversation_repo: ConversationRepository):
self.conversation_repo = conversation_repo
```

View File

@@ -1,139 +0,0 @@
# Rule Catalog — SQLAlchemy Patterns
## Scope
- Covers: SQLAlchemy session and transaction lifecycle, query construction, tenant scoping, raw SQL boundaries, and write-path concurrency safeguards.
- Does NOT cover: table/model schema and migration design details (handled by `db-schema-rule.md`).
## Rules
### Use the Session context manager with explicit transaction control
- Category: best practices
- Severity: critical
- Description: Session and transaction lifecycle must be explicit and bounded on write paths. Missing commits can silently drop intended updates, while ad-hoc or long-lived transactions increase contention, lock duration, and deadlock risk.
- Suggested fix:
- Use **explicit `session.commit()`** after completing a related write unit.
- Or use **`session.begin()` context manager** for automatic commit/rollback on a scoped block.
- Keep transaction windows short: avoid network I/O, heavy computation, or unrelated work inside the transaction.
- Example:
- Bad:
```python
# Missing commit: write may never be persisted.
with Session(db.engine, expire_on_commit=False) as session:
run = session.get(WorkflowRun, run_id)
run.status = "cancelled"
# Long transaction: external I/O inside a DB transaction.
with Session(db.engine, expire_on_commit=False) as session, session.begin():
run = session.get(WorkflowRun, run_id)
run.status = "cancelled"
call_external_api()
```
- Good:
```python
# Option 1: explicit commit.
with Session(db.engine, expire_on_commit=False) as session:
run = session.get(WorkflowRun, run_id)
run.status = "cancelled"
session.commit()
# Option 2: scoped transaction with automatic commit/rollback.
with Session(db.engine, expire_on_commit=False) as session, session.begin():
run = session.get(WorkflowRun, run_id)
run.status = "cancelled"
# Keep non-DB work outside transaction scope.
call_external_api()
```
### Enforce tenant_id scoping on shared-resource queries
- Category: security
- Severity: critical
- Description: Reads and writes against shared tables must be scoped by `tenant_id` to prevent cross-tenant data leakage or corruption.
- Suggested fix: Add `tenant_id` predicate to all tenant-owned entity queries and propagate tenant context through service/repository interfaces.
- Example:
- Bad:
```python
stmt = select(Workflow).where(Workflow.id == workflow_id)
workflow = session.execute(stmt).scalar_one_or_none()
```
- Good:
```python
stmt = select(Workflow).where(
Workflow.id == workflow_id,
Workflow.tenant_id == tenant_id,
)
workflow = session.execute(stmt).scalar_one_or_none()
```
### Prefer SQLAlchemy expressions over raw SQL by default
- Category: maintainability
- Severity: suggestion
- Description: Raw SQL should be the exception. ORM/Core expressions are easier to evolve, safer to compose, and more consistent with the codebase.
- Suggested fix: Rewrite straightforward raw SQL into SQLAlchemy `select/update/delete` expressions; keep raw SQL only when required by clear technical constraints.
- Example:
- Bad:
```python
row = session.execute(
text("SELECT * FROM workflows WHERE id = :id AND tenant_id = :tenant_id"),
{"id": workflow_id, "tenant_id": tenant_id},
).first()
```
- Good:
```python
stmt = select(Workflow).where(
Workflow.id == workflow_id,
Workflow.tenant_id == tenant_id,
)
row = session.execute(stmt).scalar_one_or_none()
```
### Protect write paths with concurrency safeguards
- Category: quality
- Severity: critical
- Description: Multi-writer paths without explicit concurrency control can silently overwrite data. Choose the safeguard based on contention level, lock scope, and throughput cost instead of defaulting to one strategy.
- Suggested fix:
- **Optimistic locking**: Use when contention is usually low and retries are acceptable. Add a version (or `updated_at`) guard in the `WHERE` clause and treat `rowcount == 0` as a conflict.
- **Redis distributed lock**: Use when the critical section spans multiple steps/processes (or includes non-DB side effects) and you need cross-worker mutual exclusion.
- **SELECT ... FOR UPDATE**: Use when contention is high on the same rows and strict in-transaction serialization is required. Keep transactions short to reduce lock wait/deadlock risk.
- In all cases, scope by `tenant_id` and verify affected row counts for conditional writes.
- Example:
- Bad:
```python
# No tenant scope, no conflict detection, and no lock on a contested write path.
session.execute(update(WorkflowRun).where(WorkflowRun.id == run_id).values(status="cancelled"))
session.commit() # silently overwrites concurrent updates
```
- Good:
```python
# 1) Optimistic lock (low contention, retry on conflict)
result = session.execute(
update(WorkflowRun)
.where(
WorkflowRun.id == run_id,
WorkflowRun.tenant_id == tenant_id,
WorkflowRun.version == expected_version,
)
.values(status="cancelled", version=WorkflowRun.version + 1)
)
if result.rowcount == 0:
raise WorkflowStateConflictError("stale version, retry")
# 2) Redis distributed lock (cross-worker critical section)
lock_name = f"workflow_run_lock:{tenant_id}:{run_id}"
with redis_client.lock(lock_name, timeout=20):
session.execute(
update(WorkflowRun)
.where(WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id)
.values(status="cancelled")
)
session.commit()
# 3) Pessimistic lock with SELECT ... FOR UPDATE (high contention)
run = session.execute(
select(WorkflowRun)
.where(WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id)
.with_for_update()
).scalar_one()
run.status = "cancelled"
session.commit()
```

View File

@@ -1 +0,0 @@
../../.agents/skills/backend-code-review

View File

@@ -77,7 +77,14 @@ jobs:
}
const body = diff.trim()
? '### Pyrefly Diff\n<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>'
? `### Pyrefly Diff
<details>
<summary>base → PR</summary>
\`\`\`diff
${diff}
\`\`\`
</details>`
: '### Pyrefly Diff\nNo changes detected.';
await github.rest.issues.createComment({

View File

@@ -74,16 +74,14 @@ jobs:
}
const body = diff.trim()
? [
'### Pyrefly Diff',
'<details>',
'<summary>base → PR</summary>',
'',
'```diff',
diff,
'```',
'</details>',
].join('\n')
? `### Pyrefly Diff
<details>
<summary>base → PR</summary>
\`\`\`diff
${diff}
\`\`\`
</details>`
: '### Pyrefly Diff\nNo changes detected.';
await github.rest.issues.createComment({

View File

@@ -3,22 +3,14 @@ name: Web Tests
on:
workflow_call:
permissions:
contents: read
concurrency:
group: web-tests-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
test:
name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }})
name: Web Tests
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
shardIndex: [1, 2, 3, 4]
shardTotal: [4]
defaults:
run:
shell: bash
@@ -47,58 +39,7 @@ jobs:
run: pnpm install --frozen-lockfile
- name: Run tests
run: pnpm vitest run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
- name: Upload blob report
if: ${{ !cancelled() }}
uses: actions/upload-artifact@v6
with:
name: blob-report-${{ matrix.shardIndex }}
path: web/.vitest-reports/*
include-hidden-files: true
retention-days: 1
merge-reports:
name: Merge Test Reports
if: ${{ !cancelled() }}
needs: [test]
runs-on: ubuntu-latest
defaults:
run:
shell: bash
working-directory: ./web
steps:
- name: Checkout code
uses: actions/checkout@v6
with:
persist-credentials: false
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Install dependencies
run: pnpm install --frozen-lockfile
- name: Download blob reports
uses: actions/download-artifact@v6
with:
path: web/.vitest-reports
pattern: blob-report-*
merge-multiple: true
- name: Merge reports
run: pnpm vitest --merge-reports --coverage --silent=passed-only
run: pnpm test:ci
- name: Coverage Summary
if: always()

View File

@@ -50,6 +50,7 @@ forbidden_modules =
allow_indirect_imports = True
ignore_imports =
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
@@ -105,6 +106,9 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> core.model_manager
core.workflow.nodes.agent.agent_node -> core.provider_manager
core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
core.workflow.nodes.datasource.datasource_node -> models.model
core.workflow.nodes.datasource.datasource_node -> models.tools
core.workflow.nodes.datasource.datasource_node -> services.datasource_provider_service
core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
core.workflow.nodes.http_request.node -> core.tools.tool_file_manager
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
@@ -142,6 +146,8 @@ ignore_imports =
core.workflow.workflow_entry -> core.app.apps.exc
core.workflow.workflow_entry -> core.app.entities.app_invoke_entities
core.workflow.workflow_entry -> core.app.workflow.node_factory
core.workflow.nodes.datasource.datasource_node -> core.datasource.datasource_manager
core.workflow.nodes.datasource.datasource_node -> core.datasource.utils.message_transformer
core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
@@ -154,6 +160,7 @@ ignore_imports =
core.workflow.nodes.code.code_node -> core.helper.code_executor.javascript.javascript_code_provider
core.workflow.nodes.code.code_node -> core.helper.code_executor.python3.python3_code_provider
core.workflow.nodes.code.entities -> core.helper.code_executor.code_executor
core.workflow.nodes.datasource.datasource_node -> core.variables.variables
core.workflow.nodes.http_request.executor -> core.helper.ssrf_proxy
core.workflow.nodes.http_request.node -> core.helper.ssrf_proxy
core.workflow.nodes.llm.file_saver -> core.helper.ssrf_proxy
@@ -190,6 +197,7 @@ ignore_imports =
core.workflow.nodes.code.code_node -> core.variables.segments
core.workflow.nodes.code.code_node -> core.variables.types
core.workflow.nodes.code.entities -> core.variables.types
core.workflow.nodes.datasource.datasource_node -> core.variables.segments
core.workflow.nodes.document_extractor.node -> core.variables
core.workflow.nodes.document_extractor.node -> core.variables.segments
core.workflow.nodes.http_request.executor -> core.variables.segments
@@ -232,6 +240,7 @@ ignore_imports =
core.workflow.variable_loader -> core.variables.consts
core.workflow.workflow_type_encoder -> core.variables
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database

View File

@@ -42,7 +42,7 @@ The scripts resolve paths relative to their location, so you can run them from a
1. Set up your application by visiting `http://localhost:3000`.
1. Start the worker service (async and scheduler tasks, runs from `api`).
1. Optional: start the worker service (async tasks, runs from `api`).
```bash
./dev/start-worker
@@ -54,6 +54,86 @@ The scripts resolve paths relative to their location, so you can run them from a
./dev/start-beat
```
### Manual commands
<details>
<summary>Show manual setup and run steps</summary>
These commands assume you start from the repository root.
1. Start the docker-compose stack.
The backend requires middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker compose`.
```bash
cp docker/middleware.env.example docker/middleware.env
# Use the mysql profile or a different vector-database profile if you are not using postgres/weaviate.
docker compose -f docker/docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
```
1. Copy env files.
```bash
cp api/.env.example api/.env
cp web/.env.example web/.env.local
```
1. Install UV if needed.
```bash
pip install uv
# Or on macOS
brew install uv
```
1. Install API dependencies.
```bash
cd api
uv sync --group dev
```
1. Install web dependencies.
```bash
cd web
pnpm install
cd ..
```
1. Start backend (runs migrations first, in a new terminal).
```bash
cd api
uv run flask db upgrade
uv run flask run --host 0.0.0.0 --port=5001 --debug
```
1. Start Dify [web](../web) service (in a new terminal).
```bash
cd web
pnpm dev:inspect
```
1. Set up your application by visiting `http://localhost:3000`.
1. Optional: start the worker service (async tasks, in a new terminal).
```bash
cd api
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
```
1. Optional: start Celery Beat (scheduled tasks, in a new terminal).
```bash
cd api
uv run celery -A app.celery beat
```
</details>
### Environment notes
> [!IMPORTANT]

View File

@@ -133,7 +133,6 @@ class EducationAutocompleteQuery(BaseModel):
class ChangeEmailSendPayload(BaseModel):
email: EmailStr
language: str | None = None
phase: str | None = None
token: str | None = None
@@ -547,13 +546,17 @@ class ChangeEmailSendEmailApi(Resource):
account = None
user_email = None
email_for_sending = args.email.lower()
if args.phase is not None and args.phase == "new_email":
if args.token is None:
raise InvalidTokenError()
send_phase = AccountService.CHANGE_EMAIL_PHASE_OLD
if args.token is not None:
send_phase = AccountService.CHANGE_EMAIL_PHASE_NEW
reset_data = AccountService.get_change_email_data(args.token)
if reset_data is None:
raise InvalidTokenError()
reset_token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
if reset_token_phase != AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED:
raise InvalidTokenError()
user_email = reset_data.get("email", "")
if user_email.lower() != current_user.email.lower():
@@ -573,7 +576,7 @@ class ChangeEmailSendEmailApi(Resource):
email=email_for_sending,
old_email=user_email,
language=language,
phase=args.phase,
phase=send_phase,
)
return {"result": "success", "data": token}
@@ -608,12 +611,26 @@ class ChangeEmailCheckApi(Resource):
AccountService.add_change_email_error_rate_limit(user_email)
raise EmailCodeError()
phase_transitions: dict[str, str] = {
AccountService.CHANGE_EMAIL_PHASE_OLD: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
AccountService.CHANGE_EMAIL_PHASE_NEW: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
}
token_phase = token_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
if not isinstance(token_phase, str):
raise InvalidTokenError()
refreshed_phase = phase_transitions.get(token_phase)
if refreshed_phase is None:
raise InvalidTokenError()
# Verified, revoke the first token
AccountService.revoke_change_email_token(args.token)
# Refresh token data by generating a new token
_, new_token = AccountService.generate_change_email_token(
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
user_email,
code=args.code,
old_email=token_data.get("old_email"),
additional_data={AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: refreshed_phase},
)
AccountService.reset_change_email_error_rate_limit(user_email)
@@ -643,13 +660,22 @@ class ChangeEmailResetApi(Resource):
if not reset_data:
raise InvalidTokenError()
AccountService.revoke_change_email_token(args.token)
token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
if token_phase != AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED:
raise InvalidTokenError()
token_email = reset_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if normalized_token_email != normalized_new_email:
raise InvalidTokenError()
old_email = reset_data.get("old_email", "")
current_user, _ = current_account_with_tenant()
if current_user.email.lower() != old_email.lower():
raise AccountNotFound()
AccountService.revoke_change_email_token(args.token)
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
AccountService.send_change_email_completed_notify_email(

View File

@@ -669,14 +669,16 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
) -> Generator[StreamResponse, None, None]:
"""Handle retriever resources events."""
self._message_cycle_manager.handle_retriever_resources(event)
yield from ()
return
yield # Make this a generator
def _handle_annotation_reply_event(
self, event: QueueAnnotationReplyEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle annotation reply events."""
self._message_cycle_manager.handle_annotation_reply(event)
yield from ()
return
yield # Make this a generator
def _handle_message_replace_event(
self, event: QueueMessageReplaceEvent, **kwargs

View File

@@ -122,7 +122,7 @@ class AppQueueManager(ABC):
"""Attach the live graph runtime state reference for downstream consumers."""
self._graph_runtime_state = graph_runtime_state
def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
def publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue
:param event:

View File

@@ -5,7 +5,6 @@ from typing_extensions import override
from configs import dify_config
from core.app.llm.model_access import build_dify_model_access
from core.datasource.datasource_manager import DatasourceManager
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.ssrf_proxy import ssrf_proxy
@@ -19,7 +18,6 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode, WorkflowCodeExecutor
from core.workflow.nodes.code.entities import CodeLanguage
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.nodes.datasource import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
@@ -180,15 +178,6 @@ class DifyNodeFactory(NodeFactory):
model_factory=self._llm_model_factory,
)
if node_type == NodeType.DATASOURCE:
return DatasourceNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
datasource_manager=DatasourceManager,
)
if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
return KnowledgeRetrievalNode(
id=node_id,

View File

@@ -1,39 +1,16 @@
import logging
from collections.abc import Generator
from threading import Lock
from typing import Any, cast
from sqlalchemy import select
import contexts
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.entities.datasource_entities import (
DatasourceMessage,
DatasourceProviderType,
GetOnlineDocumentPageContentRequest,
OnlineDriveDownloadFileRequest,
)
from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.datasource.errors import DatasourceProviderNotFoundError
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.datasource.online_drive.online_drive_provider import OnlineDriveDatasourcePluginProviderController
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
from core.db.session_factory import session_factory
from core.plugin.impl.datasource import PluginDatasourceManager
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.file import File
from core.workflow.file.enums import FileTransferMethod, FileType
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.repositories.datasource_manager_protocol import DatasourceParameter, OnlineDriveDownloadFileParam
from factories import file_factory
from models.model import UploadFile
from models.tools import ToolFile
from services.datasource_provider_service import DatasourceProviderService
logger = logging.getLogger(__name__)
@@ -126,238 +103,3 @@ class DatasourceManager:
tenant_id,
datasource_type,
).get_datasource(datasource_name)
@classmethod
def get_icon_url(cls, provider_id: str, tenant_id: str, datasource_name: str, datasource_type: str) -> str:
datasource_runtime = cls.get_datasource_runtime(
provider_id=provider_id,
datasource_name=datasource_name,
tenant_id=tenant_id,
datasource_type=DatasourceProviderType.value_of(datasource_type),
)
return datasource_runtime.get_icon_url(tenant_id)
@classmethod
def stream_online_results(
cls,
*,
user_id: str,
datasource_name: str,
datasource_type: str,
provider_id: str,
tenant_id: str,
provider: str,
plugin_id: str,
credential_id: str,
datasource_param: DatasourceParameter | None = None,
online_drive_request: OnlineDriveDownloadFileParam | None = None,
) -> Generator[DatasourceMessage, None, Any]:
"""
Pull-based streaming of domain messages from datasource plugins.
Returns a generator that yields DatasourceMessage and finally returns a minimal final payload.
Only ONLINE_DOCUMENT and ONLINE_DRIVE are streamable here; other types are handled by nodes directly.
"""
ds_type = DatasourceProviderType.value_of(datasource_type)
runtime = cls.get_datasource_runtime(
provider_id=provider_id,
datasource_name=datasource_name,
tenant_id=tenant_id,
datasource_type=ds_type,
)
dsp_service = DatasourceProviderService()
credentials = dsp_service.get_datasource_credentials(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
credential_id=credential_id,
)
if ds_type == DatasourceProviderType.ONLINE_DOCUMENT:
doc_runtime = cast(OnlineDocumentDatasourcePlugin, runtime)
if credentials:
doc_runtime.runtime.credentials = credentials
if datasource_param is None:
raise ValueError("datasource_param is required for ONLINE_DOCUMENT streaming")
inner_gen: Generator[DatasourceMessage, None, None] = doc_runtime.get_online_document_page_content(
user_id=user_id,
datasource_parameters=GetOnlineDocumentPageContentRequest(
workspace_id=datasource_param.workspace_id,
page_id=datasource_param.page_id,
type=datasource_param.type,
),
provider_type=ds_type,
)
elif ds_type == DatasourceProviderType.ONLINE_DRIVE:
drive_runtime = cast(OnlineDriveDatasourcePlugin, runtime)
if credentials:
drive_runtime.runtime.credentials = credentials
if online_drive_request is None:
raise ValueError("online_drive_request is required for ONLINE_DRIVE streaming")
inner_gen = drive_runtime.online_drive_download_file(
user_id=user_id,
request=OnlineDriveDownloadFileRequest(
id=online_drive_request.id,
bucket=online_drive_request.bucket,
),
provider_type=ds_type,
)
else:
raise ValueError(f"Unsupported datasource type for streaming: {ds_type}")
# Bridge through to caller while preserving generator return contract
yield from inner_gen
# No structured final data here; node/adapter will assemble outputs
return {}
@classmethod
def stream_node_events(
cls,
*,
node_id: str,
user_id: str,
datasource_name: str,
datasource_type: str,
provider_id: str,
tenant_id: str,
provider: str,
plugin_id: str,
credential_id: str,
parameters_for_log: dict[str, Any],
datasource_info: dict[str, Any],
variable_pool: Any,
datasource_param: DatasourceParameter | None = None,
online_drive_request: OnlineDriveDownloadFileParam | None = None,
) -> Generator[StreamChunkEvent | StreamCompletedEvent, None, None]:
ds_type = DatasourceProviderType.value_of(datasource_type)
messages = cls.stream_online_results(
user_id=user_id,
datasource_name=datasource_name,
datasource_type=datasource_type,
provider_id=provider_id,
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
credential_id=credential_id,
datasource_param=datasource_param,
online_drive_request=online_drive_request,
)
transformed = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
messages=messages, user_id=user_id, tenant_id=tenant_id, conversation_id=None
)
variables: dict[str, Any] = {}
file_out: File | None = None
for message in transformed:
mtype = message.type
if mtype in {
DatasourceMessage.MessageType.IMAGE_LINK,
DatasourceMessage.MessageType.BINARY_LINK,
DatasourceMessage.MessageType.IMAGE,
}:
wanted_ds_type = ds_type in {
DatasourceProviderType.ONLINE_DRIVE,
DatasourceProviderType.ONLINE_DOCUMENT,
}
if wanted_ds_type and isinstance(message.message, DatasourceMessage.TextMessage):
url = message.message.text
datasource_file_id = str(url).split("/")[-1].split(".")[0]
with session_factory.create_session() as session:
stmt = select(ToolFile).where(
ToolFile.id == datasource_file_id, ToolFile.tenant_id == tenant_id
)
datasource_file = session.scalar(stmt)
if not datasource_file:
raise ValueError(
f"ToolFile not found for file_id={datasource_file_id}, tenant_id={tenant_id}"
)
mime_type = datasource_file.mimetype
if datasource_file is not None:
mapping = {
"tool_file_id": datasource_file_id,
"type": file_factory.get_file_type_by_mime_type(mime_type),
"transfer_method": FileTransferMethod.TOOL_FILE,
"url": url,
}
file_out = file_factory.build_from_mapping(mapping=mapping, tenant_id=tenant_id)
elif mtype == DatasourceMessage.MessageType.TEXT:
assert isinstance(message.message, DatasourceMessage.TextMessage)
yield StreamChunkEvent(selector=[node_id, "text"], chunk=message.message.text, is_final=False)
elif mtype == DatasourceMessage.MessageType.LINK:
assert isinstance(message.message, DatasourceMessage.TextMessage)
yield StreamChunkEvent(
selector=[node_id, "text"], chunk=f"Link: {message.message.text}\n", is_final=False
)
elif mtype == DatasourceMessage.MessageType.VARIABLE:
assert isinstance(message.message, DatasourceMessage.VariableMessage)
name = message.message.variable_name
value = message.message.variable_value
if message.message.stream:
assert isinstance(value, str), "stream variable_value must be str"
variables[name] = variables.get(name, "") + value
yield StreamChunkEvent(selector=[node_id, name], chunk=value, is_final=False)
else:
variables[name] = value
elif mtype == DatasourceMessage.MessageType.FILE:
if ds_type == DatasourceProviderType.ONLINE_DRIVE and message.meta:
f = message.meta.get("file")
if isinstance(f, File):
file_out = f
else:
pass
yield StreamChunkEvent(selector=[node_id, "text"], chunk="", is_final=True)
if ds_type == DatasourceProviderType.ONLINE_DRIVE and file_out is not None:
variable_pool.add([node_id, "file"], file_out)
if ds_type == DatasourceProviderType.ONLINE_DOCUMENT:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={**variables},
)
)
else:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"file": file_out,
"datasource_type": ds_type,
},
)
)
@classmethod
def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File:
with session_factory.create_session() as session:
upload_file = (
session.query(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).first()
)
if not upload_file:
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
file_info = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=tenant_id,
type=FileType.CUSTOM,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
size=upload_file.size,
storage_key=upload_file.key,
url=upload_file.source_url,
)
return file_info

View File

@@ -379,11 +379,4 @@ class OnlineDriveDownloadFileRequest(BaseModel):
"""
id: str = Field(..., description="The id of the file")
bucket: str = Field("", description="The name of the bucket")
@field_validator("bucket", mode="before")
@classmethod
def _coerce_bucket(cls, v) -> str:
if v is None:
return ""
return str(v)
bucket: str | None = Field(None, description="The name of the bucket")

View File

@@ -1,26 +1,40 @@
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from typing import Any, cast
from core.datasource.entities.datasource_entities import DatasourceProviderType
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.datasource.entities.datasource_entities import (
DatasourceMessage,
DatasourceParameter,
DatasourceProviderType,
GetOnlineDocumentPageContentRequest,
OnlineDriveDownloadFileRequest,
)
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult, StreamCompletedEvent
from core.workflow.file import File
from core.workflow.file.enums import FileTransferMethod, FileType
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.repositories.datasource_manager_protocol import (
DatasourceManagerProtocol,
DatasourceParameter,
OnlineDriveDownloadFileParam,
)
from core.workflow.nodes.tool.exc import ToolFileError
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from factories import file_factory
from models.model import UploadFile
from models.tools import ToolFile
from services.datasource_provider_service import DatasourceProviderService
from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from .entities import DatasourceNodeData
from .exc import DatasourceNodeError
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
from .exc import DatasourceNodeError, DatasourceParameterError
class DatasourceNode(Node[DatasourceNodeData]):
@@ -31,22 +45,6 @@ class DatasourceNode(Node[DatasourceNodeData]):
node_type = NodeType.DATASOURCE
execution_type = NodeExecutionType.ROOT
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
datasource_manager: DatasourceManagerProtocol,
):
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
self.datasource_manager = datasource_manager
def _run(self) -> Generator:
"""
Run the datasource node
@@ -54,69 +52,84 @@ class DatasourceNode(Node[DatasourceNodeData]):
node_data = self.node_data
variable_pool = self.graph_runtime_state.variable_pool
datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
if not datasource_type_segment:
datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
if not datasource_type_segement:
raise DatasourceNodeError("Datasource type is not set")
datasource_type = str(datasource_type_segment.value) if datasource_type_segment.value else None
datasource_info_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO])
if not datasource_info_segment:
datasource_type = str(datasource_type_segement.value) if datasource_type_segement.value else None
datasource_info_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO])
if not datasource_info_segement:
raise DatasourceNodeError("Datasource info is not set")
datasource_info_value = datasource_info_segment.value
datasource_info_value = datasource_info_segement.value
if not isinstance(datasource_info_value, dict):
raise DatasourceNodeError("Invalid datasource info format")
datasource_info: dict[str, Any] = datasource_info_value
# get datasource runtime
from core.datasource.datasource_manager import DatasourceManager
if datasource_type is None:
raise DatasourceNodeError("Datasource type is not set")
datasource_type = DatasourceProviderType.value_of(datasource_type)
provider_id = f"{node_data.plugin_id}/{node_data.provider_name}"
datasource_info["icon"] = self.datasource_manager.get_icon_url(
provider_id=provider_id,
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
datasource_name=node_data.datasource_name or "",
tenant_id=self.tenant_id,
datasource_type=datasource_type.value,
datasource_type=datasource_type,
)
datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
parameters_for_log = datasource_info
try:
datasource_provider_service = DatasourceProviderService()
credentials = datasource_provider_service.get_datasource_credentials(
tenant_id=self.tenant_id,
provider=node_data.provider_name,
plugin_id=node_data.plugin_id,
credential_id=datasource_info.get("credential_id", ""),
)
match datasource_type:
case DatasourceProviderType.ONLINE_DOCUMENT | DatasourceProviderType.ONLINE_DRIVE:
# Build typed request objects
datasource_parameters = None
if datasource_type == DatasourceProviderType.ONLINE_DOCUMENT:
datasource_parameters = DatasourceParameter(
workspace_id=datasource_info.get("workspace_id", ""),
page_id=datasource_info.get("page", {}).get("page_id", ""),
type=datasource_info.get("page", {}).get("type", ""),
case DatasourceProviderType.ONLINE_DOCUMENT:
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
if credentials:
datasource_runtime.runtime.credentials = credentials
online_document_result: Generator[DatasourceMessage, None, None] = (
datasource_runtime.get_online_document_page_content(
user_id=self.user_id,
datasource_parameters=GetOnlineDocumentPageContentRequest(
workspace_id=datasource_info.get("workspace_id", ""),
page_id=datasource_info.get("page", {}).get("page_id", ""),
type=datasource_info.get("page", {}).get("type", ""),
),
provider_type=datasource_type,
)
online_drive_request = None
if datasource_type == DatasourceProviderType.ONLINE_DRIVE:
online_drive_request = OnlineDriveDownloadFileParam(
id=datasource_info.get("id", ""),
bucket=datasource_info.get("bucket", ""),
)
yield from self._transform_message(
messages=online_document_result,
parameters_for_log=parameters_for_log,
datasource_info=datasource_info,
)
case DatasourceProviderType.ONLINE_DRIVE:
datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
if credentials:
datasource_runtime.runtime.credentials = credentials
online_drive_result: Generator[DatasourceMessage, None, None] = (
datasource_runtime.online_drive_download_file(
user_id=self.user_id,
request=OnlineDriveDownloadFileRequest(
id=datasource_info.get("id", ""),
bucket=datasource_info.get("bucket"),
),
provider_type=datasource_type,
)
credential_id = datasource_info.get("credential_id", "")
yield from self.datasource_manager.stream_node_events(
node_id=self._node_id,
user_id=self.user_id,
datasource_name=node_data.datasource_name or "",
datasource_type=datasource_type.value,
provider_id=provider_id,
tenant_id=self.tenant_id,
provider=node_data.provider_name,
plugin_id=node_data.plugin_id,
credential_id=credential_id,
)
yield from self._transform_datasource_file_message(
messages=online_drive_result,
parameters_for_log=parameters_for_log,
datasource_info=datasource_info,
variable_pool=variable_pool,
datasource_param=datasource_parameters,
online_drive_request=online_drive_request,
datasource_type=datasource_type,
)
case DatasourceProviderType.WEBSITE_CRAWL:
yield StreamCompletedEvent(
@@ -134,9 +147,23 @@ class DatasourceNode(Node[DatasourceNodeData]):
related_id = datasource_info.get("related_id")
if not related_id:
raise DatasourceNodeError("File is not exist")
upload_file = db.session.query(UploadFile).where(UploadFile.id == related_id).first()
if not upload_file:
raise ValueError("Invalid upload file Info")
file_info = self.datasource_manager.get_upload_file_by_id(
file_id=related_id, tenant_id=self.tenant_id
file_info = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=self.tenant_id,
type=FileType.CUSTOM,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
size=upload_file.size,
storage_key=upload_file.key,
url=upload_file.source_url,
)
variable_pool.add([self._node_id, "file"], file_info)
# variable_pool.add([self.node_id, "file"], file_info.to_dict())
@@ -174,6 +201,55 @@ class DatasourceNode(Node[DatasourceNodeData]):
)
)
def _generate_parameters(
self,
*,
datasource_parameters: Sequence[DatasourceParameter],
variable_pool: VariablePool,
node_data: DatasourceNodeData,
for_log: bool = False,
) -> dict[str, Any]:
"""
Generate parameters based on the given tool parameters, variable pool, and node data.
Args:
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
variable_pool (VariablePool): The variable pool containing the variables.
node_data (ToolNodeData): The data associated with the tool node.
Returns:
Mapping[str, Any]: A dictionary containing the generated parameters.
"""
datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters}
result: dict[str, Any] = {}
if node_data.datasource_parameters:
for parameter_name in node_data.datasource_parameters:
parameter = datasource_parameters_dictionary.get(parameter_name)
if not parameter:
result[parameter_name] = None
continue
datasource_input = node_data.datasource_parameters[parameter_name]
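# Resolve the configured input: a "variable" input reads a single variable from
# the pool, while "mixed"/"constant" inputs render the value as a template
# (log-friendly text when for_log is True).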
if datasource_input.type == "variable":
variable = variable_pool.get(datasource_input.value)
if variable is None:
raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist")
parameter_value = variable.value
elif datasource_input.type in {"mixed", "constant"}:
segment_group = variable_pool.convert_template(str(datasource_input.value))
parameter_value = segment_group.log if for_log else segment_group.text
else:
raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'")
result[parameter_name] = parameter_value
return result
def _fetch_files(self, variable_pool: VariablePool) -> list[File]:
variable = variable_pool.get(["sys", SystemVariableKey.FILES])
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
@@ -211,6 +287,206 @@ class DatasourceNode(Node[DatasourceNodeData]):
return result
def _transform_message(
self,
messages: Generator[DatasourceMessage, None, None],
parameters_for_log: dict[str, Any],
datasource_info: dict[str, Any],
) -> Generator:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
# transform message and handle file storage
message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
messages=messages,
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=None,
)
text = ""
files: list[File] = []
json: list[dict | list] = []
variables: dict[str, Any] = {}
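# Fan each datasource message out to the node outputs: link/blob messages are
# resolved into File objects, TEXT/LINK messages stream as chunk events, and
# JSON/VARIABLE messages accumulate into the completed result.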
for message in message_stream:
match message.type:
case (
DatasourceMessage.MessageType.IMAGE_LINK
| DatasourceMessage.MessageType.BINARY_LINK
| DatasourceMessage.MessageType.IMAGE
):
assert isinstance(message.message, DatasourceMessage.TextMessage)
url = message.message.text
transfer_method = FileTransferMethod.TOOL_FILE
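# The tool file id is the last URL path segment with its extension stripped.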
datasource_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
mapping = {
"tool_file_id": datasource_file_id,
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
files.append(file)
case DatasourceMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, DatasourceMessage.TextMessage)
assert message.meta
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
mapping = {
"tool_file_id": datasource_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
)
case DatasourceMessage.MessageType.TEXT:
assert isinstance(message.message, DatasourceMessage.TextMessage)
text += message.message.text
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=message.message.text,
is_final=False,
)
case DatasourceMessage.MessageType.JSON:
assert isinstance(message.message, DatasourceMessage.JsonMessage)
json.append(message.message.json_object)
case DatasourceMessage.MessageType.LINK:
assert isinstance(message.message, DatasourceMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=stream_text,
is_final=False,
)
case DatasourceMessage.MessageType.VARIABLE:
assert isinstance(message.message, DatasourceMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield StreamChunkEvent(
selector=[self._node_id, variable_name],
chunk=variable_value,
is_final=False,
)
else:
variables[variable_name] = variable_value
case DatasourceMessage.MessageType.FILE:
assert message.meta is not None
files.append(message.meta["file"])
case (
DatasourceMessage.MessageType.BLOB_CHUNK
| DatasourceMessage.MessageType.LOG
| DatasourceMessage.MessageType.RETRIEVER_RESOURCES
):
pass
# mark the end of the stream
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk="",
is_final=True,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={**variables},
metadata={
WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info,
},
inputs=parameters_for_log,
)
)
@classmethod
def version(cls) -> str:
return "1"
def _transform_datasource_file_message(
self,
messages: Generator[DatasourceMessage, None, None],
parameters_for_log: dict[str, Any],
datasource_info: dict[str, Any],
variable_pool: VariablePool,
datasource_type: DatasourceProviderType,
) -> Generator:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
# transform message and handle file storage
message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
messages=messages,
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=None,
)
file = None
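# Only BINARY_LINK messages are materialized here; if several arrive, the last
# one wins and becomes the node's "file" output below.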
for message in message_stream:
if message.type == DatasourceMessage.MessageType.BINARY_LINK:
assert isinstance(message.message, DatasourceMessage.TextMessage)
url = message.message.text
transfer_method = FileTransferMethod.TOOL_FILE
datasource_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
datasource_file = session.scalar(stmt)
if datasource_file is None:
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
mapping = {
"tool_file_id": datasource_file_id,
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id,
)
if file:
variable_pool.add([self._node_id, "file"], file)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
outputs={
"file": file,
"datasource_type": datasource_type,
},
)
)

View File

@@ -1,50 +0,0 @@
from collections.abc import Generator
from typing import Any, Protocol
from pydantic import BaseModel
from core.workflow.file import File
from core.workflow.node_events import StreamChunkEvent, StreamCompletedEvent
class DatasourceParameter(BaseModel):
workspace_id: str
page_id: str
type: str
class OnlineDriveDownloadFileParam(BaseModel):
id: str
bucket: str
class DatasourceFinal(BaseModel):
data: dict[str, Any] | None = None
class DatasourceManagerProtocol(Protocol):
@classmethod
def get_icon_url(cls, provider_id: str, tenant_id: str, datasource_name: str, datasource_type: str) -> str: ...
@classmethod
def stream_node_events(
cls,
*,
node_id: str,
user_id: str,
datasource_name: str,
datasource_type: str,
provider_id: str,
tenant_id: str,
provider: str,
plugin_id: str,
credential_id: str,
parameters_for_log: dict[str, Any],
datasource_info: dict[str, Any],
variable_pool: Any,
datasource_param: DatasourceParameter | None = None,
online_drive_request: OnlineDriveDownloadFileParam | None = None,
) -> Generator[StreamChunkEvent | StreamCompletedEvent, None, None]: ...
@classmethod
def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File: ...

View File

@@ -4,6 +4,7 @@ import logging
import secrets
import uuid
from datetime import UTC, datetime, timedelta
from enum import StrEnum
from hashlib import sha256
from typing import Any, cast
@@ -80,12 +81,25 @@ class TokenPair(BaseModel):
csrf_token: str
class ChangeEmailPhase(StrEnum):
OLD = "old_email"
OLD_VERIFIED = "old_email_verified"
NEW = "new_email"
NEW_VERIFIED = "new_email_verified"
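# Intended phase progression, as exercised by the controller tests in this PR:
# OLD -> OLD_VERIFIED (old mailbox confirmed) -> NEW -> NEW_VERIFIED (new
# mailbox confirmed); only a NEW_VERIFIED token may complete the email reset.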
REFRESH_TOKEN_PREFIX = "refresh_token:"
ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:"
REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
class AccountService:
CHANGE_EMAIL_TOKEN_PHASE_KEY = "email_change_phase"
CHANGE_EMAIL_PHASE_OLD = ChangeEmailPhase.OLD
CHANGE_EMAIL_PHASE_OLD_VERIFIED = ChangeEmailPhase.OLD_VERIFIED
CHANGE_EMAIL_PHASE_NEW = ChangeEmailPhase.NEW
CHANGE_EMAIL_PHASE_NEW_VERIFIED = ChangeEmailPhase.NEW_VERIFIED
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1)
email_code_login_rate_limiter = RateLimiter(
@@ -535,13 +549,20 @@ class AccountService:
raise ValueError("Email must be provided.")
if not phase:
raise ValueError("phase must be provided.")
if phase not in (cls.CHANGE_EMAIL_PHASE_OLD, cls.CHANGE_EMAIL_PHASE_NEW):
raise ValueError("phase must be one of old_email or new_email.")
if cls.change_email_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import EmailChangeRateLimitExceededError
raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60))
code, token = cls.generate_change_email_token(account_email, account, old_email=old_email)
code, token = cls.generate_change_email_token(
account_email,
account,
old_email=old_email,
additional_data={cls.CHANGE_EMAIL_TOKEN_PHASE_KEY: phase},
)
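# Stamping the phase into the token payload lets later steps verify that the
# flow progressed through the expected phases instead of trusting client input.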
send_change_mail_task.delay(
language=language,

View File

@@ -1,42 +0,0 @@
from collections.abc import Generator
from core.datasource.datasource_manager import DatasourceManager
from core.datasource.entities.datasource_entities import DatasourceMessage
from core.workflow.node_events import StreamCompletedEvent
def _gen_var_stream() -> Generator[DatasourceMessage, None, None]:
# produce a streamed variable "a"="xy"
yield DatasourceMessage(
type=DatasourceMessage.MessageType.VARIABLE,
message=DatasourceMessage.VariableMessage(variable_name="a", variable_value="x", stream=True),
meta=None,
)
yield DatasourceMessage(
type=DatasourceMessage.MessageType.VARIABLE,
message=DatasourceMessage.VariableMessage(variable_name="a", variable_value="y", stream=True),
meta=None,
)
def test_stream_node_events_accumulates_variables(mocker):
mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_var_stream())
events = list(
DatasourceManager.stream_node_events(
node_id="A",
user_id="u",
datasource_name="ds",
datasource_type="online_document",
provider_id="p/x",
tenant_id="t",
provider="prov",
plugin_id="plug",
credential_id="",
parameters_for_log={},
datasource_info={"user_id": "u"},
variable_pool=mocker.Mock(),
datasource_param=type("P", (), {"workspace_id": "w", "page_id": "pg", "type": "t"})(),
online_drive_request=None,
)
)
assert isinstance(events[-1], StreamCompletedEvent)

View File

@@ -1,84 +0,0 @@
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult, StreamCompletedEvent
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
class _Seg:
def __init__(self, v):
self.value = v
class _VarPool:
def __init__(self, data):
self.data = data
def get(self, path):
d = self.data
for k in path:
d = d[k]
return _Seg(d)
def add(self, *_a, **_k):
pass
class _GS:
def __init__(self, vp):
self.variable_pool = vp
class _GP:
tenant_id = "t1"
app_id = "app-1"
workflow_id = "wf-1"
graph_config = {}
user_id = "u1"
user_from = "account"
invoke_from = "debugger"
call_depth = 0
def test_node_integration_minimal_stream(mocker):
sys_d = {
"sys": {
"datasource_type": "online_document",
"datasource_info": {"workspace_id": "w", "page": {"page_id": "pg", "type": "t"}, "credential_id": ""},
}
}
vp = _VarPool(sys_d)
class _Mgr:
@classmethod
def get_icon_url(cls, **_):
return "icon"
@classmethod
def stream_node_events(cls, **_):
yield from ()
yield StreamCompletedEvent(node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED))
@classmethod
def get_upload_file_by_id(cls, **_):
raise AssertionError
node = DatasourceNode(
id="n",
config={
"id": "n",
"data": {
"type": "datasource",
"version": "1",
"title": "Datasource",
"provider_type": "plugin",
"provider_name": "p",
"plugin_id": "plug",
"datasource_name": "ds",
},
},
graph_init_params=_GP(),
graph_runtime_state=_GS(vp),
datasource_manager=_Mgr,
)
out = list(node._run())
assert isinstance(out[-1], StreamCompletedEvent)

View File

@@ -1,418 +0,0 @@
"""Integration tests for SQL-oriented DatasetService scenarios.
This suite migrates SQL-backed behaviors from the old unit suite to real
container-backed integration tests. The tests exercise real ORM persistence and
only patch non-DB collaborators when needed.
"""
from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from core.model_runtime.entities.model_entities import ModelType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import RerankingModel, RetrievalModel
from services.errors.dataset import DatasetNameDuplicateError
class DatasetServiceIntegrationDataFactory:
"""Factory for creating real database entities used by integration tests."""
@staticmethod
def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.OWNER) -> tuple[Account, Tenant]:
"""Create an account and tenant, then bind the account as current tenant member."""
account = Account(
email=f"{uuid4()}@example.com",
name=f"user-{uuid4()}",
interface_language="en-US",
status="active",
)
tenant = Tenant(name=f"tenant-{uuid4()}", status="normal")
db.session.add_all([account, tenant])
db.session.flush()
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=role,
current=True,
)
db.session.add(join)
db.session.flush()
# Keep tenant context on the in-memory user without opening a separate session.
account.role = role
account._current_tenant = tenant
return account, tenant
@staticmethod
def create_dataset(
tenant_id: str,
created_by: str,
name: str = "Test Dataset",
description: str | None = "Test description",
provider: str = "vendor",
indexing_technique: str | None = "high_quality",
permission: str = DatasetPermissionEnum.ONLY_ME,
retrieval_model: dict | None = None,
embedding_model_provider: str | None = None,
embedding_model: str | None = None,
collection_binding_id: str | None = None,
chunk_structure: str | None = None,
) -> Dataset:
"""Create a dataset record with configurable SQL fields."""
dataset = Dataset(
tenant_id=tenant_id,
name=name,
description=description,
data_source_type="upload_file",
indexing_technique=indexing_technique,
created_by=created_by,
provider=provider,
permission=permission,
retrieval_model=retrieval_model,
embedding_model_provider=embedding_model_provider,
embedding_model=embedding_model,
collection_binding_id=collection_binding_id,
chunk_structure=chunk_structure,
)
db.session.add(dataset)
db.session.flush()
return dataset
@staticmethod
def create_document(dataset: Dataset, created_by: str, name: str = "doc.txt") -> Document:
"""Create a document row belonging to the given dataset."""
document = Document(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=1,
data_source_type="upload_file",
data_source_info='{"upload_file_id": "upload-file-id"}',
batch=str(uuid4()),
name=name,
created_from="web",
created_by=created_by,
indexing_status="completed",
doc_form="text_model",
)
db.session.add(document)
db.session.flush()
return document
@staticmethod
def create_embedding_model(provider: str = "openai", model_name: str = "text-embedding-ada-002") -> Mock:
"""Create a fake embedding model object for external provider boundary patching."""
embedding_model = Mock()
embedding_model.provider = provider
embedding_model.model_name = model_name
return embedding_model
class TestDatasetServiceCreateDataset:
"""Integration coverage for DatasetService.create_empty_dataset."""
def test_create_internal_dataset_basic_success(self, db_session_with_containers):
"""Create a basic internal dataset with minimal configuration."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
# Act
result = DatasetService.create_empty_dataset(
tenant_id=tenant.id,
name="Basic Internal Dataset",
description="Test description",
indexing_technique=None,
account=account,
)
# Assert
created_dataset = db.session.get(Dataset, result.id)
assert created_dataset is not None
assert created_dataset.provider == "vendor"
assert created_dataset.permission == DatasetPermissionEnum.ONLY_ME
assert created_dataset.embedding_model_provider is None
assert created_dataset.embedding_model is None
def test_create_internal_dataset_with_economy_indexing(self, db_session_with_containers):
"""Create an internal dataset with economy indexing and no embedding model."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
# Act
result = DatasetService.create_empty_dataset(
tenant_id=tenant.id,
name="Economy Dataset",
description=None,
indexing_technique="economy",
account=account,
)
# Assert
db.session.refresh(result)
assert result.indexing_technique == "economy"
assert result.embedding_model_provider is None
assert result.embedding_model is None
def test_create_internal_dataset_with_high_quality_indexing(self, db_session_with_containers):
"""Create a high-quality dataset and persist embedding model settings."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model()
# Act
with patch("services.dataset_service.ModelManager") as mock_model_manager:
mock_model_manager.return_value.get_default_model_instance.return_value = embedding_model
result = DatasetService.create_empty_dataset(
tenant_id=tenant.id,
name="High Quality Dataset",
description=None,
indexing_technique="high_quality",
account=account,
)
# Assert
db.session.refresh(result)
assert result.indexing_technique == "high_quality"
assert result.embedding_model_provider == embedding_model.provider
assert result.embedding_model == embedding_model.model_name
mock_model_manager.return_value.get_default_model_instance.assert_called_once_with(
tenant_id=tenant.id,
model_type=ModelType.TEXT_EMBEDDING,
)
def test_create_dataset_duplicate_name_error(self, db_session_with_containers):
"""Raise duplicate-name error when the same tenant already has the name."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
DatasetServiceIntegrationDataFactory.create_dataset(
tenant_id=tenant.id,
created_by=account.id,
name="Duplicate Dataset",
indexing_technique=None,
)
# Act / Assert
with pytest.raises(DatasetNameDuplicateError):
DatasetService.create_empty_dataset(
tenant_id=tenant.id,
name="Duplicate Dataset",
description=None,
indexing_technique=None,
account=account,
)
def test_create_external_dataset_success(self, db_session_with_containers):
"""Create an external dataset and persist external knowledge binding."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
external_knowledge_api_id = str(uuid4())
external_knowledge_id = "knowledge-123"
# Act
with patch("services.dataset_service.ExternalDatasetService.get_external_knowledge_api") as mock_get_api:
mock_get_api.return_value = Mock(id=external_knowledge_api_id)
result = DatasetService.create_empty_dataset(
tenant_id=tenant.id,
name="External Dataset",
description=None,
indexing_technique=None,
account=account,
provider="external",
external_knowledge_api_id=external_knowledge_api_id,
external_knowledge_id=external_knowledge_id,
)
# Assert
binding = db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=result.id).first()
assert result.provider == "external"
assert binding is not None
assert binding.external_knowledge_id == external_knowledge_id
assert binding.external_knowledge_api_id == external_knowledge_api_id
def test_create_dataset_with_retrieval_model_and_reranking(self, db_session_with_containers):
"""Create a high-quality dataset with retrieval/reranking settings."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model()
retrieval_model = RetrievalModel(
search_method=RetrievalMethod.SEMANTIC_SEARCH,
reranking_enable=True,
reranking_model=RerankingModel(
reranking_provider_name="cohere",
reranking_model_name="rerank-english-v2.0",
),
top_k=3,
score_threshold_enabled=True,
score_threshold=0.6,
)
# Act
with (
patch("services.dataset_service.ModelManager") as mock_model_manager,
patch("services.dataset_service.DatasetService.check_reranking_model_setting") as mock_check_reranking,
):
mock_model_manager.return_value.get_default_model_instance.return_value = embedding_model
result = DatasetService.create_empty_dataset(
tenant_id=tenant.id,
name="Dataset With Reranking",
description=None,
indexing_technique="high_quality",
account=account,
retrieval_model=retrieval_model,
)
# Assert
db.session.refresh(result)
assert result.retrieval_model == retrieval_model.model_dump()
mock_check_reranking.assert_called_once_with(tenant.id, "cohere", "rerank-english-v2.0")
class TestDatasetServiceUpdateAndDeleteDataset:
"""Integration coverage for SQL-backed update and delete behavior."""
def test_update_dataset_duplicate_name_error(self, db_session_with_containers):
"""Reject update when target name already exists within the same tenant."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
source_dataset = DatasetServiceIntegrationDataFactory.create_dataset(
tenant_id=tenant.id,
created_by=account.id,
name="Source Dataset",
)
DatasetServiceIntegrationDataFactory.create_dataset(
tenant_id=tenant.id,
created_by=account.id,
name="Existing Dataset",
)
# Act / Assert
with pytest.raises(ValueError, match="Dataset name already exists"):
DatasetService.update_dataset(source_dataset.id, {"name": "Existing Dataset"}, account)
def test_delete_dataset_with_documents_success(self, db_session_with_containers):
"""Delete a dataset that already has documents."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
dataset = DatasetServiceIntegrationDataFactory.create_dataset(
tenant_id=tenant.id,
created_by=account.id,
indexing_technique="high_quality",
chunk_structure="text_model",
)
DatasetServiceIntegrationDataFactory.create_document(dataset=dataset, created_by=account.id)
# Act
with patch("services.dataset_service.dataset_was_deleted") as dataset_deleted_signal:
result = DatasetService.delete_dataset(dataset.id, account)
# Assert
assert result is True
assert db.session.get(Dataset, dataset.id) is None
dataset_deleted_signal.send.assert_called_once_with(dataset)
def test_delete_empty_dataset_success(self, db_session_with_containers):
"""Delete a dataset that has no documents and no indexing technique."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
dataset = DatasetServiceIntegrationDataFactory.create_dataset(
tenant_id=tenant.id,
created_by=account.id,
indexing_technique=None,
chunk_structure=None,
)
# Act
with patch("services.dataset_service.dataset_was_deleted") as dataset_deleted_signal:
result = DatasetService.delete_dataset(dataset.id, account)
# Assert
assert result is True
assert db.session.get(Dataset, dataset.id) is None
dataset_deleted_signal.send.assert_called_once_with(dataset)
def test_delete_dataset_with_partial_none_values(self, db_session_with_containers):
"""Delete dataset when indexing_technique is None but doc_form path still exists."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
dataset = DatasetServiceIntegrationDataFactory.create_dataset(
tenant_id=tenant.id,
created_by=account.id,
indexing_technique=None,
chunk_structure="text_model",
)
# Act
with patch("services.dataset_service.dataset_was_deleted") as dataset_deleted_signal:
result = DatasetService.delete_dataset(dataset.id, account)
# Assert
assert result is True
assert db.session.get(Dataset, dataset.id) is None
dataset_deleted_signal.send.assert_called_once_with(dataset)
class TestDatasetServiceRetrievalConfiguration:
"""Integration coverage for retrieval configuration persistence."""
def test_get_dataset_retrieval_configuration(self, db_session_with_containers):
"""Return retrieval configuration that is persisted in SQL."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
retrieval_model = {
"search_method": "semantic_search",
"top_k": 5,
"score_threshold": 0.5,
"reranking_enable": True,
}
dataset = DatasetServiceIntegrationDataFactory.create_dataset(
tenant_id=tenant.id,
created_by=account.id,
retrieval_model=retrieval_model,
)
# Act
result = DatasetService.get_dataset(dataset.id)
# Assert
assert result is not None
assert result.retrieval_model == retrieval_model
assert result.retrieval_model["search_method"] == "semantic_search"
assert result.retrieval_model["top_k"] == 5
def test_update_dataset_retrieval_configuration(self, db_session_with_containers):
"""Persist retrieval configuration updates through DatasetService.update_dataset."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
dataset = DatasetServiceIntegrationDataFactory.create_dataset(
tenant_id=tenant.id,
created_by=account.id,
indexing_technique="high_quality",
retrieval_model={"search_method": "semantic_search", "top_k": 2, "score_threshold": 0.0},
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=str(uuid4()),
)
update_data = {
"indexing_technique": "high_quality",
"retrieval_model": {
"search_method": "full_text_search",
"top_k": 10,
"score_threshold": 0.7,
},
}
# Act
result = DatasetService.update_dataset(dataset.id, update_data, account)
# Assert
db.session.refresh(dataset)
assert result.id == dataset.id
assert dataset.retrieval_model == update_data["retrieval_model"]

View File

@@ -1,464 +0,0 @@
"""
Integration tests for document_indexing_sync_task using testcontainers.
This module validates SQL-backed behavior for document sync flows:
- Notion sync precondition checks
- Segment cleanup and document state updates
- Credential and indexing error handling
"""
import json
from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from psycopg2.extensions import register_adapter
from psycopg2.extras import Json
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from tasks.document_indexing_sync_task import document_indexing_sync_task
@pytest.fixture(autouse=True)
def _register_dict_adapter_for_psycopg2():
"""Align test DB adapter behavior with dict payloads used in task update flow."""
register_adapter(dict, Json)
class DocumentIndexingSyncTaskTestDataFactory:
"""Create real DB entities for document indexing sync integration tests."""
@staticmethod
def create_account_with_tenant(db_session_with_containers) -> tuple[Account, Tenant]:
account = Account(
email=f"{uuid4()}@example.com",
name=f"user-{uuid4()}",
interface_language="en-US",
status="active",
)
db_session_with_containers.add(account)
db_session_with_containers.flush()
tenant = Tenant(name=f"tenant-{account.id}", status="normal")
db_session_with_containers.add(tenant)
db_session_with_containers.flush()
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
db_session_with_containers.add(join)
db_session_with_containers.commit()
return account, tenant
@staticmethod
def create_dataset(db_session_with_containers, tenant_id: str, created_by: str) -> Dataset:
dataset = Dataset(
tenant_id=tenant_id,
name=f"dataset-{uuid4()}",
description="sync test dataset",
data_source_type="notion_import",
indexing_technique="high_quality",
created_by=created_by,
)
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
return dataset
@staticmethod
def create_document(
db_session_with_containers,
*,
tenant_id: str,
dataset_id: str,
created_by: str,
data_source_info: dict | None,
indexing_status: str = "completed",
) -> Document:
document = Document(
tenant_id=tenant_id,
dataset_id=dataset_id,
position=0,
data_source_type="notion_import",
data_source_info=json.dumps(data_source_info) if data_source_info is not None else None,
batch="test-batch",
name=f"doc-{uuid4()}",
created_from="notion_import",
created_by=created_by,
indexing_status=indexing_status,
enabled=True,
doc_form="text_model",
doc_language="en",
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
return document
@staticmethod
def create_segments(
db_session_with_containers,
*,
tenant_id: str,
dataset_id: str,
document_id: str,
created_by: str,
count: int = 3,
) -> list[DocumentSegment]:
segments: list[DocumentSegment] = []
for i in range(count):
segment = DocumentSegment(
tenant_id=tenant_id,
dataset_id=dataset_id,
document_id=document_id,
position=i,
content=f"segment-{i}",
answer=None,
word_count=10,
tokens=5,
index_node_id=f"node-{document_id}-{i}",
status="completed",
created_by=created_by,
)
db_session_with_containers.add(segment)
segments.append(segment)
db_session_with_containers.commit()
return segments
class TestDocumentIndexingSyncTask:
"""Integration tests for document_indexing_sync_task with real database assertions."""
@pytest.fixture
def mock_external_dependencies(self):
"""Patch only external collaborators; keep DB access real."""
with (
patch("tasks.document_indexing_sync_task.DatasourceProviderService") as mock_datasource_service_class,
patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_notion_extractor_class,
patch("tasks.document_indexing_sync_task.IndexProcessorFactory") as mock_index_processor_factory,
patch("tasks.document_indexing_sync_task.IndexingRunner") as mock_indexing_runner_class,
):
datasource_service = Mock()
datasource_service.get_datasource_credentials.return_value = {"integration_secret": "test_token"}
mock_datasource_service_class.return_value = datasource_service
notion_extractor = Mock()
notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
mock_notion_extractor_class.return_value = notion_extractor
index_processor = Mock()
index_processor.clean = Mock()
mock_index_processor_factory.return_value.init_index_processor.return_value = index_processor
indexing_runner = Mock(spec=IndexingRunner)
indexing_runner.run = Mock()
mock_indexing_runner_class.return_value = indexing_runner
yield {
"datasource_service": datasource_service,
"notion_extractor": notion_extractor,
"notion_extractor_class": mock_notion_extractor_class,
"index_processor": index_processor,
"index_processor_factory": mock_index_processor_factory,
"indexing_runner": indexing_runner,
}
def _create_notion_sync_context(self, db_session_with_containers, *, data_source_info: dict | None = None):
account, tenant = DocumentIndexingSyncTaskTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DocumentIndexingSyncTaskTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=account.id,
)
notion_info = data_source_info or {
"notion_workspace_id": str(uuid4()),
"notion_page_id": str(uuid4()),
"type": "page",
"last_edited_time": "2024-01-01T00:00:00Z",
"credential_id": str(uuid4()),
}
document = DocumentIndexingSyncTaskTestDataFactory.create_document(
db_session_with_containers,
tenant_id=tenant.id,
dataset_id=dataset.id,
created_by=account.id,
data_source_info=notion_info,
indexing_status="completed",
)
segments = DocumentIndexingSyncTaskTestDataFactory.create_segments(
db_session_with_containers,
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
created_by=account.id,
count=3,
)
return {
"account": account,
"tenant": tenant,
"dataset": dataset,
"document": document,
"segments": segments,
"node_ids": [segment.index_node_id for segment in segments],
"notion_info": notion_info,
}
def test_document_not_found(self, db_session_with_containers, mock_external_dependencies):
"""Test that task handles missing document gracefully."""
# Arrange
dataset_id = str(uuid4())
document_id = str(uuid4())
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
mock_external_dependencies["datasource_service"].get_datasource_credentials.assert_not_called()
mock_external_dependencies["indexing_runner"].run.assert_not_called()
def test_missing_notion_workspace_id(self, db_session_with_containers, mock_external_dependencies):
"""Test that task raises error when notion_workspace_id is missing."""
# Arrange
context = self._create_notion_sync_context(
db_session_with_containers,
data_source_info={
"notion_page_id": str(uuid4()),
"type": "page",
"last_edited_time": "2024-01-01T00:00:00Z",
},
)
# Act & Assert
with pytest.raises(ValueError, match="no notion page found"):
document_indexing_sync_task(context["dataset"].id, context["document"].id)
def test_missing_notion_page_id(self, db_session_with_containers, mock_external_dependencies):
"""Test that task raises error when notion_page_id is missing."""
# Arrange
context = self._create_notion_sync_context(
db_session_with_containers,
data_source_info={
"notion_workspace_id": str(uuid4()),
"type": "page",
"last_edited_time": "2024-01-01T00:00:00Z",
},
)
# Act & Assert
with pytest.raises(ValueError, match="no notion page found"):
document_indexing_sync_task(context["dataset"].id, context["document"].id)
def test_empty_data_source_info(self, db_session_with_containers, mock_external_dependencies):
"""Test that task raises error when data_source_info is empty."""
# Arrange
context = self._create_notion_sync_context(db_session_with_containers, data_source_info=None)
db_session_with_containers.query(Document).where(Document.id == context["document"].id).update(
{"data_source_info": None}
)
db_session_with_containers.commit()
# Act & Assert
with pytest.raises(ValueError, match="no notion page found"):
document_indexing_sync_task(context["dataset"].id, context["document"].id)
def test_credential_not_found(self, db_session_with_containers, mock_external_dependencies):
"""Test that task sets document error state when credential is missing."""
# Arrange
context = self._create_notion_sync_context(db_session_with_containers)
mock_external_dependencies["datasource_service"].get_datasource_credentials.return_value = None
# Act
document_indexing_sync_task(context["dataset"].id, context["document"].id)
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
)
assert updated_document is not None
assert updated_document.indexing_status == "error"
assert "Datasource credential not found" in updated_document.error
assert updated_document.stopped_at is not None
mock_external_dependencies["indexing_runner"].run.assert_not_called()
def test_page_not_updated(self, db_session_with_containers, mock_external_dependencies):
"""Test that task exits early when notion page is unchanged."""
# Arrange
context = self._create_notion_sync_context(db_session_with_containers)
mock_external_dependencies["notion_extractor"].get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
# Act
document_indexing_sync_task(context["dataset"].id, context["document"].id)
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
)
remaining_segments = (
db_session_with_containers.query(DocumentSegment)
.where(DocumentSegment.document_id == context["document"].id)
.count()
)
assert updated_document is not None
assert updated_document.indexing_status == "completed"
assert updated_document.processing_started_at is None
assert remaining_segments == 3
mock_external_dependencies["index_processor"].clean.assert_not_called()
mock_external_dependencies["indexing_runner"].run.assert_not_called()
def test_successful_sync_when_page_updated(self, db_session_with_containers, mock_external_dependencies):
"""Test full successful sync flow with SQL state updates and side effects."""
# Arrange
context = self._create_notion_sync_context(db_session_with_containers)
# Act
document_indexing_sync_task(context["dataset"].id, context["document"].id)
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
)
remaining_segments = (
db_session_with_containers.query(DocumentSegment)
.where(DocumentSegment.document_id == context["document"].id)
.count()
)
assert updated_document is not None
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
assert updated_document.data_source_info_dict.get("last_edited_time") == "2024-01-02T00:00:00Z"
assert remaining_segments == 0
clean_call_args = mock_external_dependencies["index_processor"].clean.call_args
assert clean_call_args is not None
clean_args, clean_kwargs = clean_call_args
assert getattr(clean_args[0], "id", None) == context["dataset"].id
assert set(clean_args[1]) == set(context["node_ids"])
assert clean_kwargs.get("with_keywords") is True
assert clean_kwargs.get("delete_child_chunks") is True
run_call_args = mock_external_dependencies["indexing_runner"].run.call_args
assert run_call_args is not None
run_documents = run_call_args[0][0]
assert len(run_documents) == 1
assert getattr(run_documents[0], "id", None) == context["document"].id
def test_dataset_not_found_during_cleaning(self, db_session_with_containers, mock_external_dependencies):
"""Test that task still updates document and reindexes if dataset vanishes before clean."""
# Arrange
context = self._create_notion_sync_context(db_session_with_containers)
def _delete_dataset_before_clean() -> str:
db_session_with_containers.query(Dataset).where(Dataset.id == context["dataset"].id).delete()
db_session_with_containers.commit()
return "2024-01-02T00:00:00Z"
mock_external_dependencies[
"notion_extractor"
].get_notion_last_edited_time.side_effect = _delete_dataset_before_clean
# Act
document_indexing_sync_task(context["dataset"].id, context["document"].id)
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
)
assert updated_document is not None
assert updated_document.indexing_status == "parsing"
mock_external_dependencies["index_processor"].clean.assert_not_called()
mock_external_dependencies["indexing_runner"].run.assert_called_once()
def test_cleaning_error_continues_to_indexing(self, db_session_with_containers, mock_external_dependencies):
"""Test that indexing continues when index cleanup fails."""
# Arrange
context = self._create_notion_sync_context(db_session_with_containers)
mock_external_dependencies["index_processor"].clean.side_effect = Exception("Cleaning error")
# Act
document_indexing_sync_task(context["dataset"].id, context["document"].id)
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
)
remaining_segments = (
db_session_with_containers.query(DocumentSegment)
.where(DocumentSegment.document_id == context["document"].id)
.count()
)
assert updated_document is not None
assert updated_document.indexing_status == "parsing"
assert remaining_segments == 0
mock_external_dependencies["indexing_runner"].run.assert_called_once()
def test_indexing_runner_document_paused_error(self, db_session_with_containers, mock_external_dependencies):
"""Test that DocumentIsPausedError does not flip document into error state."""
# Arrange
context = self._create_notion_sync_context(db_session_with_containers)
mock_external_dependencies["indexing_runner"].run.side_effect = DocumentIsPausedError("Document paused")
# Act
document_indexing_sync_task(context["dataset"].id, context["document"].id)
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
)
assert updated_document is not None
assert updated_document.indexing_status == "parsing"
assert updated_document.error is None
def test_indexing_runner_general_error(self, db_session_with_containers, mock_external_dependencies):
"""Test that indexing errors are persisted to document state."""
# Arrange
context = self._create_notion_sync_context(db_session_with_containers)
mock_external_dependencies["indexing_runner"].run.side_effect = Exception("Indexing error")
# Act
document_indexing_sync_task(context["dataset"].id, context["document"].id)
# Assert
db_session_with_containers.expire_all()
updated_document = (
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
)
assert updated_document is not None
assert updated_document.indexing_status == "error"
assert "Indexing error" in updated_document.error
assert updated_document.stopped_at is not None
def test_index_processor_clean_called_with_correct_params(
self,
db_session_with_containers,
mock_external_dependencies,
):
"""Test that clean is called with dataset instance and collected node ids."""
# Arrange
context = self._create_notion_sync_context(db_session_with_containers)
# Act
document_indexing_sync_task(context["dataset"].id, context["document"].id)
# Assert
clean_call_args = mock_external_dependencies["index_processor"].clean.call_args
assert clean_call_args is not None
clean_args, clean_kwargs = clean_call_args
assert getattr(clean_args[0], "id", None) == context["dataset"].id
assert set(clean_args[1]) == set(context["node_ids"])
assert clean_kwargs.get("with_keywords") is True
assert clean_kwargs.get("delete_child_chunks") is True

View File

@@ -77,7 +77,7 @@ def _restx_mask_defaults(app: Flask):
def test_code_based_extension_get_returns_service_data(app: Flask, monkeypatch: pytest.MonkeyPatch):
service_result = [{"entrypoint": "main:agent"}]
service_result = {"entrypoint": "main:agent"}
service_mock = MagicMock(return_value=service_result)
monkeypatch.setattr(
"controllers.console.extension.CodeBasedExtensionService.get_code_based_extension",

View File

@@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch
import pytest
from flask import Flask, g
from controllers.console.auth.error import InvalidTokenError
from controllers.console.workspace.account import (
AccountDeleteUpdateFeedbackApi,
ChangeEmailCheckApi,
@@ -52,7 +53,7 @@ class TestChangeEmailSend:
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_normalize_new_email_phase(
def test_should_infer_new_email_phase_from_token(
self,
mock_features,
mock_csrf,
@@ -68,13 +69,16 @@ class TestChangeEmailSend:
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
mock_get_change_data.return_value = {"email": "current@example.com"}
mock_get_change_data.return_value = {
"email": "current@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
}
mock_send_email.return_value = "token-abc"
with app.test_request_context(
"/account/change-email",
method="POST",
json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"},
json={"email": "New@Example.com", "language": "en-US", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
response = ChangeEmailSendEmailApi().post()
@@ -91,6 +95,107 @@ class TestChangeEmailSend:
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
mock_csrf.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.db")
@patch("controllers.console.workspace.account.Session")
@patch("controllers.console.workspace.account.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_ignore_client_phase_and_use_old_phase_when_token_missing(
self,
mock_features,
mock_csrf,
mock_extract_ip,
mock_is_ip_limit,
mock_send_email,
mock_get_account_by_email,
mock_session_cls,
mock_account_db,
mock_current_account,
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("current@example.com", "current"), None)
existing_account = _build_account("old@example.com", "acc-old")
mock_get_account_by_email.return_value = existing_account
mock_send_email.return_value = "token-legacy"
mock_session = MagicMock()
mock_session_cm = MagicMock()
mock_session_cm.__enter__.return_value = mock_session
mock_session_cm.__exit__.return_value = None
mock_session_cls.return_value = mock_session_cm
mock_account_db.engine = MagicMock()
with app.test_request_context(
"/account/change-email",
method="POST",
json={"email": "old@example.com", "language": "en-US", "phase": "new_email"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
response = ChangeEmailSendEmailApi().post()
assert response == {"result": "success", "data": "token-legacy"}
mock_get_account_by_email.assert_called_once_with("old@example.com", session=mock_session)
mock_send_email.assert_called_once_with(
account=existing_account,
email="old@example.com",
old_email="old@example.com",
language="en-US",
phase=AccountService.CHANGE_EMAIL_PHASE_OLD,
)
mock_extract_ip.assert_called_once()
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
mock_csrf.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_reject_unverified_old_email_token_for_new_email_phase(
self,
mock_features,
mock_csrf,
mock_extract_ip,
mock_is_ip_limit,
mock_send_email,
mock_get_change_data,
mock_current_account,
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
mock_get_change_data.return_value = {
"email": "current@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
}
with app.test_request_context(
"/account/change-email",
method="POST",
json={"email": "New@Example.com", "language": "en-US", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
with pytest.raises(InvalidTokenError):
ChangeEmailSendEmailApi().post()
mock_send_email.assert_not_called()
mock_extract_ip.assert_called_once()
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
mock_csrf.assert_called_once()
class TestChangeEmailValidity:
@patch("controllers.console.wraps.db")
@@ -122,7 +227,12 @@ class TestChangeEmailValidity:
mock_account = _build_account("user@example.com", "acc2")
mock_current_account.return_value = (mock_account, None)
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "user@example.com", "code": "1234", "old_email": "old@example.com"}
mock_get_data.return_value = {
"email": "user@example.com",
"code": "1234",
"old_email": "old@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
}
mock_generate_token.return_value = (None, "new-token")
with app.test_request_context(
@@ -138,11 +248,76 @@ class TestChangeEmailValidity:
mock_add_rate.assert_not_called()
mock_revoke_token.assert_called_once_with("token-123")
mock_generate_token.assert_called_once_with(
"user@example.com", code="1234", old_email="old@example.com", additional_data={}
"user@example.com",
code="1234",
old_email="old@example.com",
additional_data={
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED
},
)
mock_reset_rate.assert_called_once_with("user@example.com")
mock_csrf.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_refresh_new_email_phase_to_verified(
self,
mock_features,
mock_csrf,
mock_is_rate_limit,
mock_get_data,
mock_add_rate,
mock_revoke_token,
mock_generate_token,
mock_reset_rate,
mock_current_account,
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("old@example.com", "acc2")
mock_current_account.return_value = (mock_account, None)
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {
"email": "new@example.com",
"code": "5678",
"old_email": "old@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW,
}
mock_generate_token.return_value = (None, "new-phase-token")
with app.test_request_context(
"/account/change-email/validity",
method="POST",
json={"email": "New@Example.com", "code": "5678", "token": "token-456"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
response = ChangeEmailCheckApi().post()
assert response == {"is_valid": True, "email": "new@example.com", "token": "new-phase-token"}
mock_is_rate_limit.assert_called_once_with("new@example.com")
mock_add_rate.assert_not_called()
mock_revoke_token.assert_called_once_with("token-456")
mock_generate_token.assert_called_once_with(
"new@example.com",
code="5678",
old_email="old@example.com",
additional_data={
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED
},
)
mock_reset_rate.assert_called_once_with("new@example.com")
mock_csrf.assert_called_once()
class TestChangeEmailReset:
@patch("controllers.console.wraps.db")
@@ -175,7 +350,11 @@ class TestChangeEmailReset:
mock_current_account.return_value = (current_user, None)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True
mock_get_data.return_value = {"old_email": "OLD@example.com"}
mock_get_data.return_value = {
"old_email": "OLD@example.com",
"email": "new@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
}
mock_account_after_update = _build_account("new@example.com", "acc3-updated")
mock_update_account.return_value = mock_account_after_update
@@ -194,6 +373,106 @@ class TestChangeEmailReset:
mock_send_notify.assert_called_once_with(email="new@example.com")
mock_csrf.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
@patch("controllers.console.workspace.account.AccountService.update_account_email")
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_reject_old_phase_token_for_reset(
self,
mock_features,
mock_csrf,
mock_is_freeze,
mock_check_unique,
mock_get_data,
mock_revoke_token,
mock_update_account,
mock_send_notify,
mock_current_account,
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True
mock_get_data.return_value = {
"old_email": "OLD@example.com",
"email": "old@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
}
with app.test_request_context(
"/account/change-email/reset",
method="POST",
json={"new_email": "new@example.com", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
with pytest.raises(InvalidTokenError):
ChangeEmailResetApi().post()
mock_revoke_token.assert_not_called()
mock_update_account.assert_not_called()
mock_send_notify.assert_not_called()
mock_csrf.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
@patch("controllers.console.workspace.account.AccountService.update_account_email")
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_reject_mismatched_new_email_for_verified_token(
self,
mock_features,
mock_csrf,
mock_is_freeze,
mock_check_unique,
mock_get_data,
mock_revoke_token,
mock_update_account,
mock_send_notify,
mock_current_account,
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True
mock_get_data.return_value = {
"old_email": "OLD@example.com",
"email": "another@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
}
with app.test_request_context(
"/account/change-email/reset",
method="POST",
json={"new_email": "new@example.com", "token": "token-789"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
with pytest.raises(InvalidTokenError):
ChangeEmailResetApi().post()
mock_revoke_token.assert_not_called()
mock_update_account.assert_not_called()
mock_send_notify.assert_not_called()
mock_csrf.assert_called_once()
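# Aside: together these two tests pin down the reset endpoint's token checks.
# A minimal sketch of the implied validation (names assumed from the mocks,
# not the actual controller code):
#
#     data = AccountService.get_change_email_data(token)
#     phase = data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
#     if phase != AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED:
#         raise InvalidTokenError()
#     if data.get("email") != args["new_email"]:
#         raise InvalidTokenError()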
class TestAccountDeletionFeedback:
@patch("controllers.console.wraps.db")

View File

@@ -1,135 +0,0 @@
import types
from collections.abc import Generator
from core.datasource.datasource_manager import DatasourceManager
from core.datasource.entities.datasource_entities import DatasourceMessage
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.node_events import StreamChunkEvent, StreamCompletedEvent
def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, None]:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.TEXT,
message=DatasourceMessage.TextMessage(text=text),
meta=None,
)
def test_get_icon_url_calls_runtime(mocker):
fake_runtime = mocker.Mock()
fake_runtime.get_icon_url.return_value = "https://icon"
mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=fake_runtime)
url = DatasourceManager.get_icon_url(
provider_id="p/x",
tenant_id="t1",
datasource_name="ds",
datasource_type="online_document",
)
assert url == "https://icon"
DatasourceManager.get_datasource_runtime.assert_called_once()
def test_stream_online_results_yields_messages_online_document(mocker):
# stub runtime to yield a text message
def _doc_messages(**_):
yield from _gen_messages_text_only("hello")
fake_runtime = mocker.Mock()
fake_runtime.get_online_document_page_content.side_effect = _doc_messages
mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=fake_runtime)
mocker.patch(
"core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials",
return_value=None,
)
gen = DatasourceManager.stream_online_results(
user_id="u1",
datasource_name="ds",
datasource_type="online_document",
provider_id="p/x",
tenant_id="t1",
provider="prov",
plugin_id="plug",
credential_id="",
datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"),
online_drive_request=None,
)
msgs = list(gen)
assert len(msgs) == 1
assert msgs[0].message.text == "hello"
def test_stream_node_events_emits_events_online_document(mocker):
# make manager's low-level stream produce TEXT only
mocker.patch.object(
DatasourceManager,
"stream_online_results",
return_value=_gen_messages_text_only("hello"),
)
events = list(
DatasourceManager.stream_node_events(
node_id="nodeA",
user_id="u1",
datasource_name="ds",
datasource_type="online_document",
provider_id="p/x",
tenant_id="t1",
provider="prov",
plugin_id="plug",
credential_id="",
parameters_for_log={"k": "v"},
datasource_info={"user_id": "u1"},
variable_pool=mocker.Mock(),
datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"),
online_drive_request=None,
)
)
# should contain one StreamChunkEvent, then a final (empty) chunk, then a completed event
assert isinstance(events[0], StreamChunkEvent)
assert events[0].chunk == "hello"
assert isinstance(events[-1], StreamCompletedEvent)
assert events[-1].node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
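# Aside: the comment above describes a three-event sequence (text chunk, final
# empty chunk, completed). A stricter version of these assertions, assuming
# that convention holds, would be:
#
#     assert len(events) == 3
#     assert isinstance(events[1], StreamChunkEvent)
#     assert events[1].is_final and events[1].chunk == ""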
def test_get_upload_file_by_id_builds_file(mocker):
# fake UploadFile row
fake_row = types.SimpleNamespace(
id="fid",
name="f",
extension="txt",
mime_type="text/plain",
size=1,
key="k",
source_url="http://x",
)
class _Q:
def __init__(self, row):
self._row = row
def where(self, *_args, **_kwargs):
return self
def first(self):
return self._row
class _S:
def __init__(self, row):
self._row = row
def __enter__(self):
return self
def __exit__(self, *exc):
return False
def query(self, *_):
return _Q(self._row)
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S(fake_row))
f = DatasourceManager.get_upload_file_by_id(file_id="fid", tenant_id="t1")
assert f.related_id == "fid"
assert f.extension == ".txt"
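# Aside: the assertion above implies the loader normalizes the stored
# extension ("txt") to a dotted form (".txt"). A minimal sketch of that
# normalization (assumed; the real implementation may differ):
def _dot_ext(ext: str) -> str:
    return ext if ext.startswith(".") else f".{ext}"

assert _dot_ext("txt") == ".txt"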

View File

@@ -1,93 +0,0 @@
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
class _VarSeg:
def __init__(self, v):
self.value = v
class _VarPool:
def __init__(self, mapping):
self._m = mapping
def get(self, selector):
d = self._m
for k in selector:
d = d[k]
return _VarSeg(d)
def add(self, *_args, **_kwargs):
pass
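# Aside: selector lookup simply walks nested keys, e.g.:
_pool_demo = _VarPool({"sys": {"datasource_type": "online_document"}})
assert _pool_demo.get(["sys", "datasource_type"]).value == "online_document"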
class _GraphState:
def __init__(self, var_pool):
self.variable_pool = var_pool
class _GraphParams:
tenant_id = "t1"
app_id = "app-1"
workflow_id = "wf-1"
graph_config = {}
user_id = "u1"
user_from = "account"
invoke_from = "debugger"
call_depth = 0
def test_datasource_node_delegates_to_manager_stream(mocker):
# prepare sys variables
sys_vars = {
"sys": {
"datasource_type": "online_document",
"datasource_info": {
"workspace_id": "w",
"page": {"page_id": "pg", "type": "t"},
"credential_id": "",
},
}
}
var_pool = _VarPool(sys_vars)
gs = _GraphState(var_pool)
gp = _GraphParams()
# stub manager class
class _Mgr:
@classmethod
def get_icon_url(cls, **_):
return "icon"
@classmethod
def stream_node_events(cls, **_):
yield StreamChunkEvent(selector=["n", "text"], chunk="hi", is_final=False)
yield StreamCompletedEvent(node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED))
@classmethod
def get_upload_file_by_id(cls, **_):
raise AssertionError("not called")
node = DatasourceNode(
id="n",
config={
"id": "n",
"data": {
"type": "datasource",
"version": "1",
"title": "Datasource",
"provider_type": "plugin",
"provider_name": "p",
"plugin_id": "plug",
"datasource_name": "ds",
},
},
graph_init_params=gp,
graph_runtime_state=gs,
datasource_manager=_Mgr,
)
evts = list(node._run())
assert isinstance(evts[0], StreamChunkEvent)
assert isinstance(evts[-1], StreamCompletedEvent)

File diff suppressed because it is too large.

File diff suppressed because it is too large.

View File

@@ -1,8 +1,12 @@
"""
Unit tests for collaborator parameter wiring in document_indexing_sync_task.
Unit tests for document indexing sync task.
These tests intentionally stay in unit scope because they validate call arguments
for external collaborators rather than SQL-backed state transitions.
This module tests the document indexing sync task functionality including:
- Syncing Notion documents when updated
- Validating document and data source existence
- Credential validation and retrieval
- Cleaning old segments before re-indexing
- Error handling and edge cases
"""
import uuid
@@ -10,92 +14,187 @@ from unittest.mock import MagicMock, Mock, patch
import pytest
from models.dataset import Dataset, Document
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from models.dataset import Dataset, Document, DocumentSegment
from tasks.document_indexing_sync_task import document_indexing_sync_task
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
def dataset_id() -> str:
"""Generate a dataset id."""
def tenant_id():
"""Generate a unique tenant ID for testing."""
return str(uuid.uuid4())
@pytest.fixture
def document_id() -> str:
"""Generate a document id."""
def dataset_id():
"""Generate a unique dataset ID for testing."""
return str(uuid.uuid4())
@pytest.fixture
def notion_workspace_id() -> str:
"""Generate a notion workspace id."""
def document_id():
"""Generate a unique document ID for testing."""
return str(uuid.uuid4())
@pytest.fixture
def notion_page_id() -> str:
"""Generate a notion page id."""
def notion_workspace_id():
"""Generate a Notion workspace ID for testing."""
return str(uuid.uuid4())
@pytest.fixture
def credential_id() -> str:
"""Generate a credential id."""
def notion_page_id():
"""Generate a Notion page ID for testing."""
return str(uuid.uuid4())
@pytest.fixture
def mock_dataset(dataset_id):
"""Create a minimal dataset mock used by the task pre-check."""
def credential_id():
"""Generate a credential ID for testing."""
return str(uuid.uuid4())
@pytest.fixture
def mock_dataset(dataset_id, tenant_id):
"""Create a mock Dataset object."""
dataset = Mock(spec=Dataset)
dataset.id = dataset_id
dataset.tenant_id = tenant_id
dataset.indexing_technique = "high_quality"
dataset.embedding_model_provider = "openai"
dataset.embedding_model = "text-embedding-ada-002"
return dataset
@pytest.fixture
def mock_document(document_id, dataset_id, notion_workspace_id, notion_page_id, credential_id):
"""Create a minimal notion document mock for collaborator parameter assertions."""
document = Mock(spec=Document)
document.id = document_id
document.dataset_id = dataset_id
document.tenant_id = str(uuid.uuid4())
document.data_source_type = "notion_import"
document.indexing_status = "completed"
document.doc_form = "text_model"
document.data_source_info_dict = {
def mock_document(document_id, dataset_id, tenant_id, notion_workspace_id, notion_page_id, credential_id):
"""Create a mock Document object with Notion data source."""
doc = Mock(spec=Document)
doc.id = document_id
doc.dataset_id = dataset_id
doc.tenant_id = tenant_id
doc.data_source_type = "notion_import"
doc.indexing_status = "completed"
doc.error = None
doc.stopped_at = None
doc.processing_started_at = None
doc.doc_form = "text_model"
doc.data_source_info_dict = {
"notion_workspace_id": notion_workspace_id,
"notion_page_id": notion_page_id,
"type": "page",
"last_edited_time": "2024-01-01T00:00:00Z",
"credential_id": credential_id,
}
return document
return doc
@pytest.fixture
def mock_db_session(mock_document, mock_dataset):
"""Mock session_factory.create_session to drive deterministic read-only task flow."""
with patch("tasks.document_indexing_sync_task.session_factory") as mock_session_factory:
session = MagicMock()
session.scalars.return_value.all.return_value = []
session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
def mock_document_segments(document_id):
"""Create mock DocumentSegment objects."""
segments = []
for i in range(3):
segment = Mock(spec=DocumentSegment)
segment.id = str(uuid.uuid4())
segment.document_id = document_id
segment.index_node_id = f"node-{document_id}-{i}"
segments.append(segment)
return segments
begin_cm = MagicMock()
begin_cm.__enter__.return_value = session
begin_cm.__exit__.return_value = False
session.begin.return_value = begin_cm
session_cm = MagicMock()
session_cm.__enter__.return_value = session
session_cm.__exit__.return_value = False
@pytest.fixture
def mock_db_session():
"""Mock database session via session_factory.create_session().
mock_session_factory.create_session.return_value = session_cm
yield session
After session split refactor, the code calls create_session() multiple times.
This fixture creates shared query mocks so all sessions use the same
query configuration, simulating database persistence across sessions.
The fixture automatically converts side_effect to cycle to prevent StopIteration.
Tests configure mocks the same way as before, but behind the scenes the values
are cycled infinitely for all sessions.
"""
from itertools import cycle
with patch("tasks.document_indexing_sync_task.session_factory") as mock_sf:
sessions = []
# Shared query mocks - all sessions use these
shared_query = MagicMock()
shared_filter_by = MagicMock()
shared_scalars_result = MagicMock()
# Create custom first mock that auto-cycles side_effect
class CyclicMock(MagicMock):
def __setattr__(self, name, value):
if name == "side_effect" and value is not None:
# Convert list/tuple to infinite cycle
if isinstance(value, (list, tuple)):
value = cycle(value)
super().__setattr__(name, value)
shared_query.where.return_value.first = CyclicMock()
shared_filter_by.first = CyclicMock()
def _create_session():
"""Create a new mock session for each create_session() call."""
session = MagicMock()
session.close = MagicMock()
session.commit = MagicMock()
# Mock session.begin() context manager
begin_cm = MagicMock()
begin_cm.__enter__.return_value = session
def _begin_exit_side_effect(exc_type, exc, tb):
# commit on success
if exc_type is None:
session.commit()
# return False to propagate exceptions
return False
begin_cm.__exit__.side_effect = _begin_exit_side_effect
session.begin.return_value = begin_cm
# Mock create_session() context manager
cm = MagicMock()
cm.__enter__.return_value = session
def _exit_side_effect(exc_type, exc, tb):
session.close()
return False
cm.__exit__.side_effect = _exit_side_effect
# All sessions use the same shared query mocks
session.query.return_value = shared_query
shared_query.where.return_value = shared_query
shared_query.filter_by.return_value = shared_filter_by
session.scalars.return_value = shared_scalars_result
sessions.append(session)
# Attach helpers on the first created session for assertions across all sessions
if len(sessions) == 1:
session.get_all_sessions = lambda: sessions
session.any_close_called = lambda: any(s.close.called for s in sessions)
session.any_commit_called = lambda: any(s.commit.called for s in sessions)
return cm
mock_sf.create_session.side_effect = _create_session
# Create first session and return it
_create_session()
yield sessions[0]
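# Aside: a standalone sketch of the side_effect-cycling trick used above. With
# a plain list, a configured Mock raises StopIteration once the list is
# exhausted; itertools.cycle repeats the values indefinitely:
#
#     from itertools import cycle
#     from unittest.mock import MagicMock
#
#     first = MagicMock()
#     first.side_effect = cycle(["document", "dataset"])
#     assert [first() for _ in range(4)] == ["document", "dataset", "document", "dataset"]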
@pytest.fixture
def mock_datasource_provider_service():
"""Mock datasource credential provider."""
"""Mock DatasourceProviderService."""
with patch("tasks.document_indexing_sync_task.DatasourceProviderService") as mock_service_class:
mock_service = MagicMock()
mock_service.get_datasource_credentials.return_value = {"integration_secret": "test_token"}
@@ -105,16 +204,314 @@ def mock_datasource_provider_service():
@pytest.fixture
def mock_notion_extractor():
"""Mock notion extractor class and instance."""
"""Mock NotionExtractor."""
with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class:
mock_extractor = MagicMock()
mock_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
mock_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" # Updated time
mock_extractor_class.return_value = mock_extractor
yield {"class": mock_extractor_class, "instance": mock_extractor}
yield mock_extractor
class TestDocumentIndexingSyncTaskCollaboratorParams:
"""Unit tests for collaborator parameter passing in document_indexing_sync_task."""
@pytest.fixture
def mock_index_processor_factory():
"""Mock IndexProcessorFactory."""
with patch("tasks.document_indexing_sync_task.IndexProcessorFactory") as mock_factory:
mock_processor = MagicMock()
mock_processor.clean = Mock()
mock_factory.return_value.init_index_processor.return_value = mock_processor
yield mock_factory
@pytest.fixture
def mock_indexing_runner():
"""Mock IndexingRunner."""
with patch("tasks.document_indexing_sync_task.IndexingRunner") as mock_runner_class:
mock_runner = MagicMock(spec=IndexingRunner)
mock_runner.run = Mock()
mock_runner_class.return_value = mock_runner
yield mock_runner
# ============================================================================
# Tests for document_indexing_sync_task
# ============================================================================
class TestDocumentIndexingSyncTask:
"""Tests for the document_indexing_sync_task function."""
def test_document_not_found(self, mock_db_session, dataset_id, document_id):
"""Test that task handles document not found gracefully."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = None
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert - at least one session should have been closed
assert mock_db_session.any_close_called()
def test_missing_notion_workspace_id(self, mock_db_session, mock_document, dataset_id, document_id):
"""Test that task raises error when notion_workspace_id is missing."""
# Arrange
mock_document.data_source_info_dict = {"notion_page_id": "page123", "type": "page"}
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
# Act & Assert
with pytest.raises(ValueError, match="no notion page found"):
document_indexing_sync_task(dataset_id, document_id)
def test_missing_notion_page_id(self, mock_db_session, mock_document, dataset_id, document_id):
"""Test that task raises error when notion_page_id is missing."""
# Arrange
mock_document.data_source_info_dict = {"notion_workspace_id": "ws123", "type": "page"}
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
# Act & Assert
with pytest.raises(ValueError, match="no notion page found"):
document_indexing_sync_task(dataset_id, document_id)
def test_empty_data_source_info(self, mock_db_session, mock_document, dataset_id, document_id):
"""Test that task raises error when data_source_info is empty."""
# Arrange
mock_document.data_source_info_dict = None
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
# Act & Assert
with pytest.raises(ValueError, match="no notion page found"):
document_indexing_sync_task(dataset_id, document_id)
def test_credential_not_found(
self,
mock_db_session,
mock_datasource_provider_service,
mock_document,
dataset_id,
document_id,
):
"""Test that task handles missing credentials by updating document status."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_datasource_provider_service.get_datasource_credentials.return_value = None
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
assert mock_document.indexing_status == "error"
assert "Datasource credential not found" in mock_document.error
assert mock_document.stopped_at is not None
assert mock_db_session.any_commit_called()
assert mock_db_session.any_close_called()
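# Aside: the assertions above imply a missing-credential branch in the task
# roughly like this (sketch; the timestamp helper is assumed, not taken from
# the implementation):
#
#     if not credentials:
#         document.indexing_status = "error"
#         document.error = "Datasource credential not found"
#         document.stopped_at = utcnow()  # timestamp helper assumed
#         session.commit()
#         return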
def test_page_not_updated(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_document,
dataset_id,
document_id,
):
"""Test that task does nothing when page has not been updated."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
# Return same time as stored in document
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
# Document status should remain unchanged
assert mock_document.indexing_status == "completed"
# At least one session should have been closed via context manager teardown
assert mock_db_session.any_close_called()
def test_successful_sync_when_page_updated(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_index_processor_factory,
mock_indexing_runner,
mock_dataset,
mock_document,
mock_document_segments,
dataset_id,
document_id,
):
"""Test successful sync flow when Notion page has been updated."""
# Arrange
# Set exact sequence of returns across calls to `.first()`:
# 1) document (initial fetch)
# 2) dataset (pre-check)
# 3) dataset (cleaning phase)
# 4) document (pre-indexing update)
# 5) document (indexing runner fetch)
mock_db_session.query.return_value.where.return_value.first.side_effect = [
mock_document,
mock_dataset,
mock_dataset,
mock_document,
mock_document,
]
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
# NotionExtractor returns updated time
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
# Verify document status was updated to parsing
assert mock_document.indexing_status == "parsing"
assert mock_document.processing_started_at is not None
# Verify segments were cleaned
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
mock_processor.clean.assert_called_once()
# Verify segments were deleted from database in batch (DELETE FROM document_segments)
# Aggregate execute calls across all created sessions
execute_sqls = []
for s in mock_db_session.get_all_sessions():
execute_sqls.extend([" ".join(str(c[0][0]).split()) for c in s.execute.call_args_list])
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
# Verify indexing runner was called
mock_indexing_runner.run.assert_called_once_with([mock_document])
# Verify session operations (across any created session)
assert mock_db_session.any_commit_called()
assert mock_db_session.any_close_called()
def test_dataset_not_found_during_cleaning(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_indexing_runner,
mock_dataset,
mock_document,
dataset_id,
document_id,
):
"""Test that task handles dataset not found during cleaning phase."""
# Arrange
# Sequence: document (initial), dataset (pre-check), None (cleaning), document (update), document (indexing)
mock_db_session.query.return_value.where.return_value.first.side_effect = [
mock_document,
mock_dataset,
None,
mock_document,
mock_document,
]
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
# Document should still be set to parsing
assert mock_document.indexing_status == "parsing"
# At least one session should be closed after error
assert mock_db_session.any_close_called()
def test_cleaning_error_continues_to_indexing(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_index_processor_factory,
mock_indexing_runner,
mock_dataset,
mock_document,
dataset_id,
document_id,
):
"""Test that indexing continues even if cleaning fails."""
# Arrange
from itertools import cycle
mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
# Make the cleaning step fail but not the segment fetch
processor = mock_index_processor_factory.return_value.init_index_processor.return_value
processor.clean.side_effect = Exception("Cleaning error")
mock_db_session.scalars.return_value.all.return_value = []
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
# Indexing should still be attempted despite cleaning error
mock_indexing_runner.run.assert_called_once_with([mock_document])
assert mock_db_session.any_close_called()
def test_indexing_runner_document_paused_error(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_index_processor_factory,
mock_indexing_runner,
mock_dataset,
mock_document,
mock_document_segments,
dataset_id,
document_id,
):
"""Test that DocumentIsPausedError is handled gracefully."""
# Arrange
from itertools import cycle
mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
# Session should be closed after handling error
assert mock_db_session.any_close_called()
def test_indexing_runner_general_error(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_index_processor_factory,
mock_indexing_runner,
mock_dataset,
mock_document,
mock_document_segments,
dataset_id,
document_id,
):
"""Test that general exceptions during indexing are handled."""
# Arrange
from itertools import cycle
mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
mock_indexing_runner.run.side_effect = Exception("Indexing error")
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
# Session should be closed after error
assert mock_db_session.any_close_called()
def test_notion_extractor_initialized_with_correct_params(
self,
@@ -127,21 +524,27 @@ class TestDocumentIndexingSyncTaskCollaboratorParams:
notion_workspace_id,
notion_page_id,
):
"""Test that NotionExtractor is initialized with expected arguments."""
"""Test that NotionExtractor is initialized with correct parameters."""
# Arrange
expected_token = "test_token"
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" # No update
# Act
document_indexing_sync_task(dataset_id, document_id)
with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class:
mock_extractor = MagicMock()
mock_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
mock_extractor_class.return_value = mock_extractor
# Assert
mock_notion_extractor["class"].assert_called_once_with(
notion_workspace_id=notion_workspace_id,
notion_obj_id=notion_page_id,
notion_page_type="page",
notion_access_token=expected_token,
tenant_id=mock_document.tenant_id,
)
document_indexing_sync_task(dataset_id, document_id)
# Assert
mock_extractor_class.assert_called_once_with(
notion_workspace_id=notion_workspace_id,
notion_obj_id=notion_page_id,
notion_page_type="page",
notion_access_token="test_token",
tenant_id=mock_document.tenant_id,
)
def test_datasource_credentials_requested_correctly(
self,
@@ -153,16 +556,17 @@ class TestDocumentIndexingSyncTaskCollaboratorParams:
document_id,
credential_id,
):
"""Test that datasource credentials are requested with expected identifiers."""
"""Test that datasource credentials are requested with correct parameters."""
# Arrange
expected_tenant_id = mock_document.tenant_id
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
mock_datasource_provider_service.get_datasource_credentials.assert_called_once_with(
tenant_id=expected_tenant_id,
tenant_id=mock_document.tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
@@ -177,14 +581,16 @@ class TestDocumentIndexingSyncTaskCollaboratorParams:
dataset_id,
document_id,
):
"""Test that missing credential_id is forwarded as None."""
"""Test that task handles missing credential_id by passing None."""
# Arrange
mock_document.data_source_info_dict = {
"notion_workspace_id": "workspace-id",
"notion_page_id": "page-id",
"notion_workspace_id": "ws123",
"notion_page_id": "page123",
"type": "page",
"last_edited_time": "2024-01-01T00:00:00Z",
}
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
# Act
document_indexing_sync_task(dataset_id, document_id)
@@ -196,3 +602,39 @@ class TestDocumentIndexingSyncTaskCollaboratorParams:
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
def test_index_processor_clean_called_with_correct_params(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_index_processor_factory,
mock_indexing_runner,
mock_dataset,
mock_document,
mock_document_segments,
dataset_id,
document_id,
):
"""Test that index processor clean is called with correct parameters."""
# Arrange
# Sequence: document (initial), dataset (pre-check), dataset (cleaning), document (update), document (indexing)
mock_db_session.query.return_value.where.return_value.first.side_effect = [
mock_document,
mock_dataset,
mock_dataset,
mock_document,
mock_document,
]
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
expected_node_ids = [seg.index_node_id for seg in mock_document_segments]
mock_processor.clean.assert_called_once_with(
mock_dataset, expected_node_ids, with_keywords=True, delete_child_chunks=True
)
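# Aside: the assertion implies a cleaning step roughly like this (sketch,
# reconstructed from the mocks; not the actual task code):
#
#     index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
#     node_ids = [segment.index_node_id for segment in segments]
#     index_processor.clean(dataset, node_ids, with_keywords=True, delete_child_chunks=True)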

api/uv.lock (generated)
View File

@@ -5049,11 +5049,11 @@ wheels = [
[[package]]
name = "pypdf"
version = "6.7.4"
version = "6.7.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/09/dc/f52deef12797ad58b88e4663f097a343f53b9361338aef6573f135ac302f/pypdf-6.7.4.tar.gz", hash = "sha256:9edd1cd47938bb35ec87795f61225fd58a07cfaf0c5699018ae1a47d6f8ab0e3", size = 5304821, upload-time = "2026-02-27T10:44:39.395Z" }
sdist = { url = "https://files.pythonhosted.org/packages/ff/63/3437c4363483f2a04000a48f1cd48c40097f69d580363712fa8b0b4afe45/pypdf-6.7.1.tar.gz", hash = "sha256:6b7a63be5563a0a35d54c6d6b550d75c00b8ccf36384be96365355e296e6b3b0", size = 5302208, upload-time = "2026-02-17T17:00:48.88Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c1/be/cded021305f5c81b47265b8c5292b99388615a4391c21ff00fd538d34a56/pypdf-6.7.4-py3-none-any.whl", hash = "sha256:527d6da23274a6c70a9cb59d1986d93946ba8e36a6bc17f3f7cce86331492dda", size = 331496, upload-time = "2026-02-27T10:44:37.527Z" },
{ url = "https://files.pythonhosted.org/packages/68/77/38bd7744bb9e06d465b0c23879e6d2c187d93a383f8fa485c862822bb8a3/pypdf-6.7.1-py3-none-any.whl", hash = "sha256:a02ccbb06463f7c334ce1612e91b3e68a8e827f3cee100b9941771e6066b094e", size = 331048, upload-time = "2026-02-17T17:00:46.991Z" },
]
[[package]]

View File

@@ -33,7 +33,7 @@ Then, configure the environment variables. Create a file named `.env.local` in t
cp .env.example .env.local
```
```txt
```
# For production release, change this to PRODUCTION
NEXT_PUBLIC_DEPLOY_ENV=DEVELOPMENT
# The deployment edition, SELF_HOSTED

View File

@@ -58,11 +58,10 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
}, 1000)
}
const sendEmail = async (email: string, isOrigin: boolean, token?: string) => {
const sendEmail = async (email: string, token?: string) => {
try {
const res = await sendVerifyCode({
email,
phase: isOrigin ? 'old_email' : 'new_email',
token,
})
startCount()
@@ -106,7 +105,6 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
const sendCodeToOriginEmail = async () => {
await sendEmail(
email,
true,
)
setStep(STEP.verifyOrigin)
}
@@ -162,7 +160,6 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
}
await sendEmail(
mail,
false,
stepToken,
)
setStep(STEP.verifyNew)

View File

@@ -61,7 +61,8 @@ const ParamsConfig = ({
if (tempDataSetConfigs.retrieval_model === RETRIEVE_TYPE.multiWay) {
if (tempDataSetConfigs.reranking_enable
&& tempDataSetConfigs.reranking_mode === RerankingModeEnum.RerankingModel
&& !isCurrentRerankModelValid) {
&& !isCurrentRerankModelValid
) {
errMsg = t('datasetConfig.rerankModelRequired', { ns: 'appDebug' })
}
}

View File

@@ -43,7 +43,7 @@ describe('Input component', () => {
it('shows left icon when showLeftIcon is true', () => {
render(<Input showLeftIcon />)
const searchIcon = document.querySelector('.i-ri-search-line')
const searchIcon = document.querySelector('svg')
expect(searchIcon).toBeInTheDocument()
const input = screen.getByPlaceholderText('Search')
expect(input).toHaveClass('pl-[26px]')
@@ -51,7 +51,7 @@ describe('Input component', () => {
it('shows clear icon when showClearIcon is true and has value', () => {
render(<Input showClearIcon value="test" />)
const clearIcon = document.querySelector('.i-ri-close-circle-fill')
const clearIcon = document.querySelector('.group svg')
expect(clearIcon).toBeInTheDocument()
const input = screen.getByDisplayValue('test')
expect(input).toHaveClass('pr-[26px]')
@@ -59,21 +59,21 @@ describe('Input component', () => {
it('does not show clear icon when disabled, even with value', () => {
render(<Input showClearIcon value="test" disabled />)
const clearIcon = document.querySelector('.i-ri-close-circle-fill')
const clearIcon = document.querySelector('.group svg')
expect(clearIcon).not.toBeInTheDocument()
})
it('calls onClear when clear icon is clicked', () => {
const onClear = vi.fn()
render(<Input showClearIcon value="test" onClear={onClear} />)
const clearIconContainer = screen.getByTestId('input-clear')
const clearIconContainer = document.querySelector('.group')
fireEvent.click(clearIconContainer!)
expect(onClear).toHaveBeenCalledTimes(1)
})
it('shows warning icon when destructive is true', () => {
render(<Input destructive />)
const warningIcon = document.querySelector('.i-ri-error-warning-line')
const warningIcon = document.querySelector('svg')
expect(warningIcon).toBeInTheDocument()
const input = screen.getByPlaceholderText('Please input')
expect(input).toHaveClass('border-components-input-border-destructive')

View File

@@ -1,5 +1,6 @@
import type { VariantProps } from 'class-variance-authority'
import type { ChangeEventHandler, CSSProperties, FocusEventHandler } from 'react'
import { RiCloseCircleFill, RiErrorWarningLine, RiSearchLine } from '@remixicon/react'
import { cva } from 'class-variance-authority'
import { noop } from 'es-toolkit/function'
import * as React from 'react'
@@ -12,8 +13,8 @@ export const inputVariants = cva(
{
variants: {
size: {
regular: 'px-3 system-sm-regular radius-md',
large: 'px-4 system-md-regular radius-lg',
regular: 'px-3 radius-md system-sm-regular',
large: 'px-4 radius-lg system-md-regular',
},
},
defaultVariants: {
@@ -82,7 +83,7 @@ const Input = React.forwardRef<HTMLInputElement, InputProps>(({
}
return (
<div className={cn('relative w-full', wrapperClassName)}>
{showLeftIcon && <span className={cn('i-ri-search-line absolute left-2 top-1/2 h-4 w-4 -translate-y-1/2 text-components-input-text-placeholder')} />}
{showLeftIcon && <RiSearchLine className={cn('absolute left-2 top-1/2 h-4 w-4 -translate-y-1/2 text-components-input-text-placeholder')} />}
<input
ref={ref}
style={styleCss}
@@ -114,11 +115,11 @@ const Input = React.forwardRef<HTMLInputElement, InputProps>(({
onClick={onClear}
data-testid="input-clear"
>
<span className="i-ri-close-circle-fill h-3.5 w-3.5 cursor-pointer text-text-quaternary group-hover:text-text-tertiary" />
<RiCloseCircleFill className="h-3.5 w-3.5 cursor-pointer text-text-quaternary group-hover:text-text-tertiary" />
</div>
)}
{destructive && (
<span className="i-ri-error-warning-line absolute right-2 top-1/2 h-4 w-4 -translate-y-1/2 text-text-destructive-secondary" />
<RiErrorWarningLine className="absolute right-2 top-1/2 h-4 w-4 -translate-y-1/2 text-text-destructive-secondary" />
)}
{showCopyIcon && (
<div className={cn('group absolute right-0 top-1/2 -translate-y-1/2 cursor-pointer')}>
@@ -130,7 +131,7 @@ const Input = React.forwardRef<HTMLInputElement, InputProps>(({
)}
{
unit && (
<div className="absolute right-2 top-1/2 -translate-y-1/2 text-text-tertiary system-sm-regular">
<div className="system-sm-regular absolute right-2 top-1/2 -translate-y-1/2 text-text-tertiary">
{unit}
</div>
)

View File

@@ -334,10 +334,10 @@ describe('FileList', () => {
it('should call resetKeywords prop when clear button is clicked', () => {
const mockResetKeywords = vi.fn()
const props = createDefaultProps({ resetKeywords: mockResetKeywords, keywords: 'to-reset' })
render(<FileList {...props} />)
const { container } = render(<FileList {...props} />)
// Act - Click the clear icon div (it contains RiCloseCircleFill icon)
const clearButton = screen.getByTestId('input-clear')
const clearButton = container.querySelector('[class*="cursor-pointer"] svg[class*="h-3.5"]')?.parentElement
expect(clearButton).toBeInTheDocument()
fireEvent.click(clearButton!)
@@ -346,12 +346,12 @@ describe('FileList', () => {
it('should reset inputValue to empty string when clear is clicked', () => {
const props = createDefaultProps({ keywords: 'to-be-reset' })
render(<FileList {...props} />)
const { container } = render(<FileList {...props} />)
const input = screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')
fireEvent.change(input, { target: { value: 'some-search' } })
// Act - Find and click the clear icon
const clearButton = screen.getByTestId('input-clear')
const clearButton = container.querySelector('[class*="cursor-pointer"] svg[class*="h-3.5"]')?.parentElement
expect(clearButton).toBeInTheDocument()
fireEvent.click(clearButton!)

View File

@@ -93,8 +93,8 @@ describe('Header', () => {
const { container } = render(<Header {...props} />)
// Assert - Input should have search icon class
const searchIcon = container.querySelector('.i-ri-search-line.h-4.w-4')
// Assert - Input should have search icon (RiSearchLine is rendered as svg)
const searchIcon = container.querySelector('svg.h-4.w-4')
expect(searchIcon).toBeInTheDocument()
})
@@ -313,10 +313,10 @@ describe('Header', () => {
inputValue: 'to-clear',
handleResetKeywords: mockHandleResetKeywords,
})
render(<Header {...props} />)
const { container } = render(<Header {...props} />)
// Act - Find and click the clear icon container
const clearButton = screen.getByTestId('input-clear')
const clearButton = container.querySelector('[class*="cursor-pointer"] svg[class*="h-3.5"]')?.parentElement
expect(clearButton).toBeInTheDocument()
fireEvent.click(clearButton!)
@@ -325,19 +325,19 @@ describe('Header', () => {
it('should not show clear icon when inputValue is empty', () => {
const props = createDefaultProps({ inputValue: '' })
render(<Header {...props} />)
const { container } = render(<Header {...props} />)
// Act & Assert - Clear icon should not be visible
const clearIcon = screen.queryByTestId('input-clear')
const clearIcon = container.querySelector('[class*="cursor-pointer"] svg[class*="h-3.5"]')
expect(clearIcon).not.toBeInTheDocument()
})
it('should show clear icon when inputValue is not empty', () => {
const props = createDefaultProps({ inputValue: 'some-value' })
render(<Header {...props} />)
const { container } = render(<Header {...props} />)
// Act & Assert - Clear icon should be visible
const clearIcon = screen.getByTestId('input-clear')
const clearIcon = container.querySelector('[class*="cursor-pointer"] svg[class*="h-3.5"]')
expect(clearIcon).toBeInTheDocument()
})
})
@@ -570,12 +570,13 @@ describe('Header', () => {
inputValue: 'to-clear',
handleResetKeywords: mockHandleResetKeywords,
})
const { rerender } = render(<Header {...props} />)
const { container, rerender } = render(<Header {...props} />)
// Act - Click clear, rerender, click again
fireEvent.click(screen.getByTestId('input-clear'))
const clearButton = container.querySelector('[class*="cursor-pointer"] svg[class*="h-3.5"]')?.parentElement
fireEvent.click(clearButton!)
rerender(<Header {...props} />)
fireEvent.click(screen.getByTestId('input-clear'))
fireEvent.click(clearButton!)
expect(mockHandleResetKeywords).toHaveBeenCalledTimes(2)
})

View File

@@ -1,8 +1,24 @@
'use client'
import type { AccountSettingTab } from '@/app/components/header/account-setting/constants'
import {
RiBrain2Fill,
RiBrain2Line,
RiCloseLine,
RiColorFilterFill,
RiColorFilterLine,
RiDatabase2Fill,
RiDatabase2Line,
RiGroup2Fill,
RiGroup2Line,
RiMoneyDollarCircleFill,
RiMoneyDollarCircleLine,
RiPuzzle2Fill,
RiPuzzle2Line,
RiTranslate2,
} from '@remixicon/react'
import { useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import SearchInput from '@/app/components/base/search-input'
import Input from '@/app/components/base/input'
import BillingPage from '@/app/components/billing/billing-page'
import CustomPage from '@/app/components/custom/custom-page'
import {
@@ -60,14 +76,14 @@ export default function AccountSetting({
{
key: ACCOUNT_SETTING_TAB.PROVIDER,
name: t('settings.provider', { ns: 'common' }),
icon: <span className={cn('i-ri-brain-2-line', iconClassName)} />,
activeIcon: <span className={cn('i-ri-brain-2-fill', iconClassName)} />,
icon: <RiBrain2Line className={iconClassName} />,
activeIcon: <RiBrain2Fill className={iconClassName} />,
},
{
key: ACCOUNT_SETTING_TAB.MEMBERS,
name: t('settings.members', { ns: 'common' }),
icon: <span className={cn('i-ri-group-2-line', iconClassName)} />,
activeIcon: <span className={cn('i-ri-group-2-fill', iconClassName)} />,
icon: <RiGroup2Line className={iconClassName} />,
activeIcon: <RiGroup2Fill className={iconClassName} />,
},
]
@@ -76,8 +92,8 @@ export default function AccountSetting({
key: ACCOUNT_SETTING_TAB.BILLING,
name: t('settings.billing', { ns: 'common' }),
description: t('plansCommon.receiptInfo', { ns: 'billing' }),
icon: <span className={cn('i-ri-money-dollar-circle-line', iconClassName)} />,
activeIcon: <span className={cn('i-ri-money-dollar-circle-fill', iconClassName)} />,
icon: <RiMoneyDollarCircleLine className={iconClassName} />,
activeIcon: <RiMoneyDollarCircleFill className={iconClassName} />,
})
}
@@ -85,14 +101,14 @@ export default function AccountSetting({
{
key: ACCOUNT_SETTING_TAB.DATA_SOURCE,
name: t('settings.dataSource', { ns: 'common' }),
icon: <span className={cn('i-ri-database-2-line', iconClassName)} />,
activeIcon: <span className={cn('i-ri-database-2-fill', iconClassName)} />,
icon: <RiDatabase2Line className={iconClassName} />,
activeIcon: <RiDatabase2Fill className={iconClassName} />,
},
{
key: ACCOUNT_SETTING_TAB.API_BASED_EXTENSION,
name: t('settings.apiBasedExtension', { ns: 'common' }),
icon: <span className={cn('i-ri-puzzle-2-line', iconClassName)} />,
activeIcon: <span className={cn('i-ri-puzzle-2-fill', iconClassName)} />,
icon: <RiPuzzle2Line className={iconClassName} />,
activeIcon: <RiPuzzle2Fill className={iconClassName} />,
},
)
@@ -100,8 +116,8 @@ export default function AccountSetting({
items.push({
key: ACCOUNT_SETTING_TAB.CUSTOM,
name: t('custom', { ns: 'custom' }),
icon: <span className={cn('i-ri-color-filter-line', iconClassName)} />,
activeIcon: <span className={cn('i-ri-color-filter-fill', iconClassName)} />,
icon: <RiColorFilterLine className={iconClassName} />,
activeIcon: <RiColorFilterFill className={iconClassName} />,
})
}
@@ -124,8 +140,8 @@ export default function AccountSetting({
{
key: ACCOUNT_SETTING_TAB.LANGUAGE,
name: t('settings.language', { ns: 'common' }),
icon: <span className={cn('i-ri-translate-2', iconClassName)} />,
activeIcon: <span className={cn('i-ri-translate-2', iconClassName)} />,
icon: <RiTranslate2 className={iconClassName} />,
activeIcon: <RiTranslate2 className={iconClassName} />,
},
],
},
@@ -155,13 +171,13 @@ export default function AccountSetting({
>
<div className="mx-auto flex h-[100vh] max-w-[1048px]">
<div className="flex w-[44px] flex-col border-r border-divider-burn pl-4 pr-6 sm:w-[224px]">
<div className="mb-8 mt-6 px-3 py-2 text-text-primary title-2xl-semi-bold">{t('userProfile.settings', { ns: 'common' })}</div>
<div className="title-2xl-semi-bold mb-8 mt-6 px-3 py-2 text-text-primary">{t('userProfile.settings', { ns: 'common' })}</div>
<div className="w-full">
{
menuItems.map(menuItem => (
<div key={menuItem.key} className="mb-2">
{!isCurrentWorkspaceDatasetOperator && (
<div className="mb-0.5 py-2 pb-1 pl-3 text-text-tertiary system-xs-medium-uppercase">{menuItem.name}</div>
<div className="system-xs-medium-uppercase mb-0.5 py-2 pb-1 pl-3 text-text-tertiary">{menuItem.name}</div>
)}
<div>
{
@@ -170,7 +186,7 @@ export default function AccountSetting({
key={item.key}
className={cn(
'mb-0.5 flex h-[37px] cursor-pointer items-center rounded-lg p-1 pl-3 text-sm',
activeMenu === item.key ? 'bg-state-base-active text-components-menu-item-text-active system-sm-semibold' : 'text-components-menu-item-text system-sm-medium',
activeMenu === item.key ? 'system-sm-semibold bg-state-base-active text-components-menu-item-text-active' : 'system-sm-medium text-components-menu-item-text',
)}
title={item.name}
onClick={() => {
@@ -197,36 +213,38 @@ export default function AccountSetting({
className="px-2"
onClick={onCancel}
>
<span className="i-ri-close-line h-5 w-5" />
<RiCloseLine className="h-5 w-5" />
</Button>
<div className="mt-1 text-text-tertiary system-2xs-medium-uppercase">ESC</div>
<div className="system-2xs-medium-uppercase mt-1 text-text-tertiary">ESC</div>
</div>
<div ref={scrollRef} className="w-full overflow-y-auto bg-components-panel-bg pb-4">
<div className={cn('sticky top-0 z-20 mx-8 mb-[18px] flex items-center bg-components-panel-bg pb-2 pt-[27px]', scrolled && 'border-b border-divider-regular')}>
<div className="shrink-0 text-text-primary title-2xl-semi-bold">
<div className="title-2xl-semi-bold shrink-0 text-text-primary">
{activeItem?.name}
{activeItem?.description && (
<div className="mt-1 text-text-tertiary system-sm-regular">{activeItem?.description}</div>
<div className="system-sm-regular mt-1 text-text-tertiary">{activeItem?.description}</div>
)}
</div>
{activeItem?.key === ACCOUNT_SETTING_TAB.PROVIDER && (
{activeItem?.key === 'provider' && (
<div className="flex grow justify-end">
<SearchInput
className="w-[200px]"
onChange={setSearchValue}
<Input
showLeftIcon
wrapperClassName="!w-[200px]"
className="!h-8 !text-[13px]"
onChange={e => setSearchValue(e.target.value)}
value={searchValue}
/>
</div>
)}
</div>
<div className="px-4 pt-2 sm:px-8">
{activeMenu === ACCOUNT_SETTING_TAB.PROVIDER && <ModelProviderPage searchText={searchValue} />}
{activeMenu === ACCOUNT_SETTING_TAB.MEMBERS && <MembersPage />}
{activeMenu === ACCOUNT_SETTING_TAB.BILLING && <BillingPage />}
{activeMenu === ACCOUNT_SETTING_TAB.DATA_SOURCE && <DataSourcePage />}
{activeMenu === ACCOUNT_SETTING_TAB.API_BASED_EXTENSION && <ApiBasedExtensionPage />}
{activeMenu === ACCOUNT_SETTING_TAB.CUSTOM && <CustomPage />}
{activeMenu === ACCOUNT_SETTING_TAB.LANGUAGE && <LanguagePage />}
{activeMenu === 'provider' && <ModelProviderPage searchText={searchValue} />}
{activeMenu === 'members' && <MembersPage />}
{activeMenu === 'billing' && <BillingPage />}
{activeMenu === 'data-source' && <DataSourcePage />}
{activeMenu === 'api-based-extension' && <ApiBasedExtensionPage />}
{activeMenu === 'custom' && <CustomPage />}
{activeMenu === 'language' && <LanguagePage />}
</div>
</div>
</div>

View File

@@ -46,7 +46,8 @@ const nodeDefault: NodeDefault<HttpNodeType> = {
if (!errorMessages
&& payload.body.type === BodyType.binary
&& ((!(payload.body.data as BodyPayload)[0]?.file) || (payload.body.data as BodyPayload)[0]?.file?.length === 0)) {
&& ((!(payload.body.data as BodyPayload)[0]?.file) || (payload.body.data as BodyPayload)[0]?.file?.length === 0)
) {
errorMessages = t('errorMsg.fieldRequired', { ns: 'workflow', field: t('nodes.http.binaryFileVariable', { ns: 'workflow' }) })
}

View File

@@ -2052,6 +2052,9 @@
"app/components/base/input/index.tsx": {
"react-refresh/only-export-components": {
"count": 1
},
"tailwindcss/enforce-consistent-class-order": {
"count": 3
}
},
"app/components/base/linked-apps-panel/index.tsx": {
@@ -3991,6 +3994,9 @@
"app/components/header/account-setting/index.tsx": {
"react-hooks-extra/no-direct-set-state-in-use-effect": {
"count": 1
},
"tailwindcss/enforce-consistent-class-order": {
"count": 7
}
},
"app/components/header/account-setting/key-validator/declarations.ts": {
@@ -8163,6 +8169,11 @@
"count": 3
}
},
"i18n-config/README.md": {
"no-irregular-whitespace": {
"count": 1
}
},
"i18n/de-DE/billing.json": {
"no-irregular-whitespace": {
"count": 1

View File

@@ -6,7 +6,7 @@ This directory contains i18n tooling and configuration. Translation files live u
## File Structure
```txt
```
web/i18n
├── en-US
│ ├── app.json
@@ -36,7 +36,7 @@ By default we will use `LanguagesSupported` to determine which languages are sup
1. Create a new folder for the new language.
```txt
```
cd web/i18n
cp -r en-US id-ID
```
@@ -98,7 +98,7 @@ export const languages = [
{
value: 'ru-RU',
name: 'Русский(Россия)',
example: 'Привет, Dify!',
example: ' Привет, Dify!',
supported: false,
},
{

View File

@@ -3,7 +3,7 @@
"type": "module",
"version": "1.13.0",
"private": true,
"packageManager": "pnpm@10.27.0",
"packageManager": "pnpm@10.27.0+sha512.72d699da16b1179c14ba9e64dc71c9a40988cbdc65c264cb0e489db7de917f20dcf4d64d8723625f2969ba52d4b7e2a1170682d9ac2a5dcaeaab732b7e16f04a",
"imports": {
"#i18n": {
"react-server": "./i18n-config/lib.server.ts",
@@ -165,10 +165,10 @@
"zustand": "5.0.9"
},
"devDependencies": {
"@antfu/eslint-config": "7.6.1",
"@chromatic-com/storybook": "5.0.1",
"@antfu/eslint-config": "7.2.0",
"@chromatic-com/storybook": "5.0.0",
"@egoist/tailwindcss-icons": "1.9.2",
"@eslint-react/eslint-plugin": "2.13.0",
"@eslint-react/eslint-plugin": "2.9.4",
"@iconify-json/heroicons": "1.2.3",
"@iconify-json/ri": "1.2.9",
"@mdx-js/loader": "3.1.1",
@@ -177,12 +177,12 @@
"@next/mdx": "16.1.5",
"@rgrove/parse-xml": "4.2.0",
"@serwist/turbopack": "9.5.4",
"@storybook/addon-docs": "10.2.13",
"@storybook/addon-links": "10.2.13",
"@storybook/addon-onboarding": "10.2.13",
"@storybook/addon-themes": "10.2.13",
"@storybook/nextjs-vite": "10.2.13",
"@storybook/react": "10.2.13",
"@storybook/addon-docs": "10.2.0",
"@storybook/addon-links": "10.2.0",
"@storybook/addon-onboarding": "10.2.0",
"@storybook/addon-themes": "10.2.0",
"@storybook/nextjs-vite": "10.2.0",
"@storybook/react": "10.2.0",
"@tanstack/eslint-plugin-query": "5.91.4",
"@tanstack/react-devtools": "0.9.2",
"@tanstack/react-form-devtools": "0.2.12",
@@ -208,7 +208,7 @@
"@types/semver": "7.7.1",
"@types/sortablejs": "1.15.8",
"@types/uuid": "10.0.0",
"@typescript-eslint/parser": "8.56.1",
"@typescript-eslint/parser": "8.54.0",
"@typescript/native-preview": "7.0.0-dev.20251209.1",
"@vitejs/plugin-react": "5.1.2",
"@vitest/coverage-v8": "4.0.17",
@@ -216,13 +216,13 @@
"code-inspector-plugin": "1.3.6",
"cross-env": "10.1.0",
"esbuild": "0.27.2",
"eslint": "10.0.2",
"eslint-plugin-better-tailwindcss": "4.3.1",
"eslint-plugin-hyoban": "0.11.2",
"eslint": "9.39.2",
"eslint-plugin-better-tailwindcss": "https://pkg.pr.new/hyoban/eslint-plugin-better-tailwindcss@c0161c7",
"eslint-plugin-hyoban": "0.11.1",
"eslint-plugin-react-hooks": "7.0.1",
"eslint-plugin-react-refresh": "0.5.2",
"eslint-plugin-sonarjs": "4.0.0",
"eslint-plugin-storybook": "10.2.13",
"eslint-plugin-react-refresh": "0.5.0",
"eslint-plugin-sonarjs": "3.0.6",
"eslint-plugin-storybook": "10.2.6",
"husky": "9.1.7",
"iconify-import-svg": "0.1.1",
"jsdom": "27.3.0",
@@ -235,7 +235,7 @@
"react-scan": "0.4.3",
"sass": "1.93.2",
"serwist": "9.5.4",
"storybook": "10.2.13",
"storybook": "10.2.0",
"tailwindcss": "3.4.19",
"tsx": "4.21.0",
"typescript": "5.9.3",
@@ -249,7 +249,6 @@
"overrides": {
"@monaco-editor/loader": "1.5.0",
"@nolyfill/safe-buffer": "npm:safe-buffer@^5.2.1",
"@stylistic/eslint-plugin": "https://pkg.pr.new/@stylistic/eslint-plugin@258f9d8",
"array-includes": "npm:@nolyfill/array-includes@^1",
"array.prototype.findlast": "npm:@nolyfill/array.prototype.findlast@^1",
"array.prototype.findlastindex": "npm:@nolyfill/array.prototype.findlastindex@^1",

web/pnpm-lock.yaml (generated)

File diff suppressed because it is too large.

View File

@@ -372,7 +372,7 @@ export const submitDeleteAccountFeedback = (body: { feedback: string, email: str
export const getDocDownloadUrl = (doc_name: string): Promise<{ url: string }> =>
get<{ url: string }>('/compliance/download', { params: { doc_name } }, { silent: true })
export const sendVerifyCode = (body: { email: string, phase: string, token?: string }): Promise<CommonResponse & { data: string }> =>
export const sendVerifyCode = (body: { email: string, token?: string }): Promise<CommonResponse & { data: string }> =>
post<CommonResponse & { data: string }>('/account/change-email', { body })
export const verifyEmail = (body: { email: string, code: string, token: string }): Promise<CommonResponse & { is_valid: boolean, email: string, token: string }> =>