mirror of
https://github.com/langgenius/dify.git
synced 2026-02-28 04:15:10 +00:00
Compare commits
6 Commits
2-25-vinex
...
verify-ema
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f8d410d7fe | ||
|
|
7ea09fe49d | ||
|
|
176bab84a9 | ||
|
|
ebe7548c72 | ||
|
|
fce50cf12c | ||
|
|
9a71d70738 |
9
.github/workflows/pyrefly-diff-comment.yml
vendored
9
.github/workflows/pyrefly-diff-comment.yml
vendored
@@ -77,7 +77,14 @@ jobs:
|
||||
}
|
||||
|
||||
const body = diff.trim()
|
||||
? '### Pyrefly Diff\n<details>\n<summary>base → PR</summary>\n\n```diff\n' + diff + '\n```\n</details>'
|
||||
? `### Pyrefly Diff
|
||||
<details>
|
||||
<summary>base → PR</summary>
|
||||
|
||||
\`\`\`diff
|
||||
${diff}
|
||||
\`\`\`
|
||||
</details>`
|
||||
: '### Pyrefly Diff\nNo changes detected.';
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
|
||||
63
.github/workflows/web-tests.yml
vendored
63
.github/workflows/web-tests.yml
vendored
@@ -3,22 +3,14 @@ name: Web Tests
|
||||
on:
|
||||
workflow_call:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: web-tests-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }})
|
||||
name: Web Tests
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
shardIndex: [1, 2, 3, 4]
|
||||
shardTotal: [4]
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -47,58 +39,7 @@ jobs:
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Run tests
|
||||
run: pnpm vitest run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage
|
||||
|
||||
- name: Upload blob report
|
||||
if: ${{ !cancelled() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: blob-report-${{ matrix.shardIndex }}
|
||||
path: web/.vitest-reports/*
|
||||
include-hidden-files: true
|
||||
retention-days: 1
|
||||
|
||||
merge-reports:
|
||||
name: Merge Test Reports
|
||||
if: ${{ !cancelled() }}
|
||||
needs: [test]
|
||||
runs-on: ubuntu-latest
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: ./web
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 24
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Download blob reports
|
||||
uses: actions/download-artifact@v6
|
||||
with:
|
||||
path: web/.vitest-reports
|
||||
pattern: blob-report-*
|
||||
merge-multiple: true
|
||||
|
||||
- name: Merge reports
|
||||
run: pnpm vitest --merge-reports --coverage --silent=passed-only
|
||||
run: pnpm test:ci
|
||||
|
||||
- name: Coverage Summary
|
||||
if: always()
|
||||
|
||||
@@ -50,6 +50,7 @@ forbidden_modules =
|
||||
allow_indirect_imports = True
|
||||
ignore_imports =
|
||||
core.workflow.nodes.agent.agent_node -> extensions.ext_database
|
||||
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
|
||||
core.workflow.nodes.llm.file_saver -> extensions.ext_database
|
||||
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
|
||||
@@ -105,6 +106,9 @@ ignore_imports =
|
||||
core.workflow.nodes.agent.agent_node -> core.model_manager
|
||||
core.workflow.nodes.agent.agent_node -> core.provider_manager
|
||||
core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
|
||||
core.workflow.nodes.datasource.datasource_node -> models.model
|
||||
core.workflow.nodes.datasource.datasource_node -> models.tools
|
||||
core.workflow.nodes.datasource.datasource_node -> services.datasource_provider_service
|
||||
core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
|
||||
core.workflow.nodes.http_request.node -> core.tools.tool_file_manager
|
||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
|
||||
@@ -142,6 +146,8 @@ ignore_imports =
|
||||
core.workflow.workflow_entry -> core.app.apps.exc
|
||||
core.workflow.workflow_entry -> core.app.entities.app_invoke_entities
|
||||
core.workflow.workflow_entry -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.datasource.datasource_node -> core.datasource.datasource_manager
|
||||
core.workflow.nodes.datasource.datasource_node -> core.datasource.utils.message_transformer
|
||||
core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
|
||||
@@ -154,6 +160,7 @@ ignore_imports =
|
||||
core.workflow.nodes.code.code_node -> core.helper.code_executor.javascript.javascript_code_provider
|
||||
core.workflow.nodes.code.code_node -> core.helper.code_executor.python3.python3_code_provider
|
||||
core.workflow.nodes.code.entities -> core.helper.code_executor.code_executor
|
||||
core.workflow.nodes.datasource.datasource_node -> core.variables.variables
|
||||
core.workflow.nodes.http_request.executor -> core.helper.ssrf_proxy
|
||||
core.workflow.nodes.http_request.node -> core.helper.ssrf_proxy
|
||||
core.workflow.nodes.llm.file_saver -> core.helper.ssrf_proxy
|
||||
@@ -190,6 +197,7 @@ ignore_imports =
|
||||
core.workflow.nodes.code.code_node -> core.variables.segments
|
||||
core.workflow.nodes.code.code_node -> core.variables.types
|
||||
core.workflow.nodes.code.entities -> core.variables.types
|
||||
core.workflow.nodes.datasource.datasource_node -> core.variables.segments
|
||||
core.workflow.nodes.document_extractor.node -> core.variables
|
||||
core.workflow.nodes.document_extractor.node -> core.variables.segments
|
||||
core.workflow.nodes.http_request.executor -> core.variables.segments
|
||||
@@ -232,6 +240,7 @@ ignore_imports =
|
||||
core.workflow.variable_loader -> core.variables.consts
|
||||
core.workflow.workflow_type_encoder -> core.variables
|
||||
core.workflow.nodes.agent.agent_node -> extensions.ext_database
|
||||
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
|
||||
core.workflow.nodes.llm.file_saver -> extensions.ext_database
|
||||
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
|
||||
|
||||
@@ -42,7 +42,7 @@ The scripts resolve paths relative to their location, so you can run them from a
|
||||
|
||||
1. Set up your application by visiting `http://localhost:3000`.
|
||||
|
||||
1. Start the worker service (async and scheduler tasks, runs from `api`).
|
||||
1. Optional: start the worker service (async tasks, runs from `api`).
|
||||
|
||||
```bash
|
||||
./dev/start-worker
|
||||
@@ -54,6 +54,86 @@ The scripts resolve paths relative to their location, so you can run them from a
|
||||
./dev/start-beat
|
||||
```
|
||||
|
||||
### Manual commands
|
||||
|
||||
<details>
|
||||
<summary>Show manual setup and run steps</summary>
|
||||
|
||||
These commands assume you start from the repository root.
|
||||
|
||||
1. Start the docker-compose stack.
|
||||
|
||||
The backend requires middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`.
|
||||
|
||||
```bash
|
||||
cp docker/middleware.env.example docker/middleware.env
|
||||
# Use mysql or another vector database profile if you are not using postgres/weaviate.
|
||||
docker compose -f docker/docker-compose.middleware.yaml --profile postgresql --profile weaviate -p dify up -d
|
||||
```
|
||||
|
||||
1. Copy env files.
|
||||
|
||||
```bash
|
||||
cp api/.env.example api/.env
|
||||
cp web/.env.example web/.env.local
|
||||
```
|
||||
|
||||
1. Install UV if needed.
|
||||
|
||||
```bash
|
||||
pip install uv
|
||||
# Or on macOS
|
||||
brew install uv
|
||||
```
|
||||
|
||||
1. Install API dependencies.
|
||||
|
||||
```bash
|
||||
cd api
|
||||
uv sync --group dev
|
||||
```
|
||||
|
||||
1. Install web dependencies.
|
||||
|
||||
```bash
|
||||
cd web
|
||||
pnpm install
|
||||
cd ..
|
||||
```
|
||||
|
||||
1. Start backend (runs migrations first, in a new terminal).
|
||||
|
||||
```bash
|
||||
cd api
|
||||
uv run flask db upgrade
|
||||
uv run flask run --host 0.0.0.0 --port=5001 --debug
|
||||
```
|
||||
|
||||
1. Start Dify [web](../web) service (in a new terminal).
|
||||
|
||||
```bash
|
||||
cd web
|
||||
pnpm dev:inspect
|
||||
```
|
||||
|
||||
1. Set up your application by visiting `http://localhost:3000`.
|
||||
|
||||
1. Optional: start the worker service (async tasks, in a new terminal).
|
||||
|
||||
```bash
|
||||
cd api
|
||||
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
|
||||
```
|
||||
|
||||
1. Optional: start Celery Beat (scheduled tasks, in a new terminal).
|
||||
|
||||
```bash
|
||||
cd api
|
||||
uv run celery -A app.celery beat
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### Environment notes
|
||||
|
||||
> [!IMPORTANT]
|
||||
|
||||
@@ -133,7 +133,6 @@ class EducationAutocompleteQuery(BaseModel):
|
||||
class ChangeEmailSendPayload(BaseModel):
|
||||
email: EmailStr
|
||||
language: str | None = None
|
||||
phase: str | None = None
|
||||
token: str | None = None
|
||||
|
||||
|
||||
@@ -547,13 +546,17 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
account = None
|
||||
user_email = None
|
||||
email_for_sending = args.email.lower()
|
||||
if args.phase is not None and args.phase == "new_email":
|
||||
if args.token is None:
|
||||
raise InvalidTokenError()
|
||||
send_phase = AccountService.CHANGE_EMAIL_PHASE_OLD
|
||||
if args.token is not None:
|
||||
send_phase = AccountService.CHANGE_EMAIL_PHASE_NEW
|
||||
|
||||
reset_data = AccountService.get_change_email_data(args.token)
|
||||
if reset_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
reset_token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
|
||||
if reset_token_phase != AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED:
|
||||
raise InvalidTokenError()
|
||||
user_email = reset_data.get("email", "")
|
||||
|
||||
if user_email.lower() != current_user.email.lower():
|
||||
@@ -573,7 +576,7 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
email=email_for_sending,
|
||||
old_email=user_email,
|
||||
language=language,
|
||||
phase=args.phase,
|
||||
phase=send_phase,
|
||||
)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
@@ -608,12 +611,26 @@ class ChangeEmailCheckApi(Resource):
|
||||
AccountService.add_change_email_error_rate_limit(user_email)
|
||||
raise EmailCodeError()
|
||||
|
||||
phase_transitions: dict[str, str] = {
|
||||
AccountService.CHANGE_EMAIL_PHASE_OLD: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
|
||||
AccountService.CHANGE_EMAIL_PHASE_NEW: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
|
||||
}
|
||||
token_phase = token_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
|
||||
if not isinstance(token_phase, str):
|
||||
raise InvalidTokenError()
|
||||
refreshed_phase = phase_transitions.get(token_phase)
|
||||
if refreshed_phase is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
AccountService.revoke_change_email_token(args.token)
|
||||
|
||||
# Refresh token data by generating a new token
|
||||
_, new_token = AccountService.generate_change_email_token(
|
||||
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
|
||||
user_email,
|
||||
code=args.code,
|
||||
old_email=token_data.get("old_email"),
|
||||
additional_data={AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: refreshed_phase},
|
||||
)
|
||||
|
||||
AccountService.reset_change_email_error_rate_limit(user_email)
|
||||
@@ -643,13 +660,22 @@ class ChangeEmailResetApi(Resource):
|
||||
if not reset_data:
|
||||
raise InvalidTokenError()
|
||||
|
||||
AccountService.revoke_change_email_token(args.token)
|
||||
token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
|
||||
if token_phase != AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED:
|
||||
raise InvalidTokenError()
|
||||
|
||||
token_email = reset_data.get("email")
|
||||
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
|
||||
if normalized_token_email != normalized_new_email:
|
||||
raise InvalidTokenError()
|
||||
|
||||
old_email = reset_data.get("old_email", "")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if current_user.email.lower() != old_email.lower():
|
||||
raise AccountNotFound()
|
||||
|
||||
AccountService.revoke_change_email_token(args.token)
|
||||
|
||||
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
|
||||
|
||||
AccountService.send_change_email_completed_notify_email(
|
||||
|
||||
@@ -122,7 +122,7 @@ class AppQueueManager(ABC):
|
||||
"""Attach the live graph runtime state reference for downstream consumers."""
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
|
||||
def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||
def publish(self, event: AppQueueEvent, pub_from: PublishFrom):
|
||||
"""
|
||||
Publish event to queue
|
||||
:param event:
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing_extensions import override
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.llm.model_access import build_dify_model_access
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
@@ -19,7 +18,6 @@ from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.code_node import CodeNode, WorkflowCodeExecutor
|
||||
from core.workflow.nodes.code.entities import CodeLanguage
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
from core.workflow.nodes.datasource import DatasourceNode
|
||||
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
|
||||
from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
|
||||
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
@@ -180,15 +178,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
model_factory=self._llm_model_factory,
|
||||
)
|
||||
|
||||
if node_type == NodeType.DATASOURCE:
|
||||
return DatasourceNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
datasource_manager=DatasourceManager,
|
||||
)
|
||||
|
||||
if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
return KnowledgeRetrievalNode(
|
||||
id=node_id,
|
||||
|
||||
@@ -1,39 +1,16 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from threading import Lock
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
import contexts
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceMessage,
|
||||
DatasourceProviderType,
|
||||
GetOnlineDocumentPageContentRequest,
|
||||
OnlineDriveDownloadFileRequest,
|
||||
)
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||
from core.datasource.errors import DatasourceProviderNotFoundError
|
||||
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
|
||||
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
|
||||
from core.datasource.online_drive.online_drive_provider import OnlineDriveDatasourcePluginProviderController
|
||||
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
|
||||
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
|
||||
from core.db.session_factory import session_factory
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.file import File
|
||||
from core.workflow.file.enums import FileTransferMethod, FileType
|
||||
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from core.workflow.repositories.datasource_manager_protocol import DatasourceParameter, OnlineDriveDownloadFileParam
|
||||
from factories import file_factory
|
||||
from models.model import UploadFile
|
||||
from models.tools import ToolFile
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -126,238 +103,3 @@ class DatasourceManager:
|
||||
tenant_id,
|
||||
datasource_type,
|
||||
).get_datasource(datasource_name)
|
||||
|
||||
@classmethod
|
||||
def get_icon_url(cls, provider_id: str, tenant_id: str, datasource_name: str, datasource_type: str) -> str:
|
||||
datasource_runtime = cls.get_datasource_runtime(
|
||||
provider_id=provider_id,
|
||||
datasource_name=datasource_name,
|
||||
tenant_id=tenant_id,
|
||||
datasource_type=DatasourceProviderType.value_of(datasource_type),
|
||||
)
|
||||
return datasource_runtime.get_icon_url(tenant_id)
|
||||
|
||||
@classmethod
|
||||
def stream_online_results(
|
||||
cls,
|
||||
*,
|
||||
user_id: str,
|
||||
datasource_name: str,
|
||||
datasource_type: str,
|
||||
provider_id: str,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
plugin_id: str,
|
||||
credential_id: str,
|
||||
datasource_param: DatasourceParameter | None = None,
|
||||
online_drive_request: OnlineDriveDownloadFileParam | None = None,
|
||||
) -> Generator[DatasourceMessage, None, Any]:
|
||||
"""
|
||||
Pull-based streaming of domain messages from datasource plugins.
|
||||
Returns a generator that yields DatasourceMessage and finally returns a minimal final payload.
|
||||
Only ONLINE_DOCUMENT and ONLINE_DRIVE are streamable here; other types are handled by nodes directly.
|
||||
"""
|
||||
ds_type = DatasourceProviderType.value_of(datasource_type)
|
||||
runtime = cls.get_datasource_runtime(
|
||||
provider_id=provider_id,
|
||||
datasource_name=datasource_name,
|
||||
tenant_id=tenant_id,
|
||||
datasource_type=ds_type,
|
||||
)
|
||||
|
||||
dsp_service = DatasourceProviderService()
|
||||
credentials = dsp_service.get_datasource_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
plugin_id=plugin_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
if ds_type == DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
doc_runtime = cast(OnlineDocumentDatasourcePlugin, runtime)
|
||||
if credentials:
|
||||
doc_runtime.runtime.credentials = credentials
|
||||
if datasource_param is None:
|
||||
raise ValueError("datasource_param is required for ONLINE_DOCUMENT streaming")
|
||||
inner_gen: Generator[DatasourceMessage, None, None] = doc_runtime.get_online_document_page_content(
|
||||
user_id=user_id,
|
||||
datasource_parameters=GetOnlineDocumentPageContentRequest(
|
||||
workspace_id=datasource_param.workspace_id,
|
||||
page_id=datasource_param.page_id,
|
||||
type=datasource_param.type,
|
||||
),
|
||||
provider_type=ds_type,
|
||||
)
|
||||
elif ds_type == DatasourceProviderType.ONLINE_DRIVE:
|
||||
drive_runtime = cast(OnlineDriveDatasourcePlugin, runtime)
|
||||
if credentials:
|
||||
drive_runtime.runtime.credentials = credentials
|
||||
if online_drive_request is None:
|
||||
raise ValueError("online_drive_request is required for ONLINE_DRIVE streaming")
|
||||
inner_gen = drive_runtime.online_drive_download_file(
|
||||
user_id=user_id,
|
||||
request=OnlineDriveDownloadFileRequest(
|
||||
id=online_drive_request.id,
|
||||
bucket=online_drive_request.bucket,
|
||||
),
|
||||
provider_type=ds_type,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported datasource type for streaming: {ds_type}")
|
||||
|
||||
# Bridge through to caller while preserving generator return contract
|
||||
yield from inner_gen
|
||||
# No structured final data here; node/adapter will assemble outputs
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def stream_node_events(
|
||||
cls,
|
||||
*,
|
||||
node_id: str,
|
||||
user_id: str,
|
||||
datasource_name: str,
|
||||
datasource_type: str,
|
||||
provider_id: str,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
plugin_id: str,
|
||||
credential_id: str,
|
||||
parameters_for_log: dict[str, Any],
|
||||
datasource_info: dict[str, Any],
|
||||
variable_pool: Any,
|
||||
datasource_param: DatasourceParameter | None = None,
|
||||
online_drive_request: OnlineDriveDownloadFileParam | None = None,
|
||||
) -> Generator[StreamChunkEvent | StreamCompletedEvent, None, None]:
|
||||
ds_type = DatasourceProviderType.value_of(datasource_type)
|
||||
|
||||
messages = cls.stream_online_results(
|
||||
user_id=user_id,
|
||||
datasource_name=datasource_name,
|
||||
datasource_type=datasource_type,
|
||||
provider_id=provider_id,
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
plugin_id=plugin_id,
|
||||
credential_id=credential_id,
|
||||
datasource_param=datasource_param,
|
||||
online_drive_request=online_drive_request,
|
||||
)
|
||||
|
||||
transformed = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=messages, user_id=user_id, tenant_id=tenant_id, conversation_id=None
|
||||
)
|
||||
|
||||
variables: dict[str, Any] = {}
|
||||
file_out: File | None = None
|
||||
|
||||
for message in transformed:
|
||||
mtype = message.type
|
||||
if mtype in {
|
||||
DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
DatasourceMessage.MessageType.BINARY_LINK,
|
||||
DatasourceMessage.MessageType.IMAGE,
|
||||
}:
|
||||
wanted_ds_type = ds_type in {
|
||||
DatasourceProviderType.ONLINE_DRIVE,
|
||||
DatasourceProviderType.ONLINE_DOCUMENT,
|
||||
}
|
||||
if wanted_ds_type and isinstance(message.message, DatasourceMessage.TextMessage):
|
||||
url = message.message.text
|
||||
|
||||
datasource_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
with session_factory.create_session() as session:
|
||||
stmt = select(ToolFile).where(
|
||||
ToolFile.id == datasource_file_id, ToolFile.tenant_id == tenant_id
|
||||
)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if not datasource_file:
|
||||
raise ValueError(
|
||||
f"ToolFile not found for file_id={datasource_file_id}, tenant_id={tenant_id}"
|
||||
)
|
||||
mime_type = datasource_file.mimetype
|
||||
if datasource_file is not None:
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(mime_type),
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
"url": url,
|
||||
}
|
||||
file_out = file_factory.build_from_mapping(mapping=mapping, tenant_id=tenant_id)
|
||||
elif mtype == DatasourceMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
yield StreamChunkEvent(selector=[node_id, "text"], chunk=message.message.text, is_final=False)
|
||||
elif mtype == DatasourceMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"], chunk=f"Link: {message.message.text}\n", is_final=False
|
||||
)
|
||||
elif mtype == DatasourceMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, DatasourceMessage.VariableMessage)
|
||||
name = message.message.variable_name
|
||||
value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
assert isinstance(value, str), "stream variable_value must be str"
|
||||
variables[name] = variables.get(name, "") + value
|
||||
yield StreamChunkEvent(selector=[node_id, name], chunk=value, is_final=False)
|
||||
else:
|
||||
variables[name] = value
|
||||
elif mtype == DatasourceMessage.MessageType.FILE:
|
||||
if ds_type == DatasourceProviderType.ONLINE_DRIVE and message.meta:
|
||||
f = message.meta.get("file")
|
||||
if isinstance(f, File):
|
||||
file_out = f
|
||||
else:
|
||||
pass
|
||||
|
||||
yield StreamChunkEvent(selector=[node_id, "text"], chunk="", is_final=True)
|
||||
|
||||
if ds_type == DatasourceProviderType.ONLINE_DRIVE and file_out is not None:
|
||||
variable_pool.add([node_id, "file"], file_out)
|
||||
|
||||
if ds_type == DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={**variables},
|
||||
)
|
||||
)
|
||||
else:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
"file": file_out,
|
||||
"datasource_type": ds_type,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File:
|
||||
with session_factory.create_session() as session:
|
||||
upload_file = (
|
||||
session.query(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).first()
|
||||
)
|
||||
if not upload_file:
|
||||
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
|
||||
|
||||
file_info = File(
|
||||
id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
tenant_id=tenant_id,
|
||||
type=FileType.CUSTOM,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
related_id=upload_file.id,
|
||||
size=upload_file.size,
|
||||
storage_key=upload_file.key,
|
||||
url=upload_file.source_url,
|
||||
)
|
||||
return file_info
|
||||
|
||||
@@ -379,11 +379,4 @@ class OnlineDriveDownloadFileRequest(BaseModel):
|
||||
"""
|
||||
|
||||
id: str = Field(..., description="The id of the file")
|
||||
bucket: str = Field("", description="The name of the bucket")
|
||||
|
||||
@field_validator("bucket", mode="before")
|
||||
@classmethod
|
||||
def _coerce_bucket(cls, v) -> str:
|
||||
if v is None:
|
||||
return ""
|
||||
return str(v)
|
||||
bucket: str | None = Field(None, description="The name of the bucket")
|
||||
|
||||
@@ -1,26 +1,40 @@
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import Any, cast
|
||||
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceMessage,
|
||||
DatasourceParameter,
|
||||
DatasourceProviderType,
|
||||
GetOnlineDocumentPageContentRequest,
|
||||
OnlineDriveDownloadFileRequest,
|
||||
)
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
|
||||
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
from core.variables.variables import ArrayAnyVariable
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey
|
||||
from core.workflow.node_events import NodeRunResult, StreamCompletedEvent
|
||||
from core.workflow.file import File
|
||||
from core.workflow.file.enums import FileTransferMethod, FileType
|
||||
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.repositories.datasource_manager_protocol import (
|
||||
DatasourceManagerProtocol,
|
||||
DatasourceParameter,
|
||||
OnlineDriveDownloadFileParam,
|
||||
)
|
||||
from core.workflow.nodes.tool.exc import ToolFileError
|
||||
from core.workflow.runtime import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.model import UploadFile
|
||||
from models.tools import ToolFile
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
|
||||
from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from .entities import DatasourceNodeData
|
||||
from .exc import DatasourceNodeError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from .exc import DatasourceNodeError, DatasourceParameterError
|
||||
|
||||
|
||||
class DatasourceNode(Node[DatasourceNodeData]):
|
||||
@@ -31,22 +45,6 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
node_type = NodeType.DATASOURCE
|
||||
execution_type = NodeExecutionType.ROOT
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
datasource_manager: DatasourceManagerProtocol,
|
||||
):
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self.datasource_manager = datasource_manager
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""
|
||||
Run the datasource node
|
||||
@@ -54,69 +52,84 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
|
||||
node_data = self.node_data
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
|
||||
if not datasource_type_segment:
|
||||
datasource_type_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
|
||||
if not datasource_type_segement:
|
||||
raise DatasourceNodeError("Datasource type is not set")
|
||||
datasource_type = str(datasource_type_segment.value) if datasource_type_segment.value else None
|
||||
datasource_info_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO])
|
||||
if not datasource_info_segment:
|
||||
datasource_type = str(datasource_type_segement.value) if datasource_type_segement.value else None
|
||||
datasource_info_segement = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO])
|
||||
if not datasource_info_segement:
|
||||
raise DatasourceNodeError("Datasource info is not set")
|
||||
datasource_info_value = datasource_info_segment.value
|
||||
datasource_info_value = datasource_info_segement.value
|
||||
if not isinstance(datasource_info_value, dict):
|
||||
raise DatasourceNodeError("Invalid datasource info format")
|
||||
datasource_info: dict[str, Any] = datasource_info_value
|
||||
# get datasource runtime
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
|
||||
if datasource_type is None:
|
||||
raise DatasourceNodeError("Datasource type is not set")
|
||||
|
||||
datasource_type = DatasourceProviderType.value_of(datasource_type)
|
||||
provider_id = f"{node_data.plugin_id}/{node_data.provider_name}"
|
||||
|
||||
datasource_info["icon"] = self.datasource_manager.get_icon_url(
|
||||
provider_id=provider_id,
|
||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
|
||||
datasource_name=node_data.datasource_name or "",
|
||||
tenant_id=self.tenant_id,
|
||||
datasource_type=datasource_type.value,
|
||||
datasource_type=datasource_type,
|
||||
)
|
||||
datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
|
||||
|
||||
parameters_for_log = datasource_info
|
||||
|
||||
try:
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credentials = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=node_data.provider_name,
|
||||
plugin_id=node_data.plugin_id,
|
||||
credential_id=datasource_info.get("credential_id", ""),
|
||||
)
|
||||
match datasource_type:
|
||||
case DatasourceProviderType.ONLINE_DOCUMENT | DatasourceProviderType.ONLINE_DRIVE:
|
||||
# Build typed request objects
|
||||
datasource_parameters = None
|
||||
if datasource_type == DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
datasource_parameters = DatasourceParameter(
|
||||
workspace_id=datasource_info.get("workspace_id", ""),
|
||||
page_id=datasource_info.get("page", {}).get("page_id", ""),
|
||||
type=datasource_info.get("page", {}).get("type", ""),
|
||||
case DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
||||
if credentials:
|
||||
datasource_runtime.runtime.credentials = credentials
|
||||
online_document_result: Generator[DatasourceMessage, None, None] = (
|
||||
datasource_runtime.get_online_document_page_content(
|
||||
user_id=self.user_id,
|
||||
datasource_parameters=GetOnlineDocumentPageContentRequest(
|
||||
workspace_id=datasource_info.get("workspace_id", ""),
|
||||
page_id=datasource_info.get("page", {}).get("page_id", ""),
|
||||
type=datasource_info.get("page", {}).get("type", ""),
|
||||
),
|
||||
provider_type=datasource_type,
|
||||
)
|
||||
|
||||
online_drive_request = None
|
||||
if datasource_type == DatasourceProviderType.ONLINE_DRIVE:
|
||||
online_drive_request = OnlineDriveDownloadFileParam(
|
||||
id=datasource_info.get("id", ""),
|
||||
bucket=datasource_info.get("bucket", ""),
|
||||
)
|
||||
yield from self._transform_message(
|
||||
messages=online_document_result,
|
||||
parameters_for_log=parameters_for_log,
|
||||
datasource_info=datasource_info,
|
||||
)
|
||||
case DatasourceProviderType.ONLINE_DRIVE:
|
||||
datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
|
||||
if credentials:
|
||||
datasource_runtime.runtime.credentials = credentials
|
||||
online_drive_result: Generator[DatasourceMessage, None, None] = (
|
||||
datasource_runtime.online_drive_download_file(
|
||||
user_id=self.user_id,
|
||||
request=OnlineDriveDownloadFileRequest(
|
||||
id=datasource_info.get("id", ""),
|
||||
bucket=datasource_info.get("bucket"),
|
||||
),
|
||||
provider_type=datasource_type,
|
||||
)
|
||||
|
||||
credential_id = datasource_info.get("credential_id", "")
|
||||
|
||||
yield from self.datasource_manager.stream_node_events(
|
||||
node_id=self._node_id,
|
||||
user_id=self.user_id,
|
||||
datasource_name=node_data.datasource_name or "",
|
||||
datasource_type=datasource_type.value,
|
||||
provider_id=provider_id,
|
||||
tenant_id=self.tenant_id,
|
||||
provider=node_data.provider_name,
|
||||
plugin_id=node_data.plugin_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
yield from self._transform_datasource_file_message(
|
||||
messages=online_drive_result,
|
||||
parameters_for_log=parameters_for_log,
|
||||
datasource_info=datasource_info,
|
||||
variable_pool=variable_pool,
|
||||
datasource_param=datasource_parameters,
|
||||
online_drive_request=online_drive_request,
|
||||
datasource_type=datasource_type,
|
||||
)
|
||||
case DatasourceProviderType.WEBSITE_CRAWL:
|
||||
yield StreamCompletedEvent(
|
||||
@@ -134,9 +147,23 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
related_id = datasource_info.get("related_id")
|
||||
if not related_id:
|
||||
raise DatasourceNodeError("File is not exist")
|
||||
upload_file = db.session.query(UploadFile).where(UploadFile.id == related_id).first()
|
||||
if not upload_file:
|
||||
raise ValueError("Invalid upload file Info")
|
||||
|
||||
file_info = self.datasource_manager.get_upload_file_by_id(
|
||||
file_id=related_id, tenant_id=self.tenant_id
|
||||
file_info = File(
|
||||
id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.CUSTOM,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
related_id=upload_file.id,
|
||||
size=upload_file.size,
|
||||
storage_key=upload_file.key,
|
||||
url=upload_file.source_url,
|
||||
)
|
||||
variable_pool.add([self._node_id, "file"], file_info)
|
||||
# variable_pool.add([self.node_id, "file"], file_info.to_dict())
|
||||
@@ -174,6 +201,55 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
)
|
||||
)
|
||||
|
||||
def _generate_parameters(
|
||||
self,
|
||||
*,
|
||||
datasource_parameters: Sequence[DatasourceParameter],
|
||||
variable_pool: VariablePool,
|
||||
node_data: DatasourceNodeData,
|
||||
for_log: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||
|
||||
Args:
|
||||
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
|
||||
variable_pool (VariablePool): The variable pool containing the variables.
|
||||
node_data (ToolNodeData): The data associated with the tool node.
|
||||
|
||||
Returns:
|
||||
Mapping[str, Any]: A dictionary containing the generated parameters.
|
||||
|
||||
"""
|
||||
datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters}
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
if node_data.datasource_parameters:
|
||||
for parameter_name in node_data.datasource_parameters:
|
||||
parameter = datasource_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
datasource_input = node_data.datasource_parameters[parameter_name]
|
||||
if datasource_input.type == "variable":
|
||||
variable = variable_pool.get(datasource_input.value)
|
||||
if variable is None:
|
||||
raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist")
|
||||
parameter_value = variable.value
|
||||
elif datasource_input.type in {"mixed", "constant"}:
|
||||
segment_group = variable_pool.convert_template(str(datasource_input.value))
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
else:
|
||||
raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'")
|
||||
result[parameter_name] = parameter_value
|
||||
|
||||
return result
|
||||
|
||||
def _fetch_files(self, variable_pool: VariablePool) -> list[File]:
|
||||
variable = variable_pool.get(["sys", SystemVariableKey.FILES])
|
||||
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
||||
return list(variable.value) if variable else []
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
@@ -211,6 +287,206 @@ class DatasourceNode(Node[DatasourceNodeData]):
|
||||
|
||||
return result
|
||||
|
||||
def _transform_message(
|
||||
self,
|
||||
messages: Generator[DatasourceMessage, None, None],
|
||||
parameters_for_log: dict[str, Any],
|
||||
datasource_info: dict[str, Any],
|
||||
) -> Generator:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json: list[dict | list] = []
|
||||
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
match message.type:
|
||||
case (
|
||||
DatasourceMessage.MessageType.IMAGE_LINK
|
||||
| DatasourceMessage.MessageType.BINARY_LINK
|
||||
| DatasourceMessage.MessageType.IMAGE
|
||||
):
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
datasource_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
case DatasourceMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
)
|
||||
case DatasourceMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
case DatasourceMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, DatasourceMessage.JsonMessage)
|
||||
json.append(message.message.json_object)
|
||||
case DatasourceMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
case DatasourceMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, DatasourceMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
case DatasourceMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
files.append(message.meta["file"])
|
||||
case (
|
||||
DatasourceMessage.MessageType.BLOB_CHUNK
|
||||
| DatasourceMessage.MessageType.LOG
|
||||
| DatasourceMessage.MessageType.RETRIEVER_RESOURCES
|
||||
):
|
||||
pass
|
||||
|
||||
# mark the end of the stream
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={**variables},
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info,
|
||||
},
|
||||
inputs=parameters_for_log,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _transform_datasource_file_message(
|
||||
self,
|
||||
messages: Generator[DatasourceMessage, None, None],
|
||||
parameters_for_log: dict[str, Any],
|
||||
datasource_info: dict[str, Any],
|
||||
variable_pool: VariablePool,
|
||||
datasource_type: DatasourceProviderType,
|
||||
) -> Generator:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
file = None
|
||||
for message in message_stream:
|
||||
if message.type == DatasourceMessage.MessageType.BINARY_LINK:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
datasource_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == datasource_file_id)
|
||||
datasource_file = session.scalar(stmt)
|
||||
if datasource_file is None:
|
||||
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": datasource_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
if file:
|
||||
variable_pool.add([self._node_id, "file"], file)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
outputs={
|
||||
"file": file,
|
||||
"datasource_type": datasource_type,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.file import File
|
||||
from core.workflow.node_events import StreamChunkEvent, StreamCompletedEvent
|
||||
|
||||
|
||||
class DatasourceParameter(BaseModel):
|
||||
workspace_id: str
|
||||
page_id: str
|
||||
type: str
|
||||
|
||||
|
||||
class OnlineDriveDownloadFileParam(BaseModel):
|
||||
id: str
|
||||
bucket: str
|
||||
|
||||
|
||||
class DatasourceFinal(BaseModel):
|
||||
data: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class DatasourceManagerProtocol(Protocol):
|
||||
@classmethod
|
||||
def get_icon_url(cls, provider_id: str, tenant_id: str, datasource_name: str, datasource_type: str) -> str: ...
|
||||
|
||||
@classmethod
|
||||
def stream_node_events(
|
||||
cls,
|
||||
*,
|
||||
node_id: str,
|
||||
user_id: str,
|
||||
datasource_name: str,
|
||||
datasource_type: str,
|
||||
provider_id: str,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
plugin_id: str,
|
||||
credential_id: str,
|
||||
parameters_for_log: dict[str, Any],
|
||||
datasource_info: dict[str, Any],
|
||||
variable_pool: Any,
|
||||
datasource_param: DatasourceParameter | None = None,
|
||||
online_drive_request: OnlineDriveDownloadFileParam | None = None,
|
||||
) -> Generator[StreamChunkEvent | StreamCompletedEvent, None, None]: ...
|
||||
|
||||
@classmethod
|
||||
def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File: ...
|
||||
@@ -4,6 +4,7 @@ import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import StrEnum
|
||||
from hashlib import sha256
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -80,12 +81,25 @@ class TokenPair(BaseModel):
|
||||
csrf_token: str
|
||||
|
||||
|
||||
class ChangeEmailPhase(StrEnum):
|
||||
OLD = "old_email"
|
||||
OLD_VERIFIED = "old_email_verified"
|
||||
NEW = "new_email"
|
||||
NEW_VERIFIED = "new_email_verified"
|
||||
|
||||
|
||||
REFRESH_TOKEN_PREFIX = "refresh_token:"
|
||||
ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:"
|
||||
REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
|
||||
class AccountService:
|
||||
CHANGE_EMAIL_TOKEN_PHASE_KEY = "email_change_phase"
|
||||
CHANGE_EMAIL_PHASE_OLD = ChangeEmailPhase.OLD
|
||||
CHANGE_EMAIL_PHASE_OLD_VERIFIED = ChangeEmailPhase.OLD_VERIFIED
|
||||
CHANGE_EMAIL_PHASE_NEW = ChangeEmailPhase.NEW
|
||||
CHANGE_EMAIL_PHASE_NEW_VERIFIED = ChangeEmailPhase.NEW_VERIFIED
|
||||
|
||||
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
|
||||
email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1)
|
||||
email_code_login_rate_limiter = RateLimiter(
|
||||
@@ -535,13 +549,20 @@ class AccountService:
|
||||
raise ValueError("Email must be provided.")
|
||||
if not phase:
|
||||
raise ValueError("phase must be provided.")
|
||||
if phase not in (cls.CHANGE_EMAIL_PHASE_OLD, cls.CHANGE_EMAIL_PHASE_NEW):
|
||||
raise ValueError("phase must be one of old_email or new_email.")
|
||||
|
||||
if cls.change_email_rate_limiter.is_rate_limited(account_email):
|
||||
from controllers.console.auth.error import EmailChangeRateLimitExceededError
|
||||
|
||||
raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60))
|
||||
|
||||
code, token = cls.generate_change_email_token(account_email, account, old_email=old_email)
|
||||
code, token = cls.generate_change_email_token(
|
||||
account_email,
|
||||
account,
|
||||
old_email=old_email,
|
||||
additional_data={cls.CHANGE_EMAIL_TOKEN_PHASE_KEY: phase},
|
||||
)
|
||||
|
||||
send_change_mail_task.delay(
|
||||
language=language,
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
from core.datasource.entities.datasource_entities import DatasourceMessage
|
||||
from core.workflow.node_events import StreamCompletedEvent
|
||||
|
||||
|
||||
def _gen_var_stream() -> Generator[DatasourceMessage, None, None]:
|
||||
# produce a streamed variable "a"="xy"
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.VARIABLE,
|
||||
message=DatasourceMessage.VariableMessage(variable_name="a", variable_value="x", stream=True),
|
||||
meta=None,
|
||||
)
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.VARIABLE,
|
||||
message=DatasourceMessage.VariableMessage(variable_name="a", variable_value="y", stream=True),
|
||||
meta=None,
|
||||
)
|
||||
|
||||
|
||||
def test_stream_node_events_accumulates_variables(mocker):
|
||||
mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_var_stream())
|
||||
events = list(
|
||||
DatasourceManager.stream_node_events(
|
||||
node_id="A",
|
||||
user_id="u",
|
||||
datasource_name="ds",
|
||||
datasource_type="online_document",
|
||||
provider_id="p/x",
|
||||
tenant_id="t",
|
||||
provider="prov",
|
||||
plugin_id="plug",
|
||||
credential_id="",
|
||||
parameters_for_log={},
|
||||
datasource_info={"user_id": "u"},
|
||||
variable_pool=mocker.Mock(),
|
||||
datasource_param=type("P", (), {"workspace_id": "w", "page_id": "pg", "type": "t"})(),
|
||||
online_drive_request=None,
|
||||
)
|
||||
)
|
||||
assert isinstance(events[-1], StreamCompletedEvent)
|
||||
@@ -1,84 +0,0 @@
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult, StreamCompletedEvent
|
||||
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
|
||||
|
||||
|
||||
class _Seg:
|
||||
def __init__(self, v):
|
||||
self.value = v
|
||||
|
||||
|
||||
class _VarPool:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
def get(self, path):
|
||||
d = self.data
|
||||
for k in path:
|
||||
d = d[k]
|
||||
return _Seg(d)
|
||||
|
||||
def add(self, *_a, **_k):
|
||||
pass
|
||||
|
||||
|
||||
class _GS:
|
||||
def __init__(self, vp):
|
||||
self.variable_pool = vp
|
||||
|
||||
|
||||
class _GP:
|
||||
tenant_id = "t1"
|
||||
app_id = "app-1"
|
||||
workflow_id = "wf-1"
|
||||
graph_config = {}
|
||||
user_id = "u1"
|
||||
user_from = "account"
|
||||
invoke_from = "debugger"
|
||||
call_depth = 0
|
||||
|
||||
|
||||
def test_node_integration_minimal_stream(mocker):
|
||||
sys_d = {
|
||||
"sys": {
|
||||
"datasource_type": "online_document",
|
||||
"datasource_info": {"workspace_id": "w", "page": {"page_id": "pg", "type": "t"}, "credential_id": ""},
|
||||
}
|
||||
}
|
||||
vp = _VarPool(sys_d)
|
||||
|
||||
class _Mgr:
|
||||
@classmethod
|
||||
def get_icon_url(cls, **_):
|
||||
return "icon"
|
||||
|
||||
@classmethod
|
||||
def stream_node_events(cls, **_):
|
||||
yield from ()
|
||||
yield StreamCompletedEvent(node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED))
|
||||
|
||||
@classmethod
|
||||
def get_upload_file_by_id(cls, **_):
|
||||
raise AssertionError
|
||||
|
||||
node = DatasourceNode(
|
||||
id="n",
|
||||
config={
|
||||
"id": "n",
|
||||
"data": {
|
||||
"type": "datasource",
|
||||
"version": "1",
|
||||
"title": "Datasource",
|
||||
"provider_type": "plugin",
|
||||
"provider_name": "p",
|
||||
"plugin_id": "plug",
|
||||
"datasource_name": "ds",
|
||||
},
|
||||
},
|
||||
graph_init_params=_GP(),
|
||||
graph_runtime_state=_GS(vp),
|
||||
datasource_manager=_Mgr,
|
||||
)
|
||||
|
||||
out = list(node._run())
|
||||
assert isinstance(out[-1], StreamCompletedEvent)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,418 +0,0 @@
|
||||
"""Integration tests for SQL-oriented DatasetService scenarios.
|
||||
|
||||
This suite migrates SQL-backed behaviors from the old unit suite to real
|
||||
container-backed integration tests. The tests exercise real ORM persistence and
|
||||
only patch non-DB collaborators when needed.
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings
|
||||
from services.dataset_service import DatasetService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RerankingModel, RetrievalModel
|
||||
from services.errors.dataset import DatasetNameDuplicateError
|
||||
|
||||
|
||||
class DatasetServiceIntegrationDataFactory:
|
||||
"""Factory for creating real database entities used by integration tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.OWNER) -> tuple[Account, Tenant]:
|
||||
"""Create an account and tenant, then bind the account as current tenant member."""
|
||||
account = Account(
|
||||
email=f"{uuid4()}@example.com",
|
||||
name=f"user-{uuid4()}",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
tenant = Tenant(name=f"tenant-{uuid4()}", status="normal")
|
||||
db.session.add_all([account, tenant])
|
||||
db.session.flush()
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=role,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.flush()
|
||||
|
||||
# Keep tenant context on the in-memory user without opening a separate session.
|
||||
account.role = role
|
||||
account._current_tenant = tenant
|
||||
return account, tenant
|
||||
|
||||
@staticmethod
|
||||
def create_dataset(
|
||||
tenant_id: str,
|
||||
created_by: str,
|
||||
name: str = "Test Dataset",
|
||||
description: str | None = "Test description",
|
||||
provider: str = "vendor",
|
||||
indexing_technique: str | None = "high_quality",
|
||||
permission: str = DatasetPermissionEnum.ONLY_ME,
|
||||
retrieval_model: dict | None = None,
|
||||
embedding_model_provider: str | None = None,
|
||||
embedding_model: str | None = None,
|
||||
collection_binding_id: str | None = None,
|
||||
chunk_structure: str | None = None,
|
||||
) -> Dataset:
|
||||
"""Create a dataset record with configurable SQL fields."""
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=description,
|
||||
data_source_type="upload_file",
|
||||
indexing_technique=indexing_technique,
|
||||
created_by=created_by,
|
||||
provider=provider,
|
||||
permission=permission,
|
||||
retrieval_model=retrieval_model,
|
||||
embedding_model_provider=embedding_model_provider,
|
||||
embedding_model=embedding_model,
|
||||
collection_binding_id=collection_binding_id,
|
||||
chunk_structure=chunk_structure,
|
||||
)
|
||||
db.session.add(dataset)
|
||||
db.session.flush()
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_document(dataset: Dataset, created_by: str, name: str = "doc.txt") -> Document:
|
||||
"""Create a document row belonging to the given dataset."""
|
||||
document = Document(
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
position=1,
|
||||
data_source_type="upload_file",
|
||||
data_source_info='{"upload_file_id": "upload-file-id"}',
|
||||
batch=str(uuid4()),
|
||||
name=name,
|
||||
created_from="web",
|
||||
created_by=created_by,
|
||||
indexing_status="completed",
|
||||
doc_form="text_model",
|
||||
)
|
||||
db.session.add(document)
|
||||
db.session.flush()
|
||||
return document
|
||||
|
||||
@staticmethod
|
||||
def create_embedding_model(provider: str = "openai", model_name: str = "text-embedding-ada-002") -> Mock:
|
||||
"""Create a fake embedding model object for external provider boundary patching."""
|
||||
embedding_model = Mock()
|
||||
embedding_model.provider = provider
|
||||
embedding_model.model_name = model_name
|
||||
return embedding_model
|
||||
|
||||
|
||||
class TestDatasetServiceCreateDataset:
|
||||
"""Integration coverage for DatasetService.create_empty_dataset."""
|
||||
|
||||
def test_create_internal_dataset_basic_success(self, db_session_with_containers):
|
||||
"""Create a basic internal dataset with minimal configuration."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant.id,
|
||||
name="Basic Internal Dataset",
|
||||
description="Test description",
|
||||
indexing_technique=None,
|
||||
account=account,
|
||||
)
|
||||
|
||||
# Assert
|
||||
created_dataset = db.session.get(Dataset, result.id)
|
||||
assert created_dataset is not None
|
||||
assert created_dataset.provider == "vendor"
|
||||
assert created_dataset.permission == DatasetPermissionEnum.ONLY_ME
|
||||
assert created_dataset.embedding_model_provider is None
|
||||
assert created_dataset.embedding_model is None
|
||||
|
||||
def test_create_internal_dataset_with_economy_indexing(self, db_session_with_containers):
|
||||
"""Create an internal dataset with economy indexing and no embedding model."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant.id,
|
||||
name="Economy Dataset",
|
||||
description=None,
|
||||
indexing_technique="economy",
|
||||
account=account,
|
||||
)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(result)
|
||||
assert result.indexing_technique == "economy"
|
||||
assert result.embedding_model_provider is None
|
||||
assert result.embedding_model is None
|
||||
|
||||
def test_create_internal_dataset_with_high_quality_indexing(self, db_session_with_containers):
|
||||
"""Create a high-quality dataset and persist embedding model settings."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model()
|
||||
|
||||
# Act
|
||||
with patch("services.dataset_service.ModelManager") as mock_model_manager:
|
||||
mock_model_manager.return_value.get_default_model_instance.return_value = embedding_model
|
||||
|
||||
result = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant.id,
|
||||
name="High Quality Dataset",
|
||||
description=None,
|
||||
indexing_technique="high_quality",
|
||||
account=account,
|
||||
)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(result)
|
||||
assert result.indexing_technique == "high_quality"
|
||||
assert result.embedding_model_provider == embedding_model.provider
|
||||
assert result.embedding_model == embedding_model.model_name
|
||||
mock_model_manager.return_value.get_default_model_instance.assert_called_once_with(
|
||||
tenant_id=tenant.id,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
)
|
||||
|
||||
def test_create_dataset_duplicate_name_error(self, db_session_with_containers):
|
||||
"""Raise duplicate-name error when the same tenant already has the name."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
DatasetServiceIntegrationDataFactory.create_dataset(
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
name="Duplicate Dataset",
|
||||
indexing_technique=None,
|
||||
)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(DatasetNameDuplicateError):
|
||||
DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant.id,
|
||||
name="Duplicate Dataset",
|
||||
description=None,
|
||||
indexing_technique=None,
|
||||
account=account,
|
||||
)
|
||||
|
||||
def test_create_external_dataset_success(self, db_session_with_containers):
|
||||
"""Create an external dataset and persist external knowledge binding."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
external_knowledge_api_id = str(uuid4())
|
||||
external_knowledge_id = "knowledge-123"
|
||||
|
||||
# Act
|
||||
with patch("services.dataset_service.ExternalDatasetService.get_external_knowledge_api") as mock_get_api:
|
||||
mock_get_api.return_value = Mock(id=external_knowledge_api_id)
|
||||
result = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant.id,
|
||||
name="External Dataset",
|
||||
description=None,
|
||||
indexing_technique=None,
|
||||
account=account,
|
||||
provider="external",
|
||||
external_knowledge_api_id=external_knowledge_api_id,
|
||||
external_knowledge_id=external_knowledge_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
binding = db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=result.id).first()
|
||||
assert result.provider == "external"
|
||||
assert binding is not None
|
||||
assert binding.external_knowledge_id == external_knowledge_id
|
||||
assert binding.external_knowledge_api_id == external_knowledge_api_id
|
||||
|
||||
def test_create_dataset_with_retrieval_model_and_reranking(self, db_session_with_containers):
|
||||
"""Create a high-quality dataset with retrieval/reranking settings."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model()
|
||||
retrieval_model = RetrievalModel(
|
||||
search_method=RetrievalMethod.SEMANTIC_SEARCH,
|
||||
reranking_enable=True,
|
||||
reranking_model=RerankingModel(
|
||||
reranking_provider_name="cohere",
|
||||
reranking_model_name="rerank-english-v2.0",
|
||||
),
|
||||
top_k=3,
|
||||
score_threshold_enabled=True,
|
||||
score_threshold=0.6,
|
||||
)
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch("services.dataset_service.ModelManager") as mock_model_manager,
|
||||
patch("services.dataset_service.DatasetService.check_reranking_model_setting") as mock_check_reranking,
|
||||
):
|
||||
mock_model_manager.return_value.get_default_model_instance.return_value = embedding_model
|
||||
|
||||
result = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant.id,
|
||||
name="Dataset With Reranking",
|
||||
description=None,
|
||||
indexing_technique="high_quality",
|
||||
account=account,
|
||||
retrieval_model=retrieval_model,
|
||||
)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(result)
|
||||
assert result.retrieval_model == retrieval_model.model_dump()
|
||||
mock_check_reranking.assert_called_once_with(tenant.id, "cohere", "rerank-english-v2.0")
|
||||
|
||||
|
||||
class TestDatasetServiceUpdateAndDeleteDataset:
|
||||
"""Integration coverage for SQL-backed update and delete behavior."""
|
||||
|
||||
def test_update_dataset_duplicate_name_error(self, db_session_with_containers):
|
||||
"""Reject update when target name already exists within the same tenant."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
source_dataset = DatasetServiceIntegrationDataFactory.create_dataset(
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
name="Source Dataset",
|
||||
)
|
||||
DatasetServiceIntegrationDataFactory.create_dataset(
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
name="Existing Dataset",
|
||||
)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Dataset name already exists"):
|
||||
DatasetService.update_dataset(source_dataset.id, {"name": "Existing Dataset"}, account)
|
||||
|
||||
def test_delete_dataset_with_documents_success(self, db_session_with_containers):
|
||||
"""Delete a dataset that already has documents."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
dataset = DatasetServiceIntegrationDataFactory.create_dataset(
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
indexing_technique="high_quality",
|
||||
chunk_structure="text_model",
|
||||
)
|
||||
DatasetServiceIntegrationDataFactory.create_document(dataset=dataset, created_by=account.id)
|
||||
|
||||
# Act
|
||||
with patch("services.dataset_service.dataset_was_deleted") as dataset_deleted_signal:
|
||||
result = DatasetService.delete_dataset(dataset.id, account)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert db.session.get(Dataset, dataset.id) is None
|
||||
dataset_deleted_signal.send.assert_called_once_with(dataset)
|
||||
|
||||
def test_delete_empty_dataset_success(self, db_session_with_containers):
|
||||
"""Delete a dataset that has no documents and no indexing technique."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
dataset = DatasetServiceIntegrationDataFactory.create_dataset(
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
indexing_technique=None,
|
||||
chunk_structure=None,
|
||||
)
|
||||
|
||||
# Act
|
||||
with patch("services.dataset_service.dataset_was_deleted") as dataset_deleted_signal:
|
||||
result = DatasetService.delete_dataset(dataset.id, account)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert db.session.get(Dataset, dataset.id) is None
|
||||
dataset_deleted_signal.send.assert_called_once_with(dataset)
|
||||
|
||||
def test_delete_dataset_with_partial_none_values(self, db_session_with_containers):
|
||||
"""Delete dataset when indexing_technique is None but doc_form path still exists."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
dataset = DatasetServiceIntegrationDataFactory.create_dataset(
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
indexing_technique=None,
|
||||
chunk_structure="text_model",
|
||||
)
|
||||
|
||||
# Act
|
||||
with patch("services.dataset_service.dataset_was_deleted") as dataset_deleted_signal:
|
||||
result = DatasetService.delete_dataset(dataset.id, account)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert db.session.get(Dataset, dataset.id) is None
|
||||
dataset_deleted_signal.send.assert_called_once_with(dataset)
|
||||
|
||||
|
||||
class TestDatasetServiceRetrievalConfiguration:
|
||||
"""Integration coverage for retrieval configuration persistence."""
|
||||
|
||||
def test_get_dataset_retrieval_configuration(self, db_session_with_containers):
|
||||
"""Return retrieval configuration that is persisted in SQL."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
retrieval_model = {
|
||||
"search_method": "semantic_search",
|
||||
"top_k": 5,
|
||||
"score_threshold": 0.5,
|
||||
"reranking_enable": True,
|
||||
}
|
||||
dataset = DatasetServiceIntegrationDataFactory.create_dataset(
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
retrieval_model=retrieval_model,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = DatasetService.get_dataset(dataset.id)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.retrieval_model == retrieval_model
|
||||
assert result.retrieval_model["search_method"] == "semantic_search"
|
||||
assert result.retrieval_model["top_k"] == 5
|
||||
|
||||
def test_update_dataset_retrieval_configuration(self, db_session_with_containers):
|
||||
"""Persist retrieval configuration updates through DatasetService.update_dataset."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
dataset = DatasetServiceIntegrationDataFactory.create_dataset(
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
indexing_technique="high_quality",
|
||||
retrieval_model={"search_method": "semantic_search", "top_k": 2, "score_threshold": 0.0},
|
||||
embedding_model_provider="openai",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
collection_binding_id=str(uuid4()),
|
||||
)
|
||||
update_data = {
|
||||
"indexing_technique": "high_quality",
|
||||
"retrieval_model": {
|
||||
"search_method": "full_text_search",
|
||||
"top_k": 10,
|
||||
"score_threshold": 0.7,
|
||||
},
|
||||
}
|
||||
|
||||
# Act
|
||||
result = DatasetService.update_dataset(dataset.id, update_data, account)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(dataset)
|
||||
assert result.id == dataset.id
|
||||
assert dataset.retrieval_model == update_data["retrieval_model"]
|
||||
@@ -1,464 +0,0 @@
|
||||
"""
|
||||
Integration tests for document_indexing_sync_task using testcontainers.
|
||||
|
||||
This module validates SQL-backed behavior for document sync flows:
|
||||
- Notion sync precondition checks
|
||||
- Segment cleanup and document state updates
|
||||
- Credential and indexing error handling
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from psycopg2.extensions import register_adapter
|
||||
from psycopg2.extras import Json
|
||||
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _register_dict_adapter_for_psycopg2():
|
||||
"""Align test DB adapter behavior with dict payloads used in task update flow."""
|
||||
register_adapter(dict, Json)
|
||||
|
||||
|
||||
class DocumentIndexingSyncTaskTestDataFactory:
|
||||
"""Create real DB entities for document indexing sync integration tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_account_with_tenant(db_session_with_containers) -> tuple[Account, Tenant]:
|
||||
account = Account(
|
||||
email=f"{uuid4()}@example.com",
|
||||
name=f"user-{uuid4()}",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
tenant = Tenant(name=f"tenant-{account.id}", status="normal")
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return account, tenant
|
||||
|
||||
@staticmethod
|
||||
def create_dataset(db_session_with_containers, tenant_id: str, created_by: str) -> Dataset:
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=f"dataset-{uuid4()}",
|
||||
description="sync test dataset",
|
||||
data_source_type="notion_import",
|
||||
indexing_technique="high_quality",
|
||||
created_by=created_by,
|
||||
)
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_document(
|
||||
db_session_with_containers,
|
||||
*,
|
||||
tenant_id: str,
|
||||
dataset_id: str,
|
||||
created_by: str,
|
||||
data_source_info: dict | None,
|
||||
indexing_status: str = "completed",
|
||||
) -> Document:
|
||||
document = Document(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
position=0,
|
||||
data_source_type="notion_import",
|
||||
data_source_info=json.dumps(data_source_info) if data_source_info is not None else None,
|
||||
batch="test-batch",
|
||||
name=f"doc-{uuid4()}",
|
||||
created_from="notion_import",
|
||||
created_by=created_by,
|
||||
indexing_status=indexing_status,
|
||||
enabled=True,
|
||||
doc_form="text_model",
|
||||
doc_language="en",
|
||||
)
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
return document
|
||||
|
||||
@staticmethod
|
||||
def create_segments(
|
||||
db_session_with_containers,
|
||||
*,
|
||||
tenant_id: str,
|
||||
dataset_id: str,
|
||||
document_id: str,
|
||||
created_by: str,
|
||||
count: int = 3,
|
||||
) -> list[DocumentSegment]:
|
||||
segments: list[DocumentSegment] = []
|
||||
for i in range(count):
|
||||
segment = DocumentSegment(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
document_id=document_id,
|
||||
position=i,
|
||||
content=f"segment-{i}",
|
||||
answer=None,
|
||||
word_count=10,
|
||||
tokens=5,
|
||||
index_node_id=f"node-{document_id}-{i}",
|
||||
status="completed",
|
||||
created_by=created_by,
|
||||
)
|
||||
db_session_with_containers.add(segment)
|
||||
segments.append(segment)
|
||||
db_session_with_containers.commit()
|
||||
return segments
|
||||
|
||||
|
||||
class TestDocumentIndexingSyncTask:
|
||||
"""Integration tests for document_indexing_sync_task with real database assertions."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_dependencies(self):
|
||||
"""Patch only external collaborators; keep DB access real."""
|
||||
with (
|
||||
patch("tasks.document_indexing_sync_task.DatasourceProviderService") as mock_datasource_service_class,
|
||||
patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_notion_extractor_class,
|
||||
patch("tasks.document_indexing_sync_task.IndexProcessorFactory") as mock_index_processor_factory,
|
||||
patch("tasks.document_indexing_sync_task.IndexingRunner") as mock_indexing_runner_class,
|
||||
):
|
||||
datasource_service = Mock()
|
||||
datasource_service.get_datasource_credentials.return_value = {"integration_secret": "test_token"}
|
||||
mock_datasource_service_class.return_value = datasource_service
|
||||
|
||||
notion_extractor = Mock()
|
||||
notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
mock_notion_extractor_class.return_value = notion_extractor
|
||||
|
||||
index_processor = Mock()
|
||||
index_processor.clean = Mock()
|
||||
mock_index_processor_factory.return_value.init_index_processor.return_value = index_processor
|
||||
|
||||
indexing_runner = Mock(spec=IndexingRunner)
|
||||
indexing_runner.run = Mock()
|
||||
mock_indexing_runner_class.return_value = indexing_runner
|
||||
|
||||
yield {
|
||||
"datasource_service": datasource_service,
|
||||
"notion_extractor": notion_extractor,
|
||||
"notion_extractor_class": mock_notion_extractor_class,
|
||||
"index_processor": index_processor,
|
||||
"index_processor_factory": mock_index_processor_factory,
|
||||
"indexing_runner": indexing_runner,
|
||||
}
|
||||
|
||||
def _create_notion_sync_context(self, db_session_with_containers, *, data_source_info: dict | None = None):
|
||||
account, tenant = DocumentIndexingSyncTaskTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = DocumentIndexingSyncTaskTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
)
|
||||
|
||||
notion_info = data_source_info or {
|
||||
"notion_workspace_id": str(uuid4()),
|
||||
"notion_page_id": str(uuid4()),
|
||||
"type": "page",
|
||||
"last_edited_time": "2024-01-01T00:00:00Z",
|
||||
"credential_id": str(uuid4()),
|
||||
}
|
||||
|
||||
document = DocumentIndexingSyncTaskTestDataFactory.create_document(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
created_by=account.id,
|
||||
data_source_info=notion_info,
|
||||
indexing_status="completed",
|
||||
)
|
||||
|
||||
segments = DocumentIndexingSyncTaskTestDataFactory.create_segments(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
created_by=account.id,
|
||||
count=3,
|
||||
)
|
||||
|
||||
return {
|
||||
"account": account,
|
||||
"tenant": tenant,
|
||||
"dataset": dataset,
|
||||
"document": document,
|
||||
"segments": segments,
|
||||
"node_ids": [segment.index_node_id for segment in segments],
|
||||
"notion_info": notion_info,
|
||||
}
|
||||
|
||||
def test_document_not_found(self, db_session_with_containers, mock_external_dependencies):
|
||||
"""Test that task handles missing document gracefully."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
document_id = str(uuid4())
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
mock_external_dependencies["datasource_service"].get_datasource_credentials.assert_not_called()
|
||||
mock_external_dependencies["indexing_runner"].run.assert_not_called()
|
||||
|
||||
def test_missing_notion_workspace_id(self, db_session_with_containers, mock_external_dependencies):
|
||||
"""Test that task raises error when notion_workspace_id is missing."""
|
||||
# Arrange
|
||||
context = self._create_notion_sync_context(
|
||||
db_session_with_containers,
|
||||
data_source_info={
|
||||
"notion_page_id": str(uuid4()),
|
||||
"type": "page",
|
||||
"last_edited_time": "2024-01-01T00:00:00Z",
|
||||
},
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="no notion page found"):
|
||||
document_indexing_sync_task(context["dataset"].id, context["document"].id)
|
||||
|
||||
def test_missing_notion_page_id(self, db_session_with_containers, mock_external_dependencies):
|
||||
"""Test that task raises error when notion_page_id is missing."""
|
||||
# Arrange
|
||||
context = self._create_notion_sync_context(
|
||||
db_session_with_containers,
|
||||
data_source_info={
|
||||
"notion_workspace_id": str(uuid4()),
|
||||
"type": "page",
|
||||
"last_edited_time": "2024-01-01T00:00:00Z",
|
||||
},
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="no notion page found"):
|
||||
document_indexing_sync_task(context["dataset"].id, context["document"].id)
|
||||
|
||||
def test_empty_data_source_info(self, db_session_with_containers, mock_external_dependencies):
|
||||
"""Test that task raises error when data_source_info is empty."""
|
||||
# Arrange
|
||||
context = self._create_notion_sync_context(db_session_with_containers, data_source_info=None)
|
||||
db_session_with_containers.query(Document).where(Document.id == context["document"].id).update(
|
||||
{"data_source_info": None}
|
||||
)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="no notion page found"):
|
||||
document_indexing_sync_task(context["dataset"].id, context["document"].id)
|
||||
|
||||
def test_credential_not_found(self, db_session_with_containers, mock_external_dependencies):
|
||||
"""Test that task sets document error state when credential is missing."""
|
||||
# Arrange
|
||||
context = self._create_notion_sync_context(db_session_with_containers)
|
||||
mock_external_dependencies["datasource_service"].get_datasource_credentials.return_value = None
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(context["dataset"].id, context["document"].id)
|
||||
|
||||
# Assert
|
||||
db_session_with_containers.expire_all()
|
||||
updated_document = (
|
||||
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
|
||||
)
|
||||
assert updated_document is not None
|
||||
assert updated_document.indexing_status == "error"
|
||||
assert "Datasource credential not found" in updated_document.error
|
||||
assert updated_document.stopped_at is not None
|
||||
mock_external_dependencies["indexing_runner"].run.assert_not_called()
|
||||
|
||||
def test_page_not_updated(self, db_session_with_containers, mock_external_dependencies):
|
||||
"""Test that task exits early when notion page is unchanged."""
|
||||
# Arrange
|
||||
context = self._create_notion_sync_context(db_session_with_containers)
|
||||
mock_external_dependencies["notion_extractor"].get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(context["dataset"].id, context["document"].id)
|
||||
|
||||
# Assert
|
||||
db_session_with_containers.expire_all()
|
||||
updated_document = (
|
||||
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
|
||||
)
|
||||
remaining_segments = (
|
||||
db_session_with_containers.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == context["document"].id)
|
||||
.count()
|
||||
)
|
||||
assert updated_document is not None
|
||||
assert updated_document.indexing_status == "completed"
|
||||
assert updated_document.processing_started_at is None
|
||||
assert remaining_segments == 3
|
||||
mock_external_dependencies["index_processor"].clean.assert_not_called()
|
||||
mock_external_dependencies["indexing_runner"].run.assert_not_called()
|
||||
|
||||
def test_successful_sync_when_page_updated(self, db_session_with_containers, mock_external_dependencies):
|
||||
"""Test full successful sync flow with SQL state updates and side effects."""
|
||||
# Arrange
|
||||
context = self._create_notion_sync_context(db_session_with_containers)
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(context["dataset"].id, context["document"].id)
|
||||
|
||||
# Assert
|
||||
db_session_with_containers.expire_all()
|
||||
updated_document = (
|
||||
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
|
||||
)
|
||||
remaining_segments = (
|
||||
db_session_with_containers.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == context["document"].id)
|
||||
.count()
|
||||
)
|
||||
|
||||
assert updated_document is not None
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.processing_started_at is not None
|
||||
assert updated_document.data_source_info_dict.get("last_edited_time") == "2024-01-02T00:00:00Z"
|
||||
assert remaining_segments == 0
|
||||
|
||||
clean_call_args = mock_external_dependencies["index_processor"].clean.call_args
|
||||
assert clean_call_args is not None
|
||||
clean_args, clean_kwargs = clean_call_args
|
||||
assert getattr(clean_args[0], "id", None) == context["dataset"].id
|
||||
assert set(clean_args[1]) == set(context["node_ids"])
|
||||
assert clean_kwargs.get("with_keywords") is True
|
||||
assert clean_kwargs.get("delete_child_chunks") is True
|
||||
|
||||
run_call_args = mock_external_dependencies["indexing_runner"].run.call_args
|
||||
assert run_call_args is not None
|
||||
run_documents = run_call_args[0][0]
|
||||
assert len(run_documents) == 1
|
||||
assert getattr(run_documents[0], "id", None) == context["document"].id
|
||||
|
||||
def test_dataset_not_found_during_cleaning(self, db_session_with_containers, mock_external_dependencies):
|
||||
"""Test that task still updates document and reindexes if dataset vanishes before clean."""
|
||||
# Arrange
|
||||
context = self._create_notion_sync_context(db_session_with_containers)
|
||||
|
||||
def _delete_dataset_before_clean() -> str:
|
||||
db_session_with_containers.query(Dataset).where(Dataset.id == context["dataset"].id).delete()
|
||||
db_session_with_containers.commit()
|
||||
return "2024-01-02T00:00:00Z"
|
||||
|
||||
mock_external_dependencies[
|
||||
"notion_extractor"
|
||||
].get_notion_last_edited_time.side_effect = _delete_dataset_before_clean
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(context["dataset"].id, context["document"].id)
|
||||
|
||||
# Assert
|
||||
db_session_with_containers.expire_all()
|
||||
updated_document = (
|
||||
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
|
||||
)
|
||||
assert updated_document is not None
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
mock_external_dependencies["index_processor"].clean.assert_not_called()
|
||||
mock_external_dependencies["indexing_runner"].run.assert_called_once()
|
||||
|
||||
def test_cleaning_error_continues_to_indexing(self, db_session_with_containers, mock_external_dependencies):
|
||||
"""Test that indexing continues when index cleanup fails."""
|
||||
# Arrange
|
||||
context = self._create_notion_sync_context(db_session_with_containers)
|
||||
mock_external_dependencies["index_processor"].clean.side_effect = Exception("Cleaning error")
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(context["dataset"].id, context["document"].id)
|
||||
|
||||
# Assert
|
||||
db_session_with_containers.expire_all()
|
||||
updated_document = (
|
||||
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
|
||||
)
|
||||
remaining_segments = (
|
||||
db_session_with_containers.query(DocumentSegment)
|
||||
.where(DocumentSegment.document_id == context["document"].id)
|
||||
.count()
|
||||
)
|
||||
assert updated_document is not None
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert remaining_segments == 0
|
||||
mock_external_dependencies["indexing_runner"].run.assert_called_once()
|
||||
|
||||
def test_indexing_runner_document_paused_error(self, db_session_with_containers, mock_external_dependencies):
|
||||
"""Test that DocumentIsPausedError does not flip document into error state."""
|
||||
# Arrange
|
||||
context = self._create_notion_sync_context(db_session_with_containers)
|
||||
mock_external_dependencies["indexing_runner"].run.side_effect = DocumentIsPausedError("Document paused")
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(context["dataset"].id, context["document"].id)
|
||||
|
||||
# Assert
|
||||
db_session_with_containers.expire_all()
|
||||
updated_document = (
|
||||
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
|
||||
)
|
||||
assert updated_document is not None
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
assert updated_document.error is None
|
||||
|
||||
def test_indexing_runner_general_error(self, db_session_with_containers, mock_external_dependencies):
|
||||
"""Test that indexing errors are persisted to document state."""
|
||||
# Arrange
|
||||
context = self._create_notion_sync_context(db_session_with_containers)
|
||||
mock_external_dependencies["indexing_runner"].run.side_effect = Exception("Indexing error")
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(context["dataset"].id, context["document"].id)
|
||||
|
||||
# Assert
|
||||
db_session_with_containers.expire_all()
|
||||
updated_document = (
|
||||
db_session_with_containers.query(Document).where(Document.id == context["document"].id).first()
|
||||
)
|
||||
assert updated_document is not None
|
||||
assert updated_document.indexing_status == "error"
|
||||
assert "Indexing error" in updated_document.error
|
||||
assert updated_document.stopped_at is not None
|
||||
|
||||
def test_index_processor_clean_called_with_correct_params(
|
||||
self,
|
||||
db_session_with_containers,
|
||||
mock_external_dependencies,
|
||||
):
|
||||
"""Test that clean is called with dataset instance and collected node ids."""
|
||||
# Arrange
|
||||
context = self._create_notion_sync_context(db_session_with_containers)
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(context["dataset"].id, context["document"].id)
|
||||
|
||||
# Assert
|
||||
clean_call_args = mock_external_dependencies["index_processor"].clean.call_args
|
||||
assert clean_call_args is not None
|
||||
clean_args, clean_kwargs = clean_call_args
|
||||
assert getattr(clean_args[0], "id", None) == context["dataset"].id
|
||||
assert set(clean_args[1]) == set(context["node_ids"])
|
||||
assert clean_kwargs.get("with_keywords") is True
|
||||
assert clean_kwargs.get("delete_child_chunks") is True
|
||||
@@ -77,7 +77,7 @@ def _restx_mask_defaults(app: Flask):
|
||||
|
||||
|
||||
def test_code_based_extension_get_returns_service_data(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
service_result = [{"entrypoint": "main:agent"}]
|
||||
service_result = {"entrypoint": "main:agent"}
|
||||
service_mock = MagicMock(return_value=service_result)
|
||||
monkeypatch.setattr(
|
||||
"controllers.console.extension.CodeBasedExtensionService.get_code_based_extension",
|
||||
|
||||
@@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
from controllers.console.workspace.account import (
|
||||
AccountDeleteUpdateFeedbackApi,
|
||||
ChangeEmailCheckApi,
|
||||
@@ -52,7 +53,7 @@ class TestChangeEmailSend:
|
||||
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_normalize_new_email_phase(
|
||||
def test_should_infer_new_email_phase_from_token(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
@@ -68,13 +69,16 @@ class TestChangeEmailSend:
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("current@example.com", "acc1")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_get_change_data.return_value = {"email": "current@example.com"}
|
||||
mock_get_change_data.return_value = {
|
||||
"email": "current@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
|
||||
}
|
||||
mock_send_email.return_value = "token-abc"
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email",
|
||||
method="POST",
|
||||
json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"},
|
||||
json={"email": "New@Example.com", "language": "en-US", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
response = ChangeEmailSendEmailApi().post()
|
||||
@@ -91,6 +95,107 @@ class TestChangeEmailSend:
|
||||
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.db")
|
||||
@patch("controllers.console.workspace.account.Session")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_ignore_client_phase_and_use_old_phase_when_token_missing(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_extract_ip,
|
||||
mock_is_ip_limit,
|
||||
mock_send_email,
|
||||
mock_get_account_by_email,
|
||||
mock_session_cls,
|
||||
mock_account_db,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_current_account.return_value = (_build_account("current@example.com", "current"), None)
|
||||
existing_account = _build_account("old@example.com", "acc-old")
|
||||
mock_get_account_by_email.return_value = existing_account
|
||||
mock_send_email.return_value = "token-legacy"
|
||||
mock_session = MagicMock()
|
||||
mock_session_cm = MagicMock()
|
||||
mock_session_cm.__enter__.return_value = mock_session
|
||||
mock_session_cm.__exit__.return_value = None
|
||||
mock_session_cls.return_value = mock_session_cm
|
||||
mock_account_db.engine = MagicMock()
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email",
|
||||
method="POST",
|
||||
json={"email": "old@example.com", "language": "en-US", "phase": "new_email"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
response = ChangeEmailSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "token-legacy"}
|
||||
mock_get_account_by_email.assert_called_once_with("old@example.com", session=mock_session)
|
||||
mock_send_email.assert_called_once_with(
|
||||
account=existing_account,
|
||||
email="old@example.com",
|
||||
old_email="old@example.com",
|
||||
language="en-US",
|
||||
phase=AccountService.CHANGE_EMAIL_PHASE_OLD,
|
||||
)
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_unverified_old_email_token_for_new_email_phase(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_extract_ip,
|
||||
mock_is_ip_limit,
|
||||
mock_send_email,
|
||||
mock_get_change_data,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("current@example.com", "acc1")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_get_change_data.return_value = {
|
||||
"email": "current@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
|
||||
}
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email",
|
||||
method="POST",
|
||||
json={"email": "New@Example.com", "language": "en-US", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
with pytest.raises(InvalidTokenError):
|
||||
ChangeEmailSendEmailApi().post()
|
||||
|
||||
mock_send_email.assert_not_called()
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
|
||||
class TestChangeEmailValidity:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@@ -122,7 +227,12 @@ class TestChangeEmailValidity:
|
||||
mock_account = _build_account("user@example.com", "acc2")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "user@example.com", "code": "1234", "old_email": "old@example.com"}
|
||||
mock_get_data.return_value = {
|
||||
"email": "user@example.com",
|
||||
"code": "1234",
|
||||
"old_email": "old@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
|
||||
}
|
||||
mock_generate_token.return_value = (None, "new-token")
|
||||
|
||||
with app.test_request_context(
|
||||
@@ -138,11 +248,76 @@ class TestChangeEmailValidity:
|
||||
mock_add_rate.assert_not_called()
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"user@example.com", code="1234", old_email="old@example.com", additional_data={}
|
||||
"user@example.com",
|
||||
code="1234",
|
||||
old_email="old@example.com",
|
||||
additional_data={
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED
|
||||
},
|
||||
)
|
||||
mock_reset_rate.assert_called_once_with("user@example.com")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_refresh_new_email_phase_to_verified(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_rate_limit,
|
||||
mock_get_data,
|
||||
mock_add_rate,
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("old@example.com", "acc2")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {
|
||||
"email": "new@example.com",
|
||||
"code": "5678",
|
||||
"old_email": "old@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW,
|
||||
}
|
||||
mock_generate_token.return_value = (None, "new-phase-token")
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/validity",
|
||||
method="POST",
|
||||
json={"email": "New@Example.com", "code": "5678", "token": "token-456"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
response = ChangeEmailCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "new@example.com", "token": "new-phase-token"}
|
||||
mock_is_rate_limit.assert_called_once_with("new@example.com")
|
||||
mock_add_rate.assert_not_called()
|
||||
mock_revoke_token.assert_called_once_with("token-456")
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"new@example.com",
|
||||
code="5678",
|
||||
old_email="old@example.com",
|
||||
additional_data={
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED
|
||||
},
|
||||
)
|
||||
mock_reset_rate.assert_called_once_with("new@example.com")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
|
||||
class TestChangeEmailReset:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@@ -175,7 +350,11 @@ class TestChangeEmailReset:
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
mock_get_data.return_value = {"old_email": "OLD@example.com"}
|
||||
mock_get_data.return_value = {
|
||||
"old_email": "OLD@example.com",
|
||||
"email": "new@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
|
||||
}
|
||||
mock_account_after_update = _build_account("new@example.com", "acc3-updated")
|
||||
mock_update_account.return_value = mock_account_after_update
|
||||
|
||||
@@ -194,6 +373,106 @@ class TestChangeEmailReset:
|
||||
mock_send_notify.assert_called_once_with(email="new@example.com")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.update_account_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_old_phase_token_for_reset(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_freeze,
|
||||
mock_check_unique,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_update_account,
|
||||
mock_send_notify,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
mock_get_data.return_value = {
|
||||
"old_email": "OLD@example.com",
|
||||
"email": "old@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
|
||||
}
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/reset",
|
||||
method="POST",
|
||||
json={"new_email": "new@example.com", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
with pytest.raises(InvalidTokenError):
|
||||
ChangeEmailResetApi().post()
|
||||
|
||||
mock_revoke_token.assert_not_called()
|
||||
mock_update_account.assert_not_called()
|
||||
mock_send_notify.assert_not_called()
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.update_account_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_mismatched_new_email_for_verified_token(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_freeze,
|
||||
mock_check_unique,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_update_account,
|
||||
mock_send_notify,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
mock_get_data.return_value = {
|
||||
"old_email": "OLD@example.com",
|
||||
"email": "another@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
|
||||
}
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/reset",
|
||||
method="POST",
|
||||
json={"new_email": "new@example.com", "token": "token-789"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
with pytest.raises(InvalidTokenError):
|
||||
ChangeEmailResetApi().post()
|
||||
|
||||
mock_revoke_token.assert_not_called()
|
||||
mock_update_account.assert_not_called()
|
||||
mock_send_notify.assert_not_called()
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
|
||||
class TestAccountDeletionFeedback:
|
||||
@patch("controllers.console.wraps.db")
|
||||
|
||||
@@ -1,135 +0,0 @@
|
||||
import types
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
from core.datasource.entities.datasource_entities import DatasourceMessage
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import StreamChunkEvent, StreamCompletedEvent
|
||||
|
||||
|
||||
def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, None]:
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.TEXT,
|
||||
message=DatasourceMessage.TextMessage(text=text),
|
||||
meta=None,
|
||||
)
|
||||
|
||||
|
||||
def test_get_icon_url_calls_runtime(mocker):
|
||||
fake_runtime = mocker.Mock()
|
||||
fake_runtime.get_icon_url.return_value = "https://icon"
|
||||
mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=fake_runtime)
|
||||
|
||||
url = DatasourceManager.get_icon_url(
|
||||
provider_id="p/x",
|
||||
tenant_id="t1",
|
||||
datasource_name="ds",
|
||||
datasource_type="online_document",
|
||||
)
|
||||
assert url == "https://icon"
|
||||
DatasourceManager.get_datasource_runtime.assert_called_once()
|
||||
|
||||
|
||||
def test_stream_online_results_yields_messages_online_document(mocker):
|
||||
# stub runtime to yield a text message
|
||||
def _doc_messages(**_):
|
||||
yield from _gen_messages_text_only("hello")
|
||||
|
||||
fake_runtime = mocker.Mock()
|
||||
fake_runtime.get_online_document_page_content.side_effect = _doc_messages
|
||||
mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=fake_runtime)
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
gen = DatasourceManager.stream_online_results(
|
||||
user_id="u1",
|
||||
datasource_name="ds",
|
||||
datasource_type="online_document",
|
||||
provider_id="p/x",
|
||||
tenant_id="t1",
|
||||
provider="prov",
|
||||
plugin_id="plug",
|
||||
credential_id="",
|
||||
datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"),
|
||||
online_drive_request=None,
|
||||
)
|
||||
msgs = list(gen)
|
||||
assert len(msgs) == 1
|
||||
assert msgs[0].message.text == "hello"
|
||||
|
||||
|
||||
def test_stream_node_events_emits_events_online_document(mocker):
|
||||
# make manager's low-level stream produce TEXT only
|
||||
mocker.patch.object(
|
||||
DatasourceManager,
|
||||
"stream_online_results",
|
||||
return_value=_gen_messages_text_only("hello"),
|
||||
)
|
||||
|
||||
events = list(
|
||||
DatasourceManager.stream_node_events(
|
||||
node_id="nodeA",
|
||||
user_id="u1",
|
||||
datasource_name="ds",
|
||||
datasource_type="online_document",
|
||||
provider_id="p/x",
|
||||
tenant_id="t1",
|
||||
provider="prov",
|
||||
plugin_id="plug",
|
||||
credential_id="",
|
||||
parameters_for_log={"k": "v"},
|
||||
datasource_info={"user_id": "u1"},
|
||||
variable_pool=mocker.Mock(),
|
||||
datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"),
|
||||
online_drive_request=None,
|
||||
)
|
||||
)
|
||||
# should contain one StreamChunkEvent then a final chunk (empty) and a completed event
|
||||
assert isinstance(events[0], StreamChunkEvent)
|
||||
assert events[0].chunk == "hello"
|
||||
assert isinstance(events[-1], StreamCompletedEvent)
|
||||
assert events[-1].node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
|
||||
|
||||
def test_get_upload_file_by_id_builds_file(mocker):
|
||||
# fake UploadFile row
|
||||
fake_row = types.SimpleNamespace(
|
||||
id="fid",
|
||||
name="f",
|
||||
extension="txt",
|
||||
mime_type="text/plain",
|
||||
size=1,
|
||||
key="k",
|
||||
source_url="http://x",
|
||||
)
|
||||
|
||||
class _Q:
|
||||
def __init__(self, row):
|
||||
self._row = row
|
||||
|
||||
def where(self, *_args, **_kwargs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return self._row
|
||||
|
||||
class _S:
|
||||
def __init__(self, row):
|
||||
self._row = row
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc):
|
||||
return False
|
||||
|
||||
def query(self, *_):
|
||||
return _Q(self._row)
|
||||
|
||||
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S(fake_row))
|
||||
|
||||
f = DatasourceManager.get_upload_file_by_id(file_id="fid", tenant_id="t1")
|
||||
assert f.related_id == "fid"
|
||||
assert f.extension == ".txt"
|
||||
@@ -1,93 +0,0 @@
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from core.workflow.nodes.datasource.datasource_node import DatasourceNode
|
||||
|
||||
|
||||
class _VarSeg:
|
||||
def __init__(self, v):
|
||||
self.value = v
|
||||
|
||||
|
||||
class _VarPool:
|
||||
def __init__(self, mapping):
|
||||
self._m = mapping
|
||||
|
||||
def get(self, selector):
|
||||
d = self._m
|
||||
for k in selector:
|
||||
d = d[k]
|
||||
return _VarSeg(d)
|
||||
|
||||
def add(self, *_args, **_kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class _GraphState:
|
||||
def __init__(self, var_pool):
|
||||
self.variable_pool = var_pool
|
||||
|
||||
|
||||
class _GraphParams:
|
||||
tenant_id = "t1"
|
||||
app_id = "app-1"
|
||||
workflow_id = "wf-1"
|
||||
graph_config = {}
|
||||
user_id = "u1"
|
||||
user_from = "account"
|
||||
invoke_from = "debugger"
|
||||
call_depth = 0
|
||||
|
||||
|
||||
def test_datasource_node_delegates_to_manager_stream(mocker):
|
||||
# prepare sys variables
|
||||
sys_vars = {
|
||||
"sys": {
|
||||
"datasource_type": "online_document",
|
||||
"datasource_info": {
|
||||
"workspace_id": "w",
|
||||
"page": {"page_id": "pg", "type": "t"},
|
||||
"credential_id": "",
|
||||
},
|
||||
}
|
||||
}
|
||||
var_pool = _VarPool(sys_vars)
|
||||
gs = _GraphState(var_pool)
|
||||
gp = _GraphParams()
|
||||
|
||||
# stub manager class
|
||||
class _Mgr:
|
||||
@classmethod
|
||||
def get_icon_url(cls, **_):
|
||||
return "icon"
|
||||
|
||||
@classmethod
|
||||
def stream_node_events(cls, **_):
|
||||
yield StreamChunkEvent(selector=["n", "text"], chunk="hi", is_final=False)
|
||||
yield StreamCompletedEvent(node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED))
|
||||
|
||||
@classmethod
|
||||
def get_upload_file_by_id(cls, **_):
|
||||
raise AssertionError("not called")
|
||||
|
||||
node = DatasourceNode(
|
||||
id="n",
|
||||
config={
|
||||
"id": "n",
|
||||
"data": {
|
||||
"type": "datasource",
|
||||
"version": "1",
|
||||
"title": "Datasource",
|
||||
"provider_type": "plugin",
|
||||
"provider_name": "p",
|
||||
"plugin_id": "plug",
|
||||
"datasource_name": "ds",
|
||||
},
|
||||
},
|
||||
graph_init_params=gp,
|
||||
graph_runtime_state=gs,
|
||||
datasource_manager=_Mgr,
|
||||
)
|
||||
|
||||
evts = list(node._run())
|
||||
assert isinstance(evts[0], StreamChunkEvent)
|
||||
assert isinstance(evts[-1], StreamCompletedEvent)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,8 +1,12 @@
|
||||
"""
|
||||
Unit tests for collaborator parameter wiring in document_indexing_sync_task.
|
||||
Unit tests for document indexing sync task.
|
||||
|
||||
These tests intentionally stay in unit scope because they validate call arguments
|
||||
for external collaborators rather than SQL-backed state transitions.
|
||||
This module tests the document indexing sync task functionality including:
|
||||
- Syncing Notion documents when updated
|
||||
- Validating document and data source existence
|
||||
- Credential validation and retrieval
|
||||
- Cleaning old segments before re-indexing
|
||||
- Error handling and edge cases
|
||||
"""
|
||||
|
||||
import uuid
|
||||
@@ -10,92 +14,187 @@ from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from models.dataset import Dataset, Document
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dataset_id() -> str:
|
||||
"""Generate a dataset id."""
|
||||
def tenant_id():
|
||||
"""Generate a unique tenant ID for testing."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def document_id() -> str:
|
||||
"""Generate a document id."""
|
||||
def dataset_id():
|
||||
"""Generate a unique dataset ID for testing."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def notion_workspace_id() -> str:
|
||||
"""Generate a notion workspace id."""
|
||||
def document_id():
|
||||
"""Generate a unique document ID for testing."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def notion_page_id() -> str:
|
||||
"""Generate a notion page id."""
|
||||
def notion_workspace_id():
|
||||
"""Generate a Notion workspace ID for testing."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def credential_id() -> str:
|
||||
"""Generate a credential id."""
|
||||
def notion_page_id():
|
||||
"""Generate a Notion page ID for testing."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dataset(dataset_id):
|
||||
"""Create a minimal dataset mock used by the task pre-check."""
|
||||
def credential_id():
|
||||
"""Generate a credential ID for testing."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dataset(dataset_id, tenant_id):
|
||||
"""Create a mock Dataset object."""
|
||||
dataset = Mock(spec=Dataset)
|
||||
dataset.id = dataset_id
|
||||
dataset.tenant_id = tenant_id
|
||||
dataset.indexing_technique = "high_quality"
|
||||
dataset.embedding_model_provider = "openai"
|
||||
dataset.embedding_model = "text-embedding-ada-002"
|
||||
return dataset
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document(document_id, dataset_id, notion_workspace_id, notion_page_id, credential_id):
|
||||
"""Create a minimal notion document mock for collaborator parameter assertions."""
|
||||
document = Mock(spec=Document)
|
||||
document.id = document_id
|
||||
document.dataset_id = dataset_id
|
||||
document.tenant_id = str(uuid.uuid4())
|
||||
document.data_source_type = "notion_import"
|
||||
document.indexing_status = "completed"
|
||||
document.doc_form = "text_model"
|
||||
document.data_source_info_dict = {
|
||||
def mock_document(document_id, dataset_id, tenant_id, notion_workspace_id, notion_page_id, credential_id):
|
||||
"""Create a mock Document object with Notion data source."""
|
||||
doc = Mock(spec=Document)
|
||||
doc.id = document_id
|
||||
doc.dataset_id = dataset_id
|
||||
doc.tenant_id = tenant_id
|
||||
doc.data_source_type = "notion_import"
|
||||
doc.indexing_status = "completed"
|
||||
doc.error = None
|
||||
doc.stopped_at = None
|
||||
doc.processing_started_at = None
|
||||
doc.doc_form = "text_model"
|
||||
doc.data_source_info_dict = {
|
||||
"notion_workspace_id": notion_workspace_id,
|
||||
"notion_page_id": notion_page_id,
|
||||
"type": "page",
|
||||
"last_edited_time": "2024-01-01T00:00:00Z",
|
||||
"credential_id": credential_id,
|
||||
}
|
||||
return document
|
||||
return doc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(mock_document, mock_dataset):
|
||||
"""Mock session_factory.create_session to drive deterministic read-only task flow."""
|
||||
with patch("tasks.document_indexing_sync_task.session_factory") as mock_session_factory:
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = []
|
||||
session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
|
||||
def mock_document_segments(document_id):
|
||||
"""Create mock DocumentSegment objects."""
|
||||
segments = []
|
||||
for i in range(3):
|
||||
segment = Mock(spec=DocumentSegment)
|
||||
segment.id = str(uuid.uuid4())
|
||||
segment.document_id = document_id
|
||||
segment.index_node_id = f"node-{document_id}-{i}"
|
||||
segments.append(segment)
|
||||
return segments
|
||||
|
||||
begin_cm = MagicMock()
|
||||
begin_cm.__enter__.return_value = session
|
||||
begin_cm.__exit__.return_value = False
|
||||
session.begin.return_value = begin_cm
|
||||
|
||||
session_cm = MagicMock()
|
||||
session_cm.__enter__.return_value = session
|
||||
session_cm.__exit__.return_value = False
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Mock database session via session_factory.create_session().
|
||||
|
||||
mock_session_factory.create_session.return_value = session_cm
|
||||
yield session
|
||||
After session split refactor, the code calls create_session() multiple times.
|
||||
This fixture creates shared query mocks so all sessions use the same
|
||||
query configuration, simulating database persistence across sessions.
|
||||
|
||||
The fixture automatically converts side_effect to cycle to prevent StopIteration.
|
||||
Tests configure mocks the same way as before, but behind the scenes the values
|
||||
are cycled infinitely for all sessions.
|
||||
"""
|
||||
from itertools import cycle
|
||||
|
||||
with patch("tasks.document_indexing_sync_task.session_factory") as mock_sf:
|
||||
sessions = []
|
||||
|
||||
# Shared query mocks - all sessions use these
|
||||
shared_query = MagicMock()
|
||||
shared_filter_by = MagicMock()
|
||||
shared_scalars_result = MagicMock()
|
||||
|
||||
# Create custom first mock that auto-cycles side_effect
|
||||
class CyclicMock(MagicMock):
|
||||
def __setattr__(self, name, value):
|
||||
if name == "side_effect" and value is not None:
|
||||
# Convert list/tuple to infinite cycle
|
||||
if isinstance(value, (list, tuple)):
|
||||
value = cycle(value)
|
||||
super().__setattr__(name, value)
|
||||
|
||||
shared_query.where.return_value.first = CyclicMock()
|
||||
shared_filter_by.first = CyclicMock()
|
||||
|
||||
def _create_session():
|
||||
"""Create a new mock session for each create_session() call."""
|
||||
session = MagicMock()
|
||||
session.close = MagicMock()
|
||||
session.commit = MagicMock()
|
||||
|
||||
# Mock session.begin() context manager
|
||||
begin_cm = MagicMock()
|
||||
begin_cm.__enter__.return_value = session
|
||||
|
||||
def _begin_exit_side_effect(exc_type, exc, tb):
|
||||
# commit on success
|
||||
if exc_type is None:
|
||||
session.commit()
|
||||
# return False to propagate exceptions
|
||||
return False
|
||||
|
||||
begin_cm.__exit__.side_effect = _begin_exit_side_effect
|
||||
session.begin.return_value = begin_cm
|
||||
|
||||
# Mock create_session() context manager
|
||||
cm = MagicMock()
|
||||
cm.__enter__.return_value = session
|
||||
|
||||
def _exit_side_effect(exc_type, exc, tb):
|
||||
session.close()
|
||||
return False
|
||||
|
||||
cm.__exit__.side_effect = _exit_side_effect
|
||||
|
||||
# All sessions use the same shared query mocks
|
||||
session.query.return_value = shared_query
|
||||
shared_query.where.return_value = shared_query
|
||||
shared_query.filter_by.return_value = shared_filter_by
|
||||
session.scalars.return_value = shared_scalars_result
|
||||
|
||||
sessions.append(session)
|
||||
# Attach helpers on the first created session for assertions across all sessions
|
||||
if len(sessions) == 1:
|
||||
session.get_all_sessions = lambda: sessions
|
||||
session.any_close_called = lambda: any(s.close.called for s in sessions)
|
||||
session.any_commit_called = lambda: any(s.commit.called for s in sessions)
|
||||
return cm
|
||||
|
||||
mock_sf.create_session.side_effect = _create_session
|
||||
|
||||
# Create first session and return it
|
||||
_create_session()
|
||||
yield sessions[0]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_datasource_provider_service():
|
||||
"""Mock datasource credential provider."""
|
||||
"""Mock DatasourceProviderService."""
|
||||
with patch("tasks.document_indexing_sync_task.DatasourceProviderService") as mock_service_class:
|
||||
mock_service = MagicMock()
|
||||
mock_service.get_datasource_credentials.return_value = {"integration_secret": "test_token"}
|
||||
@@ -105,16 +204,314 @@ def mock_datasource_provider_service():
|
||||
|
||||
@pytest.fixture
|
||||
def mock_notion_extractor():
|
||||
"""Mock notion extractor class and instance."""
|
||||
"""Mock NotionExtractor."""
|
||||
with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class:
|
||||
mock_extractor = MagicMock()
|
||||
mock_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
|
||||
mock_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" # Updated time
|
||||
mock_extractor_class.return_value = mock_extractor
|
||||
yield {"class": mock_extractor_class, "instance": mock_extractor}
|
||||
yield mock_extractor
|
||||
|
||||
|
||||
class TestDocumentIndexingSyncTaskCollaboratorParams:
|
||||
"""Unit tests for collaborator parameter passing in document_indexing_sync_task."""
|
||||
@pytest.fixture
|
||||
def mock_index_processor_factory():
|
||||
"""Mock IndexProcessorFactory."""
|
||||
with patch("tasks.document_indexing_sync_task.IndexProcessorFactory") as mock_factory:
|
||||
mock_processor = MagicMock()
|
||||
mock_processor.clean = Mock()
|
||||
mock_factory.return_value.init_index_processor.return_value = mock_processor
|
||||
yield mock_factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_indexing_runner():
|
||||
"""Mock IndexingRunner."""
|
||||
with patch("tasks.document_indexing_sync_task.IndexingRunner") as mock_runner_class:
|
||||
mock_runner = MagicMock(spec=IndexingRunner)
|
||||
mock_runner.run = Mock()
|
||||
mock_runner_class.return_value = mock_runner
|
||||
yield mock_runner
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for document_indexing_sync_task
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestDocumentIndexingSyncTask:
|
||||
"""Tests for the document_indexing_sync_task function."""
|
||||
|
||||
def test_document_not_found(self, mock_db_session, dataset_id, document_id):
|
||||
"""Test that task handles document not found gracefully."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert - at least one session should have been closed
|
||||
assert mock_db_session.any_close_called()
|
||||
|
||||
def test_missing_notion_workspace_id(self, mock_db_session, mock_document, dataset_id, document_id):
|
||||
"""Test that task raises error when notion_workspace_id is missing."""
|
||||
# Arrange
|
||||
mock_document.data_source_info_dict = {"notion_page_id": "page123", "type": "page"}
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="no notion page found"):
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
def test_missing_notion_page_id(self, mock_db_session, mock_document, dataset_id, document_id):
|
||||
"""Test that task raises error when notion_page_id is missing."""
|
||||
# Arrange
|
||||
mock_document.data_source_info_dict = {"notion_workspace_id": "ws123", "type": "page"}
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="no notion page found"):
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
def test_empty_data_source_info(self, mock_db_session, mock_document, dataset_id, document_id):
|
||||
"""Test that task raises error when data_source_info is empty."""
|
||||
# Arrange
|
||||
mock_document.data_source_info_dict = None
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="no notion page found"):
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
def test_credential_not_found(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_document,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that task handles missing credentials by updating document status."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
|
||||
mock_datasource_provider_service.get_datasource_credentials.return_value = None
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
assert mock_document.indexing_status == "error"
|
||||
assert "Datasource credential not found" in mock_document.error
|
||||
assert mock_document.stopped_at is not None
|
||||
assert mock_db_session.any_commit_called()
|
||||
assert mock_db_session.any_close_called()
|
||||
|
||||
def test_page_not_updated(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_document,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that task does nothing when page has not been updated."""
|
||||
# Arrange
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
|
||||
# Return same time as stored in document
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
# Document status should remain unchanged
|
||||
assert mock_document.indexing_status == "completed"
|
||||
# At least one session should have been closed via context manager teardown
|
||||
assert mock_db_session.any_close_called()
|
||||
|
||||
def test_successful_sync_when_page_updated(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_index_processor_factory,
|
||||
mock_indexing_runner,
|
||||
mock_dataset,
|
||||
mock_document,
|
||||
mock_document_segments,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test successful sync flow when Notion page has been updated."""
|
||||
# Arrange
|
||||
# Set exact sequence of returns across calls to `.first()`:
|
||||
# 1) document (initial fetch)
|
||||
# 2) dataset (pre-check)
|
||||
# 3) dataset (cleaning phase)
|
||||
# 4) document (pre-indexing update)
|
||||
# 5) document (indexing runner fetch)
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_document,
|
||||
mock_dataset,
|
||||
mock_dataset,
|
||||
mock_document,
|
||||
mock_document,
|
||||
]
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
|
||||
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||
# NotionExtractor returns updated time
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
# Verify document status was updated to parsing
|
||||
assert mock_document.indexing_status == "parsing"
|
||||
assert mock_document.processing_started_at is not None
|
||||
|
||||
# Verify segments were cleaned
|
||||
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||
mock_processor.clean.assert_called_once()
|
||||
|
||||
# Verify segments were deleted from database in batch (DELETE FROM document_segments)
|
||||
# Aggregate execute calls across all created sessions
|
||||
execute_sqls = []
|
||||
for s in mock_db_session.get_all_sessions():
|
||||
execute_sqls.extend([" ".join(str(c[0][0]).split()) for c in s.execute.call_args_list])
|
||||
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
|
||||
|
||||
# Verify indexing runner was called
|
||||
mock_indexing_runner.run.assert_called_once_with([mock_document])
|
||||
|
||||
# Verify session operations (across any created session)
|
||||
assert mock_db_session.any_commit_called()
|
||||
assert mock_db_session.any_close_called()
|
||||
|
||||
def test_dataset_not_found_during_cleaning(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_indexing_runner,
|
||||
mock_document,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that task handles dataset not found during cleaning phase."""
|
||||
# Arrange
|
||||
# Sequence: document (initial), dataset (pre-check), None (cleaning), document (update), document (indexing)
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_document,
|
||||
mock_dataset,
|
||||
None,
|
||||
mock_document,
|
||||
mock_document,
|
||||
]
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
# Document should still be set to parsing
|
||||
assert mock_document.indexing_status == "parsing"
|
||||
# At least one session should be closed after error
|
||||
assert mock_db_session.any_close_called()
|
||||
|
||||
def test_cleaning_error_continues_to_indexing(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_index_processor_factory,
|
||||
mock_indexing_runner,
|
||||
mock_dataset,
|
||||
mock_document,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that indexing continues even if cleaning fails."""
|
||||
# Arrange
|
||||
from itertools import cycle
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
|
||||
# Make the cleaning step fail but not the segment fetch
|
||||
processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||
processor.clean.side_effect = Exception("Cleaning error")
|
||||
mock_db_session.scalars.return_value.all.return_value = []
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
# Indexing should still be attempted despite cleaning error
|
||||
mock_indexing_runner.run.assert_called_once_with([mock_document])
|
||||
assert mock_db_session.any_close_called()
|
||||
|
||||
def test_indexing_runner_document_paused_error(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_index_processor_factory,
|
||||
mock_indexing_runner,
|
||||
mock_dataset,
|
||||
mock_document,
|
||||
mock_document_segments,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that DocumentIsPausedError is handled gracefully."""
|
||||
# Arrange
|
||||
from itertools import cycle
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
|
||||
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
# Session should be closed after handling error
|
||||
assert mock_db_session.any_close_called()
|
||||
|
||||
def test_indexing_runner_general_error(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_index_processor_factory,
|
||||
mock_indexing_runner,
|
||||
mock_dataset,
|
||||
mock_document,
|
||||
mock_document_segments,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that general exceptions during indexing are handled."""
|
||||
# Arrange
|
||||
from itertools import cycle
|
||||
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
|
||||
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
mock_indexing_runner.run.side_effect = Exception("Indexing error")
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
# Session should be closed after error
|
||||
assert mock_db_session.any_close_called()
|
||||
|
||||
def test_notion_extractor_initialized_with_correct_params(
|
||||
self,
|
||||
@@ -127,21 +524,27 @@ class TestDocumentIndexingSyncTaskCollaboratorParams:
|
||||
notion_workspace_id,
|
||||
notion_page_id,
|
||||
):
|
||||
"""Test that NotionExtractor is initialized with expected arguments."""
|
||||
"""Test that NotionExtractor is initialized with correct parameters."""
|
||||
# Arrange
|
||||
expected_token = "test_token"
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" # No update
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class:
|
||||
mock_extractor = MagicMock()
|
||||
mock_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
|
||||
mock_extractor_class.return_value = mock_extractor
|
||||
|
||||
# Assert
|
||||
mock_notion_extractor["class"].assert_called_once_with(
|
||||
notion_workspace_id=notion_workspace_id,
|
||||
notion_obj_id=notion_page_id,
|
||||
notion_page_type="page",
|
||||
notion_access_token=expected_token,
|
||||
tenant_id=mock_document.tenant_id,
|
||||
)
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
mock_extractor_class.assert_called_once_with(
|
||||
notion_workspace_id=notion_workspace_id,
|
||||
notion_obj_id=notion_page_id,
|
||||
notion_page_type="page",
|
||||
notion_access_token="test_token",
|
||||
tenant_id=mock_document.tenant_id,
|
||||
)
|
||||
|
||||
def test_datasource_credentials_requested_correctly(
|
||||
self,
|
||||
@@ -153,16 +556,17 @@ class TestDocumentIndexingSyncTaskCollaboratorParams:
|
||||
document_id,
|
||||
credential_id,
|
||||
):
|
||||
"""Test that datasource credentials are requested with expected identifiers."""
|
||||
"""Test that datasource credentials are requested with correct parameters."""
|
||||
# Arrange
|
||||
expected_tenant_id = mock_document.tenant_id
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
mock_datasource_provider_service.get_datasource_credentials.assert_called_once_with(
|
||||
tenant_id=expected_tenant_id,
|
||||
tenant_id=mock_document.tenant_id,
|
||||
credential_id=credential_id,
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
@@ -177,14 +581,16 @@ class TestDocumentIndexingSyncTaskCollaboratorParams:
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that missing credential_id is forwarded as None."""
|
||||
"""Test that task handles missing credential_id by passing None."""
|
||||
# Arrange
|
||||
mock_document.data_source_info_dict = {
|
||||
"notion_workspace_id": "workspace-id",
|
||||
"notion_page_id": "page-id",
|
||||
"notion_workspace_id": "ws123",
|
||||
"notion_page_id": "page123",
|
||||
"type": "page",
|
||||
"last_edited_time": "2024-01-01T00:00:00Z",
|
||||
}
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
@@ -196,3 +602,39 @@ class TestDocumentIndexingSyncTaskCollaboratorParams:
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
|
||||
def test_index_processor_clean_called_with_correct_params(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_datasource_provider_service,
|
||||
mock_notion_extractor,
|
||||
mock_index_processor_factory,
|
||||
mock_indexing_runner,
|
||||
mock_dataset,
|
||||
mock_document,
|
||||
mock_document_segments,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""Test that index processor clean is called with correct parameters."""
|
||||
# Arrange
|
||||
# Sequence: document (initial), dataset (pre-check), dataset (cleaning), document (update), document (indexing)
|
||||
mock_db_session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_document,
|
||||
mock_dataset,
|
||||
mock_dataset,
|
||||
mock_document,
|
||||
mock_document,
|
||||
]
|
||||
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
|
||||
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
|
||||
expected_node_ids = [seg.index_node_id for seg in mock_document_segments]
|
||||
mock_processor.clean.assert_called_once_with(
|
||||
mock_dataset, expected_node_ids, with_keywords=True, delete_child_chunks=True
|
||||
)
|
||||
|
||||
@@ -33,7 +33,7 @@ Then, configure the environment variables. Create a file named `.env.local` in t
|
||||
cp .env.example .env.local
|
||||
```
|
||||
|
||||
```txt
|
||||
```
|
||||
# For production release, change this to PRODUCTION
|
||||
NEXT_PUBLIC_DEPLOY_ENV=DEVELOPMENT
|
||||
# The deployment edition, SELF_HOSTED
|
||||
|
||||
@@ -58,11 +58,10 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
}, 1000)
|
||||
}
|
||||
|
||||
const sendEmail = async (email: string, isOrigin: boolean, token?: string) => {
|
||||
const sendEmail = async (email: string, token?: string) => {
|
||||
try {
|
||||
const res = await sendVerifyCode({
|
||||
email,
|
||||
phase: isOrigin ? 'old_email' : 'new_email',
|
||||
token,
|
||||
})
|
||||
startCount()
|
||||
@@ -106,7 +105,6 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
const sendCodeToOriginEmail = async () => {
|
||||
await sendEmail(
|
||||
email,
|
||||
true,
|
||||
)
|
||||
setStep(STEP.verifyOrigin)
|
||||
}
|
||||
@@ -162,7 +160,6 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
}
|
||||
await sendEmail(
|
||||
mail,
|
||||
false,
|
||||
stepToken,
|
||||
)
|
||||
setStep(STEP.verifyNew)
|
||||
|
||||
@@ -61,7 +61,8 @@ const ParamsConfig = ({
|
||||
if (tempDataSetConfigs.retrieval_model === RETRIEVE_TYPE.multiWay) {
|
||||
if (tempDataSetConfigs.reranking_enable
|
||||
&& tempDataSetConfigs.reranking_mode === RerankingModeEnum.RerankingModel
|
||||
&& !isCurrentRerankModelValid) {
|
||||
&& !isCurrentRerankModelValid
|
||||
) {
|
||||
errMsg = t('datasetConfig.rerankModelRequired', { ns: 'appDebug' })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,7 +43,7 @@ describe('Input component', () => {
|
||||
|
||||
it('shows left icon when showLeftIcon is true', () => {
|
||||
render(<Input showLeftIcon />)
|
||||
const searchIcon = document.querySelector('.i-ri-search-line')
|
||||
const searchIcon = document.querySelector('svg')
|
||||
expect(searchIcon).toBeInTheDocument()
|
||||
const input = screen.getByPlaceholderText('Search')
|
||||
expect(input).toHaveClass('pl-[26px]')
|
||||
@@ -51,7 +51,7 @@ describe('Input component', () => {
|
||||
|
||||
it('shows clear icon when showClearIcon is true and has value', () => {
|
||||
render(<Input showClearIcon value="test" />)
|
||||
const clearIcon = document.querySelector('.i-ri-close-circle-fill')
|
||||
const clearIcon = document.querySelector('.group svg')
|
||||
expect(clearIcon).toBeInTheDocument()
|
||||
const input = screen.getByDisplayValue('test')
|
||||
expect(input).toHaveClass('pr-[26px]')
|
||||
@@ -59,21 +59,21 @@ describe('Input component', () => {
|
||||
|
||||
it('does not show clear icon when disabled, even with value', () => {
|
||||
render(<Input showClearIcon value="test" disabled />)
|
||||
const clearIcon = document.querySelector('.i-ri-close-circle-fill')
|
||||
const clearIcon = document.querySelector('.group svg')
|
||||
expect(clearIcon).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('calls onClear when clear icon is clicked', () => {
|
||||
const onClear = vi.fn()
|
||||
render(<Input showClearIcon value="test" onClear={onClear} />)
|
||||
const clearIconContainer = screen.getByTestId('input-clear')
|
||||
const clearIconContainer = document.querySelector('.group')
|
||||
fireEvent.click(clearIconContainer!)
|
||||
expect(onClear).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('shows warning icon when destructive is true', () => {
|
||||
render(<Input destructive />)
|
||||
const warningIcon = document.querySelector('.i-ri-error-warning-line')
|
||||
const warningIcon = document.querySelector('svg')
|
||||
expect(warningIcon).toBeInTheDocument()
|
||||
const input = screen.getByPlaceholderText('Please input')
|
||||
expect(input).toHaveClass('border-components-input-border-destructive')
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { VariantProps } from 'class-variance-authority'
|
||||
import type { ChangeEventHandler, CSSProperties, FocusEventHandler } from 'react'
|
||||
import { RiCloseCircleFill, RiErrorWarningLine, RiSearchLine } from '@remixicon/react'
|
||||
import { cva } from 'class-variance-authority'
|
||||
import { noop } from 'es-toolkit/function'
|
||||
import * as React from 'react'
|
||||
@@ -12,8 +13,8 @@ export const inputVariants = cva(
|
||||
{
|
||||
variants: {
|
||||
size: {
|
||||
regular: 'px-3 system-sm-regular radius-md',
|
||||
large: 'px-4 system-md-regular radius-lg',
|
||||
regular: 'px-3 radius-md system-sm-regular',
|
||||
large: 'px-4 radius-lg system-md-regular',
|
||||
},
|
||||
},
|
||||
defaultVariants: {
|
||||
@@ -82,7 +83,7 @@ const Input = React.forwardRef<HTMLInputElement, InputProps>(({
|
||||
}
|
||||
return (
|
||||
<div className={cn('relative w-full', wrapperClassName)}>
|
||||
{showLeftIcon && <span className={cn('i-ri-search-line absolute left-2 top-1/2 h-4 w-4 -translate-y-1/2 text-components-input-text-placeholder')} />}
|
||||
{showLeftIcon && <RiSearchLine className={cn('absolute left-2 top-1/2 h-4 w-4 -translate-y-1/2 text-components-input-text-placeholder')} />}
|
||||
<input
|
||||
ref={ref}
|
||||
style={styleCss}
|
||||
@@ -114,11 +115,11 @@ const Input = React.forwardRef<HTMLInputElement, InputProps>(({
|
||||
onClick={onClear}
|
||||
data-testid="input-clear"
|
||||
>
|
||||
<span className="i-ri-close-circle-fill h-3.5 w-3.5 cursor-pointer text-text-quaternary group-hover:text-text-tertiary" />
|
||||
<RiCloseCircleFill className="h-3.5 w-3.5 cursor-pointer text-text-quaternary group-hover:text-text-tertiary" />
|
||||
</div>
|
||||
)}
|
||||
{destructive && (
|
||||
<span className="i-ri-error-warning-line absolute right-2 top-1/2 h-4 w-4 -translate-y-1/2 text-text-destructive-secondary" />
|
||||
<RiErrorWarningLine className="absolute right-2 top-1/2 h-4 w-4 -translate-y-1/2 text-text-destructive-secondary" />
|
||||
)}
|
||||
{showCopyIcon && (
|
||||
<div className={cn('group absolute right-0 top-1/2 -translate-y-1/2 cursor-pointer')}>
|
||||
@@ -130,7 +131,7 @@ const Input = React.forwardRef<HTMLInputElement, InputProps>(({
|
||||
)}
|
||||
{
|
||||
unit && (
|
||||
<div className="absolute right-2 top-1/2 -translate-y-1/2 text-text-tertiary system-sm-regular">
|
||||
<div className="system-sm-regular absolute right-2 top-1/2 -translate-y-1/2 text-text-tertiary">
|
||||
{unit}
|
||||
</div>
|
||||
)
|
||||
|
||||
@@ -334,10 +334,10 @@ describe('FileList', () => {
|
||||
it('should call resetKeywords prop when clear button is clicked', () => {
|
||||
const mockResetKeywords = vi.fn()
|
||||
const props = createDefaultProps({ resetKeywords: mockResetKeywords, keywords: 'to-reset' })
|
||||
render(<FileList {...props} />)
|
||||
const { container } = render(<FileList {...props} />)
|
||||
|
||||
// Act - Click the clear icon div (it contains RiCloseCircleFill icon)
|
||||
const clearButton = screen.getByTestId('input-clear')
|
||||
const clearButton = container.querySelector('[class*="cursor-pointer"] svg[class*="h-3.5"]')?.parentElement
|
||||
expect(clearButton).toBeInTheDocument()
|
||||
fireEvent.click(clearButton!)
|
||||
|
||||
@@ -346,12 +346,12 @@ describe('FileList', () => {
|
||||
|
||||
it('should reset inputValue to empty string when clear is clicked', () => {
|
||||
const props = createDefaultProps({ keywords: 'to-be-reset' })
|
||||
render(<FileList {...props} />)
|
||||
const { container } = render(<FileList {...props} />)
|
||||
const input = screen.getByPlaceholderText('datasetPipeline.onlineDrive.breadcrumbs.searchPlaceholder')
|
||||
fireEvent.change(input, { target: { value: 'some-search' } })
|
||||
|
||||
// Act - Find and click the clear icon
|
||||
const clearButton = screen.getByTestId('input-clear')
|
||||
const clearButton = container.querySelector('[class*="cursor-pointer"] svg[class*="h-3.5"]')?.parentElement
|
||||
expect(clearButton).toBeInTheDocument()
|
||||
fireEvent.click(clearButton!)
|
||||
|
||||
|
||||
@@ -93,8 +93,8 @@ describe('Header', () => {
|
||||
|
||||
const { container } = render(<Header {...props} />)
|
||||
|
||||
// Assert - Input should have search icon class
|
||||
const searchIcon = container.querySelector('.i-ri-search-line.h-4.w-4')
|
||||
// Assert - Input should have search icon (RiSearchLine is rendered as svg)
|
||||
const searchIcon = container.querySelector('svg.h-4.w-4')
|
||||
expect(searchIcon).toBeInTheDocument()
|
||||
})
|
||||
|
||||
@@ -313,10 +313,10 @@ describe('Header', () => {
|
||||
inputValue: 'to-clear',
|
||||
handleResetKeywords: mockHandleResetKeywords,
|
||||
})
|
||||
render(<Header {...props} />)
|
||||
const { container } = render(<Header {...props} />)
|
||||
|
||||
// Act - Find and click the clear icon container
|
||||
const clearButton = screen.getByTestId('input-clear')
|
||||
const clearButton = container.querySelector('[class*="cursor-pointer"] svg[class*="h-3.5"]')?.parentElement
|
||||
expect(clearButton).toBeInTheDocument()
|
||||
fireEvent.click(clearButton!)
|
||||
|
||||
@@ -325,19 +325,19 @@ describe('Header', () => {
|
||||
|
||||
it('should not show clear icon when inputValue is empty', () => {
|
||||
const props = createDefaultProps({ inputValue: '' })
|
||||
render(<Header {...props} />)
|
||||
const { container } = render(<Header {...props} />)
|
||||
|
||||
// Act & Assert - Clear icon should not be visible
|
||||
const clearIcon = screen.queryByTestId('input-clear')
|
||||
const clearIcon = container.querySelector('[class*="cursor-pointer"] svg[class*="h-3.5"]')
|
||||
expect(clearIcon).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should show clear icon when inputValue is not empty', () => {
|
||||
const props = createDefaultProps({ inputValue: 'some-value' })
|
||||
render(<Header {...props} />)
|
||||
const { container } = render(<Header {...props} />)
|
||||
|
||||
// Act & Assert - Clear icon should be visible
|
||||
const clearIcon = screen.getByTestId('input-clear')
|
||||
const clearIcon = container.querySelector('[class*="cursor-pointer"] svg[class*="h-3.5"]')
|
||||
expect(clearIcon).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
@@ -570,12 +570,13 @@ describe('Header', () => {
|
||||
inputValue: 'to-clear',
|
||||
handleResetKeywords: mockHandleResetKeywords,
|
||||
})
|
||||
const { rerender } = render(<Header {...props} />)
|
||||
const { container, rerender } = render(<Header {...props} />)
|
||||
|
||||
// Act - Click clear, rerender, click again
|
||||
fireEvent.click(screen.getByTestId('input-clear'))
|
||||
const clearButton = container.querySelector('[class*="cursor-pointer"] svg[class*="h-3.5"]')?.parentElement
|
||||
fireEvent.click(clearButton!)
|
||||
rerender(<Header {...props} />)
|
||||
fireEvent.click(screen.getByTestId('input-clear'))
|
||||
fireEvent.click(clearButton!)
|
||||
|
||||
expect(mockHandleResetKeywords).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { CodeGroup } from '../code.tsx'
|
||||
import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
||||
import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from '../md.tsx'
|
||||
|
||||
# Completion App API
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { CodeGroup } from '../code.tsx'
|
||||
import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
||||
import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from '../md.tsx'
|
||||
|
||||
# Completion アプリ API
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { CodeGroup, Embed } from '../code.tsx'
|
||||
import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
||||
import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from '../md.tsx'
|
||||
|
||||
# Advanced Chat App API
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { CodeGroup } from '../code.tsx'
|
||||
import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
||||
import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from '../md.tsx'
|
||||
|
||||
# 高度なチャットアプリ API
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { CodeGroup } from '../code.tsx'
|
||||
import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
||||
import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from '../md.tsx'
|
||||
|
||||
# Chat App API
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { CodeGroup } from '../code.tsx'
|
||||
import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
||||
import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from '../md.tsx'
|
||||
|
||||
# チャットアプリ API
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { CodeGroup } from '../code.tsx'
|
||||
import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
||||
import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from '../md.tsx'
|
||||
|
||||
# Workflow App API
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { CodeGroup } from '../code.tsx'
|
||||
import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
||||
import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from '../md.tsx'
|
||||
|
||||
# ワークフローアプリ API
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { CodeGroup } from '../code.tsx'
|
||||
import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
||||
import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from '../md.tsx'
|
||||
|
||||
# Workflow 应用 API
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import type { AppRouterInstance } from 'next/dist/shared/lib/app-router-context.shared-runtime'
|
||||
import type { AppContextValue } from '@/context/app-context'
|
||||
import type { ModalContextState } from '@/context/modal-context'
|
||||
import type { ProviderContextState } from '@/context/provider-context'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { useRouter } from 'next/navigation'
|
||||
import { AppRouterContext } from 'next/dist/shared/lib/app-router-context.shared-runtime'
|
||||
import { Plan } from '@/app/components/billing/type'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
@@ -49,14 +50,6 @@ vi.mock('@/service/use-common', () => ({
|
||||
useLogout: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('next/navigation', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('next/navigation')>()
|
||||
return {
|
||||
...actual,
|
||||
useRouter: vi.fn(),
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/context/i18n', () => ({
|
||||
useDocLink: () => (path: string) => `https://docs.dify.ai${path}`,
|
||||
}))
|
||||
@@ -126,6 +119,15 @@ describe('AccountDropdown', () => {
|
||||
const mockSetShowAccountSettingModal = vi.fn()
|
||||
|
||||
const renderWithRouter = (ui: React.ReactElement) => {
|
||||
const mockRouter = {
|
||||
push: mockPush,
|
||||
replace: vi.fn(),
|
||||
prefetch: vi.fn(),
|
||||
back: vi.fn(),
|
||||
forward: vi.fn(),
|
||||
refresh: vi.fn(),
|
||||
} as unknown as AppRouterInstance
|
||||
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
@@ -136,7 +138,9 @@ describe('AccountDropdown', () => {
|
||||
|
||||
return render(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
{ui}
|
||||
<AppRouterContext.Provider value={mockRouter}>
|
||||
{ui}
|
||||
</AppRouterContext.Provider>
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
}
|
||||
@@ -162,14 +166,6 @@ describe('AccountDropdown', () => {
|
||||
vi.mocked(useLogout).mockReturnValue({
|
||||
mutateAsync: mockLogout,
|
||||
} as unknown as ReturnType<typeof useLogout>)
|
||||
vi.mocked(useRouter).mockReturnValue({
|
||||
push: mockPush,
|
||||
replace: vi.fn(),
|
||||
prefetch: vi.fn(),
|
||||
back: vi.fn(),
|
||||
forward: vi.fn(),
|
||||
refresh: vi.fn(),
|
||||
})
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
|
||||
@@ -1,8 +1,24 @@
|
||||
'use client'
|
||||
import type { AccountSettingTab } from '@/app/components/header/account-setting/constants'
|
||||
import {
|
||||
RiBrain2Fill,
|
||||
RiBrain2Line,
|
||||
RiCloseLine,
|
||||
RiColorFilterFill,
|
||||
RiColorFilterLine,
|
||||
RiDatabase2Fill,
|
||||
RiDatabase2Line,
|
||||
RiGroup2Fill,
|
||||
RiGroup2Line,
|
||||
RiMoneyDollarCircleFill,
|
||||
RiMoneyDollarCircleLine,
|
||||
RiPuzzle2Fill,
|
||||
RiPuzzle2Line,
|
||||
RiTranslate2,
|
||||
} from '@remixicon/react'
|
||||
import { useEffect, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import SearchInput from '@/app/components/base/search-input'
|
||||
import Input from '@/app/components/base/input'
|
||||
import BillingPage from '@/app/components/billing/billing-page'
|
||||
import CustomPage from '@/app/components/custom/custom-page'
|
||||
import {
|
||||
@@ -60,14 +76,14 @@ export default function AccountSetting({
|
||||
{
|
||||
key: ACCOUNT_SETTING_TAB.PROVIDER,
|
||||
name: t('settings.provider', { ns: 'common' }),
|
||||
icon: <span className={cn('i-ri-brain-2-line', iconClassName)} />,
|
||||
activeIcon: <span className={cn('i-ri-brain-2-fill', iconClassName)} />,
|
||||
icon: <RiBrain2Line className={iconClassName} />,
|
||||
activeIcon: <RiBrain2Fill className={iconClassName} />,
|
||||
},
|
||||
{
|
||||
key: ACCOUNT_SETTING_TAB.MEMBERS,
|
||||
name: t('settings.members', { ns: 'common' }),
|
||||
icon: <span className={cn('i-ri-group-2-line', iconClassName)} />,
|
||||
activeIcon: <span className={cn('i-ri-group-2-fill', iconClassName)} />,
|
||||
icon: <RiGroup2Line className={iconClassName} />,
|
||||
activeIcon: <RiGroup2Fill className={iconClassName} />,
|
||||
},
|
||||
]
|
||||
|
||||
@@ -76,8 +92,8 @@ export default function AccountSetting({
|
||||
key: ACCOUNT_SETTING_TAB.BILLING,
|
||||
name: t('settings.billing', { ns: 'common' }),
|
||||
description: t('plansCommon.receiptInfo', { ns: 'billing' }),
|
||||
icon: <span className={cn('i-ri-money-dollar-circle-line', iconClassName)} />,
|
||||
activeIcon: <span className={cn('i-ri-money-dollar-circle-fill', iconClassName)} />,
|
||||
icon: <RiMoneyDollarCircleLine className={iconClassName} />,
|
||||
activeIcon: <RiMoneyDollarCircleFill className={iconClassName} />,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -85,14 +101,14 @@ export default function AccountSetting({
|
||||
{
|
||||
key: ACCOUNT_SETTING_TAB.DATA_SOURCE,
|
||||
name: t('settings.dataSource', { ns: 'common' }),
|
||||
icon: <span className={cn('i-ri-database-2-line', iconClassName)} />,
|
||||
activeIcon: <span className={cn('i-ri-database-2-fill', iconClassName)} />,
|
||||
icon: <RiDatabase2Line className={iconClassName} />,
|
||||
activeIcon: <RiDatabase2Fill className={iconClassName} />,
|
||||
},
|
||||
{
|
||||
key: ACCOUNT_SETTING_TAB.API_BASED_EXTENSION,
|
||||
name: t('settings.apiBasedExtension', { ns: 'common' }),
|
||||
icon: <span className={cn('i-ri-puzzle-2-line', iconClassName)} />,
|
||||
activeIcon: <span className={cn('i-ri-puzzle-2-fill', iconClassName)} />,
|
||||
icon: <RiPuzzle2Line className={iconClassName} />,
|
||||
activeIcon: <RiPuzzle2Fill className={iconClassName} />,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -100,8 +116,8 @@ export default function AccountSetting({
|
||||
items.push({
|
||||
key: ACCOUNT_SETTING_TAB.CUSTOM,
|
||||
name: t('custom', { ns: 'custom' }),
|
||||
icon: <span className={cn('i-ri-color-filter-line', iconClassName)} />,
|
||||
activeIcon: <span className={cn('i-ri-color-filter-fill', iconClassName)} />,
|
||||
icon: <RiColorFilterLine className={iconClassName} />,
|
||||
activeIcon: <RiColorFilterFill className={iconClassName} />,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -124,8 +140,8 @@ export default function AccountSetting({
|
||||
{
|
||||
key: ACCOUNT_SETTING_TAB.LANGUAGE,
|
||||
name: t('settings.language', { ns: 'common' }),
|
||||
icon: <span className={cn('i-ri-translate-2', iconClassName)} />,
|
||||
activeIcon: <span className={cn('i-ri-translate-2', iconClassName)} />,
|
||||
icon: <RiTranslate2 className={iconClassName} />,
|
||||
activeIcon: <RiTranslate2 className={iconClassName} />,
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -155,13 +171,13 @@ export default function AccountSetting({
|
||||
>
|
||||
<div className="mx-auto flex h-[100vh] max-w-[1048px]">
|
||||
<div className="flex w-[44px] flex-col border-r border-divider-burn pl-4 pr-6 sm:w-[224px]">
|
||||
<div className="mb-8 mt-6 px-3 py-2 text-text-primary title-2xl-semi-bold">{t('userProfile.settings', { ns: 'common' })}</div>
|
||||
<div className="title-2xl-semi-bold mb-8 mt-6 px-3 py-2 text-text-primary">{t('userProfile.settings', { ns: 'common' })}</div>
|
||||
<div className="w-full">
|
||||
{
|
||||
menuItems.map(menuItem => (
|
||||
<div key={menuItem.key} className="mb-2">
|
||||
{!isCurrentWorkspaceDatasetOperator && (
|
||||
<div className="mb-0.5 py-2 pb-1 pl-3 text-text-tertiary system-xs-medium-uppercase">{menuItem.name}</div>
|
||||
<div className="system-xs-medium-uppercase mb-0.5 py-2 pb-1 pl-3 text-text-tertiary">{menuItem.name}</div>
|
||||
)}
|
||||
<div>
|
||||
{
|
||||
@@ -170,7 +186,7 @@ export default function AccountSetting({
|
||||
key={item.key}
|
||||
className={cn(
|
||||
'mb-0.5 flex h-[37px] cursor-pointer items-center rounded-lg p-1 pl-3 text-sm',
|
||||
activeMenu === item.key ? 'bg-state-base-active text-components-menu-item-text-active system-sm-semibold' : 'text-components-menu-item-text system-sm-medium',
|
||||
activeMenu === item.key ? 'system-sm-semibold bg-state-base-active text-components-menu-item-text-active' : 'system-sm-medium text-components-menu-item-text',
|
||||
)}
|
||||
title={item.name}
|
||||
onClick={() => {
|
||||
@@ -197,36 +213,38 @@ export default function AccountSetting({
|
||||
className="px-2"
|
||||
onClick={onCancel}
|
||||
>
|
||||
<span className="i-ri-close-line h-5 w-5" />
|
||||
<RiCloseLine className="h-5 w-5" />
|
||||
</Button>
|
||||
<div className="mt-1 text-text-tertiary system-2xs-medium-uppercase">ESC</div>
|
||||
<div className="system-2xs-medium-uppercase mt-1 text-text-tertiary">ESC</div>
|
||||
</div>
|
||||
<div ref={scrollRef} className="w-full overflow-y-auto bg-components-panel-bg pb-4">
|
||||
<div className={cn('sticky top-0 z-20 mx-8 mb-[18px] flex items-center bg-components-panel-bg pb-2 pt-[27px]', scrolled && 'border-b border-divider-regular')}>
|
||||
<div className="shrink-0 text-text-primary title-2xl-semi-bold">
|
||||
<div className="title-2xl-semi-bold shrink-0 text-text-primary">
|
||||
{activeItem?.name}
|
||||
{activeItem?.description && (
|
||||
<div className="mt-1 text-text-tertiary system-sm-regular">{activeItem?.description}</div>
|
||||
<div className="system-sm-regular mt-1 text-text-tertiary">{activeItem?.description}</div>
|
||||
)}
|
||||
</div>
|
||||
{activeItem?.key === ACCOUNT_SETTING_TAB.PROVIDER && (
|
||||
{activeItem?.key === 'provider' && (
|
||||
<div className="flex grow justify-end">
|
||||
<SearchInput
|
||||
className="w-[200px]"
|
||||
onChange={setSearchValue}
|
||||
<Input
|
||||
showLeftIcon
|
||||
wrapperClassName="!w-[200px]"
|
||||
className="!h-8 !text-[13px]"
|
||||
onChange={e => setSearchValue(e.target.value)}
|
||||
value={searchValue}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<div className="px-4 pt-2 sm:px-8">
|
||||
{activeMenu === ACCOUNT_SETTING_TAB.PROVIDER && <ModelProviderPage searchText={searchValue} />}
|
||||
{activeMenu === ACCOUNT_SETTING_TAB.MEMBERS && <MembersPage />}
|
||||
{activeMenu === ACCOUNT_SETTING_TAB.BILLING && <BillingPage />}
|
||||
{activeMenu === ACCOUNT_SETTING_TAB.DATA_SOURCE && <DataSourcePage />}
|
||||
{activeMenu === ACCOUNT_SETTING_TAB.API_BASED_EXTENSION && <ApiBasedExtensionPage />}
|
||||
{activeMenu === ACCOUNT_SETTING_TAB.CUSTOM && <CustomPage />}
|
||||
{activeMenu === ACCOUNT_SETTING_TAB.LANGUAGE && <LanguagePage />}
|
||||
{activeMenu === 'provider' && <ModelProviderPage searchText={searchValue} />}
|
||||
{activeMenu === 'members' && <MembersPage />}
|
||||
{activeMenu === 'billing' && <BillingPage />}
|
||||
{activeMenu === 'data-source' && <DataSourcePage />}
|
||||
{activeMenu === 'api-based-extension' && <ApiBasedExtensionPage />}
|
||||
{activeMenu === 'custom' && <CustomPage />}
|
||||
{activeMenu === 'language' && <LanguagePage />}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -46,7 +46,8 @@ const nodeDefault: NodeDefault<HttpNodeType> = {
|
||||
|
||||
if (!errorMessages
|
||||
&& payload.body.type === BodyType.binary
|
||||
&& ((!(payload.body.data as BodyPayload)[0]?.file) || (payload.body.data as BodyPayload)[0]?.file?.length === 0)) {
|
||||
&& ((!(payload.body.data as BodyPayload)[0]?.file) || (payload.body.data as BodyPayload)[0]?.file?.length === 0)
|
||||
) {
|
||||
errorMessages = t('errorMsg.fieldRequired', { ns: 'workflow', field: t('nodes.http.binaryFileVariable', { ns: 'workflow' }) })
|
||||
}
|
||||
|
||||
|
||||
@@ -2052,6 +2052,9 @@
|
||||
"app/components/base/input/index.tsx": {
|
||||
"react-refresh/only-export-components": {
|
||||
"count": 1
|
||||
},
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 3
|
||||
}
|
||||
},
|
||||
"app/components/base/linked-apps-panel/index.tsx": {
|
||||
@@ -3991,6 +3994,9 @@
|
||||
"app/components/header/account-setting/index.tsx": {
|
||||
"react-hooks-extra/no-direct-set-state-in-use-effect": {
|
||||
"count": 1
|
||||
},
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 7
|
||||
}
|
||||
},
|
||||
"app/components/header/account-setting/key-validator/declarations.ts": {
|
||||
@@ -8163,6 +8169,11 @@
|
||||
"count": 3
|
||||
}
|
||||
},
|
||||
"i18n-config/README.md": {
|
||||
"no-irregular-whitespace": {
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"i18n/de-DE/billing.json": {
|
||||
"no-irregular-whitespace": {
|
||||
"count": 1
|
||||
|
||||
@@ -6,7 +6,7 @@ This directory contains i18n tooling and configuration. Translation files live u
|
||||
|
||||
## File Structure
|
||||
|
||||
```txt
|
||||
```
|
||||
web/i18n
|
||||
├── en-US
|
||||
│ ├── app.json
|
||||
@@ -36,7 +36,7 @@ By default we will use `LanguagesSupported` to determine which languages are sup
|
||||
|
||||
1. Create a new folder for the new language.
|
||||
|
||||
```txt
|
||||
```
|
||||
cd web/i18n
|
||||
cp -r en-US id-ID
|
||||
```
|
||||
@@ -98,7 +98,7 @@ export const languages = [
|
||||
{
|
||||
value: 'ru-RU',
|
||||
name: 'Русский(Россия)',
|
||||
example: 'Привет, Dify!',
|
||||
example: ' Привет, Dify!',
|
||||
supported: false,
|
||||
},
|
||||
{
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
"type": "module",
|
||||
"version": "1.13.0",
|
||||
"private": true,
|
||||
"packageManager": "pnpm@10.27.0",
|
||||
"packageManager": "pnpm@10.27.0+sha512.72d699da16b1179c14ba9e64dc71c9a40988cbdc65c264cb0e489db7de917f20dcf4d64d8723625f2969ba52d4b7e2a1170682d9ac2a5dcaeaab732b7e16f04a",
|
||||
"imports": {
|
||||
"#i18n": {
|
||||
"react-server": "./i18n-config/lib.server.ts",
|
||||
@@ -28,12 +28,9 @@
|
||||
"scripts": {
|
||||
"dev": "next dev",
|
||||
"dev:inspect": "next dev --inspect",
|
||||
"dev:vinext": "vinext dev",
|
||||
"build": "next build",
|
||||
"build:docker": "next build && node scripts/optimize-standalone.js",
|
||||
"build:vinext": "vinext build",
|
||||
"start": "node ./scripts/copy-and-start.mjs",
|
||||
"start:vinext": "vinext start",
|
||||
"lint": "eslint --cache --concurrency=auto",
|
||||
"lint:ci": "eslint --cache --concurrency 2",
|
||||
"lint:fix": "pnpm lint --fix",
|
||||
@@ -168,25 +165,24 @@
|
||||
"zustand": "5.0.9"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@antfu/eslint-config": "7.6.1",
|
||||
"@chromatic-com/storybook": "5.0.1",
|
||||
"@antfu/eslint-config": "7.2.0",
|
||||
"@chromatic-com/storybook": "5.0.0",
|
||||
"@egoist/tailwindcss-icons": "1.9.2",
|
||||
"@eslint-react/eslint-plugin": "2.13.0",
|
||||
"@eslint-react/eslint-plugin": "2.9.4",
|
||||
"@iconify-json/heroicons": "1.2.3",
|
||||
"@iconify-json/ri": "1.2.9",
|
||||
"@mdx-js/loader": "3.1.1",
|
||||
"@mdx-js/react": "3.1.1",
|
||||
"@mdx-js/rollup": "3.1.1",
|
||||
"@next/eslint-plugin-next": "16.1.6",
|
||||
"@next/mdx": "16.1.5",
|
||||
"@rgrove/parse-xml": "4.2.0",
|
||||
"@serwist/turbopack": "9.5.4",
|
||||
"@storybook/addon-docs": "10.2.13",
|
||||
"@storybook/addon-links": "10.2.13",
|
||||
"@storybook/addon-onboarding": "10.2.13",
|
||||
"@storybook/addon-themes": "10.2.13",
|
||||
"@storybook/nextjs-vite": "10.2.13",
|
||||
"@storybook/react": "10.2.13",
|
||||
"@storybook/addon-docs": "10.2.0",
|
||||
"@storybook/addon-links": "10.2.0",
|
||||
"@storybook/addon-onboarding": "10.2.0",
|
||||
"@storybook/addon-themes": "10.2.0",
|
||||
"@storybook/nextjs-vite": "10.2.0",
|
||||
"@storybook/react": "10.2.0",
|
||||
"@tanstack/eslint-plugin-query": "5.91.4",
|
||||
"@tanstack/react-devtools": "0.9.2",
|
||||
"@tanstack/react-form-devtools": "0.2.12",
|
||||
@@ -212,22 +208,21 @@
|
||||
"@types/semver": "7.7.1",
|
||||
"@types/sortablejs": "1.15.8",
|
||||
"@types/uuid": "10.0.0",
|
||||
"@typescript-eslint/parser": "8.56.1",
|
||||
"@typescript-eslint/parser": "8.54.0",
|
||||
"@typescript/native-preview": "7.0.0-dev.20251209.1",
|
||||
"@vitejs/plugin-react": "5.1.4",
|
||||
"@vitejs/plugin-rsc": "0.5.21",
|
||||
"@vitest/coverage-v8": "4.0.18",
|
||||
"@vitejs/plugin-react": "5.1.2",
|
||||
"@vitest/coverage-v8": "4.0.17",
|
||||
"autoprefixer": "10.4.21",
|
||||
"code-inspector-plugin": "1.3.6",
|
||||
"cross-env": "10.1.0",
|
||||
"esbuild": "0.27.2",
|
||||
"eslint": "10.0.2",
|
||||
"eslint-plugin-better-tailwindcss": "4.3.1",
|
||||
"eslint-plugin-hyoban": "0.11.2",
|
||||
"eslint": "9.39.2",
|
||||
"eslint-plugin-better-tailwindcss": "https://pkg.pr.new/hyoban/eslint-plugin-better-tailwindcss@c0161c7",
|
||||
"eslint-plugin-hyoban": "0.11.1",
|
||||
"eslint-plugin-react-hooks": "7.0.1",
|
||||
"eslint-plugin-react-refresh": "0.5.2",
|
||||
"eslint-plugin-sonarjs": "4.0.0",
|
||||
"eslint-plugin-storybook": "10.2.13",
|
||||
"eslint-plugin-react-refresh": "0.5.0",
|
||||
"eslint-plugin-sonarjs": "3.0.6",
|
||||
"eslint-plugin-storybook": "10.2.6",
|
||||
"husky": "9.1.7",
|
||||
"iconify-import-svg": "0.1.1",
|
||||
"jsdom": "27.3.0",
|
||||
@@ -238,25 +233,22 @@
|
||||
"postcss": "8.5.6",
|
||||
"postcss-js": "5.0.3",
|
||||
"react-scan": "0.4.3",
|
||||
"react-server-dom-webpack": "19.2.4",
|
||||
"sass": "1.93.2",
|
||||
"serwist": "9.5.4",
|
||||
"storybook": "10.2.13",
|
||||
"storybook": "10.2.0",
|
||||
"tailwindcss": "3.4.19",
|
||||
"tsx": "4.21.0",
|
||||
"typescript": "5.9.3",
|
||||
"uglify-js": "3.19.3",
|
||||
"vinext": "https://pkg.pr.new/hyoban/vinext@e283197",
|
||||
"vite": "7.3.1",
|
||||
"vite-tsconfig-paths": "6.1.1",
|
||||
"vitest": "4.0.18",
|
||||
"vite-tsconfig-paths": "6.0.4",
|
||||
"vitest": "4.0.17",
|
||||
"vitest-canvas-mock": "1.1.3"
|
||||
},
|
||||
"pnpm": {
|
||||
"overrides": {
|
||||
"@monaco-editor/loader": "1.5.0",
|
||||
"@nolyfill/safe-buffer": "npm:safe-buffer@^5.2.1",
|
||||
"@stylistic/eslint-plugin": "https://pkg.pr.new/@stylistic/eslint-plugin@258f9d8",
|
||||
"array-includes": "npm:@nolyfill/array-includes@^1",
|
||||
"array.prototype.findlast": "npm:@nolyfill/array.prototype.findlast@^1",
|
||||
"array.prototype.findlastindex": "npm:@nolyfill/array.prototype.findlastindex@^1",
|
||||
|
||||
2576
web/pnpm-lock.yaml
generated
2576
web/pnpm-lock.yaml
generated
File diff suppressed because it is too large
Load Diff
@@ -372,7 +372,7 @@ export const submitDeleteAccountFeedback = (body: { feedback: string, email: str
|
||||
export const getDocDownloadUrl = (doc_name: string): Promise<{ url: string }> =>
|
||||
get<{ url: string }>('/compliance/download', { params: { doc_name } }, { silent: true })
|
||||
|
||||
export const sendVerifyCode = (body: { email: string, phase: string, token?: string }): Promise<CommonResponse & { data: string }> =>
|
||||
export const sendVerifyCode = (body: { email: string, token?: string }): Promise<CommonResponse & { data: string }> =>
|
||||
post<CommonResponse & { data: string }>('/account/change-email', { body })
|
||||
|
||||
export const verifyEmail = (body: { email: string, code: string, token: string }): Promise<CommonResponse & { is_valid: boolean, email: string, token: string }> =>
|
||||
|
||||
@@ -1,60 +1,16 @@
|
||||
import type { Plugin } from 'vite'
|
||||
import path from 'node:path'
|
||||
import { fileURLToPath } from 'node:url'
|
||||
import react from '@vitejs/plugin-react'
|
||||
import vinext from 'vinext'
|
||||
import { defineConfig } from 'vite'
|
||||
import tsconfigPaths from 'vite-tsconfig-paths'
|
||||
|
||||
const __dirname = path.dirname(fileURLToPath(import.meta.url))
|
||||
const isCI = !!process.env.CI
|
||||
|
||||
export default defineConfig(({ mode }) => {
|
||||
return {
|
||||
plugins: mode === 'test'
|
||||
? [
|
||||
tsconfigPaths(),
|
||||
react(),
|
||||
{
|
||||
// Stub .mdx files so components importing them can be unit-tested
|
||||
name: 'mdx-stub',
|
||||
enforce: 'pre',
|
||||
transform(_, id) {
|
||||
if (id.endsWith('.mdx'))
|
||||
return { code: 'export default () => null', map: null }
|
||||
},
|
||||
} as Plugin,
|
||||
]
|
||||
: [
|
||||
vinext(),
|
||||
],
|
||||
resolve: {
|
||||
alias: {
|
||||
'~@': __dirname,
|
||||
},
|
||||
export default defineConfig({
|
||||
plugins: [tsconfigPaths(), react()],
|
||||
resolve: {
|
||||
alias: {
|
||||
'~@': __dirname,
|
||||
},
|
||||
|
||||
// vinext related config
|
||||
...(mode !== 'test'
|
||||
? {
|
||||
optimizeDeps: {
|
||||
exclude: ['nuqs'],
|
||||
},
|
||||
server: {
|
||||
port: 3000,
|
||||
},
|
||||
}
|
||||
: {}),
|
||||
|
||||
// Vitest config
|
||||
test: {
|
||||
environment: 'jsdom',
|
||||
globals: true,
|
||||
setupFiles: ['./vitest.setup.ts'],
|
||||
coverage: {
|
||||
provider: 'v8',
|
||||
reporter: isCI ? ['json', 'json-summary'] : ['text', 'json', 'json-summary'],
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
27
web/vitest.config.ts
Normal file
27
web/vitest.config.ts
Normal file
@@ -0,0 +1,27 @@
|
||||
import { defineConfig, mergeConfig } from 'vitest/config'
|
||||
import viteConfig from './vite.config'
|
||||
|
||||
const isCI = !!process.env.CI
|
||||
|
||||
export default mergeConfig(viteConfig, defineConfig({
|
||||
plugins: [
|
||||
{
|
||||
// Stub .mdx files so components importing them can be unit-tested
|
||||
name: 'mdx-stub',
|
||||
enforce: 'pre',
|
||||
transform(_, id) {
|
||||
if (id.endsWith('.mdx'))
|
||||
return { code: 'export default () => null', map: null }
|
||||
},
|
||||
},
|
||||
],
|
||||
test: {
|
||||
environment: 'jsdom',
|
||||
globals: true,
|
||||
setupFiles: ['./vitest.setup.ts'],
|
||||
coverage: {
|
||||
provider: 'v8',
|
||||
reporter: isCI ? ['json', 'json-summary'] : ['text', 'json', 'json-summary'],
|
||||
},
|
||||
},
|
||||
}))
|
||||
Reference in New Issue
Block a user