Mirror of https://github.com/langgenius/dify.git (synced 2026-02-28 12:25:14 +00:00)

Compare commits: main...verify-ema (6 commits)

| Author | SHA1 | Date |
|---|---|---|
| | f8d410d7fe | |
| | 7ea09fe49d | |
| | 176bab84a9 | |
| | ebe7548c72 | |
| | fce50cf12c | |
| | 9a71d70738 | |
@@ -1,168 +0,0 @@
---
name: backend-code-review
description: Review backend code for quality, security, maintainability, and best practices based on established checklist rules. Use when the user requests a review, analysis, or improvement of backend files (e.g., `.py`) under the `api/` directory. Do NOT use for frontend files (e.g., `.tsx`, `.ts`, `.js`). Supports pending-change review, code snippets review, and file-focused review.
---

# Backend Code Review

## When to use this skill

Use this skill whenever the user asks to **review, analyze, or improve** backend code (e.g., `.py`) under the `api/` directory. The following review modes are supported:

- **Pending-change review**: when the user asks to review current changes (inspect staged/working-tree files slated for commit to get the changes).
- **Code snippets review**: when the user pastes code snippets (e.g., a function/class/module excerpt) into the chat and asks for a review.
- **File-focused review**: when the user points to specific files and asks for a review of those files (one file or a small, explicit set of files, e.g., `api/...`, `api/app.py`).

Do NOT use this skill when:

- The request is about frontend code or UI (e.g., `.tsx`, `.ts`, `.js`, `web/`).
- The user is not asking for a review/analysis/improvement of backend code.
- The scope is not under `api/` (unless the user explicitly asks to review backend-related changes outside `api/`).

## How to use this skill

Follow these steps when using this skill:

1. **Identify the review mode** (pending-change vs snippet vs file-focused) based on the user's input. Keep the scope tight: review only what the user provided or explicitly referenced.
2. Follow the rules defined in **Checklist** to perform the review. If no Checklist rule matches, apply **General Review Rules** as a fallback and perform a best-effort review.
3. Compose the final output strictly following the **Required Output Format**.

Notes when using this skill:

- Always include actionable fixes or suggestions (including possible code snippets).
- Use best-effort `File:Line` references when a file path and line numbers are available; otherwise, use the most specific identifier you can.

## Checklist

- db schema design: if the review scope includes code/files under `api/models/` or `api/migrations/`, follow [references/db-schema-rule.md](references/db-schema-rule.md) to perform the review
- architecture: if the review scope involves controller/service/core-domain/libs/model layering, dependency direction, or moving responsibilities across modules, follow [references/architecture-rule.md](references/architecture-rule.md) to perform the review
- repositories abstraction: if the review scope contains table/model operations (e.g., `select(...)`, `session.execute(...)`, joins, CRUD) and is not under `api/repositories`, `api/core/repositories`, or `api/extensions/*/repositories/`, follow [references/repositories-rule.md](references/repositories-rule.md) to perform the review
- sqlalchemy patterns: if the review scope involves SQLAlchemy session/query usage, db transaction/crud usage, or raw SQL usage, follow [references/sqlalchemy-rule.md](references/sqlalchemy-rule.md) to perform the review

## General Review Rules
### 1. Security Review

Check for:

- SQL injection vulnerabilities
- Server-Side Request Forgery (SSRF)
- Command injection
- Insecure deserialization
- Hardcoded secrets/credentials
- Improper authentication/authorization
- Insecure direct object references
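The first item is the classic case: building SQL by string interpolation lets user input escape into query syntax, while driver-level parameter binding cannot. A minimal, hypothetical sketch using `sqlite3` for illustration only (the names are invented; this codebase itself goes through SQLAlchemy):

```python
import sqlite3

def find_user_unsafe(conn, name: str):
    # Vulnerable: user input is interpolated directly into the SQL text.
    return conn.execute(f"SELECT id FROM users WHERE name = '{name}'").fetchall()

def find_user_safe(conn, name: str):
    # Safe: the driver binds the value as data, never as SQL syntax.
    return conn.execute("SELECT id FROM users WHERE name = ?", (name,)).fetchall()

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER, name TEXT)")
conn.executemany("INSERT INTO users VALUES (?, ?)", [(1, "alice"), (2, "bob")])

payload = "x' OR '1'='1"
leaked = find_user_unsafe(conn, payload)  # the injected OR clause matches every row
scoped = find_user_safe(conn, payload)    # no user is literally named that
```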
### 2. Performance Review

Check for:

- N+1 queries
- Missing database indexes
- Memory leaks
- Blocking operations in async code
- Missing caching opportunities
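The N+1 pattern in the first bullet is easiest to spot with a query counter. A schematic, pure-Python sketch (the dictionaries stand in for tables; all names are hypothetical):

```python
# Simulated query log so the two access patterns' query counts are visible.
QUERY_LOG: list[str] = []

APPS = {1: "chatbot", 2: "agent"}
CONVERSATIONS = [{"id": i, "app_id": 1 + i % 2} for i in range(6)]

def fetch_app(app_id: int) -> str:
    QUERY_LOG.append(f"SELECT name FROM apps WHERE id = {app_id}")
    return APPS[app_id]

def fetch_apps_bulk(app_ids: set[int]) -> dict[int, str]:
    QUERY_LOG.append(f"SELECT id, name FROM apps WHERE id IN {tuple(sorted(app_ids))}")
    return {app_id: APPS[app_id] for app_id in app_ids}

# N+1: one lookup per conversation -> 6 queries for 6 rows.
QUERY_LOG.clear()
names_n_plus_1 = [fetch_app(c["app_id"]) for c in CONVERSATIONS]
n_plus_1_queries = len(QUERY_LOG)

# Batched: collect the foreign keys, issue a single IN query, join in memory.
QUERY_LOG.clear()
apps = fetch_apps_bulk({c["app_id"] for c in CONVERSATIONS})
names_batched = [apps[c["app_id"]] for c in CONVERSATIONS]
batched_queries = len(QUERY_LOG)
```

The batched variant does the same work with one round trip regardless of row count.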
### 3. Code Quality Review

Check for:

- Code forward compatibility
- Code duplication (DRY violations)
- Functions doing too much (SRP violations)
- Deep nesting / complex conditionals
- Magic numbers/strings
- Poor naming
- Missing error handling
- Incomplete type coverage
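Of these, deep nesting is often the cheapest to fix: replacing nested conditionals with guard clauses keeps every exit condition at a single level. A small, hypothetical sketch (both functions are behaviorally identical):

```python
from dataclasses import dataclass

@dataclass
class App:
    status: str

@dataclass
class User:
    is_active: bool

def publish_status_nested(app, user) -> str:
    # Deeply nested conditionals bury the early-exit cases.
    if app is not None:
        if user.is_active:
            if app.status == "draft":
                return "published"
            else:
                return app.status
        else:
            return "forbidden"
    else:
        return "missing"

def publish_status_flat(app, user) -> str:
    # Guard clauses surface each exit condition at one nesting level.
    if app is None:
        return "missing"
    if not user.is_active:
        return "forbidden"
    if app.status == "draft":
        return "published"
    return app.status
```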
### 4. Testing Review

Check for:

- Missing test coverage for new code
- Tests that don't test behavior
- Flaky test patterns
- Missing edge cases
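"Tests that don't test behavior" typically assert implementation details that cannot fail when the contract breaks. A hypothetical sketch of the contrast:

```python
def slugify(title: str) -> str:
    return "-".join(title.lower().split())

def test_slugify_calls_lower():
    # Behavior-free: restates the implementation of str.lower and can never
    # catch a regression in what callers actually observe from slugify().
    assert "HELLO".lower() == "hello"

def test_slugify_behavior():
    # Behavioral: pins the observable contract, including a whitespace edge case.
    assert slugify("Hello World") == "hello-world"
    assert slugify("  spaced   out  ") == "spaced-out"
```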
## Required Output Format

When this skill is invoked, the response must follow exactly one of the two templates:

### Template A (any findings)

```markdown
# Code Review Summary

Found <X> critical issues that need to be fixed:

## 🔴 Critical (Must Fix)

### 1. <brief description of the issue>

FilePath: <path> line <line>
<relevant code snippet or pointer>

#### Explanation

<detailed explanation and references of the issue>

#### Suggested Fix

1. <brief description of suggested fix>
2. <code example> (optional, omit if not applicable)

---
... (repeat for each critical issue) ...

Found <Y> suggestions for improvement:

## 🟡 Suggestions (Should Consider)

### 1. <brief description of the suggestion>

FilePath: <path> line <line>
<relevant code snippet or pointer>

#### Explanation

<detailed explanation and references of the suggestion>

#### Suggested Fix

1. <brief description of suggested fix>
2. <code example> (optional, omit if not applicable)

---
... (repeat for each suggestion) ...

Found <Z> optional nits:

## 🟢 Nits (Optional)

### 1. <brief description of the nit>

FilePath: <path> line <line>
<relevant code snippet or pointer>

#### Explanation

<explanation and references of the optional nit>

#### Suggested Fix

- <minor suggestions>

---
... (repeat for each nit) ...

## ✅ What's Good

- <Positive feedback on good patterns>
```

- If there are no critical issues, suggestions, optional nits, or good points, omit that section.
- If a category has more than 10 findings, summarize as "Found 10+ critical issues/suggestions/optional nits" and only output the first 10 items.
- Don't compress the blank lines between sections; keep them as-is for readability.
- If any issue requires code changes, append a brief follow-up question after the structured output asking whether the user wants the fix(es) applied. For example: "Would you like me to use the Suggested fix(es) to address these issues?"

### Template B (no issues)

```markdown
## Code Review Summary
✅ No issues found.
```
@@ -1,91 +0,0 @@
# Rule Catalog — Architecture

## Scope
- Covers: controller/service/core-domain/libs/model layering, dependency direction, responsibility placement, observability-friendly flow.

## Rules

### Keep business logic out of controllers
- Category: maintainability
- Severity: critical
- Description: Controllers should parse input, call services, and return serialized responses. Business decisions inside controllers make behavior hard to reuse and test.
- Suggested fix: Move domain/business logic into the service or core/domain layer. Keep controller handlers thin and orchestration-focused.
- Example:
  - Bad:

    ```python
    @bp.post("/apps/<app_id>/publish")
    def publish_app(app_id: str):
        payload = request.get_json() or {}
        if payload.get("force") and current_user.role != "admin":
            raise ValueError("only admin can force publish")
        app = App.query.get(app_id)
        app.status = "published"
        db.session.commit()
        return {"result": "ok"}
    ```

  - Good:

    ```python
    @bp.post("/apps/<app_id>/publish")
    def publish_app(app_id: str):
        payload = PublishRequest.model_validate(request.get_json() or {})
        app_service.publish_app(app_id=app_id, force=payload.force, actor_id=current_user.id)
        return {"result": "ok"}
    ```

### Preserve layer dependency direction
- Category: best practices
- Severity: critical
- Description: Controllers may depend on services, and services may depend on core/domain abstractions. Reversing this direction (for example, core importing controller/web modules) creates cycles and leaks transport concerns into domain code.
- Suggested fix: Extract shared contracts into core/domain or service-level modules and make upper layers depend on lower ones, not the reverse.
- Example:
  - Bad:

    ```python
    # core/policy/publish_policy.py
    from controllers.console.app import request_context

    def can_publish() -> bool:
        return request_context.current_user.is_admin
    ```

  - Good:

    ```python
    # core/policy/publish_policy.py
    def can_publish(role: str) -> bool:
        return role == "admin"


    # service layer adapts web/user context to domain input
    allowed = can_publish(role=current_user.role)
    ```

### Keep libs business-agnostic
- Category: maintainability
- Severity: critical
- Description: Modules under `api/libs/` should remain reusable, business-agnostic building blocks. They must not encode product/domain-specific rules, workflow orchestration, or business decisions.
- Suggested fix:
  - If business logic appears in `api/libs/`, extract it into the appropriate `services/` or `core/` module and keep `libs` focused on generic, cross-cutting helpers.
  - Keep `libs` dependencies clean: avoid importing service/controller/domain-specific modules into `api/libs/`.
- Example:
  - Bad:

    ```python
    # api/libs/conversation_filter.py
    from services.conversation_service import ConversationService

    def should_archive_conversation(conversation, tenant_id: str) -> bool:
        # Domain policy and service dependency are leaking into libs.
        service = ConversationService()
        if service.has_paid_plan(tenant_id):
            return conversation.idle_days > 90
        return conversation.idle_days > 30
    ```

  - Good:

    ```python
    # api/libs/datetime_utils.py (business-agnostic helper)
    def older_than_days(idle_days: int, threshold_days: int) -> bool:
        return idle_days > threshold_days


    # services/conversation_service.py (business logic stays in service/core)
    from libs.datetime_utils import older_than_days

    def should_archive_conversation(conversation, tenant_id: str) -> bool:
        threshold_days = 90 if has_paid_plan(tenant_id) else 30
        return older_than_days(conversation.idle_days, threshold_days)
    ```
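One payoff of keeping the dependency direction as described above is testability: a policy that takes plain data needs no web context to exercise. A tiny sketch along the lines of the `can_publish` example:

```python
def can_publish(role: str) -> bool:
    # Same pure domain rule as in the Good example: no request/user imports needed.
    return role == "admin"

# Exercising the policy needs only plain values, not a running web app.
results = {role: can_publish(role) for role in ("admin", "member", "viewer")}
```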
@@ -1,157 +0,0 @@
# Rule Catalog — DB Schema Design

## Scope
- Covers: model/base inheritance, schema boundaries in model properties, tenant-aware schema design, index redundancy checks, dialect portability in models, and cross-database compatibility in migrations.
- Does NOT cover: session lifecycle, transaction boundaries, and query execution patterns (handled by `sqlalchemy-rule.md`).

## Rules

### Do not query other tables inside `@property`
- Category: [maintainability, performance]
- Severity: critical
- Description: A model `@property` must not open sessions or query other tables. This hides dependencies across models, tightly couples schema objects to data access, and can cause N+1 query explosions when iterating collections.
- Suggested fix:
  - Keep model properties pure and local to already-loaded fields.
  - Move cross-table data fetching to service/repository methods.
  - For list/batch reads, fetch required related data explicitly (join/preload/bulk query) before rendering derived values.
- Example:
  - Bad:

    ```python
    class Conversation(TypeBase):
        __tablename__ = "conversations"

        @property
        def app_name(self) -> str:
            with Session(db.engine, expire_on_commit=False) as session:
                app = session.execute(select(App).where(App.id == self.app_id)).scalar_one()
                return app.name
    ```

  - Good:

    ```python
    class Conversation(TypeBase):
        __tablename__ = "conversations"

        @property
        def display_title(self) -> str:
            return self.name or "Untitled"


    # Service/repository layer performs explicit batch fetch for related App rows.
    ```

### Prefer including `tenant_id` in model definitions
- Category: maintainability
- Severity: suggestion
- Description: In multi-tenant domains, include `tenant_id` in schema definitions whenever the entity belongs to tenant-owned data. This improves data isolation safety and keeps future partitioning/sharding strategies practical as data volume grows.
- Suggested fix:
  - Add a `tenant_id` column and ensure related unique/index constraints include the tenant dimension when applicable.
  - Propagate `tenant_id` through service/repository contracts to keep access paths tenant-aware.
  - Exception: if a table is explicitly designed as non-tenant-scoped global metadata, document that design decision clearly.
- Example:
  - Bad:

    ```python
    from sqlalchemy.orm import Mapped

    class Dataset(TypeBase):
        __tablename__ = "datasets"
        id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
        name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
    ```

  - Good:

    ```python
    from sqlalchemy.orm import Mapped

    class Dataset(TypeBase):
        __tablename__ = "datasets"
        id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
        tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
        name: Mapped[str] = mapped_column(sa.String(255), nullable=False)
    ```

### Detect and avoid duplicate/redundant indexes
- Category: performance
- Severity: suggestion
- Description: Review index definitions for leftmost-prefix redundancy. For example, index `(a, b, c)` can safely cover most lookups for `(a, b)`. Keeping both may increase write overhead and can mislead the optimizer into suboptimal execution plans.
- Suggested fix:
  - Before adding an index, compare against existing composite indexes by leftmost-prefix rules.
  - Drop or avoid creating redundant prefixes unless there is a proven query-pattern need.
  - Apply the same review standard in both model `__table_args__` and migration index DDL.
- Example:
  - Bad:

    ```python
    __table_args__ = (
        sa.Index("idx_msg_tenant_app", "tenant_id", "app_id"),
        sa.Index("idx_msg_tenant_app_created", "tenant_id", "app_id", "created_at"),
    )
    ```

  - Good:

    ```python
    __table_args__ = (
        # Keep the wider index unless profiling proves a dedicated short index is needed.
        sa.Index("idx_msg_tenant_app_created", "tenant_id", "app_id", "created_at"),
    )
    ```
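The leftmost-prefix check above is mechanical enough to script. A sketch of a hypothetical helper (not part of this codebase) that flags composite indexes fully covered by a wider one:

```python
def redundant_indexes(indexes: dict[str, tuple[str, ...]]) -> set[str]:
    """Flag indexes whose column list is a leftmost prefix of another index."""
    redundant: set[str] = set()
    for name, cols in indexes.items():
        for other, other_cols in indexes.items():
            # (a, b) is redundant when some other index starts with (a, b, ...).
            if other != name and other_cols[: len(cols)] == cols:
                redundant.add(name)
    return redundant

indexes = {
    "idx_msg_tenant_app": ("tenant_id", "app_id"),
    "idx_msg_tenant_app_created": ("tenant_id", "app_id", "created_at"),
    "idx_msg_created": ("created_at",),
}
flagged = redundant_indexes(indexes)  # only the two-column prefix is redundant
```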
### Avoid PostgreSQL-only dialect usage in models; wrap in `models.types`
- Category: maintainability
- Severity: critical
- Description: Model/schema definitions should avoid PostgreSQL-only constructs directly in business models. When database-specific behavior is required, encapsulate it in `api/models/types.py` with both PostgreSQL and MySQL dialect implementations, then consume that abstraction from model code.
- Suggested fix:
  - Do not directly place dialect-only types/operators in model columns when a portable wrapper can be used.
  - Add or extend wrappers in `models.types` (for example, `AdjustedJSON`, `LongText`, `BinaryData`) to normalize behavior across PostgreSQL and MySQL.
- Example:
  - Bad:

    ```python
    from sqlalchemy.dialects.postgresql import JSONB
    from sqlalchemy.orm import Mapped

    class ToolConfig(TypeBase):
        __tablename__ = "tool_configs"
        config: Mapped[dict] = mapped_column(JSONB, nullable=False)
    ```

  - Good:

    ```python
    from sqlalchemy.orm import Mapped

    from models.types import AdjustedJSON

    class ToolConfig(TypeBase):
        __tablename__ = "tool_configs"
        config: Mapped[dict] = mapped_column(AdjustedJSON(), nullable=False)
    ```

### Guard migration incompatibilities with dialect checks and shared types
- Category: maintainability
- Severity: critical
- Description: Migration scripts under `api/migrations/versions/` must account for PostgreSQL/MySQL incompatibilities explicitly. For dialect-sensitive DDL or defaults, branch on the active dialect (for example, `conn.dialect.name == "postgresql"`), and prefer reusable compatibility abstractions from `models.types` where applicable.
- Suggested fix:
  - In migration upgrades/downgrades, bind the connection and branch by dialect for incompatible SQL fragments.
  - Reuse `models.types` wrappers in column definitions when that keeps behavior aligned with runtime models.
  - Avoid one-dialect-only migration logic unless there is a documented, deliberate compatibility exception.
- Example:
  - Bad:

    ```python
    with op.batch_alter_table("dataset_keyword_tables") as batch_op:
        batch_op.add_column(
            sa.Column(
                "data_source_type",
                sa.String(255),
                server_default=sa.text("'database'::character varying"),
                nullable=False,
            )
        )
    ```

  - Good:

    ```python
    def _is_pg(conn) -> bool:
        return conn.dialect.name == "postgresql"


    conn = op.get_bind()
    default_expr = sa.text("'database'::character varying") if _is_pg(conn) else sa.text("'database'")

    with op.batch_alter_table("dataset_keyword_tables") as batch_op:
        batch_op.add_column(
            sa.Column("data_source_type", sa.String(255), server_default=default_expr, nullable=False)
        )
    ```
@@ -1,61 +0,0 @@
# Rule Catalog — Repositories Abstraction

## Scope
- Covers: when to reuse existing repository abstractions, when to introduce new repositories, and how to preserve dependency direction between service/core and infrastructure implementations.
- Does NOT cover: SQLAlchemy session lifecycle and query-shape specifics (handled by `sqlalchemy-rule.md`), and table schema/migration design (handled by `db-schema-rule.md`).

## Rules

### Introduce repositories abstraction
- Category: maintainability
- Severity: suggestion
- Description: If a table/model already has a repository abstraction, all reads/writes/queries for that table should use the existing repository. If no repository exists, introduce one only when complexity justifies it, such as large/high-volume tables, repeated complex query logic, or likely storage-strategy variation.
- Suggested fix:
  - First check `api/repositories`, `api/core/repositories`, and `api/extensions/*/repositories/` to verify whether the table/model already has a repository abstraction. If it exists, route all operations through it and add missing repository methods instead of bypassing it with ad-hoc SQLAlchemy access.
  - If no repository exists, add one only when complexity warrants it (for example, repeated complex queries, large data domains, or multiple storage strategies), while preserving dependency direction (service/core depends on the abstraction; infra provides the implementation).
- Example:
  - Bad:

    ```python
    # Existing repository is ignored and the service uses ad-hoc table queries.
    class AppService:
        def archive_app(self, app_id: str, tenant_id: str) -> None:
            app = self.session.execute(
                select(App).where(App.id == app_id, App.tenant_id == tenant_id)
            ).scalar_one()
            app.archived = True
            self.session.commit()
    ```

  - Good:

    ```python
    # Case A: Existing repository must be reused for all table operations.
    class AppService:
        def archive_app(self, app_id: str, tenant_id: str) -> None:
            app = self.app_repo.get_by_id(app_id=app_id, tenant_id=tenant_id)
            app.archived = True
            self.app_repo.save(app)

    # If the query is missing, extend the existing abstraction.
    active_apps = self.app_repo.list_active_for_tenant(tenant_id=tenant_id)
    ```

  - Bad:

    ```python
    # No repository exists, but large-domain query logic is scattered in service code.
    class ConversationService:
        def list_recent_for_app(self, app_id: str, tenant_id: str, limit: int) -> list[Conversation]:
            ...
            # many filters/joins/pagination variants duplicated across services
    ```

  - Good:

    ```python
    # Case B: Introduce a repository for large/complex domains or storage variation.
    class ConversationRepository(Protocol):
        def list_recent_for_app(self, app_id: str, tenant_id: str, limit: int) -> list[Conversation]: ...

    class SqlAlchemyConversationRepository:
        def list_recent_for_app(self, app_id: str, tenant_id: str, limit: int) -> list[Conversation]:
            ...

    class ConversationService:
        def __init__(self, conversation_repo: ConversationRepository):
            self.conversation_repo = conversation_repo
    ```
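The dependency direction in Case B also pays off in tests: a service that depends only on the protocol can run against an in-memory double, with no database at all. A minimal, self-contained sketch with hypothetical names:

```python
from typing import Protocol

class AppRepository(Protocol):
    def get_by_id(self, app_id: str, tenant_id: str) -> dict: ...
    def save(self, app: dict) -> None: ...

class InMemoryAppRepository:
    """Test double: satisfies the protocol with a plain dict, no database."""

    def __init__(self) -> None:
        self._rows: dict[tuple[str, str], dict] = {}

    def get_by_id(self, app_id: str, tenant_id: str) -> dict:
        return self._rows[(tenant_id, app_id)]

    def save(self, app: dict) -> None:
        self._rows[(app["tenant_id"], app["id"])] = app

class AppService:
    # The service depends only on the abstraction, never on SQLAlchemy.
    def __init__(self, app_repo: AppRepository) -> None:
        self.app_repo = app_repo

    def archive_app(self, app_id: str, tenant_id: str) -> None:
        app = self.app_repo.get_by_id(app_id=app_id, tenant_id=tenant_id)
        app["archived"] = True
        self.app_repo.save(app)
```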
@@ -1,139 +0,0 @@
# Rule Catalog — SQLAlchemy Patterns

## Scope
- Covers: SQLAlchemy session and transaction lifecycle, query construction, tenant scoping, raw SQL boundaries, and write-path concurrency safeguards.
- Does NOT cover: table/model schema and migration design details (handled by `db-schema-rule.md`).

## Rules

### Use the Session context manager with explicit transaction control
- Category: best practices
- Severity: critical
- Description: Session and transaction lifecycle must be explicit and bounded on write paths. Missing commits can silently drop intended updates, while ad-hoc or long-lived transactions increase contention, lock duration, and deadlock risk.
- Suggested fix:
  - Use **explicit `session.commit()`** after completing a related write unit.
  - Or use the **`session.begin()` context manager** for automatic commit/rollback on a scoped block.
  - Keep transaction windows short: avoid network I/O, heavy computation, or unrelated work inside the transaction.
- Example:
  - Bad:

    ```python
    # Missing commit: write may never be persisted.
    with Session(db.engine, expire_on_commit=False) as session:
        run = session.get(WorkflowRun, run_id)
        run.status = "cancelled"

    # Long transaction: external I/O inside a DB transaction.
    with Session(db.engine, expire_on_commit=False) as session, session.begin():
        run = session.get(WorkflowRun, run_id)
        run.status = "cancelled"
        call_external_api()
    ```

  - Good:

    ```python
    # Option 1: explicit commit.
    with Session(db.engine, expire_on_commit=False) as session:
        run = session.get(WorkflowRun, run_id)
        run.status = "cancelled"
        session.commit()

    # Option 2: scoped transaction with automatic commit/rollback.
    with Session(db.engine, expire_on_commit=False) as session, session.begin():
        run = session.get(WorkflowRun, run_id)
        run.status = "cancelled"

    # Keep non-DB work outside transaction scope.
    call_external_api()
    ```

### Enforce tenant_id scoping on shared-resource queries
- Category: security
- Severity: critical
- Description: Reads and writes against shared tables must be scoped by `tenant_id` to prevent cross-tenant data leakage or corruption.
- Suggested fix: Add a `tenant_id` predicate to all tenant-owned entity queries and propagate tenant context through service/repository interfaces.
- Example:
  - Bad:

    ```python
    stmt = select(Workflow).where(Workflow.id == workflow_id)
    workflow = session.execute(stmt).scalar_one_or_none()
    ```

  - Good:

    ```python
    stmt = select(Workflow).where(
        Workflow.id == workflow_id,
        Workflow.tenant_id == tenant_id,
    )
    workflow = session.execute(stmt).scalar_one_or_none()
    ```

### Prefer SQLAlchemy expressions over raw SQL by default
- Category: maintainability
- Severity: suggestion
- Description: Raw SQL should be exceptional. ORM/Core expressions are easier to evolve, safer to compose, and more consistent with the codebase.
- Suggested fix: Rewrite straightforward raw SQL into SQLAlchemy `select`/`update`/`delete` expressions; keep raw SQL only when required by clear technical constraints.
- Example:
  - Bad:

    ```python
    row = session.execute(
        text("SELECT * FROM workflows WHERE id = :id AND tenant_id = :tenant_id"),
        {"id": workflow_id, "tenant_id": tenant_id},
    ).first()
    ```

  - Good:

    ```python
    stmt = select(Workflow).where(
        Workflow.id == workflow_id,
        Workflow.tenant_id == tenant_id,
    )
    row = session.execute(stmt).scalar_one_or_none()
    ```

### Protect write paths with concurrency safeguards
- Category: quality
- Severity: critical
- Description: Multi-writer paths without explicit concurrency control can silently overwrite data. Choose the safeguard based on contention level, lock scope, and throughput cost instead of defaulting to one strategy.
- Suggested fix:
  - **Optimistic locking**: Use when contention is usually low and retries are acceptable. Add a version (or `updated_at`) guard in `WHERE` and treat `rowcount == 0` as a conflict.
  - **Redis distributed lock**: Use when the critical section spans multiple steps/processes (or includes non-DB side effects) and you need cross-worker mutual exclusion.
  - **SELECT ... FOR UPDATE**: Use when contention is high on the same rows and strict in-transaction serialization is required. Keep transactions short to reduce lock wait/deadlock risk.
  - In all cases, scope by `tenant_id` and verify affected row counts for conditional writes.
- Example:
  - Bad:

    ```python
    # No tenant scope, no conflict detection, and no lock on a contested write path.
    session.execute(update(WorkflowRun).where(WorkflowRun.id == run_id).values(status="cancelled"))
    session.commit()  # silently overwrites concurrent updates
    ```

  - Good:

    ```python
    # 1) Optimistic lock (low contention, retry on conflict)
    result = session.execute(
        update(WorkflowRun)
        .where(
            WorkflowRun.id == run_id,
            WorkflowRun.tenant_id == tenant_id,
            WorkflowRun.version == expected_version,
        )
        .values(status="cancelled", version=WorkflowRun.version + 1)
    )
    if result.rowcount == 0:
        raise WorkflowStateConflictError("stale version, retry")

    # 2) Redis distributed lock (cross-worker critical section)
    lock_name = f"workflow_run_lock:{tenant_id}:{run_id}"
    with redis_client.lock(lock_name, timeout=20):
        session.execute(
            update(WorkflowRun)
            .where(WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id)
            .values(status="cancelled")
        )
        session.commit()

    # 3) Pessimistic lock with SELECT ... FOR UPDATE (high contention)
    run = session.execute(
        select(WorkflowRun)
        .where(WorkflowRun.id == run_id, WorkflowRun.tenant_id == tenant_id)
        .with_for_update()
    ).scalar_one()
    run.status = "cancelled"
    session.commit()
    ```
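The rowcount guard in option 1 implies a retry loop at the call site. A pure-Python sketch of that loop, where an in-memory `Row` stands in for the versioned database row (all names here are hypothetical, not from this codebase):

```python
import threading

class StaleVersionError(Exception):
    pass

class Row:
    """In-memory stand-in for a versioned database row."""

    def __init__(self) -> None:
        self.version = 0
        self.status = "running"
        self._lock = threading.Lock()

    def conditional_update(self, expected_version: int, status: str) -> int:
        # Mirrors UPDATE ... WHERE version = :expected; returns the rowcount.
        with self._lock:
            if self.version != expected_version:
                return 0
            self.status = status
            self.version += 1
            return 1

def cancel_with_retry(row: Row, max_attempts: int = 3) -> None:
    for _ in range(max_attempts):
        expected = row.version  # read the version we are updating against
        if row.conditional_update(expected, "cancelled"):
            return  # rowcount == 1: the guarded write was accepted
    raise StaleVersionError("gave up after repeated write conflicts")
```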
@@ -1 +0,0 @@
../../.agents/skills/backend-code-review
.github/workflows/pyrefly-diff-comment.yml (vendored, 9 changed lines)
@@ -77,7 +77,14 @@ jobs:
          }

          const body = diff.trim()
-           ? '### Pyrefly Diff\n<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>'
+           ? `### Pyrefly Diff
+<details>
+<summary>base → PR</summary>
+
+\`\`\`diff
+${diff}
+\`\`\`
+</details>`
            : '### Pyrefly Diff\nNo changes detected.';

          await github.rest.issues.createComment({
.github/workflows/pyrefly-diff.yml (vendored, 18 changed lines)
@@ -74,16 +74,14 @@ jobs:
          }

          const body = diff.trim()
-           ? [
-               '### Pyrefly Diff',
-               '<details>',
-               '<summary>base → PR</summary>',
-               '',
-               '```diff',
-               diff,
-               '```',
-               '</details>',
-             ].join('\n')
+           ? `### Pyrefly Diff
+<details>
+<summary>base → PR</summary>
+
+\`\`\`diff
+${diff}
+\`\`\`
+</details>`
            : '### Pyrefly Diff\nNo changes detected.';

          await github.rest.issues.createComment({
.github/workflows/web-tests.yml (vendored, 63 changed lines)
@@ -3,22 +3,14 @@ name: Web Tests
on:
  workflow_call:

permissions:
  contents: read

concurrency:
  group: web-tests-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  test:
-   name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }})
+   name: Web Tests
    runs-on: ubuntu-latest
-   strategy:
-     fail-fast: false
-     matrix:
-       shardIndex: [1, 2, 3, 4]
-       shardTotal: [4]
    defaults:
      run:
        shell: bash
@@ -47,58 +39,7 @@ jobs:
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Run tests
|
||||
run: pnpm vitest run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
|
||||
|
||||
- name: Upload blob report
|
||||
if: ${{ !cancelled() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: blob-report-${{ matrix.shardIndex }}
|
||||
path: web/.vitest-reports/*
|
||||
include-hidden-files: true
|
||||
retention-days: 1
|
||||
|
||||
merge-reports:
|
||||
name: Merge Test Reports
|
||||
if: ${{ !cancelled() }}
|
||||
needs: [test]
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: ./web
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 24
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Download blob reports
|
||||
uses: actions/download-artifact@v6
|
||||
with:
|
||||
path: web/.vitest-reports
|
||||
pattern: blob-report-*
|
||||
merge-multiple: true
|
||||
|
||||
- name: Merge reports
|
||||
run: pnpm vitest --merge-reports --coverage --silent=passed-only
|
||||
run: pnpm test:ci
|
||||
|
||||
- name: Coverage Summary
|
||||
if: always()
|
||||
|
||||
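The removed configuration split the vitest run into four shards and merged the blob reports in a follow-up job. One way to picture a shard split (a round-robin sketch; vitest's actual partitioning strategy may differ):

```python
def shard(tests: list[str], shard_index: int, shard_total: int) -> list[str]:
    """Return the slice of tests assigned to a 1-based shard_index,
    mirroring a --shard=i/n style flag with a round-robin split."""
    return tests[shard_index - 1 :: shard_total]

tests = [f"test_{i}" for i in range(10)]
shards = [shard(tests, i, 4) for i in range(1, 5)]
```

Every test lands in exactly one shard, so concatenating the shard reports reconstructs the full suite.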
```diff
@@ -50,6 +50,7 @@ forbidden_modules =
 allow_indirect_imports = True
 ignore_imports =
     core.workflow.nodes.agent.agent_node -> extensions.ext_database
+    core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
     core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
     core.workflow.nodes.llm.file_saver -> extensions.ext_database
     core.workflow.nodes.llm.llm_utils -> extensions.ext_database
@@ -105,6 +106,9 @@ ignore_imports =
     core.workflow.nodes.agent.agent_node -> core.model_manager
     core.workflow.nodes.agent.agent_node -> core.provider_manager
     core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
+    core.workflow.nodes.datasource.datasource_node -> models.model
+    core.workflow.nodes.datasource.datasource_node -> models.tools
+    core.workflow.nodes.datasource.datasource_node -> services.datasource_provider_service
     core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
     core.workflow.nodes.http_request.node -> core.tools.tool_file_manager
     core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
@@ -142,6 +146,8 @@ ignore_imports =
     core.workflow.workflow_entry -> core.app.apps.exc
     core.workflow.workflow_entry -> core.app.entities.app_invoke_entities
     core.workflow.workflow_entry -> core.app.workflow.node_factory
+    core.workflow.nodes.datasource.datasource_node -> core.datasource.datasource_manager
+    core.workflow.nodes.datasource.datasource_node -> core.datasource.utils.message_transformer
     core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
     core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
     core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
@@ -154,6 +160,7 @@ ignore_imports =
     core.workflow.nodes.code.code_node -> core.helper.code_executor.javascript.javascript_code_provider
     core.workflow.nodes.code.code_node -> core.helper.code_executor.python3.python3_code_provider
     core.workflow.nodes.code.entities -> core.helper.code_executor.code_executor
+    core.workflow.nodes.datasource.datasource_node -> core.variables.variables
     core.workflow.nodes.http_request.executor -> core.helper.ssrf_proxy
     core.workflow.nodes.http_request.node -> core.helper.ssrf_proxy
     core.workflow.nodes.llm.file_saver -> core.helper.ssrf_proxy
@@ -190,6 +197,7 @@ ignore_imports =
     core.workflow.nodes.code.code_node -> core.variables.segments
     core.workflow.nodes.code.code_node -> core.variables.types
     core.workflow.nodes.code.entities -> core.variables.types
+    core.workflow.nodes.datasource.datasource_node -> core.variables.segments
     core.workflow.nodes.document_extractor.node -> core.variables
     core.workflow.nodes.document_extractor.node -> core.variables.segments
     core.workflow.nodes.http_request.executor -> core.variables.segments
@@ -232,6 +240,7 @@ ignore_imports =
     core.workflow.variable_loader -> core.variables.consts
     core.workflow.workflow_type_encoder -> core.variables
     core.workflow.nodes.agent.agent_node -> extensions.ext_database
+    core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
     core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
     core.workflow.nodes.llm.file_saver -> extensions.ext_database
     core.workflow.nodes.llm.llm_utils -> extensions.ext_database
```
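Each hunk above whitelists additional `datasource_node` edges in the import-linter configuration. The effect of `ignore_imports` can be sketched as a simple edge filter (a hypothetical helper, not import-linter's actual API):

```python
def forbidden_violations(observed, forbidden_targets, ignore):
    """Report importer -> imported edges that hit a forbidden target
    and are not explicitly whitelisted via ignore_imports."""
    return [
        (src, dst)
        for src, dst in observed
        if dst in forbidden_targets and (src, dst) not in ignore
    ]

observed = [
    ("core.workflow.nodes.datasource.datasource_node", "extensions.ext_database"),
    ("core.workflow.nodes.llm.llm_utils", "extensions.ext_database"),
    ("core.workflow.nodes.code.code_node", "extensions.ext_database"),
]
ignore = {
    ("core.workflow.nodes.datasource.datasource_node", "extensions.ext_database"),
    ("core.workflow.nodes.llm.llm_utils", "extensions.ext_database"),
}
violations = forbidden_violations(observed, {"extensions.ext_database"}, ignore)
```

An edge that is neither ignored nor allowed is what the linter reports as a contract violation.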
````diff
@@ -42,7 +42,7 @@ The scripts resolve paths relative to their location, so you can run them from a

 1. Set up your application by visiting `http://localhost:3000`.

-1. Start the worker service (async and scheduler tasks, runs from `api`).
+1. Optional: start the worker service (async tasks, runs from `api`).

    ```bash
    ./dev/start-worker
@@ -54,6 +54,86 @@ The scripts resolve paths relative to their location, so you can run them from a
    ./dev/start-beat
    ```

+### Manual commands
+
+<details>
+<summary>Show manual setup and run steps</summary>
+
+These commands assume you start from the repository root.
+
+1. Start the docker-compose stack.
+
+   The backend requires middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`.
+
+   ```bash
+   cp docker/middleware.env.example docker/middleware.env
+   # Use mysql or another vector database profile if you are not using postgres/weaviate.
+   docker compose -f docker/docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
+   ```
+
+1. Copy env files.
+
+   ```bash
+   cp api/.env.example api/.env
+   cp web/.env.example web/.env.local
+   ```
+
+1. Install UV if needed.
+
+   ```bash
+   pip install uv
+   # Or on macOS
+   brew install uv
+   ```
+
+1. Install API dependencies.
+
+   ```bash
+   cd api
+   uv sync --group dev
+   ```
+
+1. Install web dependencies.
+
+   ```bash
+   cd web
+   pnpm install
+   cd ..
+   ```
+
+1. Start backend (runs migrations first, in a new terminal).
+
+   ```bash
+   cd api
+   uv run flask db upgrade
+   uv run flask run --host 0.0.0.0 --port=5001 --debug
+   ```
+
+1. Start Dify [web](../web) service (in a new terminal).
+
+   ```bash
+   cd web
+   pnpm dev:inspect
+   ```
+
+1. Set up your application by visiting `http://localhost:3000`.
+
+1. Optional: start the worker service (async tasks, in a new terminal).
+
+   ```bash
+   cd api
+   uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
+   ```
+
+1. Optional: start Celery Beat (scheduled tasks, in a new terminal).
+
+   ```bash
+   cd api
+   uv run celery -A app.celery beat
+   ```
+
+</details>

 ### Environment notes

 > [!IMPORTANT]
````
```diff
@@ -133,7 +133,6 @@ class EducationAutocompleteQuery(BaseModel):

 class ChangeEmailSendPayload(BaseModel):
     email: EmailStr
     language: str | None = None
-    phase: str | None = None
     token: str | None = None


@@ -547,13 +546,17 @@ class ChangeEmailSendEmailApi(Resource):
         account = None
         user_email = None
         email_for_sending = args.email.lower()
-        if args.phase is not None and args.phase == "new_email":
-            if args.token is None:
-                raise InvalidTokenError()
+        send_phase = AccountService.CHANGE_EMAIL_PHASE_OLD
+        if args.token is not None:
+            send_phase = AccountService.CHANGE_EMAIL_PHASE_NEW

             reset_data = AccountService.get_change_email_data(args.token)
             if reset_data is None:
                 raise InvalidTokenError()

+            reset_token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
+            if reset_token_phase != AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED:
+                raise InvalidTokenError()
             user_email = reset_data.get("email", "")

             if user_email.lower() != current_user.email.lower():
@@ -573,7 +576,7 @@ class ChangeEmailSendEmailApi(Resource):
             email=email_for_sending,
             old_email=user_email,
             language=language,
-            phase=args.phase,
+            phase=send_phase,
         )
         return {"result": "success", "data": token}

@@ -608,12 +611,26 @@ class ChangeEmailCheckApi(Resource):
             AccountService.add_change_email_error_rate_limit(user_email)
             raise EmailCodeError()

+        phase_transitions: dict[str, str] = {
+            AccountService.CHANGE_EMAIL_PHASE_OLD: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
+            AccountService.CHANGE_EMAIL_PHASE_NEW: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
+        }
+        token_phase = token_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
+        if not isinstance(token_phase, str):
+            raise InvalidTokenError()
+        refreshed_phase = phase_transitions.get(token_phase)
+        if refreshed_phase is None:
+            raise InvalidTokenError()
+
         # Verified, revoke the first token
         AccountService.revoke_change_email_token(args.token)

         # Refresh token data by generating a new token
         _, new_token = AccountService.generate_change_email_token(
-            user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
+            user_email,
+            code=args.code,
+            old_email=token_data.get("old_email"),
+            additional_data={AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: refreshed_phase},
         )

         AccountService.reset_change_email_error_rate_limit(user_email)
@@ -643,13 +660,22 @@ class ChangeEmailResetApi(Resource):
         if not reset_data:
             raise InvalidTokenError()

-        AccountService.revoke_change_email_token(args.token)
+        token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
+        if token_phase != AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED:
+            raise InvalidTokenError()
+
+        token_email = reset_data.get("email")
+        normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
+        if normalized_token_email != normalized_new_email:
+            raise InvalidTokenError()

         old_email = reset_data.get("old_email", "")
         current_user, _ = current_account_with_tenant()
         if current_user.email.lower() != old_email.lower():
             raise AccountNotFound()

+        AccountService.revoke_change_email_token(args.token)
+
         updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)

         AccountService.send_change_email_completed_notify_email(
```
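The change threads an explicit phase through the change-email tokens: the old address is verified first, then the new one, and each successful verification mints a fresh token carrying the next phase. A condensed sketch of that state machine (constant values shortened here; the real constants live on `AccountService`):

```python
PHASE_OLD = "old_email"
PHASE_OLD_VERIFIED = "old_email_verified"
PHASE_NEW = "new_email"
PHASE_NEW_VERIFIED = "new_email_verified"

# Only these two transitions are legal; anything else means a stale or forged token.
TRANSITIONS = {PHASE_OLD: PHASE_OLD_VERIFIED, PHASE_NEW: PHASE_NEW_VERIFIED}

def advance_phase(token_data: dict) -> dict:
    """Validate the token's phase and return refreshed token data for the next step."""
    phase = token_data.get("phase")
    if not isinstance(phase, str) or phase not in TRANSITIONS:
        raise ValueError("invalid token phase")
    return {**token_data, "phase": TRANSITIONS[phase]}
```

Because the final reset endpoint requires `PHASE_NEW_VERIFIED`, a token from an earlier phase can no longer be replayed to complete the email change.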
```diff
@@ -669,14 +669,16 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
     ) -> Generator[StreamResponse, None, None]:
         """Handle retriever resources events."""
         self._message_cycle_manager.handle_retriever_resources(event)
-        yield from ()
+        return
+        yield  # Make this a generator

     def _handle_annotation_reply_event(
         self, event: QueueAnnotationReplyEvent, **kwargs
     ) -> Generator[StreamResponse, None, None]:
         """Handle annotation reply events."""
         self._message_cycle_manager.handle_annotation_reply(event)
-        yield from ()
+        return
+        yield  # Make this a generator

     def _handle_message_replace_event(
         self, event: QueueMessageReplaceEvent, **kwargs
```
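Both edits swap one empty-generator idiom for another. The two forms behave identically at runtime, since each produces a generator function that yields nothing; a quick illustration:

```python
import inspect
from collections.abc import Generator

def handler_a() -> Generator[str, None, None]:
    # A generator that yields nothing: delegate to an empty iterable.
    yield from ()

def handler_b() -> Generator[str, None, None]:
    # Equivalent: return immediately; the unreachable `yield` below is what
    # makes Python compile this function as a generator at all.
    return
    yield
```

Without some `yield` in the body, the function would be an ordinary function returning `None`, breaking callers that iterate over it.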
```diff
@@ -122,7 +122,7 @@ class AppQueueManager(ABC):
         """Attach the live graph runtime state reference for downstream consumers."""
         self._graph_runtime_state = graph_runtime_state

-    def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
+    def publish(self, event: AppQueueEvent, pub_from: PublishFrom):
         """
         Publish event to queue
         :param event:
```
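Dropping the explicit `-> None` only changes what static tooling sees; at runtime the annotation is simply absent from the function's metadata. For example:

```python
def publish_annotated(event) -> None:
    """Explicit: type checkers know callers receive nothing back."""

def publish_unannotated(event):
    """Implicit: the return type is unannotated (treated as Any by most checkers)."""
```

Either way the function returns `None`; the annotated form just records that fact where tools like mypy or Pyrefly can check it.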
```diff
@@ -5,7 +5,6 @@ from typing_extensions import override

 from configs import dify_config
 from core.app.llm.model_access import build_dify_model_access
-from core.datasource.datasource_manager import DatasourceManager
 from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
 from core.helper.code_executor.code_node_provider import CodeNodeProvider
 from core.helper.ssrf_proxy import ssrf_proxy
@@ -19,7 +18,6 @@ from core.workflow.nodes.base.node import Node
 from core.workflow.nodes.code.code_node import CodeNode, WorkflowCodeExecutor
 from core.workflow.nodes.code.entities import CodeLanguage
 from core.workflow.nodes.code.limits import CodeNodeLimits
-from core.workflow.nodes.datasource import DatasourceNode
 from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
 from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
 from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
@@ -180,15 +178,6 @@ class DifyNodeFactory(NodeFactory):
             model_factory=self._llm_model_factory,
         )

-        if node_type == NodeType.DATASOURCE:
-            return DatasourceNode(
-                id=node_id,
-                config=node_config,
-                graph_init_params=self.graph_init_params,
-                graph_runtime_state=self.graph_runtime_state,
-                datasource_manager=DatasourceManager,
-            )
-
         if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
             return KnowledgeRetrievalNode(
                 id=node_id,
```
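The factory dispatches on `NodeType` and constructs each node with its dependencies; the diff removes the datasource branch along with its injected manager. The dispatch pattern itself can be sketched with a plain registry (class and key names here are illustrative, not Dify's actual types):

```python
from typing import Any, Callable

class HttpRequestNode:
    def __init__(self, node_id: str) -> None:
        self.node_id = node_id

class CodeNode:
    def __init__(self, node_id: str) -> None:
        self.node_id = node_id

# Map node-type keys to constructors instead of a long if/elif chain.
NODE_REGISTRY: dict[str, Callable[[str], Any]] = {
    "http-request": HttpRequestNode,
    "code": CodeNode,
}

def create_node(node_type: str, node_id: str) -> Any:
    ctor = NODE_REGISTRY.get(node_type)
    if ctor is None:
        raise ValueError(f"unsupported node type: {node_type}")
    return ctor(node_id)
```

Removing a branch (or registry entry) makes the factory raise for that type, so any remaining call sites surface immediately.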
```diff
@@ -1,39 +1,16 @@
 import logging
-from collections.abc import Generator
 from threading import Lock
-from typing import Any, cast
-
-from sqlalchemy import select
-
-import contexts
 from core.datasource.__base.datasource_plugin import DatasourcePlugin
 from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
-from core.datasource.entities.datasource_entities import (
-    DatasourceMessage,
-    DatasourceProviderType,
-    GetOnlineDocumentPageContentRequest,
-    OnlineDriveDownloadFileRequest,
-)
+from core.datasource.entities.datasource_entities import DatasourceProviderType
 from core.datasource.errors import DatasourceProviderNotFoundError
 from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
 from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
 from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
 from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
 from core.datasource.online_drive.online_drive_provider import OnlineDriveDatasourcePluginProviderController
-from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
 from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
-from core.db.session_factory import session_factory
 from core.plugin.impl.datasource import PluginDatasourceManager
-from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.enums import WorkflowNodeExecutionMetadataKey
-from core.workflow.file import File
-from core.workflow.file.enums import FileTransferMethod, FileType
-from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
-from core.workflow.repositories.datasource_manager_protocol import DatasourceParameter, OnlineDriveDownloadFileParam
-from factories import file_factory
-from models.model import UploadFile
-from models.tools import ToolFile
-from services.datasource_provider_service import DatasourceProviderService

 logger = logging.getLogger(__name__)

@@ -126,238 +103,3 @@ class DatasourceManager:
             tenant_id,
             datasource_type,
         ).get_datasource(datasource_name)
-
-    @classmethod
-    def get_icon_url(cls, provider_id: str, tenant_id: str, datasource_name: str, datasource_type: str) -> str:
-        datasource_runtime = cls.get_datasource_runtime(
-            provider_id=provider_id,
-            datasource_name=datasource_name,
-            tenant_id=tenant_id,
-            datasource_type=DatasourceProviderType.value_of(datasource_type),
-        )
-        return datasource_runtime.get_icon_url(tenant_id)
-
-    @classmethod
-    def stream_online_results(
-        cls,
-        *,
-        user_id: str,
-        datasource_name: str,
-        datasource_type: str,
-        provider_id: str,
-        tenant_id: str,
-        provider: str,
-        plugin_id: str,
-        credential_id: str,
-        datasource_param: DatasourceParameter | None = None,
-        online_drive_request: OnlineDriveDownloadFileParam | None = None,
-    ) -> Generator[DatasourceMessage, None, Any]:
-        """
-        Pull-based streaming of domain messages from datasource plugins.
-        Returns a generator that yields DatasourceMessage and finally returns a minimal final payload.
-        Only ONLINE_DOCUMENT and ONLINE_DRIVE are streamable here; other types are handled by nodes directly.
-        """
-        ds_type = DatasourceProviderType.value_of(datasource_type)
-        runtime = cls.get_datasource_runtime(
-            provider_id=provider_id,
-            datasource_name=datasource_name,
-            tenant_id=tenant_id,
-            datasource_type=ds_type,
-        )
-
-        dsp_service = DatasourceProviderService()
-        credentials = dsp_service.get_datasource_credentials(
-            tenant_id=tenant_id,
-            provider=provider,
-            plugin_id=plugin_id,
-            credential_id=credential_id,
-        )
-
-        if ds_type == DatasourceProviderType.ONLINE_DOCUMENT:
-            doc_runtime = cast(OnlineDocumentDatasourcePlugin, runtime)
-            if credentials:
-                doc_runtime.runtime.credentials = credentials
-            if datasource_param is None:
-                raise ValueError("datasource_param is required for ONLINE_DOCUMENT streaming")
-            inner_gen: Generator[DatasourceMessage, None, None] = doc_runtime.get_online_document_page_content(
-                user_id=user_id,
-                datasource_parameters=GetOnlineDocumentPageContentRequest(
-                    workspace_id=datasource_param.workspace_id,
-                    page_id=datasource_param.page_id,
-                    type=datasource_param.type,
-                ),
-                provider_type=ds_type,
-            )
-        elif ds_type == DatasourceProviderType.ONLINE_DRIVE:
-            drive_runtime = cast(OnlineDriveDatasourcePlugin, runtime)
-            if credentials:
-                drive_runtime.runtime.credentials = credentials
-            if online_drive_request is None:
-                raise ValueError("online_drive_request is required for ONLINE_DRIVE streaming")
-            inner_gen = drive_runtime.online_drive_download_file(
-                user_id=user_id,
-                request=OnlineDriveDownloadFileRequest(
-                    id=online_drive_request.id,
-                    bucket=online_drive_request.bucket,
-                ),
-                provider_type=ds_type,
-            )
-        else:
-            raise ValueError(f"Unsupported datasource type for streaming: {ds_type}")
-
-        # Bridge through to caller while preserving generator return contract
-        yield from inner_gen
-        # No structured final data here; node/adapter will assemble outputs
-        return {}
-
-    @classmethod
-    def stream_node_events(
-        cls,
-        *,
-        node_id: str,
-        user_id: str,
-        datasource_name: str,
-        datasource_type: str,
-        provider_id: str,
-        tenant_id: str,
-        provider: str,
-        plugin_id: str,
-        credential_id: str,
-        parameters_for_log: dict[str, Any],
-        datasource_info: dict[str, Any],
-        variable_pool: Any,
-        datasource_param: DatasourceParameter | None = None,
-        online_drive_request: OnlineDriveDownloadFileParam | None = None,
-    ) -> Generator[StreamChunkEvent | StreamCompletedEvent, None, None]:
-        ds_type = DatasourceProviderType.value_of(datasource_type)
-
-        messages = cls.stream_online_results(
-            user_id=user_id,
-            datasource_name=datasource_name,
-            datasource_type=datasource_type,
-            provider_id=provider_id,
-            tenant_id=tenant_id,
-            provider=provider,
-            plugin_id=plugin_id,
-            credential_id=credential_id,
-            datasource_param=datasource_param,
-            online_drive_request=online_drive_request,
-        )
-
-        transformed = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
-            messages=messages, user_id=user_id, tenant_id=tenant_id, conversation_id=None
-        )
-
-        variables: dict[str, Any] = {}
-        file_out: File | None = None
-
-        for message in transformed:
-            mtype = message.type
-            if mtype in {
-                DatasourceMessage.MessageType.IMAGE_LINK,
-                DatasourceMessage.MessageType.BINARY_LINK,
-                DatasourceMessage.MessageType.IMAGE,
-            }:
-                wanted_ds_type = ds_type in {
-                    DatasourceProviderType.ONLINE_DRIVE,
-                    DatasourceProviderType.ONLINE_DOCUMENT,
-                }
-                if wanted_ds_type and isinstance(message.message, DatasourceMessage.TextMessage):
-                    url = message.message.text
-
-                    datasource_file_id = str(url).split("/")[-1].split(".")[0]
-                    with session_factory.create_session() as session:
-                        stmt = select(ToolFile).where(
-                            ToolFile.id == datasource_file_id, ToolFile.tenant_id == tenant_id
-                        )
-                        datasource_file = session.scalar(stmt)
-                        if not datasource_file:
-                            raise ValueError(
-                                f"ToolFile not found for file_id={datasource_file_id}, tenant_id={tenant_id}"
-                            )
-                        mime_type = datasource_file.mimetype
-                    if datasource_file is not None:
-                        mapping = {
-                            "tool_file_id": datasource_file_id,
-                            "type": file_factory.get_file_type_by_mime_type(mime_type),
-                            "transfer_method": FileTransferMethod.TOOL_FILE,
-                            "url": url,
-                        }
-                        file_out = file_factory.build_from_mapping(mapping=mapping, tenant_id=tenant_id)
-            elif mtype == DatasourceMessage.MessageType.TEXT:
-                assert isinstance(message.message, DatasourceMessage.TextMessage)
-                yield StreamChunkEvent(selector=[node_id, "text"], chunk=message.message.text, is_final=False)
-            elif mtype == DatasourceMessage.MessageType.LINK:
-                assert isinstance(message.message, DatasourceMessage.TextMessage)
-                yield StreamChunkEvent(
-                    selector=[node_id, "text"], chunk=f"Link: {message.message.text}\n", is_final=False
-                )
-            elif mtype == DatasourceMessage.MessageType.VARIABLE:
-                assert isinstance(message.message, DatasourceMessage.VariableMessage)
-                name = message.message.variable_name
-                value = message.message.variable_value
-                if message.message.stream:
-                    assert isinstance(value, str), "stream variable_value must be str"
-                    variables[name] = variables.get(name, "") + value
-                    yield StreamChunkEvent(selector=[node_id, name], chunk=value, is_final=False)
-                else:
-                    variables[name] = value
-            elif mtype == DatasourceMessage.MessageType.FILE:
-                if ds_type == DatasourceProviderType.ONLINE_DRIVE and message.meta:
-                    f = message.meta.get("file")
-                    if isinstance(f, File):
-                        file_out = f
-            else:
-                pass
-
-        yield StreamChunkEvent(selector=[node_id, "text"], chunk="", is_final=True)
-
-        if ds_type == DatasourceProviderType.ONLINE_DRIVE and file_out is not None:
-            variable_pool.add([node_id, "file"], file_out)
-
-        if ds_type == DatasourceProviderType.ONLINE_DOCUMENT:
-            yield StreamCompletedEvent(
-                node_run_result=NodeRunResult(
-                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
-                    inputs=parameters_for_log,
-                    metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
-                    outputs={**variables},
-                )
-            )
-        else:
-            yield StreamCompletedEvent(
-                node_run_result=NodeRunResult(
-                    status=WorkflowNodeExecutionStatus.SUCCEEDED,
-                    inputs=parameters_for_log,
-                    metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
-                    outputs={
-                        "file": file_out,
-                        "datasource_type": ds_type,
-                    },
-                )
-            )
-
-    @classmethod
-    def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File:
-        with session_factory.create_session() as session:
-            upload_file = (
-                session.query(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).first()
-            )
-        if not upload_file:
-            raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
-
-        file_info = File(
-            id=upload_file.id,
-            filename=upload_file.name,
-            extension="." + upload_file.extension,
-            mime_type=upload_file.mime_type,
-            tenant_id=tenant_id,
-            type=FileType.CUSTOM,
-            transfer_method=FileTransferMethod.LOCAL_FILE,
-            remote_url=upload_file.source_url,
-            related_id=upload_file.id,
-            size=upload_file.size,
-            storage_key=upload_file.key,
-            url=upload_file.source_url,
-        )
-        return file_info
```
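The removed `stream_online_results` relies on `yield from` both to forward plugin messages and to propagate a generator return value ("preserving generator return contract"). A self-contained illustration of that contract:

```python
from collections.abc import Generator

def inner() -> Generator[str, None, dict]:
    yield "chunk-1"
    yield "chunk-2"
    return {"status": "done"}  # travels to the caller via StopIteration.value

def bridge() -> Generator[str, None, dict]:
    # `yield from` forwards every chunk AND captures inner's return value.
    result = yield from inner()
    return result

def drain(gen):
    """Consume a generator, collecting its chunks and its return value."""
    chunks = []
    while True:
        try:
            chunks.append(next(gen))
        except StopIteration as stop:
            return chunks, stop.value
```

A plain `for` loop over `bridge()` would see only the chunks; the return value is only available through `yield from` or by catching `StopIteration` explicitly, which is why the bridging layer matters.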
```diff
@@ -379,11 +379,4 @@ class OnlineDriveDownloadFileRequest(BaseModel):
     """

     id: str = Field(..., description="The id of the file")
-    bucket: str = Field("", description="The name of the bucket")
-
-    @field_validator("bucket", mode="before")
-    @classmethod
-    def _coerce_bucket(cls, v) -> str:
-        if v is None:
-            return ""
-        return str(v)
+    bucket: str | None = Field(None, description="The name of the bucket")
```
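The two sides differ in how a missing bucket is handled: one coerces `None` to `""` in a `mode="before"` validator so the field stays a plain `str`, the other widens the annotation to `str | None`. The coercion half, restated in plain Python outside pydantic for illustration:

```python
def coerce_bucket(value) -> str:
    """Normalize an absent bucket name to an empty string, mirroring the
    behavior of the removed mode="before" validator."""
    if value is None:
        return ""
    return str(value)
```

With coercion, downstream code can treat `bucket` as always-a-string; with the `str | None` field, every consumer must handle the `None` case itself.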
```diff
@@ -1,26 +1,40 @@
 from collections.abc import Generator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any
+from typing import Any, cast

-from core.datasource.entities.datasource_entities import DatasourceProviderType
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from core.datasource.entities.datasource_entities import (
+    DatasourceMessage,
+    DatasourceParameter,
+    DatasourceProviderType,
+    GetOnlineDocumentPageContentRequest,
+    OnlineDriveDownloadFileRequest,
+)
+from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
+from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
+from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
 from core.plugin.impl.exc import PluginDaemonClientSideError
 from core.variables.segments import ArrayAnySegment
 from core.variables.variables import ArrayAnyVariable
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
 from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey
-from core.workflow.node_events import NodeRunResult, StreamCompletedEvent
+from core.workflow.file import File
+from core.workflow.file.enums import FileTransferMethod, FileType
+from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
 from core.workflow.nodes.base.node import Node
 from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
-from core.workflow.repositories.datasource_manager_protocol import (
-    DatasourceManagerProtocol,
-    DatasourceParameter,
-    OnlineDriveDownloadFileParam,
-)
+from core.workflow.nodes.tool.exc import ToolFileError
 from core.workflow.runtime import VariablePool
+from extensions.ext_database import db
+from factories import file_factory
+from models.model import UploadFile
+from models.tools import ToolFile
+from services.datasource_provider_service import DatasourceProviderService

 from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
 from .entities import DatasourceNodeData
-from .exc import DatasourceNodeError
-
-if TYPE_CHECKING:
-    from core.workflow.entities import GraphInitParams
-    from core.workflow.runtime import GraphRuntimeState
+from .exc import DatasourceNodeError, DatasourceParameterError


 class DatasourceNode(Node[DatasourceNodeData]):
@@ -31,22 +45,6 @@ class DatasourceNode(Node[DatasourceNodeData]):
     node_type = NodeType.DATASOURCE
     execution_type = NodeExecutionType.ROOT

-    def __init__(
-        self,
-        id: str,
-        config: Mapping[str, Any],
-        graph_init_params: "GraphInitParams",
-        graph_runtime_state: "GraphRuntimeState",
-        datasource_manager: DatasourceManagerProtocol,
-    ):
-        super().__init__(
-            id=id,
-            config=config,
-            graph_init_params=graph_init_params,
-            graph_runtime_state=graph_runtime_state,
-        )
-        self.datasource_manager = datasource_manager
-
     def _run(self) -> Generator:
         """
         Run the datasource node
@@ -54,69 +52,84 @@ class DatasourceNode(Node[DatasourceNodeData]):

         node_data = self.node_data
         variable_pool = self.graph_runtime_state.variable_pool
-        datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
-        if not datasource_type_segment:
+        datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
+        if not datasource_type_segement:
             raise DatasourceNodeError("Datasource type is not set")
-        datasource_type = str(datasource_type_segment.value) if datasource_type_segment.value else None
-        datasource_info_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO])
-        if not datasource_info_segment:
+        datasource_type = str(datasource_type_segement.value) if datasource_type_segement.value else None
+        datasource_info_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO])
+        if not datasource_info_segement:
             raise DatasourceNodeError("Datasource info is not set")
-        datasource_info_value = datasource_info_segment.value
+        datasource_info_value = datasource_info_segement.value
         if not isinstance(datasource_info_value, dict):
             raise DatasourceNodeError("Invalid datasource info format")
         datasource_info: dict[str, Any] = datasource_info_value
+        # get datasource runtime
+        from core.datasource.datasource_manager import DatasourceManager

         if datasource_type is None:
             raise DatasourceNodeError("Datasource type is not set")

         datasource_type = DatasourceProviderType.value_of(datasource_type)
-        provider_id = f"{node_data.plugin_id}/{node_data.provider_name}"
-
-        datasource_info["icon"] = self.datasource_manager.get_icon_url(
-            provider_id=provider_id,
+        datasource_runtime = DatasourceManager.get_datasource_runtime(
+            provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
             datasource_name=node_data.datasource_name or "",
             tenant_id=self.tenant_id,
-            datasource_type=datasource_type.value,
+            datasource_type=datasource_type,
         )
+        datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)

         parameters_for_log = datasource_info

         try:
             datasource_provider_service = DatasourceProviderService()
             credentials = datasource_provider_service.get_datasource_credentials(
                 tenant_id=self.tenant_id,
                 provider=node_data.provider_name,
                 plugin_id=node_data.plugin_id,
                 credential_id=datasource_info.get("credential_id", ""),
             )
             match datasource_type:
-                case DatasourceProviderType.ONLINE_DOCUMENT | DatasourceProviderType.ONLINE_DRIVE:
-                    # Build typed request objects
-                    datasource_parameters = None
-                    if datasource_type == DatasourceProviderType.ONLINE_DOCUMENT:
-                        datasource_parameters = DatasourceParameter(
-                            workspace_id=datasource_info.get("workspace_id", ""),
-                            page_id=datasource_info.get("page", {}).get("page_id", ""),
-                            type=datasource_info.get("page", {}).get("type", ""),
-                        )
-
-                    online_drive_request = None
-                    if datasource_type == DatasourceProviderType.ONLINE_DRIVE:
-                        online_drive_request = OnlineDriveDownloadFileParam(
-                            id=datasource_info.get("id", ""),
-                            bucket=datasource_info.get("bucket", ""),
-                        )
+                case DatasourceProviderType.ONLINE_DOCUMENT:
+                    datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
+                    if credentials:
+                        datasource_runtime.runtime.credentials = credentials
+                    online_document_result: Generator[DatasourceMessage, None, None] = (
+                        datasource_runtime.get_online_document_page_content(
+                            user_id=self.user_id,
+                            datasource_parameters=GetOnlineDocumentPageContentRequest(
+                                workspace_id=datasource_info.get("workspace_id", ""),
+                                page_id=datasource_info.get("page", {}).get("page_id", ""),
+                                type=datasource_info.get("page", {}).get("type", ""),
+                            ),
+                            provider_type=datasource_type,
+                        )
+                    )
+                    yield from self._transform_message(
+                        messages=online_document_result,
+                        parameters_for_log=parameters_for_log,
+                        datasource_info=datasource_info,
+                    )
+                case DatasourceProviderType.ONLINE_DRIVE:
+                    datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
+                    if credentials:
+                        datasource_runtime.runtime.credentials = credentials
```
|
||||
online_drive_result: Generator[DatasourceMessage, None, None] = (
|
||||
datasource_runtime.online_drive_download_file(
|
||||
user_id=self.user_id,
|
||||
request=OnlineDriveDownloadFileRequest(
|
||||
id=datasource_info.get("id", ""),
|
||||
bucket=datasource_info.get("bucket"),
|
||||
),
|
||||
provider_type=datasource_type,
|
||||
)
|
||||
|
||||
credential_id = datasource_info.get("credential_id", "")
|
||||
|
||||
yield from self.datasource_manager.stream_node_events(
|
||||
node_id=self._node_id,
|
||||
user_id=self.user_id,
|
||||
datasource_name=node_data.datasource_name or "",
|
||||
datasource_type=datasource_type.value,
|
||||
provider_id=provider_id,
|
||||
tenant_id=self.tenant_id,
|
||||
provider=node_data.provider_name,
|
||||
plugin_id=node_data.plugin_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
yield from self._transform_datasource_file_message(
|
||||
messages=online_drive_result,
|
||||
parameters_for_log=parameters_for_log,
|
||||
datasource_info=datasource_info,
|
||||
variable_pool=variable_pool,
|
||||
datasource_param=datasource_parameters,
|
||||
online_drive_request=online_drive_request,
|
||||
datasource_type=datasource_type,
|
||||
)
|
||||
case DatasourceProviderType.WEBSITE_CRAWL:
|
||||
yield StreamCompletedEvent(
|
@@ -134,9 +147,23 @@ class DatasourceNode(Node[DatasourceNodeData]):
                    related_id = datasource_info.get("related_id")
                    if not related_id:
                        raise DatasourceNodeError("File is not exist")
                    upload_file = db.session.query(UploadFile).where(UploadFile.id == related_id).first()
                    if not upload_file:
                        raise ValueError("Invalid upload file Info")

                    file_info = self.datasource_manager.get_upload_file_by_id(
                        file_id=related_id, tenant_id=self.tenant_id
                    file_info = File(
                        id=upload_file.id,
                        filename=upload_file.name,
                        extension="." + upload_file.extension,
                        mime_type=upload_file.mime_type,
                        tenant_id=self.tenant_id,
                        type=FileType.CUSTOM,
                        transfer_method=FileTransferMethod.LOCAL_FILE,
                        remote_url=upload_file.source_url,
                        related_id=upload_file.id,
                        size=upload_file.size,
                        storage_key=upload_file.key,
                        url=upload_file.source_url,
                    )
                    variable_pool.add([self._node_id, "file"], file_info)
                    # variable_pool.add([self.node_id, "file"], file_info.to_dict())
@@ -174,6 +201,55 @@ class DatasourceNode(Node[DatasourceNodeData]):
            )
        )

    def _generate_parameters(
        self,
        *,
        datasource_parameters: Sequence[DatasourceParameter],
        variable_pool: VariablePool,
        node_data: DatasourceNodeData,
        for_log: bool = False,
    ) -> dict[str, Any]:
        """
        Generate parameters based on the given tool parameters, variable pool, and node data.

        Args:
            tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
            variable_pool (VariablePool): The variable pool containing the variables.
            node_data (ToolNodeData): The data associated with the tool node.

        Returns:
            Mapping[str, Any]: A dictionary containing the generated parameters.

        """
        datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters}

        result: dict[str, Any] = {}
        if node_data.datasource_parameters:
            for parameter_name in node_data.datasource_parameters:
                parameter = datasource_parameters_dictionary.get(parameter_name)
                if not parameter:
                    result[parameter_name] = None
                    continue
                datasource_input = node_data.datasource_parameters[parameter_name]
                if datasource_input.type == "variable":
                    variable = variable_pool.get(datasource_input.value)
                    if variable is None:
                        raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist")
                    parameter_value = variable.value
                elif datasource_input.type in {"mixed", "constant"}:
                    segment_group = variable_pool.convert_template(str(datasource_input.value))
                    parameter_value = segment_group.log if for_log else segment_group.text
                else:
                    raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'")
                result[parameter_name] = parameter_value

        return result

    def _fetch_files(self, variable_pool: VariablePool) -> list[File]:
        variable = variable_pool.get(["sys", SystemVariableKey.FILES])
        assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
        return list(variable.value) if variable else []

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
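The `_generate_parameters` hunk above resolves each declared parameter from one of three input types ("variable", "mixed", "constant"). A minimal standalone sketch of that resolution logic, using hypothetical stand-ins for the variable pool and input objects (the names `DatasourceInput`, `VariablePool.convert_template` here are simplified illustrations, not Dify's actual API):

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class DatasourceInput:  # hypothetical stand-in for the node's parameter config
    type: str  # "variable", "mixed", or "constant"
    value: Any


class VariablePool:  # hypothetical stand-in exposing the two lookups used above
    def __init__(self, variables: dict[str, Any]):
        self._variables = variables

    def get(self, selector: list[str]) -> Any:
        return self._variables.get("/".join(selector))

    def convert_template(self, template: str) -> str:
        # the real pool renders {{#node.var#}} references; kept trivial here
        return template


def generate_parameters(
    declared: set[str], inputs: dict[str, DatasourceInput], pool: VariablePool
) -> dict[str, Any]:
    result: dict[str, Any] = {}
    for name, datasource_input in inputs.items():
        if name not in declared:
            # unknown parameters resolve to None, mirroring the node's behavior
            result[name] = None
            continue
        if datasource_input.type == "variable":
            value = pool.get(datasource_input.value)
            if value is None:
                raise ValueError(f"Variable {datasource_input.value} does not exist")
            result[name] = value
        elif datasource_input.type in {"mixed", "constant"}:
            result[name] = pool.convert_template(str(datasource_input.value))
        else:
            raise ValueError(f"Unknown datasource input type '{datasource_input.type}'")
    return result


pool = VariablePool({"sys/query": "hello"})
params = generate_parameters(
    declared={"q", "limit"},
    inputs={
        "q": DatasourceInput("variable", ["sys", "query"]),
        "limit": DatasourceInput("constant", 10),
    },
    pool=pool,
)
print(params)  # {'q': 'hello', 'limit': '10'}
```

Note that "constant" inputs pass through template conversion, which is why the numeric constant comes back as a string.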
@@ -211,6 +287,206 @@ class DatasourceNode(Node[DatasourceNodeData]):

        return result

    def _transform_message(
        self,
        messages: Generator[DatasourceMessage, None, None],
        parameters_for_log: dict[str, Any],
        datasource_info: dict[str, Any],
    ) -> Generator:
        """
        Convert ToolInvokeMessages into tuple[plain_text, files]
        """
        # transform message and handle file storage
        message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
            messages=messages,
            user_id=self.user_id,
            tenant_id=self.tenant_id,
            conversation_id=None,
        )

        text = ""
        files: list[File] = []
        json: list[dict | list] = []

        variables: dict[str, Any] = {}

        for message in message_stream:
            match message.type:
                case (
                    DatasourceMessage.MessageType.IMAGE_LINK
                    | DatasourceMessage.MessageType.BINARY_LINK
                    | DatasourceMessage.MessageType.IMAGE
                ):
                    assert isinstance(message.message, DatasourceMessage.TextMessage)

                    url = message.message.text
                    transfer_method = FileTransferMethod.TOOL_FILE

                    datasource_file_id = str(url).split("/")[-1].split(".")[0]

                    with Session(db.engine) as session:
                        stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
                        datasource_file = session.scalar(stmt)
                        if datasource_file is None:
                            raise ToolFileError(f"Tool file {datasource_file_id} does not exist")

                        mapping = {
                            "tool_file_id": datasource_file_id,
                            "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
                            "transfer_method": transfer_method,
                            "url": url,
                        }
                        file = file_factory.build_from_mapping(
                            mapping=mapping,
                            tenant_id=self.tenant_id,
                        )
                        files.append(file)
                case DatasourceMessage.MessageType.BLOB:
                    # get tool file id
                    assert isinstance(message.message, DatasourceMessage.TextMessage)
                    assert message.meta

                    datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
                    with Session(db.engine) as session:
                        stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
                        datasource_file = session.scalar(stmt)
                        if datasource_file is None:
                            raise ToolFileError(f"datasource file {datasource_file_id} not exists")

                        mapping = {
                            "tool_file_id": datasource_file_id,
                            "transfer_method": FileTransferMethod.TOOL_FILE,
                        }

                        files.append(
                            file_factory.build_from_mapping(
                                mapping=mapping,
                                tenant_id=self.tenant_id,
                            )
                        )
                case DatasourceMessage.MessageType.TEXT:
                    assert isinstance(message.message, DatasourceMessage.TextMessage)
                    text += message.message.text
                    yield StreamChunkEvent(
                        selector=[self._node_id, "text"],
                        chunk=message.message.text,
                        is_final=False,
                    )
                case DatasourceMessage.MessageType.JSON:
                    assert isinstance(message.message, DatasourceMessage.JsonMessage)
                    json.append(message.message.json_object)
                case DatasourceMessage.MessageType.LINK:
                    assert isinstance(message.message, DatasourceMessage.TextMessage)
                    stream_text = f"Link: {message.message.text}\n"
                    text += stream_text
                    yield StreamChunkEvent(
                        selector=[self._node_id, "text"],
                        chunk=stream_text,
                        is_final=False,
                    )
                case DatasourceMessage.MessageType.VARIABLE:
                    assert isinstance(message.message, DatasourceMessage.VariableMessage)
                    variable_name = message.message.variable_name
                    variable_value = message.message.variable_value
                    if message.message.stream:
                        if not isinstance(variable_value, str):
                            raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
                        if variable_name not in variables:
                            variables[variable_name] = ""
                        variables[variable_name] += variable_value

                        yield StreamChunkEvent(
                            selector=[self._node_id, variable_name],
                            chunk=variable_value,
                            is_final=False,
                        )
                    else:
                        variables[variable_name] = variable_value
                case DatasourceMessage.MessageType.FILE:
                    assert message.meta is not None
                    files.append(message.meta["file"])
                case (
                    DatasourceMessage.MessageType.BLOB_CHUNK
                    | DatasourceMessage.MessageType.LOG
                    | DatasourceMessage.MessageType.RETRIEVER_RESOURCES
                ):
                    pass

        # mark the end of the stream
        yield StreamChunkEvent(
            selector=[self._node_id, "text"],
            chunk="",
            is_final=True,
        )
        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                outputs={**variables},
                metadata={
                    WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info,
                },
                inputs=parameters_for_log,
            )
        )

    @classmethod
    def version(cls) -> str:
        return "1"

    def _transform_datasource_file_message(
        self,
        messages: Generator[DatasourceMessage, None, None],
        parameters_for_log: dict[str, Any],
        datasource_info: dict[str, Any],
        variable_pool: VariablePool,
        datasource_type: DatasourceProviderType,
    ) -> Generator:
        """
        Convert ToolInvokeMessages into tuple[plain_text, files]
        """
        # transform message and handle file storage
        message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
            messages=messages,
            user_id=self.user_id,
            tenant_id=self.tenant_id,
            conversation_id=None,
        )
        file = None
        for message in message_stream:
            if message.type == DatasourceMessage.MessageType.BINARY_LINK:
                assert isinstance(message.message, DatasourceMessage.TextMessage)

                url = message.message.text
                transfer_method = FileTransferMethod.TOOL_FILE

                datasource_file_id = str(url).split("/")[-1].split(".")[0]

                with Session(db.engine) as session:
                    stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
                    datasource_file = session.scalar(stmt)
                    if datasource_file is None:
                        raise ToolFileError(f"Tool file {datasource_file_id} does not exist")

                    mapping = {
                        "tool_file_id": datasource_file_id,
                        "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
                        "transfer_method": transfer_method,
                        "url": url,
                    }
                    file = file_factory.build_from_mapping(
                        mapping=mapping,
                        tenant_id=self.tenant_id,
                    )
        if file:
            variable_pool.add([self._node_id, "file"], file)
        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=parameters_for_log,
                metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
                outputs={
                    "file": file,
                    "datasource_type": datasource_type,
                },
            )
        )
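The `_transform_message` loop above folds a stream of typed messages into accumulated text and variables (streamed variables are concatenated chunk by chunk) before the final completion event. A compact sketch of that fold, using plain tuples in place of `DatasourceMessage` (purely illustrative; the tuple shape and `fold_messages` name are not part of Dify):

```python
from typing import Any


def fold_messages(messages: list[tuple[str, str, Any]]) -> tuple[str, dict[str, Any]]:
    """Accumulate TEXT chunks and (possibly streamed) VARIABLE messages."""
    text = ""
    variables: dict[str, Any] = {}
    for kind, name, value in messages:
        if kind == "text":
            text += value
        elif kind == "variable_stream":
            # streamed variable chunks must be strings and are concatenated
            if not isinstance(value, str):
                raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
            variables[name] = variables.get(name, "") + value
        elif kind == "variable":
            # non-streamed variables overwrite in one shot
            variables[name] = value
    return text, variables


text, variables = fold_messages(
    [
        ("text", "", "Hello "),
        ("text", "", "world"),
        ("variable_stream", "a", "x"),
        ("variable_stream", "a", "y"),
        ("variable", "b", 42),
    ]
)
print(text, variables)  # Hello world {'a': 'xy', 'b': 42}
```

This is the same accumulation the deleted unit test below (`test_stream_node_events_accumulates_variables`) exercises: two streamed chunks "x" and "y" for variable "a" end up as "xy".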
@@ -1,50 +0,0 @@
from collections.abc import Generator
from typing import Any, Protocol

from pydantic import BaseModel

from core.workflow.file import File
from core.workflow.node_events import StreamChunkEvent, StreamCompletedEvent


class DatasourceParameter(BaseModel):
    workspace_id: str
    page_id: str
    type: str


class OnlineDriveDownloadFileParam(BaseModel):
    id: str
    bucket: str


class DatasourceFinal(BaseModel):
    data: dict[str, Any] | None = None


class DatasourceManagerProtocol(Protocol):
    @classmethod
    def get_icon_url(cls, provider_id: str, tenant_id: str, datasource_name: str, datasource_type: str) -> str: ...

    @classmethod
    def stream_node_events(
        cls,
        *,
        node_id: str,
        user_id: str,
        datasource_name: str,
        datasource_type: str,
        provider_id: str,
        tenant_id: str,
        provider: str,
        plugin_id: str,
        credential_id: str,
        parameters_for_log: dict[str, Any],
        datasource_info: dict[str, Any],
        variable_pool: Any,
        datasource_param: DatasourceParameter | None = None,
        online_drive_request: OnlineDriveDownloadFileParam | None = None,
    ) -> Generator[StreamChunkEvent | StreamCompletedEvent, None, None]: ...

    @classmethod
    def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File: ...
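`DatasourceManagerProtocol` above is a `typing.Protocol`: any class whose classmethods match the declared signatures satisfies it structurally, with no inheritance, which is what lets the tests below inject a fake manager (`_Mgr`) into the node. A much smaller sketch of that pattern (the `IconProvider`/`FakeManager` names and URL are invented for illustration):

```python
from typing import Protocol


class IconProvider(Protocol):  # hypothetical, far smaller than DatasourceManagerProtocol
    @classmethod
    def get_icon_url(cls, provider_id: str, tenant_id: str) -> str: ...


class FakeManager:
    # No subclassing of IconProvider: the matching signature alone is enough.
    @classmethod
    def get_icon_url(cls, provider_id: str, tenant_id: str) -> str:
        return f"https://icons.example/{tenant_id}/{provider_id}"


def icon_for(manager: IconProvider, provider_id: str, tenant_id: str) -> str:
    # static type checkers accept FakeManager here structurally
    return manager.get_icon_url(provider_id=provider_id, tenant_id=tenant_id)


print(icon_for(FakeManager, "plug/p", "t1"))  # https://icons.example/t1/plug/p
```

Structural typing keeps the production manager and the test double decoupled: the protocol file owns the contract, and neither implementation imports the other.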
@@ -4,6 +4,7 @@ import logging
import secrets
import uuid
from datetime import UTC, datetime, timedelta
from enum import StrEnum
from hashlib import sha256
from typing import Any, cast

@@ -80,12 +81,25 @@ class TokenPair(BaseModel):
    csrf_token: str


class ChangeEmailPhase(StrEnum):
    OLD = "old_email"
    OLD_VERIFIED = "old_email_verified"
    NEW = "new_email"
    NEW_VERIFIED = "new_email_verified"


REFRESH_TOKEN_PREFIX = "refresh_token:"
ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:"
REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)


class AccountService:
    CHANGE_EMAIL_TOKEN_PHASE_KEY = "email_change_phase"
    CHANGE_EMAIL_PHASE_OLD = ChangeEmailPhase.OLD
    CHANGE_EMAIL_PHASE_OLD_VERIFIED = ChangeEmailPhase.OLD_VERIFIED
    CHANGE_EMAIL_PHASE_NEW = ChangeEmailPhase.NEW
    CHANGE_EMAIL_PHASE_NEW_VERIFIED = ChangeEmailPhase.NEW_VERIFIED

    reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
    email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1)
    email_code_login_rate_limiter = RateLimiter(

@@ -535,13 +549,20 @@ class AccountService:
            raise ValueError("Email must be provided.")
        if not phase:
            raise ValueError("phase must be provided.")
        if phase not in (cls.CHANGE_EMAIL_PHASE_OLD, cls.CHANGE_EMAIL_PHASE_NEW):
            raise ValueError("phase must be one of old_email or new_email.")

        if cls.change_email_rate_limiter.is_rate_limited(account_email):
            from controllers.console.auth.error import EmailChangeRateLimitExceededError

            raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60))

        code, token = cls.generate_change_email_token(account_email, account, old_email=old_email)
        code, token = cls.generate_change_email_token(
            account_email,
            account,
            old_email=old_email,
            additional_data={cls.CHANGE_EMAIL_TOKEN_PHASE_KEY: phase},
        )

        send_change_mail_task.delay(
            language=language,
@@ -1,42 +0,0 @@
from collections.abc import Generator

from core.datasource.datasource_manager import DatasourceManager
from core.datasource.entities.datasource_entities import DatasourceMessage
from core.workflow.node_events import StreamCompletedEvent


def _gen_var_stream() -> Generator[DatasourceMessage, None, None]:
    # produce a streamed variable "a"="xy"
    yield DatasourceMessage(
        type=DatasourceMessage.MessageType.VARIABLE,
        message=DatasourceMessage.VariableMessage(variable_name="a", variable_value="x", stream=True),
        meta=None,
    )
    yield DatasourceMessage(
        type=DatasourceMessage.MessageType.VARIABLE,
        message=DatasourceMessage.VariableMessage(variable_name="a", variable_value="y", stream=True),
        meta=None,
    )


def test_stream_node_events_accumulates_variables(mocker):
    mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_var_stream())
    events = list(
        DatasourceManager.stream_node_events(
            node_id="A",
            user_id="u",
            datasource_name="ds",
            datasource_type="online_document",
            provider_id="p/x",
            tenant_id="t",
            provider="prov",
            plugin_id="plug",
            credential_id="",
            parameters_for_log={},
            datasource_info={"user_id": "u"},
            variable_pool=mocker.Mock(),
            datasource_param=type("P", (), {"workspace_id": "w", "page_id": "pg", "type": "t"})(),
            online_drive_request=None,
        )
    )
    assert isinstance(events[-1], StreamCompletedEvent)
@@ -1,84 +0,0 @@
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult, StreamCompletedEvent
from core.workflow.nodes.datasource.datasource_node import DatasourceNode


class _Seg:
    def __init__(self, v):
        self.value = v


class _VarPool:
    def __init__(self, data):
        self.data = data

    def get(self, path):
        d = self.data
        for k in path:
            d = d[k]
        return _Seg(d)

    def add(self, *_a, **_k):
        pass


class _GS:
    def __init__(self, vp):
        self.variable_pool = vp


class _GP:
    tenant_id = "t1"
    app_id = "app-1"
    workflow_id = "wf-1"
    graph_config = {}
    user_id = "u1"
    user_from = "account"
    invoke_from = "debugger"
    call_depth = 0


def test_node_integration_minimal_stream(mocker):
    sys_d = {
        "sys": {
            "datasource_type": "online_document",
            "datasource_info": {"workspace_id": "w", "page": {"page_id": "pg", "type": "t"}, "credential_id": ""},
        }
    }
    vp = _VarPool(sys_d)

    class _Mgr:
        @classmethod
        def get_icon_url(cls, **_):
            return "icon"

        @classmethod
        def stream_node_events(cls, **_):
            yield from ()
            yield StreamCompletedEvent(node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED))

        @classmethod
        def get_upload_file_by_id(cls, **_):
            raise AssertionError

    node = DatasourceNode(
        id="n",
        config={
            "id": "n",
            "data": {
                "type": "datasource",
                "version": "1",
                "title": "Datasource",
                "provider_type": "plugin",
                "provider_name": "p",
                "plugin_id": "plug",
                "datasource_name": "ds",
            },
        },
        graph_init_params=_GP(),
        graph_runtime_state=_GS(vp),
        datasource_manager=_Mgr,
    )

    out = list(node._run())
    assert isinstance(out[-1], StreamCompletedEvent)
File diff suppressed because it is too large
@@ -1,418 +0,0 @@
"""Integration tests for SQL-oriented DatasetService scenarios.

This suite migrates SQL-backed behaviors from the old unit suite to real
container-backed integration tests. The tests exercise real ORM persistence and
only patch non-DB collaborators when needed.
"""

from unittest.mock import Mock, patch
from uuid import uuid4

import pytest

from core.model_runtime.entities.model_entities import ModelType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import RerankingModel, RetrievalModel
from services.errors.dataset import DatasetNameDuplicateError


class DatasetServiceIntegrationDataFactory:
    """Factory for creating real database entities used by integration tests."""

    @staticmethod
    def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.OWNER) -> tuple[Account, Tenant]:
        """Create an account and tenant, then bind the account as current tenant member."""
        account = Account(
            email=f"{uuid4()}@example.com",
            name=f"user-{uuid4()}",
            interface_language="en-US",
            status="active",
        )
        tenant = Tenant(name=f"tenant-{uuid4()}", status="normal")
        db.session.add_all([account, tenant])
        db.session.flush()

        join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=role,
            current=True,
        )
        db.session.add(join)
        db.session.flush()

        # Keep tenant context on the in-memory user without opening a separate session.
        account.role = role
        account._current_tenant = tenant
        return account, tenant

    @staticmethod
    def create_dataset(
        tenant_id: str,
        created_by: str,
        name: str = "Test Dataset",
        description: str | None = "Test description",
        provider: str = "vendor",
        indexing_technique: str | None = "high_quality",
        permission: str = DatasetPermissionEnum.ONLY_ME,
        retrieval_model: dict | None = None,
        embedding_model_provider: str | None = None,
        embedding_model: str | None = None,
        collection_binding_id: str | None = None,
        chunk_structure: str | None = None,
    ) -> Dataset:
        """Create a dataset record with configurable SQL fields."""
        dataset = Dataset(
            tenant_id=tenant_id,
            name=name,
            description=description,
            data_source_type="upload_file",
            indexing_technique=indexing_technique,
            created_by=created_by,
            provider=provider,
            permission=permission,
            retrieval_model=retrieval_model,
            embedding_model_provider=embedding_model_provider,
            embedding_model=embedding_model,
            collection_binding_id=collection_binding_id,
            chunk_structure=chunk_structure,
        )
        db.session.add(dataset)
        db.session.flush()
        return dataset

    @staticmethod
    def create_document(dataset: Dataset, created_by: str, name: str = "doc.txt") -> Document:
        """Create a document row belonging to the given dataset."""
        document = Document(
            tenant_id=dataset.tenant_id,
            dataset_id=dataset.id,
            position=1,
            data_source_type="upload_file",
            data_source_info='{"upload_file_id": "upload-file-id"}',
            batch=str(uuid4()),
            name=name,
            created_from="web",
            created_by=created_by,
            indexing_status="completed",
            doc_form="text_model",
        )
        db.session.add(document)
        db.session.flush()
        return document

    @staticmethod
    def create_embedding_model(provider: str = "openai", model_name: str = "text-embedding-ada-002") -> Mock:
        """Create a fake embedding model object for external provider boundary patching."""
        embedding_model = Mock()
        embedding_model.provider = provider
        embedding_model.model_name = model_name
        return embedding_model


class TestDatasetServiceCreateDataset:
    """Integration coverage for DatasetService.create_empty_dataset."""

    def test_create_internal_dataset_basic_success(self, db_session_with_containers):
        """Create a basic internal dataset with minimal configuration."""
        # Arrange
        account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()

        # Act
        result = DatasetService.create_empty_dataset(
            tenant_id=tenant.id,
            name="Basic Internal Dataset",
            description="Test description",
            indexing_technique=None,
            account=account,
        )

        # Assert
        created_dataset = db.session.get(Dataset, result.id)
        assert created_dataset is not None
        assert created_dataset.provider == "vendor"
        assert created_dataset.permission == DatasetPermissionEnum.ONLY_ME
        assert created_dataset.embedding_model_provider is None
        assert created_dataset.embedding_model is None

    def test_create_internal_dataset_with_economy_indexing(self, db_session_with_containers):
        """Create an internal dataset with economy indexing and no embedding model."""
        # Arrange
        account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()

        # Act
        result = DatasetService.create_empty_dataset(
            tenant_id=tenant.id,
            name="Economy Dataset",
            description=None,
            indexing_technique="economy",
            account=account,
        )

        # Assert
        db.session.refresh(result)
        assert result.indexing_technique == "economy"
        assert result.embedding_model_provider is None
        assert result.embedding_model is None

    def test_create_internal_dataset_with_high_quality_indexing(self, db_session_with_containers):
        """Create a high-quality dataset and persist embedding model settings."""
        # Arrange
        account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
        embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model()

        # Act
        with patch("services.dataset_service.ModelManager") as mock_model_manager:
            mock_model_manager.return_value.get_default_model_instance.return_value = embedding_model

            result = DatasetService.create_empty_dataset(
                tenant_id=tenant.id,
                name="High Quality Dataset",
                description=None,
                indexing_technique="high_quality",
                account=account,
            )

        # Assert
        db.session.refresh(result)
        assert result.indexing_technique == "high_quality"
        assert result.embedding_model_provider == embedding_model.provider
        assert result.embedding_model == embedding_model.model_name
        mock_model_manager.return_value.get_default_model_instance.assert_called_once_with(
            tenant_id=tenant.id,
            model_type=ModelType.TEXT_EMBEDDING,
        )

    def test_create_dataset_duplicate_name_error(self, db_session_with_containers):
        """Raise duplicate-name error when the same tenant already has the name."""
        # Arrange
        account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
        DatasetServiceIntegrationDataFactory.create_dataset(
            tenant_id=tenant.id,
            created_by=account.id,
            name="Duplicate Dataset",
            indexing_technique=None,
        )

        # Act / Assert
        with pytest.raises(DatasetNameDuplicateError):
            DatasetService.create_empty_dataset(
                tenant_id=tenant.id,
                name="Duplicate Dataset",
                description=None,
                indexing_technique=None,
                account=account,
            )

    def test_create_external_dataset_success(self, db_session_with_containers):
        """Create an external dataset and persist external knowledge binding."""
        # Arrange
        account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
        external_knowledge_api_id = str(uuid4())
        external_knowledge_id = "knowledge-123"

        # Act
        with patch("services.dataset_service.ExternalDatasetService.get_external_knowledge_api") as mock_get_api:
            mock_get_api.return_value = Mock(id=external_knowledge_api_id)
            result = DatasetService.create_empty_dataset(
                tenant_id=tenant.id,
                name="External Dataset",
                description=None,
                indexing_technique=None,
                account=account,
                provider="external",
                external_knowledge_api_id=external_knowledge_api_id,
                external_knowledge_id=external_knowledge_id,
            )

        # Assert
        binding = db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=result.id).first()
        assert result.provider == "external"
        assert binding is not None
        assert binding.external_knowledge_id == external_knowledge_id
        assert binding.external_knowledge_api_id == external_knowledge_api_id

    def test_create_dataset_with_retrieval_model_and_reranking(self, db_session_with_containers):
        """Create a high-quality dataset with retrieval/reranking settings."""
        # Arrange
        account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
        embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model()
        retrieval_model = RetrievalModel(
            search_method=RetrievalMethod.SEMANTIC_SEARCH,
            reranking_enable=True,
            reranking_model=RerankingModel(
                reranking_provider_name="cohere",
                reranking_model_name="rerank-english-v2.0",
            ),
            top_k=3,
            score_threshold_enabled=True,
            score_threshold=0.6,
        )

        # Act
        with (
            patch("services.dataset_service.ModelManager") as mock_model_manager,
            patch("services.dataset_service.DatasetService.check_reranking_model_setting") as mock_check_reranking,
        ):
            mock_model_manager.return_value.get_default_model_instance.return_value = embedding_model

            result = DatasetService.create_empty_dataset(
                tenant_id=tenant.id,
                name="Dataset With Reranking",
                description=None,
                indexing_technique="high_quality",
                account=account,
                retrieval_model=retrieval_model,
            )

        # Assert
        db.session.refresh(result)
        assert result.retrieval_model == retrieval_model.model_dump()
        mock_check_reranking.assert_called_once_with(tenant.id, "cohere", "rerank-english-v2.0")


class TestDatasetServiceUpdateAndDeleteDataset:
    """Integration coverage for SQL-backed update and delete behavior."""

    def test_update_dataset_duplicate_name_error(self, db_session_with_containers):
        """Reject update when target name already exists within the same tenant."""
        # Arrange
        account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
        source_dataset = DatasetServiceIntegrationDataFactory.create_dataset(
            tenant_id=tenant.id,
            created_by=account.id,
            name="Source Dataset",
        )
        DatasetServiceIntegrationDataFactory.create_dataset(
            tenant_id=tenant.id,
            created_by=account.id,
            name="Existing Dataset",
        )

        # Act / Assert
        with pytest.raises(ValueError, match="Dataset name already exists"):
            DatasetService.update_dataset(source_dataset.id, {"name": "Existing Dataset"}, account)

    def test_delete_dataset_with_documents_success(self, db_session_with_containers):
        """Delete a dataset that already has documents."""
        # Arrange
        account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
dataset = DatasetServiceIntegrationDataFactory.create_dataset(
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
indexing_technique="high_quality",
|
||||
chunk_structure="text_model",
|
||||
)
|
||||
DatasetServiceIntegrationDataFactory.create_document(dataset=dataset, created_by=account.id)
|
||||
|
||||
# Act
|
||||
with patch("services.dataset_service.dataset_was_deleted") as dataset_deleted_signal:
|
||||
result = DatasetService.delete_dataset(dataset.id, account)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert db.session.get(Dataset, dataset.id) is None
|
||||
dataset_deleted_signal.send.assert_called_once_with(dataset)
|
||||
|
||||
def test_delete_empty_dataset_success(self, db_session_with_containers):
|
||||
"""Delete a dataset that has no documents and no indexing technique."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
dataset = DatasetServiceIntegrationDataFactory.create_dataset(
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
indexing_technique=None,
|
||||
chunk_structure=None,
|
||||
)
|
||||
|
||||
# Act
|
||||
with patch("services.dataset_service.dataset_was_deleted") as dataset_deleted_signal:
|
||||
result = DatasetService.delete_dataset(dataset.id, account)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert db.session.get(Dataset, dataset.id) is None
|
||||
dataset_deleted_signal.send.assert_called_once_with(dataset)
|
||||
|
||||
def test_delete_dataset_with_partial_none_values(self, db_session_with_containers):
|
||||
"""Delete dataset when indexing_technique is None but doc_form path still exists."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
dataset = DatasetServiceIntegrationDataFactory.create_dataset(
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
indexing_technique=None,
|
||||
chunk_structure="text_model",
|
||||
)
|
||||
|
||||
# Act
|
||||
with patch("services.dataset_service.dataset_was_deleted") as dataset_deleted_signal:
|
||||
result = DatasetService.delete_dataset(dataset.id, account)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert db.session.get(Dataset, dataset.id) is None
|
||||
dataset_deleted_signal.send.assert_called_once_with(dataset)
|
||||
|
||||
|
||||
class TestDatasetServiceRetrievalConfiguration:
|
||||
"""Integration coverage for retrieval configuration persistence."""
|
||||
|
||||
def test_get_dataset_retrieval_configuration(self, db_session_with_containers):
|
||||
"""Return retrieval configuration that is persisted in SQL."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
retrieval_model = {
|
||||
"search_method": "semantic_search",
|
||||
"top_k": 5,
|
||||
"score_threshold": 0.5,
|
||||
"reranking_enable": True,
|
||||
}
|
||||
dataset = DatasetServiceIntegrationDataFactory.create_dataset(
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
retrieval_model=retrieval_model,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = DatasetService.get_dataset(dataset.id)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.retrieval_model == retrieval_model
|
||||
assert result.retrieval_model["search_method"] == "semantic_search"
|
||||
assert result.retrieval_model["top_k"] == 5
|
||||
|
||||
def test_update_dataset_retrieval_configuration(self, db_session_with_containers):
|
||||
"""Persist retrieval configuration updates through DatasetService.update_dataset."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
dataset = DatasetServiceIntegrationDataFactory.create_dataset(
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
indexing_technique="high_quality",
|
||||
retrieval_model={"search_method": "semantic_search", "top_k": 2, "score_threshold": 0.0},
|
||||
embedding_model_provider="openai",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
collection_binding_id=str(uuid4()),
|
||||
)
|
||||
update_data = {
|
||||
"indexing_technique": "high_quality",
|
||||
"retrieval_model": {
|
||||
"search_method": "full_text_search",
|
||||
"top_k": 10,
|
||||
"score_threshold": 0.7,
|
||||
},
|
||||
}
|
||||
|
||||
# Act
|
||||
result = DatasetService.update_dataset(dataset.id, update_data, account)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(dataset)
|
||||
assert result.id == dataset.id
|
||||
assert dataset.retrieval_model == update_data["retrieval_model"]
|
||||
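The assertions in the dataset tests above repeatedly unpack a mock's recorded call into positional and keyword arguments. A minimal standalone sketch of that `call_args` pattern, using illustrative names rather than the real Dify services:

```python
from unittest.mock import Mock

# Stand-in for a patched index processor like the one in the tests above.
processor = Mock()

# Code under test would invoke clean(dataset, node_ids, with_keywords=..., ...).
processor.clean("dataset-1", ["node-a", "node-b"], with_keywords=True, delete_child_chunks=True)

# call_args unpacks into (positional_args, keyword_args), exactly as the
# `clean_args, clean_kwargs = clean_call_args` assertions rely on.
clean_args, clean_kwargs = processor.clean.call_args
assert clean_args[0] == "dataset-1"
assert set(clean_args[1]) == {"node-a", "node-b"}
assert clean_kwargs.get("with_keywords") is True
assert clean_kwargs.get("delete_child_chunks") is True
```

This is why the tests can assert on argument sets (`set(clean_args[1])`) without caring about ordering of the node ids.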
@@ -1,464 +0,0 @@
"""
Integration tests for document_indexing_sync_task using testcontainers.

This module validates SQL-backed behavior for document sync flows:
- Notion sync precondition checks
- Segment cleanup and document state updates
- Credential and indexing error handling
"""

import json
from unittest.mock import Mock, patch
from uuid import uuid4

import pytest
from psycopg2.extensions import register_adapter
from psycopg2.extras import Json

from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from tasks.document_indexing_sync_task import document_indexing_sync_task


@pytest.fixture(autouse=True)
def _register_dict_adapter_for_psycopg2():
    """Align test DB adapter behavior with dict payloads used in task update flow."""
    register_adapter(dict, Json)


class DocumentIndexingSyncTaskTestDataFactory:
    """Create real DB entities for document indexing sync integration tests."""

    @staticmethod
    def create_account_with_tenant(db_session_with_containers) -> tuple[Account, Tenant]:
        account = Account(
            email=f"{uuid4()}@example.com",
            name=f"user-{uuid4()}",
            interface_language="en-US",
            status="active",
        )
        db_session_with_containers.add(account)
        db_session_with_containers.flush()

        tenant = Tenant(name=f"tenant-{account.id}", status="normal")
        db_session_with_containers.add(tenant)
        db_session_with_containers.flush()

        join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=TenantAccountRole.OWNER,
            current=True,
        )
        db_session_with_containers.add(join)
        db_session_with_containers.commit()

        return account, tenant

    @staticmethod
    def create_dataset(db_session_with_containers, tenant_id: str, created_by: str) -> Dataset:
        dataset = Dataset(
            tenant_id=tenant_id,
            name=f"dataset-{uuid4()}",
            description="sync test dataset",
            data_source_type="notion_import",
            indexing_technique="high_quality",
            created_by=created_by,
        )
        db_session_with_containers.add(dataset)
        db_session_with_containers.commit()
        return dataset

    @staticmethod
    def create_document(
        db_session_with_containers,
        *,
        tenant_id: str,
        dataset_id: str,
        created_by: str,
        data_source_info: dict | None,
        indexing_status: str = "completed",
    ) -> Document:
        document = Document(
            tenant_id=tenant_id,
            dataset_id=dataset_id,
            position=0,
            data_source_type="notion_import",
            data_source_info=json.dumps(data_source_info) if data_source_info is not None else None,
            batch="test-batch",
            name=f"doc-{uuid4()}",
            created_from="notion_import",
            created_by=created_by,
            indexing_status=indexing_status,
            enabled=True,
            doc_form="text_model",
            doc_language="en",
        )
        db_session_with_containers.add(document)
        db_session_with_containers.commit()
        return document

    @staticmethod
    def create_segments(
        db_session_with_containers,
        *,
        tenant_id: str,
        dataset_id: str,
        document_id: str,
        created_by: str,
        count: int = 3,
    ) -> list[DocumentSegment]:
        segments: list[DocumentSegment] = []
        for i in range(count):
            segment = DocumentSegment(
                tenant_id=tenant_id,
                dataset_id=dataset_id,
                document_id=document_id,
                position=i,
                content=f"segment-{i}",
                answer=None,
                word_count=10,
                tokens=5,
                index_node_id=f"node-{document_id}-{i}",
                status="completed",
                created_by=created_by,
            )
            db_session_with_containers.add(segment)
            segments.append(segment)
        db_session_with_containers.commit()
        return segments


class TestDocumentIndexingSyncTask:
    """Integration tests for document_indexing_sync_task with real database assertions."""

    @pytest.fixture
    def mock_external_dependencies(self):
        """Patch only external collaborators; keep DB access real."""
        with (
            patch("tasks.document_indexing_sync_task.DatasourceProviderService") as mock_datasource_service_class,
            patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_notion_extractor_class,
            patch("tasks.document_indexing_sync_task.IndexProcessorFactory") as mock_index_processor_factory,
            patch("tasks.document_indexing_sync_task.IndexingRunner") as mock_indexing_runner_class,
        ):
            datasource_service = Mock()
            datasource_service.get_datasource_credentials.return_value = {"integration_secret": "test_token"}
            mock_datasource_service_class.return_value = datasource_service

            notion_extractor = Mock()
            notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
            mock_notion_extractor_class.return_value = notion_extractor

            index_processor = Mock()
            index_processor.clean = Mock()
            mock_index_processor_factory.return_value.init_index_processor.return_value = index_processor

            indexing_runner = Mock(spec=IndexingRunner)
            indexing_runner.run = Mock()
            mock_indexing_runner_class.return_value = indexing_runner

            yield {
                "datasource_service": datasource_service,
                "notion_extractor": notion_extractor,
                "notion_extractor_class": mock_notion_extractor_class,
                "index_processor": index_processor,
                "index_processor_factory": mock_index_processor_factory,
                "indexing_runner": indexing_runner,
            }

    def _create_notion_sync_context(self, db_session_with_containers, *, data_source_info: dict | None = None):
        account, tenant = DocumentIndexingSyncTaskTestDataFactory.create_account_with_tenant(db_session_with_containers)
        dataset = DocumentIndexingSyncTaskTestDataFactory.create_dataset(
            db_session_with_containers,
            tenant_id=tenant.id,
            created_by=account.id,
        )

        notion_info = data_source_info or {
            "notion_workspace_id": str(uuid4()),
            "notion_page_id": str(uuid4()),
            "type": "page",
            "last_edited_time": "2024-01-01T00:00:00Z",
            "credential_id": str(uuid4()),
        }

        document = DocumentIndexingSyncTaskTestDataFactory.create_document(
            db_session_with_containers,
            tenant_id=tenant.id,
            dataset_id=dataset.id,
            created_by=account.id,
            data_source_info=notion_info,
            indexing_status="completed",
        )

        segments = DocumentIndexingSyncTaskTestDataFactory.create_segments(
            db_session_with_containers,
            tenant_id=tenant.id,
            dataset_id=dataset.id,
            document_id=document.id,
            created_by=account.id,
            count=3,
        )

        return {
            "account": account,
            "tenant": tenant,
            "dataset": dataset,
            "document": document,
            "segments": segments,
            "node_ids": [segment.index_node_id for segment in segments],
            "notion_info": notion_info,
        }

    def test_document_not_found(self, db_session_with_containers, mock_external_dependencies):
        """Test that task handles missing document gracefully."""
        # Arrange
        dataset_id = str(uuid4())
        document_id = str(uuid4())

        # Act
        document_indexing_sync_task(dataset_id, document_id)

        # Assert
        mock_external_dependencies["datasource_service"].get_datasource_credentials.assert_not_called()
        mock_external_dependencies["indexing_runner"].run.assert_not_called()

    def test_missing_notion_workspace_id(self, db_session_with_containers, mock_external_dependencies):
        """Test that task raises error when notion_workspace_id is missing."""
        # Arrange
        context = self._create_notion_sync_context(
            db_session_with_containers,
            data_source_info={
                "notion_page_id": str(uuid4()),
                "type": "page",
                "last_edited_time": "2024-01-01T00:00:00Z",
            },
        )

        # Act & Assert
        with pytest.raises(ValueError, match="no notion page found"):
            document_indexing_sync_task(context["dataset"].id, context["document"].id)

    def test_missing_notion_page_id(self, db_session_with_containers, mock_external_dependencies):
        """Test that task raises error when notion_page_id is missing."""
        # Arrange
        context = self._create_notion_sync_context(
            db_session_with_containers,
            data_source_info={
                "notion_workspace_id": str(uuid4()),
                "type": "page",
                "last_edited_time": "2024-01-01T00:00:00Z",
            },
        )

        # Act & Assert
        with pytest.raises(ValueError, match="no notion page found"):
            document_indexing_sync_task(context["dataset"].id, context["document"].id)

    def test_empty_data_source_info(self, db_session_with_containers, mock_external_dependencies):
        """Test that task raises error when data_source_info is empty."""
        # Arrange
        context = self._create_notion_sync_context(db_session_with_containers, data_source_info=None)
        db_session_with_containers.query(Document).where(Document.id == context["document"].id).update(
            {"data_source_info": None}
        )
        db_session_with_containers.commit()

        # Act & Assert
        with pytest.raises(ValueError, match="no notion page found"):
            document_indexing_sync_task(context["dataset"].id, context["document"].id)

    def test_credential_not_found(self, db_session_with_containers, mock_external_dependencies):
        """Test that task sets document error state when credential is missing."""
        # Arrange
        context = self._create_notion_sync_context(db_session_with_containers)
        mock_external_dependencies["datasource_service"].get_datasource_credentials.return_value = None

        # Act
        document_indexing_sync_task(context["dataset"].id, context["document"].id)

        # Assert
        db_session_with_containers.expire_all()
        updated_document = (
            db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
        )
        assert updated_document is not None
        assert updated_document.indexing_status == "error"
        assert "Datasource credential not found" in updated_document.error
        assert updated_document.stopped_at is not None
        mock_external_dependencies["indexing_runner"].run.assert_not_called()

    def test_page_not_updated(self, db_session_with_containers, mock_external_dependencies):
        """Test that task exits early when notion page is unchanged."""
        # Arrange
        context = self._create_notion_sync_context(db_session_with_containers)
        mock_external_dependencies["notion_extractor"].get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"

        # Act
        document_indexing_sync_task(context["dataset"].id, context["document"].id)

        # Assert
        db_session_with_containers.expire_all()
        updated_document = (
            db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
        )
        remaining_segments = (
            db_session_with_containers.query(DocumentSegment)
            .where(DocumentSegment.document_id == context["document"].id)
            .count()
        )
        assert updated_document is not None
        assert updated_document.indexing_status == "completed"
        assert updated_document.processing_started_at is None
        assert remaining_segments == 3
        mock_external_dependencies["index_processor"].clean.assert_not_called()
        mock_external_dependencies["indexing_runner"].run.assert_not_called()

    def test_successful_sync_when_page_updated(self, db_session_with_containers, mock_external_dependencies):
        """Test full successful sync flow with SQL state updates and side effects."""
        # Arrange
        context = self._create_notion_sync_context(db_session_with_containers)

        # Act
        document_indexing_sync_task(context["dataset"].id, context["document"].id)

        # Assert
        db_session_with_containers.expire_all()
        updated_document = (
            db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
        )
        remaining_segments = (
            db_session_with_containers.query(DocumentSegment)
            .where(DocumentSegment.document_id == context["document"].id)
            .count()
        )

        assert updated_document is not None
        assert updated_document.indexing_status == "parsing"
        assert updated_document.processing_started_at is not None
        assert updated_document.data_source_info_dict.get("last_edited_time") == "2024-01-02T00:00:00Z"
        assert remaining_segments == 0

        clean_call_args = mock_external_dependencies["index_processor"].clean.call_args
        assert clean_call_args is not None
        clean_args, clean_kwargs = clean_call_args
        assert getattr(clean_args[0], "id", None) == context["dataset"].id
        assert set(clean_args[1]) == set(context["node_ids"])
        assert clean_kwargs.get("with_keywords") is True
        assert clean_kwargs.get("delete_child_chunks") is True

        run_call_args = mock_external_dependencies["indexing_runner"].run.call_args
        assert run_call_args is not None
        run_documents = run_call_args[0][0]
        assert len(run_documents) == 1
        assert getattr(run_documents[0], "id", None) == context["document"].id

    def test_dataset_not_found_during_cleaning(self, db_session_with_containers, mock_external_dependencies):
        """Test that task still updates document and reindexes if dataset vanishes before clean."""
        # Arrange
        context = self._create_notion_sync_context(db_session_with_containers)

        def _delete_dataset_before_clean() -> str:
            db_session_with_containers.query(Dataset).where(Dataset.id == context["dataset"].id).delete()
            db_session_with_containers.commit()
            return "2024-01-02T00:00:00Z"

        mock_external_dependencies[
            "notion_extractor"
        ].get_notion_last_edited_time.side_effect = _delete_dataset_before_clean

        # Act
        document_indexing_sync_task(context["dataset"].id, context["document"].id)

        # Assert
        db_session_with_containers.expire_all()
        updated_document = (
            db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
        )
        assert updated_document is not None
        assert updated_document.indexing_status == "parsing"
        mock_external_dependencies["index_processor"].clean.assert_not_called()
        mock_external_dependencies["indexing_runner"].run.assert_called_once()

    def test_cleaning_error_continues_to_indexing(self, db_session_with_containers, mock_external_dependencies):
        """Test that indexing continues when index cleanup fails."""
        # Arrange
        context = self._create_notion_sync_context(db_session_with_containers)
        mock_external_dependencies["index_processor"].clean.side_effect = Exception("Cleaning error")

        # Act
        document_indexing_sync_task(context["dataset"].id, context["document"].id)

        # Assert
        db_session_with_containers.expire_all()
        updated_document = (
            db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
        )
        remaining_segments = (
            db_session_with_containers.query(DocumentSegment)
            .where(DocumentSegment.document_id == context["document"].id)
            .count()
        )
        assert updated_document is not None
        assert updated_document.indexing_status == "parsing"
        assert remaining_segments == 0
        mock_external_dependencies["indexing_runner"].run.assert_called_once()

    def test_indexing_runner_document_paused_error(self, db_session_with_containers, mock_external_dependencies):
        """Test that DocumentIsPausedError does not flip document into error state."""
        # Arrange
        context = self._create_notion_sync_context(db_session_with_containers)
        mock_external_dependencies["indexing_runner"].run.side_effect = DocumentIsPausedError("Document paused")

        # Act
        document_indexing_sync_task(context["dataset"].id, context["document"].id)

        # Assert
        db_session_with_containers.expire_all()
        updated_document = (
            db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
        )
        assert updated_document is not None
        assert updated_document.indexing_status == "parsing"
        assert updated_document.error is None

    def test_indexing_runner_general_error(self, db_session_with_containers, mock_external_dependencies):
        """Test that indexing errors are persisted to document state."""
        # Arrange
        context = self._create_notion_sync_context(db_session_with_containers)
        mock_external_dependencies["indexing_runner"].run.side_effect = Exception("Indexing error")

        # Act
        document_indexing_sync_task(context["dataset"].id, context["document"].id)

        # Assert
        db_session_with_containers.expire_all()
        updated_document = (
            db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
        )
        assert updated_document is not None
        assert updated_document.indexing_status == "error"
        assert "Indexing error" in updated_document.error
        assert updated_document.stopped_at is not None

    def test_index_processor_clean_called_with_correct_params(
        self,
        db_session_with_containers,
        mock_external_dependencies,
    ):
        """Test that clean is called with dataset instance and collected node ids."""
        # Arrange
        context = self._create_notion_sync_context(db_session_with_containers)

        # Act
        document_indexing_sync_task(context["dataset"].id, context["document"].id)

        # Assert
        clean_call_args = mock_external_dependencies["index_processor"].clean.call_args
        assert clean_call_args is not None
        clean_args, clean_kwargs = clean_call_args
        assert getattr(clean_args[0], "id", None) == context["dataset"].id
        assert set(clean_args[1]) == set(context["node_ids"])
        assert clean_kwargs.get("with_keywords") is True
        assert clean_kwargs.get("delete_child_chunks") is True
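`test_dataset_not_found_during_cleaning` above uses a `side_effect` callable to mutate database state in the middle of the task's execution. A minimal standalone sketch of that technique, with an in-memory dict standing in for the dataset table (names here are illustrative, not the real Dify objects):

```python
from unittest.mock import Mock

# Hypothetical stand-in for the dataset store.
datasets = {"ds-1": {"name": "demo"}}

def delete_dataset_then_report_time() -> str:
    # Simulate the dataset vanishing while the task is mid-flight,
    # exactly when the extractor is asked for the last-edited time.
    datasets.pop("ds-1", None)
    return "2024-01-02T00:00:00Z"

extractor = Mock()
extractor.get_notion_last_edited_time.side_effect = delete_dataset_then_report_time

# Calling the mock runs the side_effect function and returns its result.
last_edited = extractor.get_notion_last_edited_time()
assert last_edited == "2024-01-02T00:00:00Z"
assert "ds-1" not in datasets  # the deletion happened during the call
```

Because `side_effect` runs before the value is returned to the caller, the test can exercise the "dataset deleted mid-task" race without threads or timing tricks.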
@@ -77,7 +77,7 @@ def _restx_mask_defaults(app: Flask):
 
 
 def test_code_based_extension_get_returns_service_data(app: Flask, monkeypatch: pytest.MonkeyPatch):
-    service_result = [{"entrypoint": "main:agent"}]
+    service_result = {"entrypoint": "main:agent"}
     service_mock = MagicMock(return_value=service_result)
     monkeypatch.setattr(
         "controllers.console.extension.CodeBasedExtensionService.get_code_based_extension",

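The hunk above stubs a service method with `MagicMock(return_value=...)` and patches it in by attribute path. A minimal standalone sketch of that stubbing pattern, using a simplified fake class rather than the real Dify controller wiring:

```python
from unittest.mock import MagicMock

# Illustrative stand-in for CodeBasedExtensionService; the real test patches
# the dotted path via pytest's monkeypatch.setattr instead.
class FakeService:
    get_code_based_extension = None

service_result = {"entrypoint": "main:agent"}
service_mock = MagicMock(return_value=service_result)
FakeService.get_code_based_extension = service_mock

# The controller under test would call through the patched attribute:
data = FakeService.get_code_based_extension("tenant-1", "external_data_tool", "my-ext")
assert data == {"entrypoint": "main:agent"}
service_mock.assert_called_once_with("tenant-1", "external_data_tool", "my-ext")
```

The stub returns the same dict object on every call, which is why changing `service_result` from a list to a dict in the hunk directly changes the payload the endpoint serializes.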
@@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch
 import pytest
 from flask import Flask, g
 
+from controllers.console.auth.error import InvalidTokenError
 from controllers.console.workspace.account import (
     AccountDeleteUpdateFeedbackApi,
     ChangeEmailCheckApi,
@@ -52,7 +53,7 @@ class TestChangeEmailSend:
     @patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
     @patch("libs.login.check_csrf_token", return_value=None)
     @patch("controllers.console.wraps.FeatureService.get_system_features")
-    def test_should_normalize_new_email_phase(
+    def test_should_infer_new_email_phase_from_token(
         self,
         mock_features,
         mock_csrf,
@@ -68,13 +69,16 @@ class TestChangeEmailSend:
         mock_features.return_value = SimpleNamespace(enable_change_email=True)
         mock_account = _build_account("current@example.com", "acc1")
         mock_current_account.return_value = (mock_account, None)
-        mock_get_change_data.return_value = {"email": "current@example.com"}
+        mock_get_change_data.return_value = {
+            "email": "current@example.com",
+            AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
+        }
         mock_send_email.return_value = "token-abc"
 
         with app.test_request_context(
             "/account/change-email",
             method="POST",
-            json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"},
+            json={"email": "New@Example.com", "language": "en-US", "token": "token-123"},
         ):
             _set_logged_in_user(_build_account("tester@example.com", "tester"))
             response = ChangeEmailSendEmailApi().post()
@@ -91,6 +95,107 @@ class TestChangeEmailSend:
         mock_is_ip_limit.assert_called_once_with("127.0.0.1")
         mock_csrf.assert_called_once()
 
+    @patch("controllers.console.wraps.db")
+    @patch("controllers.console.workspace.account.current_account_with_tenant")
+    @patch("controllers.console.workspace.account.db")
+    @patch("controllers.console.workspace.account.Session")
+    @patch("controllers.console.workspace.account.AccountService.get_account_by_email_with_case_fallback")
+    @patch("controllers.console.workspace.account.AccountService.send_change_email_email")
+    @patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
+    @patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
+    @patch("libs.login.check_csrf_token", return_value=None)
+    @patch("controllers.console.wraps.FeatureService.get_system_features")
+    def test_should_ignore_client_phase_and_use_old_phase_when_token_missing(
+        self,
+        mock_features,
+        mock_csrf,
+        mock_extract_ip,
+        mock_is_ip_limit,
+        mock_send_email,
+        mock_get_account_by_email,
+        mock_session_cls,
+        mock_account_db,
+        mock_current_account,
+        mock_db,
+        app,
+    ):
+        _mock_wraps_db(mock_db)
+        mock_features.return_value = SimpleNamespace(enable_change_email=True)
+        mock_current_account.return_value = (_build_account("current@example.com", "current"), None)
+        existing_account = _build_account("old@example.com", "acc-old")
+        mock_get_account_by_email.return_value = existing_account
+        mock_send_email.return_value = "token-legacy"
+        mock_session = MagicMock()
+        mock_session_cm = MagicMock()
+        mock_session_cm.__enter__.return_value = mock_session
+        mock_session_cm.__exit__.return_value = None
+        mock_session_cls.return_value = mock_session_cm
+        mock_account_db.engine = MagicMock()
+
+        with app.test_request_context(
+            "/account/change-email",
+            method="POST",
+            json={"email": "old@example.com", "language": "en-US", "phase": "new_email"},
+        ):
+            _set_logged_in_user(_build_account("tester@example.com", "tester"))
+            response = ChangeEmailSendEmailApi().post()
+
+        assert response == {"result": "success", "data": "token-legacy"}
+        mock_get_account_by_email.assert_called_once_with("old@example.com", session=mock_session)
+        mock_send_email.assert_called_once_with(
+            account=existing_account,
+            email="old@example.com",
+            old_email="old@example.com",
+            language="en-US",
+            phase=AccountService.CHANGE_EMAIL_PHASE_OLD,
+        )
+        mock_extract_ip.assert_called_once()
+        mock_is_ip_limit.assert_called_once_with("127.0.0.1")
+        mock_csrf.assert_called_once()
+
+    @patch("controllers.console.wraps.db")
+    @patch("controllers.console.workspace.account.current_account_with_tenant")
+    @patch("controllers.console.workspace.account.AccountService.get_change_email_data")
+    @patch("controllers.console.workspace.account.AccountService.send_change_email_email")
+    @patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
+    @patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
+    @patch("libs.login.check_csrf_token", return_value=None)
+    @patch("controllers.console.wraps.FeatureService.get_system_features")
+    def test_should_reject_unverified_old_email_token_for_new_email_phase(
+        self,
+        mock_features,
+        mock_csrf,
+        mock_extract_ip,
+        mock_is_ip_limit,
+        mock_send_email,
+        mock_get_change_data,
+        mock_current_account,
+        mock_db,
+        app,
+    ):
+        _mock_wraps_db(mock_db)
+        mock_features.return_value = SimpleNamespace(enable_change_email=True)
+        mock_account = _build_account("current@example.com", "acc1")
+        mock_current_account.return_value = (mock_account, None)
+        mock_get_change_data.return_value = {
+            "email": "current@example.com",
+            AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
+        }
+
+        with app.test_request_context(
+            "/account/change-email",
+            method="POST",
+            json={"email": "New@Example.com", "language": "en-US", "token": "token-123"},
+        ):
+            _set_logged_in_user(_build_account("tester@example.com", "tester"))
+            with pytest.raises(InvalidTokenError):
+                ChangeEmailSendEmailApi().post()
+
+        mock_send_email.assert_not_called()
+        mock_extract_ip.assert_called_once()
+        mock_is_ip_limit.assert_called_once_with("127.0.0.1")
+        mock_csrf.assert_called_once()
 
 
 class TestChangeEmailValidity:
     @patch("controllers.console.wraps.db")
@@ -122,7 +227,12 @@ class TestChangeEmailValidity:
         mock_account = _build_account("user@example.com", "acc2")
         mock_current_account.return_value = (mock_account, None)
         mock_is_rate_limit.return_value = False
-        mock_get_data.return_value = {"email": "user@example.com", "code": "1234", "old_email": "old@example.com"}
+        mock_get_data.return_value = {
+            "email": "user@example.com",
+            "code": "1234",
+            "old_email": "old@example.com",
+            AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
+        }
         mock_generate_token.return_value = (None, "new-token")
 
         with app.test_request_context(
@@ -138,11 +248,76 @@ class TestChangeEmailValidity:
         mock_add_rate.assert_not_called()
         mock_revoke_token.assert_called_once_with("token-123")
         mock_generate_token.assert_called_once_with(
             "user@example.com", code="1234", old_email="old@example.com", additional_data={}
|
||||
"user@example.com",
|
||||
code="1234",
|
||||
old_email="old@example.com",
|
||||
additional_data={
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED
|
||||
},
|
||||
)
|
||||
mock_reset_rate.assert_called_once_with("user@example.com")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_refresh_new_email_phase_to_verified(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_rate_limit,
|
||||
mock_get_data,
|
||||
mock_add_rate,
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("old@example.com", "acc2")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {
|
||||
"email": "new@example.com",
|
||||
"code": "5678",
|
||||
"old_email": "old@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW,
|
||||
}
|
||||
mock_generate_token.return_value = (None, "new-phase-token")
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/validity",
|
||||
method="POST",
|
||||
json={"email": "New@Example.com", "code": "5678", "token": "token-456"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
response = ChangeEmailCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "new@example.com", "token": "new-phase-token"}
|
||||
mock_is_rate_limit.assert_called_once_with("new@example.com")
|
||||
mock_add_rate.assert_not_called()
|
||||
mock_revoke_token.assert_called_once_with("token-456")
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"new@example.com",
|
||||
code="5678",
|
||||
old_email="old@example.com",
|
||||
additional_data={
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED
|
||||
},
|
||||
)
|
||||
mock_reset_rate.assert_called_once_with("new@example.com")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
|
||||
class TestChangeEmailReset:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@@ -175,7 +350,11 @@ class TestChangeEmailReset:
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
mock_get_data.return_value = {"old_email": "OLD@example.com"}
|
||||
mock_get_data.return_value = {
|
||||
"old_email": "OLD@example.com",
|
||||
"email": "new@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
|
||||
}
|
||||
mock_account_after_update = _build_account("new@example.com", "acc3-updated")
|
||||
mock_update_account.return_value = mock_account_after_update
|
||||
|
||||
@@ -194,6 +373,106 @@ class TestChangeEmailReset:
|
||||
mock_send_notify.assert_called_once_with(email="new@example.com")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.update_account_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_old_phase_token_for_reset(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_freeze,
|
||||
mock_check_unique,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_update_account,
|
||||
mock_send_notify,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
mock_get_data.return_value = {
|
||||
"old_email": "OLD@example.com",
|
||||
"email": "old@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
|
||||
}
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/reset",
|
||||
method="POST",
|
||||
json={"new_email": "new@example.com", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
with pytest.raises(InvalidTokenError):
|
||||
ChangeEmailResetApi().post()
|
||||
|
||||
mock_revoke_token.assert_not_called()
|
||||
mock_update_account.assert_not_called()
|
||||
mock_send_notify.assert_not_called()
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.update_account_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_mismatched_new_email_for_verified_token(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_freeze,
|
||||
mock_check_unique,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_update_account,
|
||||
mock_send_notify,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
mock_get_data.return_value = {
|
||||
"old_email": "OLD@example.com",
|
||||
"email": "another@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
|
||||
}
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/reset",
|
||||
method="POST",
|
||||
json={"new_email": "new@example.com", "token": "token-789"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
with pytest.raises(InvalidTokenError):
|
||||
ChangeEmailResetApi().post()
|
||||
|
||||
mock_revoke_token.assert_not_called()
|
||||
mock_update_account.assert_not_called()
|
||||
mock_send_notify.assert_not_called()
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
|
||||
class TestAccountDeletionFeedback:
|
||||
@patch("controllers.console.wraps.db")
|
||||
|
||||
@@ -1,135 +0,0 @@
import types
from collections.abc import Generator

from core.datasource.datasource_manager import DatasourceManager
from core.datasource.entities.datasource_entities import DatasourceMessage
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.node_events import StreamChunkEvent, StreamCompletedEvent


def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, None]:
    yield DatasourceMessage(
        type=DatasourceMessage.MessageType.TEXT,
        message=DatasourceMessage.TextMessage(text=text),
        meta=None,
    )


def test_get_icon_url_calls_runtime(mocker):
    fake_runtime = mocker.Mock()
    fake_runtime.get_icon_url.return_value = "https://icon"
    mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=fake_runtime)

    url = DatasourceManager.get_icon_url(
        provider_id="p/x",
        tenant_id="t1",
        datasource_name="ds",
        datasource_type="online_document",
    )
    assert url == "https://icon"
    DatasourceManager.get_datasource_runtime.assert_called_once()


def test_stream_online_results_yields_messages_online_document(mocker):
    # stub runtime to yield a text message
    def _doc_messages(**_):
        yield from _gen_messages_text_only("hello")

    fake_runtime = mocker.Mock()
    fake_runtime.get_online_document_page_content.side_effect = _doc_messages
    mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=fake_runtime)
    mocker.patch(
        "core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials",
        return_value=None,
    )

    gen = DatasourceManager.stream_online_results(
        user_id="u1",
        datasource_name="ds",
        datasource_type="online_document",
        provider_id="p/x",
        tenant_id="t1",
        provider="prov",
        plugin_id="plug",
        credential_id="",
        datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"),
        online_drive_request=None,
    )
    msgs = list(gen)
    assert len(msgs) == 1
    assert msgs[0].message.text == "hello"


def test_stream_node_events_emits_events_online_document(mocker):
    # make manager's low-level stream produce TEXT only
    mocker.patch.object(
        DatasourceManager,
        "stream_online_results",
        return_value=_gen_messages_text_only("hello"),
    )

    events = list(
        DatasourceManager.stream_node_events(
            node_id="nodeA",
            user_id="u1",
            datasource_name="ds",
            datasource_type="online_document",
            provider_id="p/x",
            tenant_id="t1",
            provider="prov",
            plugin_id="plug",
            credential_id="",
            parameters_for_log={"k": "v"},
            datasource_info={"user_id": "u1"},
            variable_pool=mocker.Mock(),
            datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"),
            online_drive_request=None,
        )
    )
    # should contain one StreamChunkEvent then a final chunk (empty) and a completed event
    assert isinstance(events[0], StreamChunkEvent)
    assert events[0].chunk == "hello"
    assert isinstance(events[-1], StreamCompletedEvent)
    assert events[-1].node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED


def test_get_upload_file_by_id_builds_file(mocker):
    # fake UploadFile row
    fake_row = types.SimpleNamespace(
        id="fid",
        name="f",
        extension="txt",
        mime_type="text/plain",
        size=1,
        key="k",
        source_url="http://x",
    )

    class _Q:
        def __init__(self, row):
            self._row = row

        def where(self, *_args, **_kwargs):
            return self

        def first(self):
            return self._row

    class _S:
        def __init__(self, row):
            self._row = row

        def __enter__(self):
            return self

        def __exit__(self, *exc):
            return False

        def query(self, *_):
            return _Q(self._row)

    mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S(fake_row))

    f = DatasourceManager.get_upload_file_by_id(file_id="fid", tenant_id="t1")
    assert f.related_id == "fid"
    assert f.extension == ".txt"
@@ -1,93 +0,0 @@
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.nodes.datasource.datasource_node import DatasourceNode


class _VarSeg:
    def __init__(self, v):
        self.value = v


class _VarPool:
    def __init__(self, mapping):
        self._m = mapping

    def get(self, selector):
        d = self._m
        for k in selector:
            d = d[k]
        return _VarSeg(d)

    def add(self, *_args, **_kwargs):
        pass


class _GraphState:
    def __init__(self, var_pool):
        self.variable_pool = var_pool


class _GraphParams:
    tenant_id = "t1"
    app_id = "app-1"
    workflow_id = "wf-1"
    graph_config = {}
    user_id = "u1"
    user_from = "account"
    invoke_from = "debugger"
    call_depth = 0


def test_datasource_node_delegates_to_manager_stream(mocker):
    # prepare sys variables
    sys_vars = {
        "sys": {
            "datasource_type": "online_document",
            "datasource_info": {
                "workspace_id": "w",
                "page": {"page_id": "pg", "type": "t"},
                "credential_id": "",
            },
        }
    }
    var_pool = _VarPool(sys_vars)
    gs = _GraphState(var_pool)
    gp = _GraphParams()

    # stub manager class
    class _Mgr:
        @classmethod
        def get_icon_url(cls, **_):
            return "icon"

        @classmethod
        def stream_node_events(cls, **_):
            yield StreamChunkEvent(selector=["n", "text"], chunk="hi", is_final=False)
            yield StreamCompletedEvent(node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED))

        @classmethod
        def get_upload_file_by_id(cls, **_):
            raise AssertionError("not called")

    node = DatasourceNode(
        id="n",
        config={
            "id": "n",
            "data": {
                "type": "datasource",
                "version": "1",
                "title": "Datasource",
                "provider_type": "plugin",
                "provider_name": "p",
                "plugin_id": "plug",
                "datasource_name": "ds",
            },
        },
        graph_init_params=gp,
        graph_runtime_state=gs,
        datasource_manager=_Mgr,
    )

    evts = list(node._run())
    assert isinstance(evts[0], StreamChunkEvent)
    assert isinstance(evts[-1], StreamCompletedEvent)
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -1,8 +1,12 @@
"""
Unit tests for collaborator parameter wiring in document_indexing_sync_task.
Unit tests for document indexing sync task.

These tests intentionally stay in unit scope because they validate call arguments
for external collaborators rather than SQL-backed state transitions.
This module tests the document indexing sync task functionality including:
- Syncing Notion documents when updated
- Validating document and data source existence
- Credential validation and retrieval
- Cleaning old segments before re-indexing
- Error handling and edge cases
"""

import uuid
@@ -10,92 +14,187 @@ from unittest.mock import MagicMock, Mock, patch

import pytest

from models.dataset import Dataset, Document
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from models.dataset import Dataset, Document, DocumentSegment
from tasks.document_indexing_sync_task import document_indexing_sync_task

# ============================================================================
# Fixtures
# ============================================================================


@pytest.fixture
def dataset_id() -> str:
    """Generate a dataset id."""
def tenant_id():
    """Generate a unique tenant ID for testing."""
    return str(uuid.uuid4())


@pytest.fixture
def document_id() -> str:
    """Generate a document id."""
def dataset_id():
    """Generate a unique dataset ID for testing."""
    return str(uuid.uuid4())


@pytest.fixture
def notion_workspace_id() -> str:
    """Generate a notion workspace id."""
def document_id():
    """Generate a unique document ID for testing."""
    return str(uuid.uuid4())


@pytest.fixture
def notion_page_id() -> str:
    """Generate a notion page id."""
def notion_workspace_id():
    """Generate a Notion workspace ID for testing."""
    return str(uuid.uuid4())


@pytest.fixture
def credential_id() -> str:
    """Generate a credential id."""
def notion_page_id():
    """Generate a Notion page ID for testing."""
    return str(uuid.uuid4())


@pytest.fixture
def mock_dataset(dataset_id):
    """Create a minimal dataset mock used by the task pre-check."""
def credential_id():
    """Generate a credential ID for testing."""
    return str(uuid.uuid4())


@pytest.fixture
def mock_dataset(dataset_id, tenant_id):
    """Create a mock Dataset object."""
    dataset = Mock(spec=Dataset)
    dataset.id = dataset_id
    dataset.tenant_id = tenant_id
    dataset.indexing_technique = "high_quality"
    dataset.embedding_model_provider = "openai"
    dataset.embedding_model = "text-embedding-ada-002"
    return dataset


@pytest.fixture
def mock_document(document_id, dataset_id, notion_workspace_id, notion_page_id, credential_id):
    """Create a minimal notion document mock for collaborator parameter assertions."""
    document = Mock(spec=Document)
    document.id = document_id
    document.dataset_id = dataset_id
    document.tenant_id = str(uuid.uuid4())
    document.data_source_type = "notion_import"
    document.indexing_status = "completed"
    document.doc_form = "text_model"
    document.data_source_info_dict = {
def mock_document(document_id, dataset_id, tenant_id, notion_workspace_id, notion_page_id, credential_id):
    """Create a mock Document object with Notion data source."""
    doc = Mock(spec=Document)
    doc.id = document_id
    doc.dataset_id = dataset_id
    doc.tenant_id = tenant_id
    doc.data_source_type = "notion_import"
    doc.indexing_status = "completed"
    doc.error = None
    doc.stopped_at = None
    doc.processing_started_at = None
    doc.doc_form = "text_model"
    doc.data_source_info_dict = {
        "notion_workspace_id": notion_workspace_id,
        "notion_page_id": notion_page_id,
        "type": "page",
        "last_edited_time": "2024-01-01T00:00:00Z",
        "credential_id": credential_id,
    }
    return document
    return doc


@pytest.fixture
def mock_db_session(mock_document, mock_dataset):
    """Mock session_factory.create_session to drive deterministic read-only task flow."""
    with patch("tasks.document_indexing_sync_task.session_factory") as mock_session_factory:
        session = MagicMock()
        session.scalars.return_value.all.return_value = []
        session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
def mock_document_segments(document_id):
    """Create mock DocumentSegment objects."""
    segments = []
    for i in range(3):
        segment = Mock(spec=DocumentSegment)
        segment.id = str(uuid.uuid4())
        segment.document_id = document_id
        segment.index_node_id = f"node-{document_id}-{i}"
        segments.append(segment)
    return segments

        begin_cm = MagicMock()
        begin_cm.__enter__.return_value = session
        begin_cm.__exit__.return_value = False
        session.begin.return_value = begin_cm

        session_cm = MagicMock()
        session_cm.__enter__.return_value = session
        session_cm.__exit__.return_value = False
@pytest.fixture
def mock_db_session():
    """Mock database session via session_factory.create_session().

        mock_session_factory.create_session.return_value = session_cm
        yield session
    After session split refactor, the code calls create_session() multiple times.
    This fixture creates shared query mocks so all sessions use the same
    query configuration, simulating database persistence across sessions.

    The fixture automatically converts side_effect to cycle to prevent StopIteration.
    Tests configure mocks the same way as before, but behind the scenes the values
    are cycled infinitely for all sessions.
    """
    from itertools import cycle

    with patch("tasks.document_indexing_sync_task.session_factory") as mock_sf:
        sessions = []

        # Shared query mocks - all sessions use these
        shared_query = MagicMock()
        shared_filter_by = MagicMock()
        shared_scalars_result = MagicMock()

        # Create custom first mock that auto-cycles side_effect
        class CyclicMock(MagicMock):
            def __setattr__(self, name, value):
                if name == "side_effect" and value is not None:
                    # Convert list/tuple to infinite cycle
                    if isinstance(value, (list, tuple)):
                        value = cycle(value)
                super().__setattr__(name, value)

        shared_query.where.return_value.first = CyclicMock()
        shared_filter_by.first = CyclicMock()

        def _create_session():
            """Create a new mock session for each create_session() call."""
            session = MagicMock()
            session.close = MagicMock()
            session.commit = MagicMock()

            # Mock session.begin() context manager
            begin_cm = MagicMock()
            begin_cm.__enter__.return_value = session

            def _begin_exit_side_effect(exc_type, exc, tb):
                # commit on success
                if exc_type is None:
                    session.commit()
                # return False to propagate exceptions
                return False

            begin_cm.__exit__.side_effect = _begin_exit_side_effect
            session.begin.return_value = begin_cm

            # Mock create_session() context manager
            cm = MagicMock()
            cm.__enter__.return_value = session

            def _exit_side_effect(exc_type, exc, tb):
                session.close()
                return False

            cm.__exit__.side_effect = _exit_side_effect

            # All sessions use the same shared query mocks
            session.query.return_value = shared_query
            shared_query.where.return_value = shared_query
            shared_query.filter_by.return_value = shared_filter_by
            session.scalars.return_value = shared_scalars_result

            sessions.append(session)
            # Attach helpers on the first created session for assertions across all sessions
            if len(sessions) == 1:
                session.get_all_sessions = lambda: sessions
                session.any_close_called = lambda: any(s.close.called for s in sessions)
                session.any_commit_called = lambda: any(s.commit.called for s in sessions)
            return cm

        mock_sf.create_session.side_effect = _create_session

        # Create first session and return it
        _create_session()
        yield sessions[0]

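The `CyclicMock` fixture helper above wraps a list-valued `side_effect` in `itertools.cycle` so that repeated mock calls never raise `StopIteration`. A minimal standalone sketch of that pattern using only the standard library (the `first` mock and its values here are illustrative, not from the diff):

```python
from itertools import cycle
from unittest.mock import MagicMock


class CyclicMock(MagicMock):
    """MagicMock whose list/tuple side_effect is replayed forever instead of exhausting."""

    def __setattr__(self, name, value):
        if name == "side_effect" and isinstance(value, (list, tuple)):
            value = cycle(value)  # wrap the finite sequence in an infinite iterator
        super().__setattr__(name, value)


first = CyclicMock()
first.side_effect = ["document", "dataset"]

# Five calls against a two-item side_effect: a plain MagicMock would raise
# StopIteration on the third call; CyclicMock keeps cycling the values.
results = [first() for _ in range(5)]
print(results)
```

Because `Mock.side_effect` accepts any iterator and returns `next(iterator)` per call, handing it a `cycle` object is enough to make the configured return values repeat indefinitely.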
@pytest.fixture
def mock_datasource_provider_service():
    """Mock datasource credential provider."""
    """Mock DatasourceProviderService."""
    with patch("tasks.document_indexing_sync_task.DatasourceProviderService") as mock_service_class:
        mock_service = MagicMock()
        mock_service.get_datasource_credentials.return_value = {"integration_secret": "test_token"}
@@ -105,16 +204,314 @@ def mock_datasource_provider_service():

@pytest.fixture
def mock_notion_extractor():
    """Mock notion extractor class and instance."""
    """Mock NotionExtractor."""
    with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class:
        mock_extractor = MagicMock()
        mock_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
        mock_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"  # Updated time
        mock_extractor_class.return_value = mock_extractor
        yield {"class": mock_extractor_class, "instance": mock_extractor}
        yield mock_extractor


class TestDocumentIndexingSyncTaskCollaboratorParams:
    """Unit tests for collaborator parameter passing in document_indexing_sync_task."""
@pytest.fixture
def mock_index_processor_factory():
    """Mock IndexProcessorFactory."""
    with patch("tasks.document_indexing_sync_task.IndexProcessorFactory") as mock_factory:
        mock_processor = MagicMock()
        mock_processor.clean = Mock()
        mock_factory.return_value.init_index_processor.return_value = mock_processor
        yield mock_factory


@pytest.fixture
def mock_indexing_runner():
    """Mock IndexingRunner."""
    with patch("tasks.document_indexing_sync_task.IndexingRunner") as mock_runner_class:
        mock_runner = MagicMock(spec=IndexingRunner)
        mock_runner.run = Mock()
        mock_runner_class.return_value = mock_runner
        yield mock_runner


# ============================================================================
# Tests for document_indexing_sync_task
# ============================================================================


class TestDocumentIndexingSyncTask:
    """Tests for the document_indexing_sync_task function."""

    def test_document_not_found(self, mock_db_session, dataset_id, document_id):
        """Test that task handles document not found gracefully."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.return_value = None

        # Act
        document_indexing_sync_task(dataset_id, document_id)

        # Assert - at least one session should have been closed
        assert mock_db_session.any_close_called()

    def test_missing_notion_workspace_id(self, mock_db_session, mock_document, dataset_id, document_id):
        """Test that task raises error when notion_workspace_id is missing."""
        # Arrange
        mock_document.data_source_info_dict = {"notion_page_id": "page123", "type": "page"}
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_document

        # Act & Assert
        with pytest.raises(ValueError, match="no notion page found"):
            document_indexing_sync_task(dataset_id, document_id)

    def test_missing_notion_page_id(self, mock_db_session, mock_document, dataset_id, document_id):
        """Test that task raises error when notion_page_id is missing."""
        # Arrange
        mock_document.data_source_info_dict = {"notion_workspace_id": "ws123", "type": "page"}
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_document

        # Act & Assert
        with pytest.raises(ValueError, match="no notion page found"):
            document_indexing_sync_task(dataset_id, document_id)

    def test_empty_data_source_info(self, mock_db_session, mock_document, dataset_id, document_id):
        """Test that task raises error when data_source_info is empty."""
        # Arrange
        mock_document.data_source_info_dict = None
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_document

        # Act & Assert
        with pytest.raises(ValueError, match="no notion page found"):
            document_indexing_sync_task(dataset_id, document_id)

    def test_credential_not_found(
        self,
        mock_db_session,
        mock_datasource_provider_service,
        mock_document,
        dataset_id,
        document_id,
    ):
        """Test that task handles missing credentials by updating document status."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
        mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
        mock_datasource_provider_service.get_datasource_credentials.return_value = None

        # Act
        document_indexing_sync_task(dataset_id, document_id)

        # Assert
        assert mock_document.indexing_status == "error"
        assert "Datasource credential not found" in mock_document.error
        assert mock_document.stopped_at is not None
        assert mock_db_session.any_commit_called()
        assert mock_db_session.any_close_called()

    def test_page_not_updated(
        self,
        mock_db_session,
        mock_datasource_provider_service,
        mock_notion_extractor,
        mock_document,
        dataset_id,
        document_id,
    ):
        """Test that task does nothing when page has not been updated."""
        # Arrange
        mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
        mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
        # Return same time as stored in document
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"

        # Act
        document_indexing_sync_task(dataset_id, document_id)

        # Assert
        # Document status should remain unchanged
        assert mock_document.indexing_status == "completed"
        # At least one session should have been closed via context manager teardown
        assert mock_db_session.any_close_called()

    def test_successful_sync_when_page_updated(
        self,
        mock_db_session,
        mock_datasource_provider_service,
        mock_notion_extractor,
        mock_index_processor_factory,
        mock_indexing_runner,
        mock_dataset,
        mock_document,
        mock_document_segments,
        dataset_id,
        document_id,
    ):
        """Test successful sync flow when Notion page has been updated."""
        # Arrange
        # Set exact sequence of returns across calls to `.first()`:
        # 1) document (initial fetch)
        # 2) dataset (pre-check)
        # 3) dataset (cleaning phase)
        # 4) document (pre-indexing update)
        # 5) document (indexing runner fetch)
        mock_db_session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_document,
|
||||
mock_dataset,
|
||||
mock_dataset,
|
||||
mock_document,
|
||||
mock_document,
|
||||
]
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
|
||||
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||
# NotionExtractor returns updated time
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
# Verify document status was updated to parsing
|
||||
assert mock_document.indexing_status == "parsing"
|
||||
assert mock_document.processing_started_at is not None
|
||||
|
||||
# Verify segments were cleaned
|
||||
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||
mock_processor.clean.assert_called_once()
|
||||
|
||||
# Verify segments were deleted from database in batch (DELETE FROM document_segments)
|
||||
# Aggregate execute calls across all created sessions
|
||||
execute_sqls = []
|
||||
for s in mock_db_session.get_all_sessions():
|
||||
execute_sqls.extend([" ".join(str(c[0][0]).split()) for c in s.execute.call_args_list])
|
||||
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
|
||||
|
||||
# Verify indexing runner was called
|
||||
mock_indexing_runner.run.assert_called_once_with([mock_document])
|
||||
|
||||
# Verify session operations (across any created session)
|
||||
assert mock_db_session.any_commit_called()
|
||||
assert mock_db_session.any_close_called()
|
||||
|
||||
def test_dataset_not_found_during_cleaning(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_indexing_runner,
|
||||
mock_document,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that task handles dataset not found during cleaning phase."""
|
||||
# Arrange
|
||||
# Sequence: document (initial), dataset (pre-check), None (cleaning), document (update), document (indexing)
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_document,
|
||||
mock_dataset,
|
||||
None,
|
||||
mock_document,
|
||||
mock_document,
|
||||
]
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
# Document should still be set to parsing
|
||||
assert mock_document.indexing_status == "parsing"
|
||||
# At least one session should be closed after error
|
||||
assert mock_db_session.any_close_called()
|
||||
|
||||
def test_cleaning_error_continues_to_indexing(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_index_processor_factory,
|
||||
mock_indexing_runner,
|
||||
mock_dataset,
|
||||
mock_document,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that indexing continues even if cleaning fails."""
|
||||
# Arrange
|
||||
from itertools import cycle
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
|
||||
# Make the cleaning step fail but not the segment fetch
|
||||
processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||
processor.clean.side_effect = Exception("Cleaning error")
|
||||
mock_db_session.scalars.return_value.all.return_value = []
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
# Indexing should still be attempted despite cleaning error
|
||||
mock_indexing_runner.run.assert_called_once_with([mock_document])
|
||||
assert mock_db_session.any_close_called()
|
||||
|
||||
def test_indexing_runner_document_paused_error(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_index_processor_factory,
|
||||
mock_indexing_runner,
|
||||
mock_dataset,
|
||||
mock_document,
|
||||
mock_document_segments,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that DocumentIsPausedError is handled gracefully."""
|
||||
# Arrange
|
||||
from itertools import cycle
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
|
||||
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
# Session should be closed after handling error
|
||||
assert mock_db_session.any_close_called()
|
||||
|
||||
def test_indexing_runner_general_error(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_index_processor_factory,
|
||||
mock_indexing_runner,
|
||||
mock_dataset,
|
||||
mock_document,
|
||||
mock_document_segments,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that general exceptions during indexing are handled."""
|
||||
# Arrange
|
||||
from itertools import cycle
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
|
||||
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
mock_indexing_runner.run.side_effect = Exception("Indexing error")
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
# Session should be closed after error
|
||||
assert mock_db_session.any_close_called()
|
||||
|
||||
def test_notion_extractor_initialized_with_correct_params(
|
||||
self,
|
||||
@@ -127,21 +524,27 @@ class TestDocumentIndexingSyncTaskCollaboratorParams:
         notion_workspace_id,
         notion_page_id,
     ):
-        """Test that NotionExtractor is initialized with expected arguments."""
+        """Test that NotionExtractor is initialized with correct parameters."""
         # Arrange
-        expected_token = "test_token"
         mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
-        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"  # No update
-
-        # Act
-        document_indexing_sync_task(dataset_id, document_id)
-
-        # Assert
-        mock_notion_extractor["class"].assert_called_once_with(
-            notion_workspace_id=notion_workspace_id,
-            notion_obj_id=notion_page_id,
-            notion_page_type="page",
-            notion_access_token=expected_token,
-            tenant_id=mock_document.tenant_id,
-        )
+        with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class:
+            mock_extractor = MagicMock()
+            mock_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
+            mock_extractor_class.return_value = mock_extractor
+
+            document_indexing_sync_task(dataset_id, document_id)
+
+            # Assert
+            mock_extractor_class.assert_called_once_with(
+                notion_workspace_id=notion_workspace_id,
+                notion_obj_id=notion_page_id,
+                notion_page_type="page",
+                notion_access_token="test_token",
+                tenant_id=mock_document.tenant_id,
+            )
     def test_datasource_credentials_requested_correctly(
         self,
@@ -153,16 +556,17 @@ class TestDocumentIndexingSyncTaskCollaboratorParams:
         document_id,
         credential_id,
     ):
-        """Test that datasource credentials are requested with expected identifiers."""
+        """Test that datasource credentials are requested with correct parameters."""
         # Arrange
-        expected_tenant_id = mock_document.tenant_id
         mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
         mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"

         # Act
         document_indexing_sync_task(dataset_id, document_id)

         # Assert
         mock_datasource_provider_service.get_datasource_credentials.assert_called_once_with(
-            tenant_id=expected_tenant_id,
+            tenant_id=mock_document.tenant_id,
             credential_id=credential_id,
             provider="notion_datasource",
             plugin_id="langgenius/notion_datasource",
@@ -177,14 +581,16 @@ class TestDocumentIndexingSyncTaskCollaboratorParams:
         dataset_id,
         document_id,
     ):
-        """Test that missing credential_id is forwarded as None."""
+        """Test that task handles missing credential_id by passing None."""
         # Arrange
         mock_document.data_source_info_dict = {
-            "notion_workspace_id": "workspace-id",
-            "notion_page_id": "page-id",
+            "notion_workspace_id": "ws123",
+            "notion_page_id": "page123",
             "type": "page",
             "last_edited_time": "2024-01-01T00:00:00Z",
         }
         mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
         mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"

         # Act
         document_indexing_sync_task(dataset_id, document_id)
@@ -196,3 +602,39 @@ class TestDocumentIndexingSyncTaskCollaboratorParams:
             provider="notion_datasource",
             plugin_id="langgenius/notion_datasource",
         )
+
+    def test_index_processor_clean_called_with_correct_params(
+        self,
+        mock_db_session,
+        mock_datasource_provider_service,
+        mock_notion_extractor,
+        mock_index_processor_factory,
+        mock_indexing_runner,
+        mock_dataset,
+        mock_document,
+        mock_document_segments,
+        dataset_id,
+        document_id,
+    ):
+        """Test that index processor clean is called with correct parameters."""
+        # Arrange
+        # Sequence: document (initial), dataset (pre-check), dataset (cleaning), document (update), document (indexing)
+        mock_db_session.query.return_value.where.return_value.first.side_effect = [
+            mock_document,
+            mock_dataset,
+            mock_dataset,
+            mock_document,
+            mock_document,
+        ]
+        mock_db_session.scalars.return_value.all.return_value = mock_document_segments
+        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
+
+        # Act
+        document_indexing_sync_task(dataset_id, document_id)
+
+        # Assert
+        mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
+        expected_node_ids = [seg.index_node_id for seg in mock_document_segments]
+        mock_processor.clean.assert_called_once_with(
+            mock_dataset, expected_node_ids, with_keywords=True, delete_child_chunks=True
+        )
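The tests above drive every database lookup through one chained mock; the load-bearing idiom is assigning a list to `side_effect` on the final `.first()` call, so each successive query returns the next object in the planned sequence. A minimal, self-contained sketch of that pattern (the string return values here are illustrative stand-ins, not objects from the Dify codebase):

```python
from unittest.mock import MagicMock

# Assigning a list to `side_effect` makes each successive call return the
# next element, so a single chained-attribute mock can stand in for a
# sequence of distinct database lookups (document, dataset, miss).
session = MagicMock()
session.query.return_value.where.return_value.first.side_effect = [
    "document",
    "dataset",
    None,
]

first = session.query("Document").where(True).first()   # "document"
second = session.query("Dataset").where(True).first()   # "dataset"
third = session.query("Dataset").where(True).first()    # None (simulated miss)
```

`itertools.cycle([...])`, used in the error-path tests, is the unbounded variant of the same idea: the mock never exhausts its return values no matter how many times the task re-queries.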
api/uv.lock (generated, 6 changed lines)
@@ -5049,11 +5049,11 @@ wheels = [

 [[package]]
 name = "pypdf"
-version = "6.7.4"
+version = "6.7.1"
 source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/09/dc/f52deef12797ad58b88e4663f097a343f53b9361338aef6573f135ac302f/pypdf-6.7.4.tar.gz", hash = "sha256:9edd1cd47938bb35ec87795f61225fd58a07cfaf0c5699018ae1a47d6f8ab0e3", size = 5304821, upload-time = "2026-02-27T10:44:39.395Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/ff/63/3437c4363483f2a04000a48f1cd48c40097f69d580363712fa8b0b4afe45/pypdf-6.7.1.tar.gz", hash = "sha256:6b7a63be5563a0a35d54c6d6b550d75c00b8ccf36384be96365355e296e6b3b0", size = 5302208, upload-time = "2026-02-17T17:00:48.88Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/c1/be/cded021305f5c81b47265b8c5292b99388615a4391c21ff00fd538d34a56/pypdf-6.7.4-py3-none-any.whl", hash = "sha256:527d6da23274a6c70a9cb59d1986d93946ba8e36a6bc17f3f7cce86331492dda", size = 331496, upload-time = "2026-02-27T10:44:37.527Z" },
+    { url = "https://files.pythonhosted.org/packages/68/77/38bd7744bb9e06d465b0c23879e6d2c187d93a383f8fa485c862822bb8a3/pypdf-6.7.1-py3-none-any.whl", hash = "sha256:a02ccbb06463f7c334ce1612e91b3e68a8e827f3cee100b9941771e6066b094e", size = 331048, upload-time = "2026-02-17T17:00:46.991Z" },
 ]

 [[package]]
@@ -33,7 +33,7 @@ Then, configure the environment variables. Create a file named `.env.local` in t
 cp .env.example .env.local
 ```

-```txt
+```
 # For production release, change this to PRODUCTION
 NEXT_PUBLIC_DEPLOY_ENV=DEVELOPMENT
 # The deployment edition, SELF_HOSTED
@@ -58,11 +58,10 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
}, 1000)
|
||||
}
|
||||
|
||||
const sendEmail = async (email: string, isOrigin: boolean, token?: string) => {
|
||||
const sendEmail = async (email: string, token?: string) => {
|
||||
try {
|
||||
const res = await sendVerifyCode({
|
||||
email,
|
||||
phase: isOrigin ? 'old_email' : 'new_email',
|
||||
token,
|
||||
})
|
||||
startCount()
|
||||
@@ -106,7 +105,6 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
const sendCodeToOriginEmail = async () => {
|
||||
await sendEmail(
|
||||
email,
|
||||
true,
|
||||
)
|
||||
setStep(STEP.verifyOrigin)
|
||||
}
|
||||
@@ -162,7 +160,6 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
}
|
||||
await sendEmail(
|
||||
mail,
|
||||
false,
|
||||
stepToken,
|
||||
)
|
||||
setStep(STEP.verifyNew)
|
||||
|
||||
@@ -61,7 +61,8 @@ const ParamsConfig = ({
   if (tempDataSetConfigs.retrieval_model === RETRIEVE_TYPE.multiWay) {
     if (tempDataSetConfigs.reranking_enable
       && tempDataSetConfigs.reranking_mode === RerankingModeEnum.RerankingModel
-      && !isCurrentRerankModelValid) {
+      && !isCurrentRerankModelValid
+    ) {
       errMsg = t('datasetConfig.rerankModelRequired', { ns: 'appDebug' })
     }
   }
@@ -43,7 +43,7 @@ describe('Input component', () => {

   it('shows left icon when showLeftIcon is true', () => {
     render(<Input showLeftIcon />)
-    const searchIcon = document.querySelector('.i-ri-search-line')
+    const searchIcon = document.querySelector('svg')
     expect(searchIcon).toBeInTheDocument()
     const input = screen.getByPlaceholderText('Search')
     expect(input).toHaveClass('pl-[26px]')
@@ -51,7 +51,7 @@ describe('Input component', () => {

   it('shows clear icon when showClearIcon is true and has value', () => {
     render(<Input showClearIcon value="test" />)
-    const clearIcon = document.querySelector('.i-ri-close-circle-fill')
+    const clearIcon = document.querySelector('.group svg')
     expect(clearIcon).toBeInTheDocument()
     const input = screen.getByDisplayValue('test')
     expect(input).toHaveClass('pr-[26px]')
@@ -59,21 +59,21 @@ describe('Input component', () => {

   it('does not show clear icon when disabled, even with value', () => {
     render(<Input showClearIcon value="test" disabled />)
-    const clearIcon = document.querySelector('.i-ri-close-circle-fill')
+    const clearIcon = document.querySelector('.group svg')
     expect(clearIcon).not.toBeInTheDocument()
   })

   it('calls onClear when clear icon is clicked', () => {
     const onClear = vi.fn()
     render(<Input showClearIcon value="test" onClear={onClear} />)
-    const clearIconContainer = screen.getByTestId('input-clear')
+    const clearIconContainer = document.querySelector('.group')
     fireEvent.click(clearIconContainer!)
     expect(onClear).toHaveBeenCalledTimes(1)
   })

   it('shows warning icon when destructive is true', () => {
     render(<Input destructive />)
-    const warningIcon = document.querySelector('.i-ri-error-warning-line')
+    const warningIcon = document.querySelector('svg')
     expect(warningIcon).toBeInTheDocument()
     const input = screen.getByPlaceholderText('Please input')
     expect(input).toHaveClass('border-components-input-border-destructive')
@@ -1,5 +1,6 @@
 import type { VariantProps } from 'class-variance-authority'
 import type { ChangeEventHandler, CSSProperties, FocusEventHandler } from 'react'
+import { RiCloseCircleFill, RiErrorWarningLine, RiSearchLine } from '@remixicon/react'
 import { cva } from 'class-variance-authority'
 import { noop } from 'es-toolkit/function'
 import * as React from 'react'
@@ -12,8 +13,8 @@ export const inputVariants = cva(
   {
     variants: {
       size: {
-        regular: 'px-3 system-sm-regular radius-md',
-        large: 'px-4 system-md-regular radius-lg',
+        regular: 'px-3 radius-md system-sm-regular',
+        large: 'px-4 radius-lg system-md-regular',
       },
     },
     defaultVariants: {
@@ -82,7 +83,7 @@ const Input = React.forwardRef<HTMLInputElement, InputProps>(({
   }
   return (
     <div className={cn('relative w-full', wrapperClassName)}>
-      {showLeftIcon && <span className={cn('i-ri-search-line absolute left-2 top-1/2 h-4 w-4 -translate-y-1/2 text-components-input-text-placeholder')} />}
+      {showLeftIcon && <RiSearchLine className={cn('absolute left-2 top-1/2 h-4 w-4 -translate-y-1/2 text-components-input-text-placeholder')} />}
       <input
         ref={ref}
         style={styleCss}
@@ -114,11 +115,11 @@ const Input = React.forwardRef<HTMLInputElement, InputProps>(({
           onClick={onClear}
-          data-testid="input-clear"
         >
-          <span className="i-ri-close-circle-fill h-3.5 w-3.5 cursor-pointer text-text-quaternary group-hover:text-text-tertiary" />
+          <RiCloseCircleFill className="h-3.5 w-3.5 cursor-pointer text-text-quaternary group-hover:text-text-tertiary" />
         </div>
       )}
       {destructive && (
-        <span className="i-ri-error-warning-line absolute right-2 top-1/2 h-4 w-4 -translate-y-1/2 text-text-destructive-secondary" />
+        <RiErrorWarningLine className="absolute right-2 top-1/2 h-4 w-4 -translate-y-1/2 text-text-destructive-secondary" />
       )}
       {showCopyIcon && (
         <div className={cn('group absolute right-0 top-1/2 -translate-y-1/2 cursor-pointer')}>
@@ -130,7 +131,7 @@ const Input = React.forwardRef<HTMLInputElement, InputProps>(({
       )}
       {
         unit && (
-          <div className="absolute right-2 top-1/2 -translate-y-1/2 text-text-tertiary system-sm-regular">
+          <div className="system-sm-regular absolute right-2 top-1/2 -translate-y-1/2 text-text-tertiary">
            {unit}
          </div>
        )
@@ -334,10 +334,10 @@ describe('FileList', () => {
   it('should call resetKeywords prop when clear button is clicked', () => {
     const mockResetKeywords = vi.fn()
     const props = createDefaultProps({ resetKeywords: mockResetKeywords, keywords: 'to-reset' })
-    render(<FileList {...props} />)
+    const { container } = render(<FileList {...props} />)

     // Act - Click the clear icon div (it contains RiCloseCircleFill icon)
-    const clearButton = screen.getByTestId('input-clear')
+    const clearButton = container.querySelector('[class*="cursor-pointer"] svg[class*="h-3.5"]')?.parentElement
     expect(clearButton).toBeInTheDocument()
     fireEvent.click(clearButton!)

@@ -346,12 +346,12 @@ describe('FileList', () => {

   it('should reset inputValue to empty string when clear is clicked', () => {
     const props = createDefaultProps({ keywords: 'to-be-reset' })
-    render(<FileList {...props} />)
+    const { container } = render(<FileList {...props} />)
     const input = screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')
     fireEvent.change(input, { target: { value: 'some-search' } })

     // Act - Find and click the clear icon
-    const clearButton = screen.getByTestId('input-clear')
+    const clearButton = container.querySelector('[class*="cursor-pointer"] svg[class*="h-3.5"]')?.parentElement
     expect(clearButton).toBeInTheDocument()
     fireEvent.click(clearButton!)
@@ -93,8 +93,8 @@ describe('Header', () => {

     const { container } = render(<Header {...props} />)

-    // Assert - Input should have search icon class
-    const searchIcon = container.querySelector('.i-ri-search-line.h-4.w-4')
+    // Assert - Input should have search icon (RiSearchLine is rendered as svg)
+    const searchIcon = container.querySelector('svg.h-4.w-4')
     expect(searchIcon).toBeInTheDocument()
   })

@@ -313,10 +313,10 @@ describe('Header', () => {
       inputValue: 'to-clear',
       handleResetKeywords: mockHandleResetKeywords,
     })
-    render(<Header {...props} />)
+    const { container } = render(<Header {...props} />)

     // Act - Find and click the clear icon container
-    const clearButton = screen.getByTestId('input-clear')
+    const clearButton = container.querySelector('[class*="cursor-pointer"] svg[class*="h-3.5"]')?.parentElement
     expect(clearButton).toBeInTheDocument()
     fireEvent.click(clearButton!)

@@ -325,19 +325,19 @@ describe('Header', () => {

   it('should not show clear icon when inputValue is empty', () => {
     const props = createDefaultProps({ inputValue: '' })
-    render(<Header {...props} />)
+    const { container } = render(<Header {...props} />)

     // Act & Assert - Clear icon should not be visible
-    const clearIcon = screen.queryByTestId('input-clear')
+    const clearIcon = container.querySelector('[class*="cursor-pointer"] svg[class*="h-3.5"]')
     expect(clearIcon).not.toBeInTheDocument()
   })

   it('should show clear icon when inputValue is not empty', () => {
     const props = createDefaultProps({ inputValue: 'some-value' })
-    render(<Header {...props} />)
+    const { container } = render(<Header {...props} />)

     // Act & Assert - Clear icon should be visible
-    const clearIcon = screen.getByTestId('input-clear')
+    const clearIcon = container.querySelector('[class*="cursor-pointer"] svg[class*="h-3.5"]')
     expect(clearIcon).toBeInTheDocument()
   })
 })
@@ -570,12 +570,12 @@ describe('Header', () => {
       inputValue: 'to-clear',
       handleResetKeywords: mockHandleResetKeywords,
     })
-    const { rerender } = render(<Header {...props} />)
+    const { container, rerender } = render(<Header {...props} />)

     // Act - Click clear, rerender, click again
-    fireEvent.click(screen.getByTestId('input-clear'))
+    const clearButton = container.querySelector('[class*="cursor-pointer"] svg[class*="h-3.5"]')?.parentElement
+    fireEvent.click(clearButton!)
     rerender(<Header {...props} />)
-    fireEvent.click(screen.getByTestId('input-clear'))
+    fireEvent.click(clearButton!)

     expect(mockHandleResetKeywords).toHaveBeenCalledTimes(2)
   })
@@ -1,8 +1,24 @@
|
||||
'use client'
|
||||
import type { AccountSettingTab } from '@/app/components/header/account-setting/constants'
|
||||
import {
|
||||
RiBrain2Fill,
|
||||
RiBrain2Line,
|
||||
RiCloseLine,
|
||||
RiColorFilterFill,
|
||||
RiColorFilterLine,
|
||||
RiDatabase2Fill,
|
||||
RiDatabase2Line,
|
||||
RiGroup2Fill,
|
||||
RiGroup2Line,
|
||||
RiMoneyDollarCircleFill,
|
||||
RiMoneyDollarCircleLine,
|
||||
RiPuzzle2Fill,
|
||||
RiPuzzle2Line,
|
||||
RiTranslate2,
|
||||
} from '@remixicon/react'
|
||||
import { useEffect, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import SearchInput from '@/app/components/base/search-input'
|
||||
import Input from '@/app/components/base/input'
|
||||
import BillingPage from '@/app/components/billing/billing-page'
|
||||
import CustomPage from '@/app/components/custom/custom-page'
|
||||
import {
|
||||
@@ -60,14 +76,14 @@ export default function AccountSetting({
|
||||
{
|
||||
key: ACCOUNT_SETTING_TAB.PROVIDER,
|
||||
name: t('settings.provider', { ns: 'common' }),
|
||||
icon: <span className={cn('i-ri-brain-2-line', iconClassName)} />,
|
||||
activeIcon: <span className={cn('i-ri-brain-2-fill', iconClassName)} />,
|
||||
icon: <RiBrain2Line className={iconClassName} />,
|
||||
activeIcon: <RiBrain2Fill className={iconClassName} />,
|
||||
},
|
||||
{
|
||||
key: ACCOUNT_SETTING_TAB.MEMBERS,
|
||||
name: t('settings.members', { ns: 'common' }),
|
||||
icon: <span className={cn('i-ri-group-2-line', iconClassName)} />,
|
||||
activeIcon: <span className={cn('i-ri-group-2-fill', iconClassName)} />,
|
||||
icon: <RiGroup2Line className={iconClassName} />,
|
||||
activeIcon: <RiGroup2Fill className={iconClassName} />,
|
||||
},
|
||||
]
|
||||
|
||||
@@ -76,8 +92,8 @@ export default function AccountSetting({
|
||||
key: ACCOUNT_SETTING_TAB.BILLING,
|
||||
name: t('settings.billing', { ns: 'common' }),
|
||||
description: t('plansCommon.receiptInfo', { ns: 'billing' }),
|
||||
icon: <span className={cn('i-ri-money-dollar-circle-line', iconClassName)} />,
|
||||
activeIcon: <span className={cn('i-ri-money-dollar-circle-fill', iconClassName)} />,
|
||||
icon: <RiMoneyDollarCircleLine className={iconClassName} />,
|
||||
activeIcon: <RiMoneyDollarCircleFill className={iconClassName} />,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -85,14 +101,14 @@ export default function AccountSetting({
|
||||
{
|
||||
key: ACCOUNT_SETTING_TAB.DATA_SOURCE,
|
||||
name: t('settings.dataSource', { ns: 'common' }),
|
||||
icon: <span className={cn('i-ri-database-2-line', iconClassName)} />,
|
||||
activeIcon: <span className={cn('i-ri-database-2-fill', iconClassName)} />,
|
||||
icon: <RiDatabase2Line className={iconClassName} />,
|
||||
activeIcon: <RiDatabase2Fill className={iconClassName} />,
|
||||
},
|
||||
{
|
||||
key: ACCOUNT_SETTING_TAB.API_BASED_EXTENSION,
|
||||
name: t('settings.apiBasedExtension', { ns: 'common' }),
|
||||
icon: <span className={cn('i-ri-puzzle-2-line', iconClassName)} />,
|
||||
activeIcon: <span className={cn('i-ri-puzzle-2-fill', iconClassName)} />,
|
||||
icon: <RiPuzzle2Line className={iconClassName} />,
|
||||
activeIcon: <RiPuzzle2Fill className={iconClassName} />,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -100,8 +116,8 @@ export default function AccountSetting({
|
||||
items.push({
|
||||
key: ACCOUNT_SETTING_TAB.CUSTOM,
|
||||
name: t('custom', { ns: 'custom' }),
|
||||
icon: <span className={cn('i-ri-color-filter-line', iconClassName)} />,
|
||||
activeIcon: <span className={cn('i-ri-color-filter-fill', iconClassName)} />,
|
||||
icon: <RiColorFilterLine className={iconClassName} />,
|
||||
activeIcon: <RiColorFilterFill className={iconClassName} />,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -124,8 +140,8 @@ export default function AccountSetting({
|
||||
{
|
||||
key: ACCOUNT_SETTING_TAB.LANGUAGE,
|
||||
name: t('settings.language', { ns: 'common' }),
|
||||
icon: <span className={cn('i-ri-translate-2', iconClassName)} />,
|
||||
activeIcon: <span className={cn('i-ri-translate-2', iconClassName)} />,
|
||||
icon: <RiTranslate2 className={iconClassName} />,
|
||||
activeIcon: <RiTranslate2 className={iconClassName} />,
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -155,13 +171,13 @@ export default function AccountSetting({
|
||||
>
|
||||
<div className="mx-auto flex h-[100vh] max-w-[1048px]">
|
||||
<div className="flex w-[44px] flex-col border-r border-divider-burn pl-4 pr-6 sm:w-[224px]">
|
||||
<div className="mb-8 mt-6 px-3 py-2 text-text-primary title-2xl-semi-bold">{t('userProfile.settings', { ns: 'common' })}</div>
|
||||
<div className="title-2xl-semi-bold mb-8 mt-6 px-3 py-2 text-text-primary">{t('userProfile.settings', { ns: 'common' })}</div>
|
||||
<div className="w-full">
|
||||
{
|
||||
menuItems.map(menuItem => (
|
||||
<div key={menuItem.key} className="mb-2">
|
||||
{!isCurrentWorkspaceDatasetOperator && (
|
||||
<div className="mb-0.5 py-2 pb-1 pl-3 text-text-tertiary system-xs-medium-uppercase">{menuItem.name}</div>
|
||||
<div className="system-xs-medium-uppercase mb-0.5 py-2 pb-1 pl-3 text-text-tertiary">{menuItem.name}</div>
|
||||
)}
|
||||
<div>
|
||||
{
|
||||
@@ -170,7 +186,7 @@ export default function AccountSetting({
|
||||
key={item.key}
|
||||
className={cn(
|
||||
'mb-0.5 flex h-[37px] cursor-pointer items-center rounded-lg p-1 pl-3 text-sm',
|
||||
activeMenu === item.key ? 'bg-state-base-active text-components-menu-item-text-active system-sm-semibold' : 'text-components-menu-item-text system-sm-medium',
|
||||
activeMenu === item.key ? 'system-sm-semibold bg-state-base-active text-components-menu-item-text-active' : 'system-sm-medium text-components-menu-item-text',
|
||||
)}
|
||||
title={item.name}
|
||||
onClick={() => {
|
||||
@@ -197,36 +213,38 @@ export default function AccountSetting({
className="px-2"
onClick={onCancel}
>
- <span className="i-ri-close-line h-5 w-5" />
+ <RiCloseLine className="h-5 w-5" />
</Button>
- <div className="mt-1 text-text-tertiary system-2xs-medium-uppercase">ESC</div>
+ <div className="system-2xs-medium-uppercase mt-1 text-text-tertiary">ESC</div>
</div>
<div ref={scrollRef} className="w-full overflow-y-auto bg-components-panel-bg pb-4">
<div className={cn('sticky top-0 z-20 mx-8 mb-[18px] flex items-center bg-components-panel-bg pb-2 pt-[27px]', scrolled && 'border-b border-divider-regular')}>
- <div className="shrink-0 text-text-primary title-2xl-semi-bold">
+ <div className="title-2xl-semi-bold shrink-0 text-text-primary">
{activeItem?.name}
{activeItem?.description && (
- <div className="mt-1 text-text-tertiary system-sm-regular">{activeItem?.description}</div>
+ <div className="system-sm-regular mt-1 text-text-tertiary">{activeItem?.description}</div>
)}
</div>
- {activeItem?.key === ACCOUNT_SETTING_TAB.PROVIDER && (
+ {activeItem?.key === 'provider' && (
<div className="flex grow justify-end">
- <SearchInput
- className="w-[200px]"
- onChange={setSearchValue}
+ <Input
+ showLeftIcon
+ wrapperClassName="!w-[200px]"
+ className="!h-8 !text-[13px]"
+ onChange={e => setSearchValue(e.target.value)}
value={searchValue}
/>
</div>
)}
</div>
<div className="px-4 pt-2 sm:px-8">
- {activeMenu === ACCOUNT_SETTING_TAB.PROVIDER && <ModelProviderPage searchText={searchValue} />}
- {activeMenu === ACCOUNT_SETTING_TAB.MEMBERS && <MembersPage />}
- {activeMenu === ACCOUNT_SETTING_TAB.BILLING && <BillingPage />}
- {activeMenu === ACCOUNT_SETTING_TAB.DATA_SOURCE && <DataSourcePage />}
- {activeMenu === ACCOUNT_SETTING_TAB.API_BASED_EXTENSION && <ApiBasedExtensionPage />}
- {activeMenu === ACCOUNT_SETTING_TAB.CUSTOM && <CustomPage />}
- {activeMenu === ACCOUNT_SETTING_TAB.LANGUAGE && <LanguagePage />}
+ {activeMenu === 'provider' && <ModelProviderPage searchText={searchValue} />}
+ {activeMenu === 'members' && <MembersPage />}
+ {activeMenu === 'billing' && <BillingPage />}
+ {activeMenu === 'data-source' && <DataSourcePage />}
+ {activeMenu === 'api-based-extension' && <ApiBasedExtensionPage />}
+ {activeMenu === 'custom' && <CustomPage />}
+ {activeMenu === 'language' && <LanguagePage />}
</div>
</div>
</div>
@@ -46,7 +46,8 @@ const nodeDefault: NodeDefault<HttpNodeType> = {

if (!errorMessages
&& payload.body.type === BodyType.binary
- && ((!(payload.body.data as BodyPayload)[0]?.file) || (payload.body.data as BodyPayload)[0]?.file?.length === 0)) {
+ && ((!(payload.body.data as BodyPayload)[0]?.file) || (payload.body.data as BodyPayload)[0]?.file?.length === 0)
+ ) {
errorMessages = t('errorMsg.fieldRequired', { ns: 'workflow', field: t('nodes.http.binaryFileVariable', { ns: 'workflow' }) })
}
@@ -2052,6 +2052,9 @@
"app/components/base/input/index.tsx": {
"react-refresh/only-export-components": {
"count": 1
},
+ "tailwindcss/enforce-consistent-class-order": {
+ "count": 3
+ }
},
"app/components/base/linked-apps-panel/index.tsx": {

@@ -3991,6 +3994,9 @@
"app/components/header/account-setting/index.tsx": {
"react-hooks-extra/no-direct-set-state-in-use-effect": {
"count": 1
},
+ "tailwindcss/enforce-consistent-class-order": {
+ "count": 7
+ }
},
"app/components/header/account-setting/key-validator/declarations.ts": {

@@ -8163,6 +8169,11 @@
"count": 3
}
},
+ "i18n-config/README.md": {
+ "no-irregular-whitespace": {
+ "count": 1
+ }
+ },
"i18n/de-DE/billing.json": {
"no-irregular-whitespace": {
"count": 1
@@ -6,7 +6,7 @@ This directory contains i18n tooling and configuration. Translation files live u

## File Structure

- ```txt
+ ```
web/i18n
├── en-US
│ ├── app.json

@@ -36,7 +36,7 @@ By default we will use `LanguagesSupported` to determine which languages are sup

1. Create a new folder for the new language.

- ```txt
+ ```
cd web/i18n
cp -r en-US id-ID
```

@@ -98,7 +98,7 @@ export const languages = [
{
value: 'ru-RU',
name: 'Русский(Россия)',
- example: 'Привет, Dify!',
+ example: ' Привет, Dify!',
supported: false,
},
{
@@ -3,7 +3,7 @@
"type": "module",
"version": "1.13.0",
"private": true,
- "packageManager": "pnpm@10.27.0",
+ "packageManager": "pnpm@10.27.0+sha512.72d699da16b1179c14ba9e64dc71c9a40988cbdc65c264cb0e489db7de917f20dcf4d64d8723625f2969ba52d4b7e2a1170682d9ac2a5dcaeaab732b7e16f04a",
"imports": {
"#i18n": {
"react-server": "./i18n-config/lib.server.ts",

@@ -165,10 +165,10 @@
"zustand": "5.0.9"
},
"devDependencies": {
- "@antfu/eslint-config": "7.6.1",
- "@chromatic-com/storybook": "5.0.1",
+ "@antfu/eslint-config": "7.2.0",
+ "@chromatic-com/storybook": "5.0.0",
"@egoist/tailwindcss-icons": "1.9.2",
- "@eslint-react/eslint-plugin": "2.13.0",
+ "@eslint-react/eslint-plugin": "2.9.4",
"@iconify-json/heroicons": "1.2.3",
"@iconify-json/ri": "1.2.9",
"@mdx-js/loader": "3.1.1",

@@ -177,12 +177,12 @@
"@next/mdx": "16.1.5",
"@rgrove/parse-xml": "4.2.0",
"@serwist/turbopack": "9.5.4",
- "@storybook/addon-docs": "10.2.13",
- "@storybook/addon-links": "10.2.13",
- "@storybook/addon-onboarding": "10.2.13",
- "@storybook/addon-themes": "10.2.13",
- "@storybook/nextjs-vite": "10.2.13",
- "@storybook/react": "10.2.13",
+ "@storybook/addon-docs": "10.2.0",
+ "@storybook/addon-links": "10.2.0",
+ "@storybook/addon-onboarding": "10.2.0",
+ "@storybook/addon-themes": "10.2.0",
+ "@storybook/nextjs-vite": "10.2.0",
+ "@storybook/react": "10.2.0",
"@tanstack/eslint-plugin-query": "5.91.4",
"@tanstack/react-devtools": "0.9.2",
"@tanstack/react-form-devtools": "0.2.12",

@@ -208,7 +208,7 @@
"@types/semver": "7.7.1",
"@types/sortablejs": "1.15.8",
"@types/uuid": "10.0.0",
- "@typescript-eslint/parser": "8.56.1",
+ "@typescript-eslint/parser": "8.54.0",
"@typescript/native-preview": "7.0.0-dev.20251209.1",
"@vitejs/plugin-react": "5.1.2",
"@vitest/coverage-v8": "4.0.17",

@@ -216,13 +216,13 @@
"code-inspector-plugin": "1.3.6",
"cross-env": "10.1.0",
"esbuild": "0.27.2",
- "eslint": "10.0.2",
- "eslint-plugin-better-tailwindcss": "4.3.1",
- "eslint-plugin-hyoban": "0.11.2",
+ "eslint": "9.39.2",
+ "eslint-plugin-better-tailwindcss": "https://pkg.pr.new/hyoban/eslint-plugin-better-tailwindcss@c0161c7",
+ "eslint-plugin-hyoban": "0.11.1",
"eslint-plugin-react-hooks": "7.0.1",
- "eslint-plugin-react-refresh": "0.5.2",
- "eslint-plugin-sonarjs": "4.0.0",
- "eslint-plugin-storybook": "10.2.13",
+ "eslint-plugin-react-refresh": "0.5.0",
+ "eslint-plugin-sonarjs": "3.0.6",
+ "eslint-plugin-storybook": "10.2.6",
"husky": "9.1.7",
"iconify-import-svg": "0.1.1",
"jsdom": "27.3.0",

@@ -235,7 +235,7 @@
"react-scan": "0.4.3",
"sass": "1.93.2",
"serwist": "9.5.4",
- "storybook": "10.2.13",
+ "storybook": "10.2.0",
"tailwindcss": "3.4.19",
"tsx": "4.21.0",
"typescript": "5.9.3",

@@ -249,7 +249,6 @@
"overrides": {
"@monaco-editor/loader": "1.5.0",
"@nolyfill/safe-buffer": "npm:safe-buffer@^5.2.1",
- "@stylistic/eslint-plugin": "https://pkg.pr.new/@stylistic/eslint-plugin@258f9d8",
"array-includes": "npm:@nolyfill/array-includes@^1",
"array.prototype.findlast": "npm:@nolyfill/array.prototype.findlast@^1",
"array.prototype.findlastindex": "npm:@nolyfill/array.prototype.findlastindex@^1",
web/pnpm-lock.yaml (generated, 1883 changed lines): file diff suppressed because it is too large.
@@ -372,7 +372,7 @@ export const submitDeleteAccountFeedback = (body: { feedback: string, email: str

export const getDocDownloadUrl = (doc_name: string): Promise<{ url: string }> =>
get<{ url: string }>('/compliance/download', { params: { doc_name } }, { silent: true })

- export const sendVerifyCode = (body: { email: string, phase: string, token?: string }): Promise<CommonResponse & { data: string }> =>
+ export const sendVerifyCode = (body: { email: string, token?: string }): Promise<CommonResponse & { data: string }> =>
post<CommonResponse & { data: string }>('/account/change-email', { body })

export const verifyEmail = (body: { email: string, code: string, token: string }): Promise<CommonResponse & { is_valid: boolean, email: string, token: string }> =>
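The `sendVerifyCode` change in `web/service/common.ts` drops the `phase` field from the request body. As a hedged sketch of the new call shape (the `post` helper below is a stand-in stub, not the real helper from `web/service/base`), a caller would now look like this:

```typescript
// Stand-in response type and POST helper; the real helpers live in
// web/service/base and issue actual HTTP requests.
type CommonResponse = { result: 'success' | 'fail' }

async function post<T>(url: string, options: { body: unknown }): Promise<T> {
  // Stub: echo the request so the call shape can be inspected offline.
  return { result: 'success', data: JSON.stringify({ url, body: options.body }) } as unknown as T
}

// New signature from the diff: `phase` is no longer part of the payload.
export const sendVerifyCode = (body: { email: string, token?: string }): Promise<CommonResponse & { data: string }> =>
  post<CommonResponse & { data: string }>('/account/change-email', { body })
```

A caller that still passed `phase` would now fail type-checking, which is the practical effect of the narrowed signature.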